package cc.mallet.classify;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
import java.util.Arrays;
import java.util.Iterator;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/classify/MaxEntOptimizableByLabelDistribution.class */
/**
 * Optimizable objective for training a {@link MaxEnt} classifier from instances whose targets are
 * label distributions ({@link Labeling}s) rather than single labels.
 *
 * <p>The value returned by {@link #getValue()} is the instance-weighted sum of
 * {@code sum_l p(l) * log q(l)} (target distribution {@code p}, classifier scores {@code q}), minus
 * a Gaussian prior penalty {@code sum theta^2 / (2 * variance)} over all parameters.
 *
 * <p>Parameters are a flattened {@code numLabels x numFeatures} matrix indexed by
 * {@code labelIndex * numFeatures + featureIndex}. The last column ({@link #defaultFeatureIndex})
 * is a default feature that contributes with value 1.0 for every instance. The parameter array is
 * intended to be shared with {@link #getClassifier()}, which is used to score instances.
 */
public class MaxEntOptimizableByLabelDistribution implements Optimizable.ByGradientValue {
    private static final Logger logger =
            MalletLogger.getLogger(MaxEntOptimizableByLabelDistribution.class.getName());
    private static final Logger progressLogger =
            MalletProgressMessageLogger.getLogger(
                    MaxEntOptimizableByLabelDistribution.class.getName() + "-pl");

    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0d;
    static final Class<?> DEFAULT_MAXIMIZER_CLASS = LimitedMemoryBFGS.class;

    /** Flattened parameter matrix, indexed by {@code labelIndex * numFeatures + featureIndex}. */
    double[] parameters;
    /** Weighted feature counts under the target label distributions (same layout as parameters). */
    double[] constraints;
    /** Gradient accumulator; after {@link #getValue()} it holds the negated model expectations. */
    double[] cachedGradient;
    MaxEnt theClassifier;
    InstanceList trainingList;
    double cachedValue;
    boolean cachedValueStale;
    boolean cachedGradientStale;
    int numLabels;
    /** Number of data features plus one for the default feature. */
    int numFeatures;
    int defaultFeatureIndex;
    FeatureSelection featureSelection;
    FeatureSelection[] perLabelFeatureSelection;
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    /** Stored for callers; not read by this class. */
    Class<?> maximizerClass = DEFAULT_MAXIMIZER_CLASS;
    int numGetValueCalls = 0;
    int numGetValueGradientCalls = 0;

    public MaxEntOptimizableByLabelDistribution() {
    }

    /**
     * Builds the objective over {@code instanceList} and precomputes the constraint vector.
     *
     * @param instanceList training instances; its target alphabet must be a {@link LabelAlphabet},
     *     which is frozen (stopGrowth) by this constructor
     * @param maxEnt an existing classifier whose parameters, feature selections and default feature
     *     index are reused, or {@code null} to create a fresh {@link MaxEnt} over zero parameters
     */
    public MaxEntOptimizableByLabelDistribution(InstanceList instanceList, MaxEnt maxEnt) {
        this.trainingList = instanceList;
        Alphabet dataAlphabet = instanceList.getDataAlphabet();
        LabelAlphabet labelAlphabet = (LabelAlphabet) instanceList.getTargetAlphabet();
        labelAlphabet.stopGrowth();
        this.numLabels = labelAlphabet.size();
        // One extra column for the default feature, which is always the last index.
        this.numFeatures = dataAlphabet.size() + 1;
        this.defaultFeatureIndex = this.numFeatures - 1;
        // New Java arrays are already zero-filled.
        int size = this.numLabels * this.numFeatures;
        this.parameters = new double[size];
        this.constraints = new double[size];
        this.cachedGradient = new double[size];

        this.featureSelection = instanceList.getFeatureSelection();
        this.perLabelFeatureSelection = instanceList.getPerLabelFeatureSelection();
        // Make sure the default feature is part of any active feature selection.
        if (this.featureSelection != null) {
            this.featureSelection.add(this.defaultFeatureIndex);
        }
        if (this.perLabelFeatureSelection != null) {
            for (FeatureSelection selection : this.perLabelFeatureSelection) {
                selection.add(this.defaultFeatureIndex);
            }
        }
        assert this.featureSelection == null || this.perLabelFeatureSelection == null;

        if (maxEnt != null) {
            this.theClassifier = maxEnt;
            this.parameters = this.theClassifier.parameters;
            this.featureSelection = this.theClassifier.featureSelection;
            this.perLabelFeatureSelection = this.theClassifier.perClassFeatureSelection;
            this.defaultFeatureIndex = this.theClassifier.defaultFeatureIndex;
            assert maxEnt.getInstancePipe() == instanceList.getPipe();
        } else {
            this.theClassifier = new MaxEnt(instanceList.getPipe(), this.parameters,
                    this.featureSelection, this.perLabelFeatureSelection);
        }
        this.cachedValueStale = true;
        this.cachedGradientStale = true;
        logger.fine("Number of instances in training list = " + this.trainingList.size());
        accumulateConstraints(dataAlphabet);
    }

