package cc.mallet.grmm.inference;

import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.Variable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;

/* loaded from: input_file:cc/mallet/grmm/inference/VariableElimination.class */
/**
 * Exact inference by variable elimination.
 *
 * <p>{@link #computeMarginals(FactorGraph)} only records the graph. Each call to
 * {@link #lookupMarginal(Variable)} then runs a complete elimination pass over
 * duplicates of the graph's factors. Variables are eliminated in the iteration
 * order of {@code variablesSet()}; no elimination-order heuristic is applied.
 */
public class VariableElimination extends AbstractInferencer {

    private static final long serialVersionUID = 1;

    /**
     * Graph recorded by the most recent {@link #computeMarginals(FactorGraph)} call.
     * Transient, so it is {@code null} after deserialization until
     * computeMarginals is called again.
     */
    transient FactorGraph mdlCurrent;

    /**
     * Removes from {@code factors} every factor that either has no variables or
     * mentions {@code variable}, and returns the product of the removed factors.
     *
     * @param factors  working set of factors; modified in place
     * @param variable the variable being eliminated
     * @return product of the removed factors, not yet summed over {@code variable}
     */
    private Factor eliminate(Collection<Factor> factors, Variable variable) {
        Collection<Factor> removed = new HashSet<Factor>();
        for (Iterator<Factor> it = factors.iterator(); it.hasNext();) {
            Factor factor = it.next();
            // Variable-free factors are swept into whichever product is formed next,
            // so constants left behind by earlier eliminations are not lost.
            if (factor.varSet().isEmpty() || factor.containsVar(variable)) {
                removed.add(factor);
                it.remove();
            }
        }
        return TableFactor.multiplyAll(removed);
    }

    /**
     * Computes the unnormalized marginal of {@code variable}.
     *
     * <p>Every other variable of the graph is eliminated in turn. Finally, all
     * remaining factors that mention {@code variable} (or no variable at all) are
     * multiplied together. The graph's own factors are never modified; the method
     * works on duplicates.
     *
     * @param factorGraph the graph to run elimination on
     * @param variable    the query variable
     * @return a factor over {@code variable} alone, not normalized
     */
    public Factor unnormalizedMarginal(FactorGraph factorGraph, Variable variable) {
        Collection<Factor> factors = new HashSet<Factor>();
        for (Iterator<?> it = factorGraph.factorsIterator(); it.hasNext();) {
            factors.add(((Factor) it.next()).duplicate());
        }

        for (Variable other : factorGraph.variablesSet()) {
            if (other == variable) {
                continue; // the query variable is never summed out
            }
            Factor product = eliminate(factors, other);
            // NOTE(review): a single-variable product is re-added without being summed
            // out. eliminate() only re-collects factors that are variable-free or that
            // contain the variable currently being eliminated. If this product's only
            // variable is `other` (presumably the usual case), it is never absorbed
            // again. It may then be missing from the final product, which would skew
            // computeNormalizationFactor on disconnected graphs. Verify before relying
            // on the unnormalized scale.
            if (product.varSet().size() == 1) {
                factors.add(product);
            } else {
                factors.add(product.marginalizeOut(other));
            }
        }

        Factor marginal = eliminate(factors, variable);
        assert marginal.containsVar(variable);
        assert marginal.varSet().size() == 1;
        return marginal;
    }

    /**
     * Returns the sum of the unnormalized marginal of the first variable yielded by
     * {@code factorGraph.variablesSet()}. This value is used as the graph's
     * normalization constant.
     *
     * @param factorGraph the graph to evaluate; expected to contain at least one
     *                    variable (an empty variable set makes {@code next()} fail)
     * @return the summed unnormalized marginal
     */
    public double computeNormalizationFactor(FactorGraph factorGraph) {
        Variable first = (Variable) factorGraph.variablesSet().iterator().next();
        return unnormalizedMarginal(factorGraph, first).sum();
    }

    /**
     * Records {@code factorGraph} for later {@link #lookupMarginal(Variable)} calls.
     * No inference work is done here; it is deferred to each lookup.
     */
    @Override // cc.mallet.grmm.inference.AbstractInferencer, cc.mallet.grmm.inference.Inferencer
    public void computeMarginals(FactorGraph factorGraph) {
        this.mdlCurrent = factorGraph;
    }

    /**
     * Runs variable elimination on the recorded graph and returns the normalized
     * marginal of {@code variable}.
     *
     * @throws IllegalStateException if no graph has been recorded, i.e.
     *         computeMarginals was never called or the object was deserialized
     *         (the recorded graph is transient)
     */
    @Override // cc.mallet.grmm.inference.AbstractInferencer, cc.mallet.grmm.inference.Inferencer
    public Factor lookupMarginal(Variable variable) {
        if (this.mdlCurrent == null) {
            throw new IllegalStateException(
                    "computeMarginals(FactorGraph) must be called before lookupMarginal(Variable)");
        }
        Factor marginal = unnormalizedMarginal(this.mdlCurrent, variable);
        marginal.normalize();
        return marginal;
    }

    /** Default serialization; the transient recorded graph is not written. */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.defaultWriteObject();
    }

    /** Default deserialization; the recorded graph comes back as {@code null}. */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
    }
}
