package cc.mallet.grmm.types;

import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Matrix;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.types.SparseMatrixn;
import cc.mallet.util.Maths;
import gnu.trove.TDoubleArrayList;
import gnu.trove.TIntArrayList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;

/* loaded from: input_file:cc/mallet/grmm/types/Factors.class */
/**
 * Static helper methods for normalizing, combining, comparing, and summarizing {@link Factor}s.
 */
public class Factors {

    /**
     * Normalizes a table factor in place so that, for each assignment to the other variables,
     * the entries over {@code variable} sum to one (in log space), then wraps it as a {@link CPT}.
     *
     * <p>Slices whose entry and normalizer are both infinite (a 0/0 case) are set to log 0
     * rather than NaN.
     *
     * @param abstractTableFactor factor to normalize; it is modified in place
     * @param variable variable whose conditional distribution is normalized
     * @return a {@link CPT} built from the normalized factor and {@code variable}
     */
    public static CPT normalizeAsCpt(AbstractTableFactor abstractTableFactor, Variable variable) {
        // Log normalizer for each assignment of the remaining variables, indexed by that
        // assignment's singleIndex().
        // NOTE(review): sized by numLocations(); assumes every marginalized assignment's
        // singleIndex() is below that bound — confirm for sparse factors.
        double[] logNormalizers = new double[abstractTableFactor.numLocations()];
        Arrays.fill(logNormalizers, Double.NEGATIVE_INFINITY);

        // Pass 1: log-sum-exp every entry into the slot of its neighbour assignment.
        for (AssignmentIterator it = abstractTableFactor.assignmentIterator(); it.hasNext(); it.advance()) {
            Assignment assignment = it.assignment();
            int slot = ((Assignment) assignment.marginalizeOut(variable)).singleIndex();
            logNormalizers[slot] = Maths.sumLogProb(abstractTableFactor.logValue(assignment), logNormalizers[slot]);
        }

        // Pass 2: subtract the matching log normalizer from every entry.
        for (AssignmentIterator it = abstractTableFactor.assignmentIterator(); it.hasNext(); it.advance()) {
            Assignment assignment = it.assignment();
            double logValue = abstractTableFactor.logValue(assignment);
            double logZ = logNormalizers[((Assignment) assignment.marginalizeOut(variable)).singleIndex()];
            if (Double.isInfinite(logValue) && Double.isInfinite(logZ)) {
                // Infinite minus infinite would be NaN; treat 0/0 as 0.
                abstractTableFactor.setLogValue(assignment, Double.NEGATIVE_INFINITY);
            } else {
                abstractTableFactor.setLogValue(assignment, logValue - logZ);
            }
        }
        return new CPT(abstractTableFactor, variable);
    }

    /**
     * Mixes two factors by delegating to {@code TableFactor.hackyMixture}.
     *
     * @param factor first factor; must be a {@link TableFactor} (it is cast)
     * @param factor2 second factor; must be a {@link TableFactor} (it is cast)
     * @param d mixture weight passed through to {@code hackyMixture}
     * @return the mixture produced by {@code hackyMixture}
     * @throws ClassCastException if either argument is not a {@link TableFactor}
     */
    public static Factor average(Factor factor, Factor factor2, double d) {
        return TableFactor.hackyMixture((TableFactor) factor, (TableFactor) factor2, d);
    }

    /**
     * Computes the L1 distance between two factors over the same variables, summing
     * {@code |f1(a) - f2(a)|} over every assignment produced by {@code factor}'s iterator.
     *
     * @param factor first factor
     * @param factor2 second factor; must have a var set equal to {@code factor}'s
     * @return the summed absolute difference of values
     * @throws IllegalArgumentException if the two var sets differ
     */
    public static double oneDistance(Factor factor, Factor factor2) {
        if (!factor.varSet().equals(factor2.varSet())) {
            throw new IllegalArgumentException("Attempt to take distance between mismatching potentials " + factor + " and " + factor2);
        }
        double distance = 0.0d;
        for (AssignmentIterator it = factor.assignmentIterator(); it.hasNext(); it.advance()) {
            Assignment assignment = it.assignment();
            distance += Math.abs(factor.value(assignment) - factor2.value(assignment));
        }
        return distance;
    }

