// 924. Minimize Malware Spread
//
// Back to Homepage | Back to Code List


class Solution {
    /**
     * Union-find over node indices {@code 0..n-1}, using union by size and path compression.
     * Also tracks, per root, how many initially-infected nodes belong to that component.
     */
    static class DSU {
        // size[r] is the number of nodes in the component rooted at r (meaningful only for roots).
        int[] size;
        // root[x] is x's parent; x is a root exactly when root[x] == x.
        int[] root;
        // infected[r] counts initially-infected nodes whose component root is r.
        int[] infected;

        public DSU(int n) {
            size = new int[n];
            root = new int[n];
            infected = new int[n];
            for (int i = 0; i < n; i++) {
                root[i] = i;
                size[i] = 1; // every node starts as its own one-node component
            }
        }

        /** Returns the root of x's component, compressing the path along the way. */
        public int find(int x) {
            if (root[x] != x) {
                root[x] = find(root[x]);
            }
            return root[x];
        }

        /** Merges the components containing x and y; no-op if they are already joined. */
        public void union(int x, int y) {
            int rootX = find(x);
            int rootY = find(y);
            if (rootX == rootY) return;

            // Hang the smaller tree under the larger one and accumulate the FULL size of the
            // absorbed component, so size[root] stays equal to the real node count.
            if (size[rootX] < size[rootY]) {
                root[rootX] = rootY;
                size[rootY] += size[rootX];
            } else {
                root[rootY] = rootX;
                size[rootX] += size[rootY];
            }
        }
    }

    /**
     * Returns the node in {@code initial} whose removal minimizes the final number of infected
     * nodes once malware has spread through each connected component.
     *
     * <p>Removing a node only saves its component if it is the sole initially-infected node
     * there; the component's size is then the number of nodes saved. The method picks the node
     * that saves the most, breaking ties by smallest index. If no removal saves anything, every
     * choice is equally good, so it returns the smallest index in {@code initial}.
     *
     * @param graph   symmetric n x n adjacency matrix; {@code graph[i][j] == 1} means i and j are connected
     * @param initial initially-infected nodes; assumed non-empty (not modified by this method)
     * @return the node to remove
     */
    public int minMalwareSpread(int[][] graph, int[] initial) {
        int n = graph.length;
        DSU dsu = new DSU(n);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (graph[i][j] == 1) {
                    dsu.union(i, j);
                }
            }
        }

        // Sort a copy so the strict '>' below keeps the smallest index on ties, without
        // reordering the caller's array.
        int[] sorted = initial.clone();
        Arrays.sort(sorted);
        for (int node : sorted) {
            dsu.infected[dsu.find(node)]++;
        }

        int bestSaved = 0;
        int res = -1;
        for (int node : sorted) {
            int r = dsu.find(node);
            // Only a node that is the sole infection in its component saves anything.
            if (dsu.infected[r] == 1 && dsu.size[r] > bestSaved) {
                bestSaved = dsu.size[r];
                res = node;
            }
        }

        // No removal saves any node: fall back to the smallest infected index.
        return res == -1 ? sorted[0] : res;
    }
}