package cc.mallet.grmm.inference.gbp;

import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import gnu.trove.THashSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import org.springframework.jdbc.datasource.init.ScriptUtils;

/* loaded from: input_file:cc/mallet/grmm/inference/gbp/RegionGraph.class */
/**
 * A directed graph of {@link Region}s connected by parent-to-child {@link RegionEdge}s.
 *
 * <p>Besides the structure itself, this class derives per-region and per-edge caches
 * (descendants, factors to send, counting numbers, cousins, neighboring parents and
 * looping messages) in {@link #computeInferenceCaches()}. Those caches are not updated
 * by later calls to {@link #add(Region, Region)}.
 */
class RegionGraph {

    /** All regions of this graph; each region's {@code index} is its 0-based insertion position. */
    @SuppressWarnings("unchecked")
    private final Set<Region> regions = new THashSet();

    /** All parent-to-child edges, in the order they were added. */
    private final List<RegionEdge> edges = new ArrayList<RegionEdge>();

    /**
     * Adds the edge {@code region -> region2}, registering both regions with this graph if needed.
     * Does nothing if {@code region2} is already a child of {@code region}.
     *
     * @param region the parent region
     * @param region2 the child region; it is marked as a non-root
     * @throws IllegalArgumentException if either region already belongs to another region graph
     *     (checked before anything is modified)
     */
    public void add(Region region, Region region2) {
        if (isConnected(region, region2)) {
            return;
        }
        // Validate both endpoints up front so a rejected call leaves the graph untouched.
        checkAddable(region);
        checkAddable(region2);
        addRegion(region);
        addRegion(region2);
        region2.isRoot = false;
        if (region.children == null) {
            region.children = new ArrayList();
        }
        region.children.add(region2);
        if (region2.parents == null) {
            region2.parents = new ArrayList();
        }
        region2.parents.add(region);
        edges.add(new RegionEdge(region, region2));
    }

    /** Returns true if {@code child} is already a direct child of {@code parent}. */
    private boolean isConnected(Region parent, Region child) {
        return parent.children != null && parent.children.contains(child);
    }

    /** Throws if {@code region} is not in this graph but has already been assigned an index elsewhere. */
    private void checkAddable(Region region) {
        if (!regions.contains(region) && region.index != -1) {
            throw new IllegalArgumentException("Region " + region + " has already been added to a different region graph.");
        }
    }

    /**
     * Registers {@code region} with this graph (no-op if already present) and assigns its index.
     *
     * @throws IllegalArgumentException if the region already belongs to another region graph
     */
    private void addRegion(Region region) {
        checkAddable(region);
        if (regions.add(region)) {
            region.index = regions.size() - 1;
        }
    }

    /** Returns the children of {@code region}, treating an uninitialized list as empty. */
    private static Collection<?> childrenOf(Region region) {
        return region.children == null ? Collections.emptyList() : region.children;
    }

    /** Returns the parents of {@code region}, treating an uninitialized list as empty. */
    private static Collection<?> parentsOf(Region region) {
        return region.parents == null ? Collections.emptyList() : region.parents;
    }

    /** Returns the number of regions in this graph. */
    public int size() {
        return regions.size();
    }

    /** Returns an iterator over the regions of this graph. */
    public Iterator iterator() {
        return regions.iterator();
    }

    /** Returns an iterator over the {@link RegionEdge}s of this graph, in insertion order. */
    public Iterator edgeIterator() {
        return edges.iterator();
    }

    /**
     * Recomputes all derived caches. Call after every region and edge has been added.
     * The passes run in dependency order: later passes read the descendant sets and
     * cousin sets produced by earlier ones.
     *
     * @throws IllegalStateException if the parent/child relation contains a cycle
     */
    public void computeInferenceCaches() {
        computeDescendants();
        includeDescendantFactors();
        computeFactorsToSend();
        computeCountingNumbers();
        computeCousins();
        computeNeighboringParents();
        computeLoopingMessages();
    }