    /**
     * Builds a sparse {@link TableFactor} that keeps only the top-ranked entries of
     * {@code discreteFactor}. Entries are added in rank order, accumulating their log mass,
     * until the cumulative mass first exceeds {@code d}; the entry that crosses the
     * threshold is kept.
     *
     * @param discreteFactor factor whose entries are ranked
     * @param d mass threshold in probability space (compared after {@code Math.log})
     * @return a new table factor over the same variables holding only the retained entries
     */
    public static TableFactor retainMass(DiscreteFactor discreteFactor, double d) {
        int numLocations = discreteFactor.numLocations();
        int[] indices = new int[numLocations];
        double[] logValues = new double[numLocations];
        for (int loc = 0; loc < numLocations; loc++) {
            indices[loc] = discreteFactor.indexAtLocation(loc);
            // NOTE(review): logValue(int) is called with a location, not an index — confirm
            // this overload is location-based for DiscreteFactor.
            logValues[loc] = discreteFactor.logValue(loc);
        }
        RankedFeatureVector ranked = new RankedFeatureVector(new Alphabet(), indices, logValues);

        TIntArrayList keptIndices = new TIntArrayList();
        TDoubleArrayList keptValues = new TDoubleArrayList();
        double logMass = Double.NEGATIVE_INFINITY;
        double logThreshold = Math.log(d);
        for (int rank = 0; rank < ranked.numLocations(); rank++) {
            int index = ranked.getIndexAtRank(rank);
            double value = ranked.value(index);
            logMass = Maths.sumLogProb(logMass, value);
            keptIndices.add(index);
            keptValues.add(value);
            if (logMass > logThreshold) {
                break;
            }
        }

        // NOTE(review): the kept indices are in rank order (not sorted), and the stored values
        // are log values passed to setValues() — confirm SparseMatrixn accepts unsorted indices
        // and that setValues expects log-space input.
        SparseMatrixn retained = new SparseMatrixn(computeSizes(discreteFactor), keptIndices.toNativeArray(), keptValues.toNativeArray());
        TableFactor result = new TableFactor(computeVars(discreteFactor));
        result.setValues(retained);
        return result;
    }

    /**
     * Returns the number of outcomes of each variable of {@code factor}, in
     * {@code getVariable(i)} order.
     *
     * @param factor factor whose variables are inspected
     * @return array of outcome counts, one per variable
     */
    public static int[] computeSizes(Factor factor) {
        int size = factor.varSet().size();
        int[] sizes = new int[size];
        for (int i = 0; i < size; i++) {
            sizes[i] = factor.getVariable(i).getNumOutcomes();
        }
        return sizes;
    }

    /**
     * Returns the variables of {@code factor} as an array, in {@code getVariable(i)} order.
     *
     * @param factor factor whose variables are collected
     * @return array of the factor's variables
     */
    public static Variable[] computeVars(Factor factor) {
        int size = factor.varSet().size();
        Variable[] vars = new Variable[size];
        for (int i = 0; i < size; i++) {
            vars[i] = factor.getVariable(i);
        }
        return vars;
    }

