package cc.mallet.fst.semi_supervised.pr;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/fst/semi_supervised/pr/ConstraintsOptimizableByPR.class */
/**
 * Exposes the auxiliary-distribution objective of a {@link PRAuxiliaryModel} through the
 * {@link Optimizable.ByGradientValue} interface, so that a gradient-based optimizer can tune the
 * auxiliary model's parameters.
 *
 * <p>The value and gradient are computed lazily and cached until the parameters change (via
 * {@link #setParameters(double[])} or {@link #setParameter(int, double)}).
 *
 * <p>How expectations are computed:
 * <ul>
 *   <li>The training set is split into {@code numThreads} contiguous slices.</li>
 *   <li>Each slice is processed by an {@link ExpectationTask} that works on its own copy of the
 *       auxiliary model.</li>
 *   <li>Afterwards, the copies' constraint expectations are summed back into {@link #model}.</li>
 * </ul>
 *
 * <p>Call {@link #shutdown()} when finished. The internal pool comes from
 * {@link Executors#newFixedThreadPool(int)}, whose default thread factory creates non-daemon
 * threads.
 *
 * <p>NOTE(review): this class is declared {@link Serializable}, but {@link #executor} is a
 * {@link ThreadPoolExecutor}, which is not serializable. As written, serializing an instance will
 * fail.
 */
public class ConstraintsOptimizableByPR implements Serializable, Optimizable.ByGradientValue {
    private static final Logger logger = MalletLogger.getLogger(ConstraintsOptimizableByPR.class.getName());
    private static final long serialVersionUID = 1;
    /** True when {@link #cachedValue} / {@link #cachedGradient} must be recomputed. */
    protected boolean cacheStale;
    protected int numParameters;
    protected int numThreads;
    protected InstanceList trainingSet;
    protected double cachedValue;
    protected double[] cachedGradient;
    protected CRF crf;
    protected ThreadPoolExecutor executor;
    /**
     * Cached CRF transition weights, indexed as
     * {@code cachedDots[instance][position][sourceState][destState]}. Entries for transitions that
     * the CRF's transition iterator does not enumerate are {@link Double#NEGATIVE_INFINITY}.
     */
    protected double[][][][] cachedDots;
    PRAuxiliaryModel model;

    /**
     * Computes this task's share of the objective for the training instances in
     * {@code [start, end)}, using a private copy of the auxiliary model.
     *
     * <p>Each instance runs a {@code SumLatticePR} pass. Presumably that pass accumulates
     * constraint expectations into {@link #getModelCopy()}; {@code combine} later reads those
     * expectations back. TODO confirm against {@code SumLatticePR}.
     */
    public class ExpectationTask implements Callable<Double> {
        private int start;
        private int end;
        private PRAuxiliaryModel modelCopy;

        /**
         * @param start first training-instance index (inclusive)
         * @param end last training-instance index (exclusive)
         * @param modelCopy auxiliary-model copy owned by this task
         */
        public ExpectationTask(int start, int end, PRAuxiliaryModel modelCopy) {
            this.start = start;
            this.end = end;
            this.modelCopy = modelCopy;
        }

        /**
         * @return the negated sum of {@code SumLatticePR.getTotalWeight()} over this task's
         *     instances
         */
        @Override
        public Double call() throws Exception {
            double value = 0.0d;
            for (int i = this.start; i < this.end; i++) {
                Sequence input = (Sequence) ConstraintsOptimizableByPR.this.trainingSet.get(i).getData();
                value -= new SumLatticePR(ConstraintsOptimizableByPR.this.crf, i, input, null, this.modelCopy,
                        ConstraintsOptimizableByPR.this.cachedDots[i], true, null, null, false).getTotalWeight();
            }
            return Double.valueOf(value);
        }

        /** @return the auxiliary-model copy this task accumulates into */
        public PRAuxiliaryModel getModelCopy() {
            return this.modelCopy;
        }
    }

    /** Creates a single-threaded instance; equivalent to passing {@code numThreads = 1}. */
    public ConstraintsOptimizableByPR(CRF crf, InstanceList instanceList, PRAuxiliaryModel pRAuxiliaryModel) {
        this(crf, instanceList, pRAuxiliaryModel, 1);
    }