    /** Adds the factors of every descendant of each region to that region's own factor collection. */
    private void includeDescendantFactors() {
        for (Region region : regions) {
            for (Object descendant : region.descendants) {
                region.factors.addAll(((Region) descendant).factors);
            }
        }
    }

    /**
     * For each edge {@code from -> to}, collects the edges {@code cousin -> c} where
     * {@code cousin} is one of the edge's cousins other than {@code from}, and
     * {@code c} is {@code to} or a descendant of {@code to}.
     */
    private void computeLoopingMessages() {
        for (RegionEdge edge : edges) {
            Region target = edge.to;
            ArrayList<RegionEdge> looping = new ArrayList<RegionEdge>();
            for (Object c : edge.cousins) {
                Region cousin = (Region) c;
                if (cousin == edge.from) {
                    continue;
                }
                for (Object g : childrenOf(cousin)) {
                    Region grandchild = (Region) g;
                    if (grandchild == target || target.descendants.contains(grandchild)) {
                        looping.add(findEdge(cousin, grandchild));
                    }
                }
            }
            edge.loopingMessages = looping;
        }
    }

    /**
     * Sets each region's counting number to {@code 1 - sum(counting numbers of its direct parents)}.
     * Regions with no parents therefore get 1. Each region is computed once, after all of its
     * parents (memoized DFS).
     *
     * <p>NOTE(review): the usual region-graph definition (Yedidia, Freeman &amp; Weiss) sums over
     * all ancestors, not only direct parents. The two agree only for two-level graphs. This keeps
     * the original direct-parent rule; confirm that it is intended.
     *
     * @throws IllegalStateException if the parent relation contains a cycle
     */
    private void computeCountingNumbers() {
        Set<Region> done = new HashSet<Region>();
        Set<Region> inProgress = new HashSet<Region>();
        for (Region region : regions) {
            computeCountingNumberRec(region, done, inProgress);
        }
    }

    /** Computes {@code region}'s counting number after recursively computing its parents'. */
    private void computeCountingNumberRec(Region region, Set<Region> done, Set<Region> inProgress) {
        if (done.contains(region)) {
            return;
        }
        if (!inProgress.add(region)) {
            throw new IllegalStateException("Region graph contains a cycle through " + region);
        }
        int parentSum = 0;
        for (Object p : parentsOf(region)) {
            Region parent = (Region) p;
            computeCountingNumberRec(parent, done, inProgress);
            parentSum += parent.countingNumber;
        }
        region.countingNumber = 1 - parentSum;
        inProgress.remove(region);
        done.add(region);
    }

    /** Asks every edge to initialize its factors-to-send cache. */
    private void computeFactorsToSend() {
        for (RegionEdge edge : edges) {
            edge.initializeFactorsToSend();
        }
    }

    /**
     * For each edge {@code from -> to}, stores as cousins:
     * {@code from}, plus descendants of {@code from} that are neither {@code to}
     * nor descendants of {@code to}.
     */
    private void computeCousins() {
        for (RegionEdge edge : edges) {
            THashSet cousins = new THashSet(edge.from.descendants);
            cousins.removeAll(edge.to.descendants);
            cousins.remove(edge.to);
            cousins.add(edge.from);
            edge.cousins = cousins;
        }
    }

    /** Computes every region's transitive set of descendants (children, their children, ...). */
    private void computeDescendants() {
        Set<Region> visited = new HashSet<Region>();
        for (Region region : regions) {
            computeDescendantsRec(region, visited);
        }
    }

    /** Fills {@code region.descendants}, reusing results already computed for shared sub-graphs. */
    private void computeDescendantsRec(Region region, Set<Region> visited) {
        if (!visited.add(region)) {
            return;
        }
        Collection<?> children = childrenOf(region);
        THashSet descendants = new THashSet(children.size());
        for (Object c : children) {
            Region child = (Region) c;
            computeDescendantsRec(child, visited);
            descendants.add(child);
            descendants.addAll(child.descendants);
        }
        region.descendants = descendants;
    }

