package edu.tau.compbio.graph.algo;

import edu.tau.compbio.graph.FastMaskedGraph;
import edu.tau.compbio.graph.GraphUtilities;
import edu.tau.compbio.math.VecCalc;
import edu.tau.compbio.med.graph.Graph;
import edu.tau.compbio.med.graph.Node;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/tau/compbio/graph/algo/SteinerTree.class */
/**
 * Greedy heuristic for node-weighted Steiner trees.
 *
 * <p>Each connected component of the input graph that contains terminals is processed
 * independently. Inside a component the terminals are first grouped into "trees" (the connected
 * components of the subgraph induced by the terminals). While more than one tree remains, every
 * unused non-terminal node {@code v} is scored by the minimum, over {@code k >= 2}, of
 * {@code (w(v) + d_1 + ... + d_k) / k}, where {@code d_1 <= d_2 <= ...} are its distances to the
 * current trees. The best-scoring node is joined by shortest paths to its {@code k} closest trees,
 * which are then merged into a single tree. (This resembles the Klein-Ravi ratio heuristic.)
 *
 * <p>Each merge is logged to standard output ("Merging N trees"); a result that is still
 * fragmented is reported on standard error.
 */
public class SteinerTree {

    /** Weight given to every candidate node when no weight map is supplied. */
    private static final float DEFAULT_NODE_WEIGHT = 1f;

    /**
     * Finds connector nodes that join the given terminals inside each connected component.
     *
     * @param graph the graph to search
     * @param set   the terminal nodes
     * @param map   optional node weights; {@code null} gives every node weight 1. When non-null it
     *              must contain an entry for every non-terminal node of a component holding terminals.
     * @return the connector nodes and shortest-path nodes added to join the terminals, over all
     *         components (whether path endpoints are included depends on
     *         {@code GraphUtilities.computeShortestPath})
     */
    public Set<Node> findSteinerTree(Graph graph, Set<Node> set, Map<Node, Float> map) {
        Set<Node> steinerNodes = new HashSet<>();
        Map<Node, AbstractList<Node>> adjacencies =
                GraphUtilities.getGraphAdjacencies(graph, false, false, new ArrayList<>());
        for (Set<Node> component : GraphUtilities.getConnectedComponents(graph)) {
            Set<Node> componentTerminals = new HashSet<>(component);
            componentTerminals.retainAll(set);
            if (!componentTerminals.isEmpty()) {
                steinerNodes.addAll(
                        findSteinerTreeCC(graph, adjacencies, component, componentTerminals, map));
            }
        }
        return steinerNodes;
    }

    /**
     * Integer-indexed variant of {@link #findSteinerTree(Graph, Set, Map)} for a
     * {@link FastMaskedGraph}.
     *
     * @param fastMaskedGraph the graph to search
     * @param set             the terminal node indices
     * @param map             optional node weights, or {@code null} for unit weights
     * @return the connector and shortest-path node indices added to join the terminals
     */
    public Set<Integer> findSteinerTree(FastMaskedGraph fastMaskedGraph, Set<Integer> set, Map<Integer, Float> map) {
        Set<Integer> steinerNodes = new HashSet<>();
        for (int[] componentNodes : fastMaskedGraph.getConnectedComponents()) {
            Set<Integer> component = new HashSet<>();
            for (int node : componentNodes) {
                component.add(node);
            }
            Set<Integer> componentTerminals = new HashSet<>(component);
            componentTerminals.retainAll(set);
            if (!componentTerminals.isEmpty()) {
                steinerNodes.addAll(findSteinerTreeCC(fastMaskedGraph, component, componentTerminals, map));
            }
        }
        return steinerNodes;
    }