    /**
     * Computes the mutual information between the two variables of a two-variable factor,
     * treating its values as a joint distribution (assumes the factor is normalized).
     * Zero-probability cells contribute nothing, following the convention 0 log 0 = 0.
     *
     * @param factor factor over exactly two variables
     * @return sum over assignments of p(x,y) * (log p(x,y) - log p(x) - log p(y))
     * @throws IllegalArgumentException if the factor does not have exactly two variables
     */
    public static double mutualInformation(Factor factor) {
        VarSet varSet = factor.varSet();
        if (varSet.size() != 2) {
            throw new IllegalArgumentException("Factor must have size 2");
        }
        Factor marginal0 = factor.marginalize(varSet.get(0));
        Factor marginal1 = factor.marginalize(varSet.get(1));
        double mi = 0.0d;
        AssignmentIterator it = factor.assignmentIterator();
        while (it.hasNext()) {
            Assignment assignment = (Assignment) it.next();
            double p = factor.value(assignment);
            // Skip p == 0: its log value would be -Infinity and 0 * -Infinity is NaN, which
            // would poison the whole sum.
            if (p > 0.0d) {
                mi += p * ((factor.logValue(assignment) - marginal0.logValue(assignment)) - marginal1.logValue(assignment));
            }
        }
        return mi;
    }

    /**
     * Computes KL(f1 || f2) = sum of p * log(p / q) over the locations of the first factor.
     * Entries of the first factor with value at or below 1e-5 are skipped. Assumes both factors
     * range over the same variables.
     *
     * @param abstractTableFactor the distribution p
     * @param abstractTableFactor2 the distribution q, looked up by p's indices
     * @return the (thresholded) KL divergence
     */
    public static double KL(AbstractTableFactor abstractTableFactor, AbstractTableFactor abstractTableFactor2) {
        double kl = 0.0d;
        for (int loc = 0; loc < abstractTableFactor.numLocations(); loc++) {
            double p = abstractTableFactor.valueAtLocation(loc);
            double q = abstractTableFactor2.value(abstractTableFactor.indexAtLocation(loc));
            // Tiny p values are ignored rather than contributing p * log(p / q).
            if (p > 1.0E-5d) {
                kl += p * Math.log(p / q);
            }
        }
        return kl;
    }

    /**
     * Mixes two table factors by delegating to {@code AbstractTableFactor.hackyMixture}.
     *
     * @param abstractTableFactor first factor
     * @param abstractTableFactor2 second factor
     * @param d mixture weight passed through to {@code hackyMixture}
     * @return the mixture produced by {@code hackyMixture}
     */
    public static Factor mix(AbstractTableFactor abstractTableFactor, AbstractTableFactor abstractTableFactor2, double d) {
        return AbstractTableFactor.hackyMixture(abstractTableFactor, abstractTableFactor2, d);
    }

    /**
     * Euclidean distance between two table factors' values, summed over the locations of the
     * first factor only (the second is looked up by the first's indices).
     *
     * @param abstractTableFactor first factor (defines which entries are compared)
     * @param abstractTableFactor2 second factor
     * @return square root of the summed squared differences
     */
    public static double euclideanDistance(AbstractTableFactor abstractTableFactor, AbstractTableFactor abstractTableFactor2) {
        double sumSq = 0.0d;
        for (int loc = 0; loc < abstractTableFactor.numLocations(); loc++) {
            double diff = abstractTableFactor.valueAtLocation(loc) - abstractTableFactor2.value(abstractTableFactor.indexAtLocation(loc));
            sumSq += diff * diff;
        }
        return Math.sqrt(sumSq);
    }

    /**
     * L1 distance between two table factors' values, summed over the locations of the first
     * factor only (the second is looked up by the first's indices).
     *
     * @param abstractTableFactor first factor (defines which entries are compared)
     * @param abstractTableFactor2 second factor
     * @return summed absolute differences
     */
    public static double l1Distance(AbstractTableFactor abstractTableFactor, AbstractTableFactor abstractTableFactor2) {
        double distance = 0.0d;
        for (int loc = 0; loc < abstractTableFactor.numLocations(); loc++) {
            distance += Math.abs(abstractTableFactor.valueAtLocation(loc) - abstractTableFactor2.value(abstractTableFactor.indexAtLocation(loc)));
        }
        return distance;
    }