    /**
     * For each edge {@code from -> to}, collects the edges {@code p -> c} where
     * {@code p} is neither {@code from} nor a descendant of {@code from}, and
     * {@code c} is one of the edge's cousins. Edges are gathered in region-iteration order.
     */
    private void computeNeighboringParents() {
        for (RegionEdge edge : edges) {
            edge.neighboringParents = new ArrayList();
            for (Region region : regions) {
                if (edge.from.equals(region) || edge.from.descendants.contains(region)) {
                    continue;
                }
                for (Object c : childrenOf(region)) {
                    Region child = (Region) c;
                    if (edge.cousins.contains(child)) {
                        edge.neighboringParents.add(findEdge(region, child));
                    }
                }
            }
        }
    }

    /**
     * Returns the stored edge equal to {@code new RegionEdge(parent, child)}.
     *
     * @throws IllegalStateException if no such edge exists
     */
    private RegionEdge findEdge(Region parent, Region child) {
        int i = edges.indexOf(new RegionEdge(parent, child));
        if (i < 0) {
            throw new IllegalStateException("No edge from " + parent + " to " + child);
        }
        return edges.get(i);
    }

    /** Returns a multi-line listing of all regions followed by all edges. */
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("REGION GRAPH\nRegions:\n");
        for (Region region : regions) {
            sb.append("\n    ");
            sb.append(region);
        }
        sb.append("\nEdges:");
        for (RegionEdge edge : edges) {
            sb.append("\n   ");
            sb.append(edge.from);
            sb.append(" --> ");
            sb.append(edge.to);
        }
        sb.append("\n");
        return sb.toString();
    }

    /** Returns true if {@code region} is part of this graph. */
    public boolean contains(Region region) {
        return regions.contains(region);
    }

    /**
     * Finds a region whose variable set has the same size as, and contains all of, the
     * factor's variables.
     *
     * @param factor the factor whose variable set is looked up
     * @param z if true and no region matches, a new unconnected region for {@code factor}
     *     is created, added and returned
     * @return the matching region, the newly created one, or null
     */
    public Region findRegion(Factor factor, boolean z) {
        VarSet varSet = factor.varSet();
        for (Region region : regions) {
            if (region.vars.size() == varSet.size() && region.vars.containsAll(varSet)) {
                return region;
            }
        }
        if (!z) {
            return null;
        }
        Region created = new Region(factor);
        addRegion(created);
        return created;
    }

    /**
     * Finds a region whose variable set is exactly {@code variable}.
     *
     * @param variable the variable to look up
     * @param z if true and no region matches, a new unconnected region for {@code variable}
     *     is created, added and returned
     * @return the matching region, the newly created one, or null
     */
    public Region findRegion(Variable variable, boolean z) {
        for (Region region : regions) {
            if (region.vars.size() == 1 && region.vars.contains(variable)) {
                return region;
            }
        }
        if (!z) {
            return null;
        }
        Region created = new Region(variable);
        addRegion(created);
        return created;
    }

    /**
     * Returns the region with the fewest variables that contains {@code variable}.
     * Ties go to the first such region in iteration order. Returns null if none exists.
     */
    public Region findContainingRegion(Variable variable) {
        Region best = null;
        for (Region region : regions) {
            if (region.vars.contains(variable) && (best == null || region.vars.size() < best.vars.size())) {
                best = region;
            }
        }
        return best;
    }

    /**
     * Returns the region with the fewest variables that contains every variable of {@code varSet}.
     * Ties go to the first such region in iteration order. Returns null if none exists.
     */
    public Region findContainingRegion(VarSet varSet) {
        Region best = null;
        for (Region region : regions) {
            if (region.vars.containsAll(varSet) && (best == null || region.vars.size() < best.vars.size())) {
                best = region;
            }
        }
        return best;
    }

    /** Returns the number of edges in this graph. */
    public int numEdges() {
        return edges.size();
    }
}