    /**
     * Connects the terminals of one connected component of a {@link FastMaskedGraph}.
     *
     * @param fastMaskedGraph the graph
     * @param set             all nodes of the component (the candidate connectors)
     * @param set2            the terminals inside this component
     * @param map             optional node weights, or {@code null} for unit weights
     * @return the connector and shortest-path nodes added to join the terminals
     */
    private Set<Integer> findSteinerTreeCC(FastMaskedGraph fastMaskedGraph, Set<Integer> set, Set<Integer> set2,
            Map<Integer, ? extends Number> map) {
        // Distance tables are computed once, from every terminal to the nodes of this component.
        Map<Integer, int[]> distances = new HashMap<>();
        for (Integer terminal : set2) {
            distances.put(terminal, GraphUtilities.computeDistancesFromNode(fastMaskedGraph, terminal, set));
        }

        boolean[] terminalMask = new boolean[fastMaskedGraph.sizeNodes()];
        for (int i = 0; i < terminalMask.length; i++) {
            terminalMask[i] = set2.contains(i);
        }
        // Copied so that removeAll/add below never touch a list owned by GraphUtilities.
        AbstractList<Set<Integer>> trees =
                new ArrayList<>(GraphUtilities.getConnectedComponents(fastMaskedGraph, terminalMask));

        Set<Integer> steinerNodes = new HashSet<>();
        while (trees.size() > 1) {
            float bestRatio = Float.MAX_VALUE;
            Integer bestNode = null;
            ArrayList<Set<Integer>> bestTrees = null;

            for (int candidate : set) {
                // Terminals and nodes already in the tree are not candidates: they are part of a tree.
                if (set2.contains(candidate) || steinerNodes.contains(candidate)) {
                    continue;
                }
                int[] treeDistances = calcDistances(distances, candidate, trees);
                AbstractList<Set<Integer>> sortedTrees = sortTreesFMG(trees, treeDistances);
                float cost = map != null ? map.get(candidate).floatValue() : DEFAULT_NODE_WEIGHT;

                float candidateRatio = Float.MAX_VALUE;
                int candidateCount = 0;
                for (int i = 0; i < sortedTrees.size(); i++) {
                    cost += treeDistances[i];
                    // Reaching a single tree would not reduce the tree count, so only spiders
                    // spanning at least two trees are scored; this guarantees termination.
                    if (i >= 1) {
                        float ratio = cost / (i + 1);
                        if (ratio < candidateRatio) {
                            candidateRatio = ratio;
                            candidateCount = i + 1;
                        }
                    }
                }
                if (candidateRatio < bestRatio) {
                    bestRatio = candidateRatio;
                    bestNode = candidate;
                    bestTrees = new ArrayList<>(sortedTrees.subList(0, candidateCount));
                }
            }

            if (bestNode == null) {
                // No unused connector is left; the connectivity check below reports the outcome.
                break;
            }
            System.out.println(bestNode + ": Merging " + bestTrees.size() + " trees");
            trees.removeAll(bestTrees);
            Set<Integer> merged = new HashSet<>();
            merged.add(bestNode);
            steinerNodes.add(bestNode);
            for (Set<Integer> tree : bestTrees) {
                AbstractList<Integer> path = GraphUtilities.computeShortestPath(fastMaskedGraph, bestNode.intValue(), tree);
                steinerNodes.addAll(path);
                merged.addAll(path);
                merged.addAll(tree);
            }
            trees.add(merged);
        }

        Set<Integer> treeNodes = new HashSet<>(set2);
        treeNodes.addAll(steinerNodes);
        boolean[] treeMask = new boolean[fastMaskedGraph.sizeNodes()];
        for (int i = 0; i < treeMask.length; i++) {
            treeMask[i] = treeNodes.contains(i);
        }
        AbstractList<Set<Integer>> finalComponents = GraphUtilities.getConnectedComponents(fastMaskedGraph, treeMask);
        if (finalComponents.size() > 1) {
            System.err.println("The Steiner tree has " + finalComponents.size() + " components");
        }
        return steinerNodes;
    }

    /**
     * Connects the terminals of one connected component of a {@link Graph}.
     *
     * @param graph the graph
     * @param map   adjacency lists of {@code graph}
     * @param set   all nodes of the component (the candidate connectors)
     * @param set2  the terminals inside this component
     * @param map2  optional node weights, or {@code null} for unit weights
     * @return the connector and shortest-path nodes added to join the terminals
     */
    private Set<Node> findSteinerTreeCC(Graph graph, Map<Node, AbstractList<Node>> map, Set<Node> set, Set<Node> set2,
            Map<Node, ? extends Number> map2) {
        // Distance tables are computed once, from every terminal to the nodes of this component.
        Map<Node, Map<Node, Integer>> distances = new HashMap<>();
        for (Node terminal : set2) {
            distances.put(terminal, GraphUtilities.computeDistancesFromNode(map, terminal, set));
        }
        // Copied so that removeAll/add below never touch a list owned by GraphUtilities.
        AbstractList<Set<Node>> trees =
                new ArrayList<>(GraphUtilities.getConnectedComponents(graph, set2, new HashSet<>()));

        Set<Node> steinerNodes = new HashSet<>();
        while (trees.size() > 1) {
            float bestRatio = Float.MAX_VALUE;
            Node bestNode = null;
            ArrayList<Set<Node>> bestTrees = null;

            for (Node candidate : set) {
                // Terminals and nodes already in the tree are not candidates: they are part of a tree.
                if (set2.contains(candidate) || steinerNodes.contains(candidate)) {
                    continue;
                }
                int[] treeDistances = calcDistances(distances, candidate, trees);
                AbstractList<Set<Node>> sortedTrees = sortTrees(trees, treeDistances);
                float cost = map2 != null ? map2.get(candidate).floatValue() : DEFAULT_NODE_WEIGHT;

                float candidateRatio = Float.MAX_VALUE;
                int candidateCount = 0;
                for (int i = 0; i < sortedTrees.size(); i++) {
                    cost += treeDistances[i];
                    // Reaching a single tree would not reduce the tree count, so only spiders
                    // spanning at least two trees are scored; this guarantees termination.
                    if (i >= 1) {
                        float ratio = cost / (i + 1);
                        if (ratio < candidateRatio) {
                            candidateRatio = ratio;
                            candidateCount = i + 1;
                        }
                    }
                }
                if (candidateRatio < bestRatio) {
                    bestRatio = candidateRatio;
                    bestNode = candidate;
                    bestTrees = new ArrayList<>(sortedTrees.subList(0, candidateCount));
                }
            }

            if (bestNode == null) {
                // No unused connector is left; the connectivity check below reports the outcome.
                break;
            }
            System.out.println(bestNode + ": Merging " + bestTrees.size() + " trees");
            trees.removeAll(bestTrees);
            Set<Node> merged = new HashSet<>();
            merged.add(bestNode);
            steinerNodes.add(bestNode);
            for (Set<Node> tree : bestTrees) {
                AbstractList<Node> path = GraphUtilities.computeShortestPath(map, bestNode, tree);
                steinerNodes.addAll(path);
                merged.addAll(path);
                merged.addAll(tree);
            }
            trees.add(merged);
        }

        Set<Node> treeNodes = new HashSet<>(set2);
        treeNodes.addAll(steinerNodes);
        AbstractList<Set<Node>> finalComponents = GraphUtilities.getConnectedComponents(graph, treeNodes, new HashSet<>());
        if (finalComponents.size() > 1) {
            System.err.println("The Steiner tree has " + finalComponents.size() + " components");
        }
        return steinerNodes;
    }

