/**
 * Minimal disjoint-set (union-find) forest over the elements {@code 0 .. n-1}.
 *
 * <p>The class itself only offers representative lookup with path compression.
 * It has no union operation: callers link two sets by writing to {@link #root}
 * directly and are responsible for keeping {@link #size} meaningful.
 */
class DisjointSet {
    /** Parent pointers; {@code root[x] == x} marks {@code x} as the representative of its set. */
    int[] root;

    /** Per-element counter, initialised to 1 for every element and afterwards maintained by callers. */
    int[] size;

    /**
     * Creates {@code n} singleton sets: every element is its own representative
     * and has a counter of 1.
     *
     * @param n number of elements; must be non-negative
     * @throws NegativeArraySizeException if {@code n} is negative
     */
    public DisjointSet(int n) {
        root = new int[n];
        size = new int[n];
        for (int i = 0; i < n; i++) {
            root[i] = i;
            size[i] = 1;
        }
    }

    /**
     * Returns the representative of the set containing {@code x} and points
     * every element visited on the way directly at that representative.
     *
     * <p>Implemented iteratively: because callers link sets by plain assignment
     * (no union by rank), parent chains can be as long as the number of
     * elements, and a recursive walk would overflow the call stack on them.
     *
     * @param x element index in {@code [0, n)}
     * @return index of the representative of {@code x}'s set
     * @throws ArrayIndexOutOfBoundsException if {@code x} is outside {@code [0, n)}
     */
    public int find(int x) {
        // Pass 1: walk up to the representative.
        int rep = x;
        while (root[rep] != rep) {
            rep = root[rep];
        }
        // Pass 2: path compression - re-point every node on the path at the representative.
        while (root[x] != rep) {
            int next = root[x];
            root[x] = rep;
            x = next;
        }
        return rep;
    }
}
/**
 * Counts pairs of equal-valued nodes that are connected through nodes whose
 * values never exceed theirs (every node also counts once, paired with itself).
 */
class Solution {
    /**
     * Returns the number of node pairs {@code (u, v)}, including {@code u == v},
     * such that {@code vals[u] == vals[v]} and {@code u} and {@code v} are
     * connected using only nodes whose value is at most {@code vals[u]}.
     *
     * <p>Nodes are processed in ascending order of value and merged into a
     * union-find structure. Each component's representative always carries the
     * component's maximum value, and {@code ds.size[rep]} counts how many nodes
     * in the component carry that maximum.
     *
     * <p>NOTE(review): the input is presumably a tree (as the method name
     * suggests), in which case every counted pair corresponds to exactly one
     * path - confirm against the caller.
     *
     * @param vals  value of each node; node ids are {@code 0 .. vals.length - 1}
     * @param edges undirected edges, each given as a two-element array of node ids
     * @return the number of qualifying pairs, single nodes included
     */
    public int numberOfGoodPaths(int[] vals, int[][] edges) {
        int n = vals.length;
        List<List<Integer>> graph = buildAdjacency(n, edges);

        DisjointSet ds = new DisjointSet(n);

        // Node ids ordered by ascending value. Integer.compare avoids the
        // overflow that a subtraction-based comparator can hit.
        Integer[] ids = new Integer[n];
        for (int i = 0; i < n; i++) {
            ids[i] = i;
        }
        Arrays.sort(ids, (a, b) -> Integer.compare(vals[a], vals[b]));

        // Every single node is counted once up front.
        int res = n;
        for (int id : ids) {
            int valX = vals[id];
            // Stays valid for the whole neighbour loop: other roots are always
            // attached beneath rootX, never the other way round.
            int rootX = ds.find(id);
            for (int y : graph.get(id)) {
                int rootY = ds.find(y);
                int valY = vals[rootY];
                // Skip neighbours already in this component, and neighbours whose
                // component maximum is larger (they are handled when their turn comes).
                if (rootX == rootY || valY > valX) {
                    continue;
                }
                if (valX == valY) {
                    // Each max-valued node on one side pairs with each on the other.
                    res += ds.size[rootX] * ds.size[rootY];
                    ds.size[rootX] += ds.size[rootY];
                }
                // When valY < valX the absorbed component holds no node of value
                // valX, so the counter of rootX is left untouched.
                ds.root[rootY] = rootX;
            }
        }
        return res;
    }

    /** Builds an undirected adjacency list for {@code n} nodes from the given edge array. */
    private static List<List<Integer>> buildAdjacency(int n, int[][] edges) {
        List<List<Integer>> graph = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            graph.add(new ArrayList<>());
        }
        for (int[] edge : edges) {
            int a = edge[0];
            int b = edge[1];
            graph.get(a).add(b);
            graph.get(b).add(a);
        }
        return graph;
    }
}