package cc.mallet.fst;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.BitSet;
import java.util.Iterator;
import java.util.logging.Logger;
import org.springframework.jdbc.datasource.init.ScriptUtils;

/* loaded from: input_file:cc/mallet/fst/CRFOptimizableByLabelLikelihood.class */
/**
 * Optimizable objective that trains a {@link CRF} by label likelihood.
 *
 * <p>The value is the sum over training instances of
 * {@code instanceWeight * (labeledWeight - unlabeledWeight)}, where both weights are
 * {@code SumLatticeDefault.getTotalWeight()} results computed with and without the instance's
 * target sequence. A Gaussian or hyperbolic prior term is then added. Instances whose difference
 * is not finite are skipped and logged.
 *
 * <p>The value and gradient are cached against {@code crf.weightsValueChangeStamp}. Changing any
 * prior setting invalidates both caches. No synchronization is performed, and
 * {@code expectations} is overwritten in place on every recomputation.
 */
public class CRFOptimizableByLabelLikelihood implements Optimizable.ByGradientValue, Serializable {
    private static final Logger logger = MalletLogger.getLogger(CRFOptimizableByLabelLikelihood.class.getName());

    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0d;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SLOPE = 0.2d;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS = 10.0d;

    private static final long serialVersionUID = 1;
    /**
     * Serialized-form version. Version 1 additionally stores the prior settings. Version 0
     * streams are still readable and get the default prior settings.
     */
    private static final int CURRENT_SERIAL_VERSION = 1;
    /** Cache stamp meaning "nothing cached yet"; never equal to a real weights stamp we cached. */
    private static final int NO_STAMP = -1;

    protected InstanceList trainingSet;
    protected double[] cachedGradient;
    protected CRF crf;
    /** Factors accumulated by lattices constrained to the target label sequences. */
    protected CRF.Factors constraints;
    /** Factors accumulated by unconstrained lattices; reused as scratch space for the gradient. */
    protected CRF.Factors expectations;
    protected double cachedValue = -1.23456789E8d;
    /** Indices of instances whose weight was non-finite on the first evaluation; lazily created. */
    protected BitSet infiniteValues = null;
    private int cachedValueWeightsStamp = NO_STAMP;
    private int cachedGradientWeightsStamp = NO_STAMP;

    boolean usingHyperbolicPrior = false;
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    double hyperbolicPriorSlope = DEFAULT_HYPERBOLIC_PRIOR_SLOPE;
    double hyperbolicPriorSharpness = DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS;

    /** Creates {@link CRFOptimizableByLabelLikelihood} instances for a CRF and training set. */
    public static class Factory {
        public Optimizable.ByGradientValue newCRFOptimizable(CRF crf, InstanceList instanceList) {
            return new CRFOptimizableByLabelLikelihood(crf, instanceList);
        }
    }

    /**
     * Builds the objective and immediately gathers the constraints from {@code instanceList}.
     *
     * @param crf the model whose parameters are optimized
     * @param instanceList training data; each instance's data is cast to
     *     {@link FeatureVectorSequence} and its target to {@link FeatureSequence}
     */
    public CRFOptimizableByLabelLikelihood(CRF crf, InstanceList instanceList) {
        this.crf = crf;
        this.trainingSet = instanceList;
        this.cachedGradient = new double[crf.parameters.getNumFactors()];
        this.constraints = new CRF.Factors(crf.parameters);
        this.expectations = new CRF.Factors(crf.parameters);
        gatherConstraints(instanceList);
    }

    /**
     * Zeroes {@link #constraints} and refills it.
     *
     * <p>For every instance, a lattice restricted to the instance's target sequence is built. Its
     * incrementor adds into the constraints, scaled by the instance weight when that weight is not
     * exactly 1.
     */
    protected void gatherConstraints(InstanceList instanceList) {
        assert constraints.structureMatches(crf.parameters);
        constraints.zero();
        for (Instance instance : instanceList) {
            FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
            FeatureSequence output = (FeatureSequence) instance.getTarget();
            double instanceWeight = instanceList.getInstanceWeight(instance);
            // The lattice itself is discarded: only the incrementor's side effect on the constraints is needed.
            new SumLatticeDefault(crf, input, output, newIncrementor(constraints, instanceWeight));
        }
    }

    /** Returns an incrementor into {@code factors}, weighted unless {@code instanceWeight} is exactly 1. */
    private static Transducer.Incrementor newIncrementor(CRF.Factors factors, double instanceWeight) {
        // Incrementor / WeightedIncrementor are inner classes of CRF.Factors, so they need an outer instance.
        return instanceWeight == 1.0d
                ? factors.new Incrementor()
                : factors.new WeightedIncrementor(instanceWeight);
    }

    @Override // cc.mallet.optimize.Optimizable
    public int getNumParameters() {
        return crf.parameters.getNumFactors();
    }

    @Override // cc.mallet.optimize.Optimizable
    public void getParameters(double[] buffer) {
        crf.parameters.getParameters(buffer);
    }