    /** Adds every labeled instance's weighted, label-scaled feature counts to {@link #constraints}. */
    private void accumulateConstraints(Alphabet dataAlphabet) {
        for (Instance inst : this.trainingList) {
            double instanceWeight = this.trainingList.getInstanceWeight(inst);
            Labeling labeling = inst.getLabeling();
            if (labeling == null) {
                continue;
            }
            FeatureVector fv = (FeatureVector) inst.getData();
            assert fv.getAlphabet() == dataAlphabet;
            assert labeling.numLocations() == this.trainingList.getTargetAlphabet().size();
            assert !Double.isNaN(instanceWeight) : "instanceWeight is NaN";
            for (int loc = 0; loc < labeling.numLocations(); loc++) {
                int labelIndex = labeling.indexAtLocation(loc);
                double mass = instanceWeight * labeling.valueAtLocation(loc);
                MatrixOps.rowPlusEquals(this.constraints, this.numFeatures, labelIndex, fv, mass);
                // The default feature has value 1.0 in every instance.
                this.constraints[labelIndex * this.numFeatures + this.defaultFeatureIndex] += mass;
            }
            logNaNFeatures(inst, fv);
        }
    }

    /** Logs each NaN-valued feature of {@code fv}, then the instance name if any were found. */
    private static void logNaNFeatures(Instance inst, FeatureVector fv) {
        Alphabet alphabet = fv.getAlphabet();
        boolean sawNaN = false;
        for (int loc = 0; loc < fv.numLocations(); loc++) {
            if (Double.isNaN(fv.valueAtLocation(loc))) {
                logger.info("NaN for feature "
                        + alphabet.lookupObject(fv.indexAtLocation(loc)).toString());
                sawNaN = true;
            }
        }
        if (sawNaN) {
            logger.info("NaN in instance: " + inst.getName());
        }
    }

    /** Returns the classifier whose parameters this objective optimizes. */
    public MaxEnt getClassifier() {
        return this.theClassifier;
    }

    @Override // cc.mallet.optimize.Optimizable
    public double getParameter(int i) {
        return this.parameters[i];
    }

    @Override // cc.mallet.optimize.Optimizable
    public void setParameter(int i, double d) {
        this.cachedValueStale = true;
        this.cachedGradientStale = true;
        this.parameters[i] = d;
    }

    @Override // cc.mallet.optimize.Optimizable
    public int getNumParameters() {
        return this.parameters.length;
    }

    /**
     * Copies the current parameters into {@code buffer}.
     *
     * @throws IllegalArgumentException if {@code buffer} is null or its length differs from
     *     {@link #getNumParameters()}; previously such a call silently copied into a discarded
     *     temporary array
     */
    @Override // cc.mallet.optimize.Optimizable
    public void getParameters(double[] buffer) {
        if (buffer == null || buffer.length != this.parameters.length) {
            throw new IllegalArgumentException("parameter buffer must have length "
                    + this.parameters.length + " but was "
                    + (buffer == null ? "null" : String.valueOf(buffer.length)));
        }
        System.arraycopy(this.parameters, 0, buffer, 0, this.parameters.length);
    }

    /**
     * Overwrites the parameters with a copy of {@code newParameters} and invalidates the caches.
     * If the length differs, a new array is allocated and handed to the classifier as well, so
     * that scoring in {@link #getValue()} keeps seeing the current parameters.
     */
    @Override // cc.mallet.optimize.Optimizable
    public void setParameters(double[] newParameters) {
        assert newParameters != null;
        this.cachedValueStale = true;
        this.cachedGradientStale = true;
        if (newParameters.length != this.parameters.length) {
            this.parameters = new double[newParameters.length];
            if (this.theClassifier != null) {
                this.theClassifier.parameters = this.parameters;
            }
        }
        System.arraycopy(newParameters, 0, this.parameters, 0, newParameters.length);
    }