    /**
     * Wraps an {@link Inferencer} as a {@link Factor} view whose values and marginals are
     * answered by {@code inferencer.lookupMarginal}. {@code marginalizeOut} and
     * {@code varSet} are not supported.
     *
     * @param inferencer the inferencer queried for marginals
     * @return a factor view backed by {@code inferencer}
     */
    public static Factor asFactor(final Inferencer inferencer) {
        return new SkeletonFactor() {
            @Override
            public double value(Assignment assignment) {
                return inferencer.lookupMarginal(assignment.varSet()).value(assignment);
            }

            @Override
            public Factor marginalize(Variable[] variableArr) {
                return inferencer.lookupMarginal(new HashVarSet(variableArr));
            }

            @Override
            public Factor marginalize(Collection collection) {
                return inferencer.lookupMarginal(new HashVarSet(collection));
            }

            @Override
            public Factor marginalize(Variable variable) {
                return inferencer.lookupMarginal(new HashVarSet(new Variable[]{variable}));
            }

            @Override
            public Factor marginalizeOut(Variable variable) {
                throw new UnsupportedOperationException();
            }

            @Override
            public Factor marginalizeOut(VarSet varSet) {
                throw new UnsupportedOperationException();
            }

            @Override
            public VarSet varSet() {
                throw new UnsupportedOperationException();
            }
        };
    }

    /**
     * Returns the variables of {@code factor} that are not continuous, in var-set order.
     *
     * @param factor factor to inspect
     * @return the discrete variables
     */
    public static Variable[] discreteVarsOf(Factor factor) {
        return varsOfKind(factor, false);
    }

    /**
     * Returns the variables of {@code factor} that are continuous, in var-set order.
     *
     * @param factor factor to inspect
     * @return the continuous variables
     */
    public static Variable[] continuousVarsOf(Factor factor) {
        return varsOfKind(factor, true);
    }

    /** Collects the factor's variables whose {@code isContinuous()} equals {@code continuous}. */
    private static Variable[] varsOfKind(Factor factor, boolean continuous) {
        VarSet varSet = factor.varSet();
        ArrayList<Variable> selected = new ArrayList<Variable>();
        for (int i = 0; i < varSet.size(); i++) {
            Variable variable = varSet.get(i);
            if (variable.isContinuous() == continuous) {
                selected.add(variable);
            }
        }
        return selected.toArray(new Variable[selected.size()]);
    }

    /**
     * Computes E[XY] - E[X]E[Y] for the two variables of {@code factor}, using the outcome
     * values returned by {@code Assignment.get}. Despite the name, no division by standard
     * deviations is performed, so this is the covariance rather than the correlation coefficient.
     *
     * @param factor factor over exactly two variables, treated as a joint distribution
     * @return E[XY] - E[X]E[Y]
     * @throws IllegalArgumentException if the factor does not have exactly two variables
     */
    public static double corr(Factor factor) {
        if (factor.varSet().size() != 2) {
            throw new IllegalArgumentException("corr() only works on Factors of size 2, tried " + factor);
        }
        Variable variable = factor.varSet().get(0);
        Variable variable2 = factor.varSet().get(1);
        double expectedProduct = 0.0d;
        AssignmentIterator it = factor.assignmentIterator();
        while (it.hasNext()) {
            Assignment assignment = (Assignment) it.next();
            expectedProduct += factor.value(assignment) * assignment.get(variable) * assignment.get(variable2);
        }
        return expectedProduct - (mean(factor.marginalize(variable)) * mean(factor.marginalize(variable2)));
    }

    /**
     * Computes the expected outcome value of a single-variable factor, weighting each
     * assignment's {@code get(variable)} by the factor's value.
     *
     * @param factor factor over exactly one variable
     * @return sum of value(a) * a.get(variable)
     * @throws IllegalArgumentException if the factor does not have exactly one variable
     */
    private static double mean(Factor factor) {
        if (factor.varSet().size() != 1) {
            throw new IllegalArgumentException("mean() only works on Factors of size 1, tried " + factor);
        }
        Variable variable = factor.varSet().get(0);
        double expectation = 0.0d;
        AssignmentIterator it = factor.assignmentIterator();
        while (it.hasNext()) {
            Assignment assignment = (Assignment) it.next();
            expectation += factor.value(assignment) * assignment.get(variable);
        }
        return expectation;
    }