    @Override // cc.mallet.optimize.Optimizable
    public double getParameter(int index) {
        return crf.parameters.getParameter(index);
    }

    /** Sets all CRF parameters and bumps the CRF's weights stamp, invalidating cached results. */
    @Override // cc.mallet.optimize.Optimizable
    public void setParameters(double[] params) {
        crf.parameters.setParameters(params);
        crf.weightsValueChanged();
    }

    /** Sets one CRF parameter and bumps the CRF's weights stamp, invalidating cached results. */
    @Override // cc.mallet.optimize.Optimizable
    public void setParameter(int index, double value) {
        crf.parameters.setParameter(index, value);
        crf.weightsValueChanged();
    }

    /**
     * Computes the prior-free part of the objective and refills {@link #expectations}.
     *
     * <p>Returns {@code sum(instanceWeight * (labeledWeight - unlabeledWeight))} over all
     * instances whose difference is finite. The set of skipped instances is recorded on the first
     * call. On later calls, an instance that becomes non-finite without having been recorded
     * causes an exception, so the total stays comparable across calls.
     *
     * @throws IllegalStateException if a previously finite instance now has a non-finite weight
     */
    protected double getExpectationValue() {
        boolean initializingInfiniteValues = false;
        if (infiniteValues == null) {
            infiniteValues = new BitSet();
            initializingInfiniteValues = true;
        }
        assert expectations.structureMatches(crf.parameters);
        expectations.zero();

        int numInfLabeledWeight = 0;
        int numInfUnlabeledWeight = 0;
        int numInfWeight = 0;
        double value = 0.0d;
        for (int ii = 0; ii < trainingSet.size(); ii++) {
            Instance instance = trainingSet.get(ii);
            double instanceWeight = trainingSet.getInstanceWeight(instance);
            FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
            FeatureSequence output = (FeatureSequence) instance.getTarget();
            // Null-safe display name used by every message below.
            String instanceName = instance.getName() == null ? "instance#" + ii : instance.getName().toString();

            double labeledWeight = new SumLatticeDefault(crf, input, output, (Transducer.Incrementor) null).getTotalWeight();
            if (Double.isInfinite(labeledWeight)) {
                numInfLabeledWeight++;
                logger.warning(instanceName + " has -infinite labeled weight.\n" + (instance.getSource() != null ? instance.getSource() : ""));
            }

            double unlabeledWeight = new SumLatticeDefault(crf, input, null, newIncrementor(expectations, instanceWeight)).getTotalWeight();
            if (Double.isInfinite(unlabeledWeight)) {
                numInfUnlabeledWeight++;
                logger.warning(instanceName + " has -infinite unlabeled weight.\n" + (instance.getSource() != null ? instance.getSource() : ""));
            }

            double weight = labeledWeight - unlabeledWeight;
            // NaN arises when both weights are infinite (inf - inf); treat it like an infinite
            // difference instead of letting it poison the total.
            if (Double.isNaN(weight) || Double.isInfinite(weight)) {
                numInfWeight++;
                logger.warning(instanceName + " has -infinite weight; skipping.");
                if (initializingInfiniteValues) {
                    infiniteValues.set(ii);
                } else if (!infiniteValues.get(ii)) {
                    throw new IllegalStateException("Instance " + instanceName + " (index " + ii
                            + ") used to have non-infinite value, but now it has infinite value.");
                }
                continue;
            }
            value += weight * instanceWeight;
        }

        if (numInfLabeledWeight > 0 || numInfUnlabeledWeight > 0 || numInfWeight > 0) {
            logger.warning("Number of instances with:\n\t -infinite labeled weight: " + numInfLabeledWeight + "\n"
                    + "\t -infinite unlabeled weight: " + numInfUnlabeledWeight + "\n"
                    + "\t -infinite weight: " + numInfWeight);
        }
        return value;
    }

    /**
     * Returns the objective: {@link #getExpectationValue()} plus the configured prior term.
     *
     * <p>The result is cached until the CRF's weights stamp changes or a prior setting is
     * modified. The cache is marked fresh only after a successful computation, so an exception
     * is re-raised on the next call instead of a stale value being returned.
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public double getValue() {
        int stamp = crf.weightsValueChangeStamp;
        if (stamp != cachedValueWeightsStamp) {
            long startMillis = System.currentTimeMillis();
            double value = getExpectationValue();
            if (usingHyperbolicPrior) {
                value += crf.parameters.hyberbolicPrior(hyperbolicPriorSlope, hyperbolicPriorSharpness);
            } else {
                value += crf.parameters.gaussianPrior(gaussianPriorVariance);
            }
            cachedValue = value;
            assert !(Double.isNaN(cachedValue) || Double.isInfinite(cachedValue)) : "Label likelihood is NaN/Infinite";
            cachedValueWeightsStamp = stamp;
            logger.info("getValue() (loglikelihood, optimizable by label likelihood) = " + cachedValue);
            logger.fine("Inference milliseconds = " + (System.currentTimeMillis() - startMillis));
        }
        return cachedValue;
    }

    /** Checks the CRF parameters for NaN and both factor sets for NaN/infinite entries. */
    private void assertNotNaNOrInfinite() {
        crf.parameters.assertNotNaN();
        expectations.assertNotNaNOrInfinite();
        constraints.assertNotNaNOrInfinite();
    }