    /**
     * Computes, for every tree, the smallest distance from {@code node} to one of the tree's
     * members that owns a distance table.
     *
     * <p>Only terminals have precomputed tables. Connector and path nodes added while merging
     * are skipped, so the result may overestimate the true distance. It stays finite because
     * every tree retains at least one terminal.
     *
     * @return one entry per tree, in list order; {@code Integer.MAX_VALUE} if no distance is known
     */
    private int[] calcDistances(Map<Node, Map<Node, Integer>> map, Node node, AbstractList<Set<Node>> abstractList) {
        int[] result = new int[abstractList.size()];
        for (int t = 0; t < abstractList.size(); t++) {
            int best = Integer.MAX_VALUE;
            for (Node member : abstractList.get(t)) {
                Map<Node, Integer> fromMember = map.get(member);
                if (fromMember == null) {
                    continue;
                }
                Integer distance = fromMember.get(node);
                if (distance != null && distance < best) {
                    best = distance;
                }
            }
            result[t] = best;
        }
        return result;
    }

    /**
     * Integer-indexed counterpart of
     * {@link #calcDistances(Map, Node, AbstractList)}. Distance tables are arrays indexed by
     * node id.
     */
    private int[] calcDistances(Map<Integer, int[]> map, int i, AbstractList<Set<Integer>> abstractList) {
        int[] result = new int[abstractList.size()];
        for (int t = 0; t < abstractList.size(); t++) {
            int best = Integer.MAX_VALUE;
            for (Integer member : abstractList.get(t)) {
                int[] fromMember = map.get(member);
                if (fromMember != null && fromMember[i] < best) {
                    best = fromMember[i];
                }
            }
            result[t] = best;
        }
        return result;
    }

    /** Returns the trees reordered by {@link #sortByDistance}. */
    private AbstractList<Set<Node>> sortTrees(AbstractList<Set<Node>> abstractList, int[] iArr) {
        return sortByDistance(abstractList, iArr);
    }

    /** Returns the trees reordered by {@link #sortByDistance}. */
    private AbstractList<Set<Integer>> sortTreesFMG(AbstractList<Set<Integer>> abstractList, int[] iArr) {
        return sortByDistance(abstractList, iArr);
    }

    /**
     * Returns a new list holding {@code trees} in the order given by
     * {@code VecCalc.sortWithRanks(distances)}.
     *
     * <p>NOTE(review): callers pair {@code distances[i]} with the i-th returned tree. That
     * presumes {@code sortWithRanks} sorts {@code distances} ascending in place and returns the
     * original index of each sorted position; TODO confirm against VecCalc.
     */
    private static <T> AbstractList<Set<T>> sortByDistance(AbstractList<Set<T>> trees, int[] distances) {
        int[] order = VecCalc.sortWithRanks(distances);
        ArrayList<Set<T>> sorted = new ArrayList<>(trees.size());
        for (int i = 0; i < trees.size(); i++) {
            sorted.add(trees.get(order[i]));
        }
        return sorted;
    }
}