    /**
     * Returns the penalized objective value, recomputing it if the parameters changed.
     *
     * <p>As a by-product this rebuilds the negated model expectations in {@link #cachedGradient};
     * {@link #getValueGradient(double[])} completes the gradient from there. If some instance gets
     * zero probability for a label with positive target mass, or its cost is infinite, the scan
     * stops early and an infinite value is cached and returned.
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public double getValue() {
        if (!this.cachedValueStale) {
            return this.cachedValue;
        }
        this.numGetValueCalls++;
        this.cachedValue = 0.0d;
        this.cachedGradientStale = true;
        MatrixOps.setAll(this.cachedGradient, 0.0d);

        double[] scores = new double[this.trainingList.getTargetAlphabet().size()];
        for (Instance inst : this.trainingList) {
            double instanceWeight = this.trainingList.getInstanceWeight(inst);
            Labeling labeling = inst.getLabeling();
            if (labeling == null) {
                continue;
            }
            this.theClassifier.getClassificationScores(inst, scores);
            FeatureVector fv = (FeatureVector) inst.getData();

            // Weighted cross-entropy of the target distribution against the classifier scores.
            double instanceCost = 0.0d;
            for (int loc = 0; loc < labeling.numLocations(); loc++) {
                int labelIndex = labeling.indexAtLocation(loc);
                double target = labeling.valueAtLocation(loc);
                if (scores[labelIndex] == 0.0d && target > 0.0d) {
                    logger.warning("Instance " + inst.getSource()
                            + " has infinite value; skipping value and gradient");
                    this.cachedValue = Double.NEGATIVE_INFINITY;
                    this.cachedValueStale = false;
                    return this.cachedValue;
                }
                if (target != 0.0d) {
                    instanceCost -= (instanceWeight * target) * Math.log(scores[labelIndex]);
                }
            }
            if (Double.isNaN(instanceCost)) {
                logger.fine("MaxEntOptimizableByLabelDistribution: Instance " + inst.getName()
                        + " has NaN value.");
            }
            if (Double.isInfinite(instanceCost)) {
                logger.warning("Instance " + inst.getSource()
                        + " has infinite value; skipping value and gradient");
                this.cachedValue -= instanceCost;
                this.cachedValueStale = false;
                return -instanceCost;
            }
            this.cachedValue += instanceCost;

            // Subtract this instance's weighted model expectations from the gradient accumulator.
            for (int labelIndex = 0; labelIndex < scores.length; labelIndex++) {
                if (scores[labelIndex] == 0.0d) {
                    continue;
                }
                assert !Double.isInfinite(scores[labelIndex]);
                double scale = -instanceWeight * scores[labelIndex];
                MatrixOps.rowPlusEquals(this.cachedGradient, this.numFeatures, labelIndex, fv, scale);
                this.cachedGradient[this.numFeatures * labelIndex + this.defaultFeatureIndex] += scale;
            }
        }

        double prior = gaussianPriorPenalty();
        double labelCost = this.cachedValue;
        // The optimizer maximizes, so report the negated (cost + penalty).
        this.cachedValue = -(labelCost + prior);
        this.cachedValueStale = false;
        progressLogger.info("Value (labelProb=" + (-labelCost) + " prior=" + (-prior)
                + ") loglikelihood = " + this.cachedValue);
        return this.cachedValue;
    }

    /** Returns {@code sum theta^2 / (2 * gaussianPriorVariance)} over the parameter matrix. */
    private double gaussianPriorPenalty() {
        double penalty = 0.0d;
        for (int labelIndex = 0; labelIndex < this.numLabels; labelIndex++) {
            for (int featureIndex = 0; featureIndex < this.numFeatures; featureIndex++) {
                double theta = this.parameters[labelIndex * this.numFeatures + featureIndex];
                penalty += (theta * theta) / (2.0d * this.gaussianPriorVariance);
            }
        }
        return penalty;
    }

    /**
     * Copies the gradient of {@link #getValue()} into {@code buffer}, recomputing it if stale.
     *
     * <p>The gradient is constraints minus model expectations minus
     * {@code parameters / gaussianPriorVariance}. Negative-infinite entries are replaced by 0.
     * Each label row is then passed to {@code MatrixOps.rowSetAll(..., 0.0, selection, false)}
     * with the global or per-label feature selection (presumably zeroing unselected features).
     *
     * @param buffer destination; must be non-null and have length {@link #getNumParameters()}
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public void getValueGradient(double[] buffer) {
        if (this.cachedGradientStale) {
            this.numGetValueGradientCalls++;
            if (this.cachedValueStale) {
                getValue();
            }
            MatrixOps.plusEquals(this.cachedGradient, this.constraints);
            MatrixOps.plusEquals(this.cachedGradient, this.parameters, -1.0d / this.gaussianPriorVariance);
            MatrixOps.substitute(this.cachedGradient, Double.NEGATIVE_INFINITY, 0.0d);
            for (int labelIndex = 0; labelIndex < this.numLabels; labelIndex++) {
                FeatureSelection selection = this.perLabelFeatureSelection == null
                        ? this.featureSelection
                        : this.perLabelFeatureSelection[labelIndex];
                MatrixOps.rowSetAll(this.cachedGradient, this.numFeatures, labelIndex, 0.0d, selection, false);
            }
            this.cachedGradientStale = false;
        }
        assert buffer != null && buffer.length == this.parameters.length;
        System.arraycopy(this.cachedGradient, 0, buffer, 0, this.cachedGradient.length);
    }

    /** Returns how many times the gradient was actually recomputed. */
    public int getValueGradientCalls() {
        return this.numGetValueGradientCalls;
    }

    /** Returns how many times the value was actually recomputed. */
    public int getValueCalls() {
        return this.numGetValueCalls;
    }

    /** No-op: the Gaussian prior is always applied. Kept for API compatibility. */
    public MaxEntOptimizableByLabelDistribution useGaussianPrior() {
        return this;
    }

    /**
     * Sets the Gaussian prior variance and invalidates the cached value and gradient, both of
     * which depend on it.
     */
    public MaxEntOptimizableByLabelDistribution setGaussianPriorVariance(double d) {
        this.gaussianPriorVariance = d;
        this.cachedValueStale = true;
        this.cachedGradientStale = true;
        return this;
    }
}