    /**
     * Copies the gradient of {@link #getValue()} into {@code buffer}.
     *
     * <p>{@code buffer} must hold at least {@link #getNumParameters()} entries. The gradient is
     * computed by first updating the expectations: subtract the constraints, then add the prior
     * gradient called with a negated slope or variance. The resulting parameters are then negated.
     * This overwrites {@link #expectations}. If the computation fails part-way, the value cache is
     * also invalidated, so the next call recomputes the expectations from scratch.
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public void getValueGradient(double[] buffer) {
        int stamp = crf.weightsValueChangeStamp;
        if (cachedGradientWeightsStamp != stamp) {
            boolean succeeded = false;
            try {
                getValue(); // ensures expectations reflect the current weights
                assertNotNaNOrInfinite();
                expectations.plusEquals(constraints, -1.0d);
                if (usingHyperbolicPrior) {
                    expectations.plusEqualsHyperbolicPriorGradient(crf.parameters, -hyperbolicPriorSlope, hyperbolicPriorSharpness);
                } else {
                    expectations.plusEqualsGaussianPriorGradient(crf.parameters, -gaussianPriorVariance);
                }
                expectations.assertNotNaNOrInfinite();
                expectations.getParameters(cachedGradient);
                MatrixOps.timesEquals(cachedGradient, -1.0d);
                succeeded = true;
            } finally {
                if (succeeded) {
                    cachedGradientWeightsStamp = stamp;
                } else {
                    // expectations may be half-updated; force getValue() to rebuild it next time.
                    cachedValueWeightsStamp = NO_STAMP;
                }
            }
        }
        System.arraycopy(cachedGradient, 0, buffer, 0, cachedGradient.length);
    }

    /** Drops cached value and gradient; called whenever a setting they depend on changes. */
    private void invalidateCaches() {
        cachedValueWeightsStamp = NO_STAMP;
        cachedGradientWeightsStamp = NO_STAMP;
    }

    /** Chooses the hyperbolic prior ({@code true}) or the Gaussian prior ({@code false}). */
    public void setUseHyperbolicPrior(boolean useHyperbolicPrior) {
        usingHyperbolicPrior = useHyperbolicPrior;
        invalidateCaches();
    }

    public void setHyperbolicPriorSlope(double slope) {
        hyperbolicPriorSlope = slope;
        invalidateCaches();
    }

    public void setHyperbolicPriorSharpness(double sharpness) {
        hyperbolicPriorSharpness = sharpness;
        invalidateCaches();
    }

    /** Returns the hyperbolic prior slope (method name kept for compatibility). */
    public double getUseHyperbolicPriorSlope() {
        return hyperbolicPriorSlope;
    }

    /** Returns the hyperbolic prior sharpness (method name kept for compatibility). */
    public double getUseHyperbolicPriorSharpness() {
        return hyperbolicPriorSharpness;
    }

    public void setGaussianPriorVariance(double variance) {
        gaussianPriorVariance = variance;
        invalidateCaches();
    }

    public double getGaussianPriorVariance() {
        return gaussianPriorVariance;
    }

    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(trainingSet);
        out.writeDouble(cachedValue);
        out.writeObject(cachedGradient);
        out.writeObject(infiniteValues);
        out.writeObject(crf);
        // Version 1: prior settings (previously lost on a serialization round-trip).
        out.writeBoolean(usingHyperbolicPrior);
        out.writeDouble(gaussianPriorVariance);
        out.writeDouble(hyperbolicPriorSlope);
        out.writeDouble(hyperbolicPriorSharpness);
    }

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        trainingSet = (InstanceList) in.readObject();
        cachedValue = in.readDouble();
        cachedGradient = (double[]) in.readObject();
        infiniteValues = (BitSet) in.readObject();
        crf = (CRF) in.readObject();
        if (version >= 1) {
            usingHyperbolicPrior = in.readBoolean();
            gaussianPriorVariance = in.readDouble();
            hyperbolicPriorSlope = in.readDouble();
            hyperbolicPriorSharpness = in.readDouble();
        } else {
            // Field initializers do not run during deserialization, so restore the defaults by hand.
            usingHyperbolicPrior = false;
            gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
            hyperbolicPriorSlope = DEFAULT_HYPERBOLIC_PRIOR_SLOPE;
            hyperbolicPriorSharpness = DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS;
        }
        // The factor sets are not serialized; rebuild them as the constructor does, and
        // invalidate the caches so nothing is reused against the new factor objects.
        invalidateCaches();
        constraints = new CRF.Factors(crf.parameters);
        expectations = new CRF.Factors(crf.parameters);
        gatherConstraints(trainingSet);
    }
}