    /**
     * Multiplies together every factor in {@code collection}. The result starts as a duplicate
     * of the first factor and is multiplied in place by each remaining factor exactly once;
     * none of the inputs is modified.
     *
     * <p>NOTE(review): whether {@code multiplyBy} extends the result's scope to variables that
     * appear only in later factors depends on the concrete Factor implementation — confirm.
     *
     * @param collection non-empty collection of {@link Factor}s
     * @return the product of all factors
     * @throws java.util.NoSuchElementException if {@code collection} is empty
     */
    public static Factor multiplyAll(Collection collection) {
        Iterator it = collection.iterator();
        // Start from a copy of the first factor and skip it in the loop below; multiplying by it
        // again would square it.
        Factor product = ((Factor) it.next()).duplicate();
        while (it.hasNext()) {
            product.multiplyBy((Factor) it.next());
        }
        return product;
    }

    /**
     * L-infinity distance between the log-value matrices of two table factors.
     *
     * @return largest absolute elementwise difference, or +Infinity if the sizes differ
     */
    public static double distLinf(AbstractTableFactor abstractTableFactor, AbstractTableFactor abstractTableFactor2) {
        return matrixDistLinf(abstractTableFactor.getLogValueMatrix(), abstractTableFactor2.getLogValueMatrix());
    }

    /**
     * L-infinity distance between the value matrices of two table factors.
     *
     * @return largest absolute elementwise difference, or +Infinity if the sizes differ
     */
    public static double distValueLinf(AbstractTableFactor abstractTableFactor, AbstractTableFactor abstractTableFactor2) {
        return matrixDistLinf(abstractTableFactor.getValueMatrix(), abstractTableFactor2.getValueMatrix());
    }

    /**
     * Largest absolute elementwise difference between two matrices, compared by single index.
     * NaN differences never win the comparison and are therefore ignored.
     *
     * @return the maximum difference (0 for empty matrices), or +Infinity if the sizes differ
     */
    private static double matrixDistLinf(Matrix matrix, Matrix matrix2) {
        int singleSize = matrix.singleSize();
        if (singleSize != matrix2.singleSize()) {
            return Double.POSITIVE_INFINITY;
        }
        double max = 0.0d;
        for (int i = 0; i < singleSize; i++) {
            double diff = Math.abs(matrix.singleValue(i) - matrix2.singleValue(i));
            // Deliberately not Math.max, which would propagate NaN.
            max = diff > max ? diff : max;
        }
        return max;
    }

    /**
     * Spread of the elementwise absolute log-value differences between two table factors:
     * the largest difference minus the smallest. NaN differences are ignored.
     *
     * @return max minus min difference, 0 if the factors have no entries, or +Infinity if the
     *     matrix sizes differ
     */
    public static double logErrorRange(AbstractTableFactor abstractTableFactor, AbstractTableFactor abstractTableFactor2) {
        Matrix logValueMatrix = abstractTableFactor.getLogValueMatrix();
        Matrix logValueMatrix2 = abstractTableFactor2.getLogValueMatrix();
        int singleSize = logValueMatrix.singleSize();
        if (singleSize != logValueMatrix2.singleSize()) {
            return Double.POSITIVE_INFINITY;
        }
        if (singleSize == 0) {
            // No entries: without this guard the result would be 0 - Double.MAX_VALUE.
            return 0.0d;
        }
        double min = Double.MAX_VALUE;
        double max = 0.0d;
        for (int i = 0; i < singleSize; i++) {
            double diff = Math.abs(logValueMatrix.singleValue(i) - logValueMatrix2.singleValue(i));
            max = diff > max ? diff : max;
            min = diff < min ? diff : min;
        }
        return max - min;
    }
}
