class Solution {
    /**
     * Union-find over graph nodes 0..n-1 using union by size and path compression.
     *
     * size[r] is the number of nodes in the set whose representative is r (only
     * meaningful when r is a root). infected[r] is filled in by the caller and counts
     * the initially infected nodes whose set representative is r.
     */
    class DSU {
        int[] size;
        int[] root;
        int[] infected;

        public DSU(int n) {
            size = new int[n];
            root = new int[n];
            infected = new int[n];
            for (int i = 0; i < n; i++) {
                root[i] = i;
                // Every node starts as its own singleton set, so its size is 1 (not 0).
                size[i] = 1;
            }
        }

        /** Returns the representative of x's set, pointing every visited node at it. */
        public int find(int x) {
            if (root[x] != x) {
                root[x] = find(root[x]);
            }
            return root[x];
        }

        /** Merges the sets containing x and y, hanging the smaller tree under the larger. */
        public void union(int x, int y) {
            int rootX = find(x);
            int rootY = find(y);
            if (rootX == rootY) {
                return;
            }
            if (size[rootX] < size[rootY]) {
                root[rootX] = rootY;
                // Absorb the entire subtree's node count, not just one node.
                size[rootY] += size[rootX];
            } else {
                root[rootY] = rootX;
                size[rootX] += size[rootY];
            }
        }
    }

    /**
     * Picks the node from {@code initial} whose removal minimises the final number of
     * infected nodes (LeetCode 924).
     *
     * Removing an initially infected node only helps when it is the sole infected node
     * in its connected component; the whole component is then saved. Among such nodes
     * the one in the largest component wins, with ties broken by the smallest index. If
     * no removal saves anything, the smallest index in {@code initial} is returned.
     *
     * @param graph   n x n adjacency matrix; graph[i][j] == 1 means i and j are connected
     * @param initial initially infected nodes; must be non-empty. It is not modified.
     * @return the node to remove
     */
    public int minMalwareSpread(int[][] graph, int[] initial) {
        int n = graph.length;
        DSU dsu = new DSU(n);
        // Each unordered pair is visited once; checking both directions keeps an edge
        // even if the matrix happens not to be symmetric.
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                if (graph[i][j] == 1 || graph[j][i] == 1) {
                    dsu.union(i, j);
                }
            }
        }

        // Sort a copy so the caller's array is left untouched. Ascending order lets the
        // strict '>' comparison below keep the smallest index on ties.
        int[] candidates = initial.clone();
        Arrays.sort(candidates);

        for (int node : candidates) {
            dsu.infected[dsu.find(node)]++;
        }

        int bestSaved = 0;
        int res = -1;
        for (int node : candidates) {
            int r = dsu.find(node);
            // With two or more infected nodes in a component, removing one still leaves
            // the whole component infected, so only sole-infected components count.
            if (dsu.infected[r] == 1 && dsu.size[r] > bestSaved) {
                bestSaved = dsu.size[r];
                res = node;
            }
        }
        // No removal saves any node: every choice is equivalent, so take the smallest index.
        return res == -1 ? candidates[0] : res;
    }
}