    /**
     * Creates the optimizable and eagerly caches the CRF transition weights for every training
     * instance.
     *
     * @param crf model whose transition weights are cached
     * @param instanceList training instances; each instance's data must be a
     *     {@link FeatureVectorSequence}
     * @param pRAuxiliaryModel auxiliary model whose parameters are optimized
     * @param numThreads number of worker threads; must be at least 1
     * @throws IllegalArgumentException if {@code numThreads < 1}
     */
    public ConstraintsOptimizableByPR(CRF crf, InstanceList instanceList, PRAuxiliaryModel pRAuxiliaryModel, int numThreads) {
        // Fail fast; otherwise getExpectationValue would later divide by zero.
        if (numThreads < 1) {
            throw new IllegalArgumentException("numThreads must be >= 1, got " + numThreads);
        }
        this.cachedValue = -1.23456789E8d;
        this.crf = crf;
        this.trainingSet = instanceList;
        this.model = pRAuxiliaryModel;
        this.numParameters = pRAuxiliaryModel.numParameters();
        this.cachedGradient = new double[this.numParameters];
        this.cacheStale = true;
        this.numThreads = numThreads;
        this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numThreads);
        cacheDotProducts();
    }

    /**
     * (Re)builds {@link #cachedDots} from the CRF's current transition weights.
     *
     * <p>Every cell starts at {@link Double#NEGATIVE_INFINITY}. Each transition enumerated by
     * {@code crf.getState(src).transitionIterator(sequence, position)} then overwrites its cell
     * with the iterator's weight.
     */
    public void cacheDotProducts() {
        int numStates = this.crf.numStates();
        int numInstances = this.trainingSet.size();
        // Must be a 4-D allocation; `new double[n][]` would be a double[][] and not compile.
        this.cachedDots = new double[numInstances][][][];
        for (int inst = 0; inst < numInstances; inst++) {
            FeatureVectorSequence fvs = (FeatureVectorSequence) this.trainingSet.get(inst).getData();
            double[][][] dots = new double[fvs.size()][numStates][numStates];
            this.cachedDots[inst] = dots;
            for (int pos = 0; pos < fvs.size(); pos++) {
                for (int src = 0; src < numStates; src++) {
                    // Transitions not enumerated below keep -inf. Presumably the weights are in
                    // log space, so -inf marks an impossible transition; confirm against the CRF.
                    Arrays.fill(dots[pos][src], Double.NEGATIVE_INFINITY);
                    Transducer.TransitionIterator iter = this.crf.getState(src).transitionIterator(fvs, pos);
                    while (iter.hasNext()) {
                        int dest = iter.next().getIndex();
                        dots[pos][src][dest] = iter.getWeight();
                    }
                }
            }
        }
    }

    @Override
    public int getNumParameters() {
        return this.numParameters;
    }

    @Override
    public void getParameters(double[] dArr) {
        this.model.getParameters(dArr);
    }

    @Override
    public double getParameter(int i) {
        return this.model.getParameter(i);
    }

    /** Sets all auxiliary-model parameters and invalidates the cached value and gradient. */
    @Override
    public void setParameters(double[] dArr) {
        this.cacheStale = true;
        this.model.setParameters(dArr);
    }

    /** Sets one auxiliary-model parameter and invalidates the cached value and gradient. */
    @Override
    public void setParameter(int i, double d) {
        this.cacheStale = true;
        this.model.setParameter(i, d);
    }

    /**
     * Recomputes expectations in parallel and returns the objective value.
     *
     * <p>Steps:
     * <ol>
     *   <li>Zero the model's expectations.</li>
     *   <li>Run one {@link ExpectationTask} per thread over contiguous slices of the training set.
     *       The last slice absorbs the remainder of the integer division.</li>
     *   <li>Merge every task's expectations into {@link #model}.</li>
     * </ol>
     *
     * @return the sum of the tasks' values plus {@code model.getValue()}
     * @throws IllegalStateException if a task fails, or if the calling thread is interrupted while
     *     waiting. In the interrupted case the thread's interrupt flag is restored. Returning a
     *     partial sum here would silently hand the optimizer a wrong value and gradient.
     */
    protected double getExpectationValue() {
        this.model.zeroExpectations();
        int numInstances = this.trainingSet.size();
        int increment = numInstances / this.numThreads;
        List<ExpectationTask> tasks = new ArrayList<>(this.numThreads);
        int start = 0;
        for (int t = 0; t < this.numThreads; t++) {
            int end = (t == this.numThreads - 1) ? numInstances : start + increment;
            tasks.add(new ExpectationTask(start, end, this.model.copy()));
            start = end;
        }

        double value = 0.0d;
        try {
            for (Future<Double> future : this.executor.invokeAll(tasks)) {
                value += future.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while computing PR auxiliary expectations", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Failed to compute PR auxiliary expectations", e);
        }
        combine(this.model, tasks);
        return value + this.model.getValue();
    }

    /**
     * Returns the objective value. If the cache is stale, recomputes the value and the gradient
     * (via {@code model.getValueGradient}) and logs the new value.
     */
    @Override
    public double getValue() {
        if (this.cacheStale) {
            this.cachedValue = getExpectationValue();
            this.model.getValueGradient(this.cachedGradient);
            this.cacheStale = false;
            logger.info("getValue (auxiliary distribution) = " + this.cachedValue);
        }
        return this.cachedValue;
    }

    /**
     * Refreshes the cache if it is stale, then returns
     * {@code model.getCompleteValueContribution()}.
     */
    public double getCompleteValueContribution() {
        if (this.cacheStale) {
            getValue();
        }
        return this.model.getCompleteValueContribution();
    }

    /**
     * Copies the cached gradient into {@code dArr}, refreshing the cache first if it is stale.
     *
     * @param dArr destination; must hold at least {@link #getNumParameters()} entries
     */
    @Override
    public void getValueGradient(double[] dArr) {
        if (this.cacheStale) {
            getValue();
        }
        System.arraycopy(this.cachedGradient, 0, dArr, 0, this.cachedGradient.length);
    }

    /**
     * Adds each task's per-constraint expectations into the matching constraint of
     * {@code target}.
     */
    private void combine(PRAuxiliaryModel target, List<ExpectationTask> tasks) {
        for (ExpectationTask task : tasks) {
            PRAuxiliaryModel copy = task.getModelCopy();
            for (int c = 0; c < copy.numConstraints(); c++) {
                PRConstraint destination = target.getConstraint(c);
                double[] expectations = new double[destination.numDimensions()];
                copy.getConstraint(c).getExpectations(expectations);
                destination.addExpectations(expectations);
            }
        }
    }

    /** Shuts down the worker pool; no new expectation computations can run afterwards. */
    public void shutdown() {
        this.executor.shutdown();
    }

    /** @return the live (not copied) transition-weight cache; see {@link #cachedDots} */
    public double[][][][] getCachedDots() {
        return this.cachedDots;
    }

    /** @return the auxiliary model being optimized */
    public PRAuxiliaryModel getAuxModel() {
        return this.model;
    }
}
