package de.up.ling.irtg.maxent;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.OptimizationException;
import de.up.ling.irtg.InterpretedTreeAutomaton;
import de.up.ling.irtg.automata.Rule;
import de.up.ling.irtg.automata.TreeAutomaton;
import de.up.ling.irtg.corpus.Corpus;
import de.up.ling.irtg.corpus.Instance;
import de.up.ling.irtg.util.ProgressListener;
import de.up.ling.tree.Tree;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import java.io.IOException;
import java.io.Reader;
import java.io.StringWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.springframework.jdbc.datasource.init.ScriptUtils;

/* loaded from: input_file:de/up/ling/irtg/maxent/MaximumEntropyIrtg.class */
/**
 * An {@link InterpretedTreeAutomaton} whose rule weights are given by a log-linear
 * (maximum entropy) model.
 *
 * <p>Every rule of a chart returned by {@link #parseInputObjects(Map)} is assigned the
 * weight {@code exp(sum_i w_i * f_i(rule))}, where {@code f_i} are the configured
 * {@link FeatureFunction}s and {@code w_i} their weights. The weights can be estimated from
 * an annotated corpus with {@link #trainMaxent(Corpus)}. That method maximizes the average
 * log-likelihood of each instance's derivation tree, relative to the total weight of its
 * chart, using MALLET's L-BFGS optimizer.
 */
public class MaximumEntropyIrtg extends InterpretedTreeAutomaton {
    private static final Logger log = Logger.getLogger(MaximumEntropyIrtg.class.getName());

    /** Weight given to every feature when the feature set is (re)initialized. */
    private static final double INITIAL_WEIGHT = 0.5d;

    /**
     * Separator written after each feature in {@link #toString()}. This is the same "\n"
     * that the decompiled code obtained via Spring's ScriptUtils constant; using it directly
     * avoids coupling this class to spring-jdbc.
     */
    private static final String FEATURE_SEPARATOR = "\n";

    /** Feature weights, index-parallel to {@link #featureNames}; {@code null} without features. */
    private double[] weights;

    /** Feature functions, index-parallel to {@link #featureNames}; {@code null} without features. */
    private FeatureFunction[] features;

    /** Feature names in index order; {@code null} if no features are configured. */
    private List<String> featureNames;

    /**
     * Adapter exposing the enclosing grammar's feature weights as the parameters of a MALLET
     * {@link Optimizable.ByGradientValue}.
     *
     * <p>The objective is the average over the corpus of
     * {@code log weight(gold derivation) - log(sum of inside weights of final states)}. Its
     * gradient is (observed feature counts - expected feature counts) / #instances. Both are
     * recomputed lazily after the parameters change.
     */
    public class MaxEntIrtgOptimizable implements Optimizable.ByGradientValue {
        /** True when the cached value/gradient no longer match the current weights. */
        private boolean cachedStale = true;
        private double cachedValue;
        private final double[] cachedGradient;
        private final Corpus trainingData;
        /** Optional progress callback; may be {@code null}. */
        private final ProgressListener listener;

        public MaxEntIrtgOptimizable(Corpus corpus, ProgressListener progressListener) {
            this.trainingData = corpus;
            this.cachedGradient = new double[MaximumEntropyIrtg.this.getNumFeatures()];
            this.listener = progressListener;
        }

        /**
         * Adds the feature values of every rule in {@code tree} (recursively) to {@code dArr}.
         *
         * @param tree          a rule tree over {@code treeAutomaton}
         * @param treeAutomaton the chart the rules belong to
         * @param dArr          accumulator for observed feature counts, modified in place
         * @param map           the instance's input objects, passed on to the feature functions
         */
        private void getFiFor(Tree<Rule> tree, TreeAutomaton treeAutomaton, double[] dArr, Map<String, Object> map) {
            double[] values = MaximumEntropyIrtg.this.getOrComputeFeatureValues(tree.getLabel(), treeAutomaton, map);
            for (int i = 0; i < values.length; i++) {
                dArr[i] += values[i];
            }
            for (Tree<Rule> child : tree.getChildren()) {
                getFiFor(child, treeAutomaton, dArr, map);
            }
        }

        @Override // cc.mallet.optimize.Optimizable.ByGradientValue
        public double getValue() {
            if (this.cachedStale) {
                recomputeValueAndGradient();
            }
            return this.cachedValue;
        }

        /** Recomputes {@link #cachedValue} and {@link #cachedGradient} from the whole corpus. */
        private void recomputeValueAndGradient() {
            int numberOfInstances = this.trainingData.getNumberOfInstances();
            double goldLogWeight = 0.0d;   // sum of log weights of the annotated derivations
            double logTotalWeight = 0.0d;  // sum of log total chart weights
            double[] observed = new double[this.cachedGradient.length];
            double[] expected = new double[this.cachedGradient.length];
            int skipped = 0;
            int processed = 0;

            Iterator<Instance> instances = this.trainingData.iterator();
            while (instances.hasNext()) {
                Instance instance = instances.next();
                Map<String, Object> inputs = instance.getInputObjects();
                TreeAutomaton chart = MaximumEntropyIrtg.this.parseInputObjects(inputs);
                if (chart == null) {
                    skipped++;
                    continue;
                }

                Int2ObjectMap<Double> inside = chart.inside();
                Map<Integer, Double> outside = chart.outside(inside);

                double totalWeight = 0.0d;
                Iterator<Integer> finalStates = chart.getFinalStates().iterator();
                while (finalStates.hasNext()) {
                    totalWeight += inside.get(finalStates.next()).doubleValue();
                }

                goldLogWeight += Math.log(chart.getWeightRaw(instance.getDerivationTree()));
                logTotalWeight += Math.log(totalWeight);
                accumulateExpectedCounts(chart, inside, outside, totalWeight, inputs, expected);

                try {
                    getFiFor(chart.getRuleTree(instance.getDerivationTree()), chart, observed, inputs);
                } catch (Exception e) {
                    throw new RuntimeException("Could not reconstruct rule tree from derivation tree for instance " + instance.toString(MaximumEntropyIrtg.this.getAutomaton()), e);
                }

                // Kept outside the try block so listener failures are not misreported
                // as rule-tree reconstruction errors.
                if (this.listener != null) {
                    this.listener.accept(processed, numberOfInstances, null);
                }
                processed++;
            }

            // Guard against an empty corpus, which would otherwise yield NaN (0/0) everywhere.
            double denominator = numberOfInstances > 0 ? numberOfInstances : 1;
            this.cachedValue = (goldLogWeight - logTotalWeight) / denominator;
            for (int i = 0; i < this.cachedGradient.length; i++) {
                this.cachedGradient[i] = (observed[i] - expected[i]) / denominator;
            }
            this.cachedStale = false;

            if (skipped > 0) {
                MaximumEntropyIrtg.log.log(Level.WARNING, "Skipped {0} instances. No suitable chart found.", Integer.valueOf(skipped));
            }
        }

        /**
         * Adds the expected feature counts of {@code chart} to {@code expected}.
         *
         * <p>The posterior of each rule is computed as
         * {@code outside(parent) * weight(rule) * prod inside(child) / totalWeight}. It is 0
         * when the parent has no outside weight or a child has no inside weight.
         */
        private void accumulateExpectedCounts(TreeAutomaton chart, Int2ObjectMap<Double> inside, Map<Integer, Double> outside, double totalWeight, Map<String, Object> inputs, double[] expected) {
            // Assign to a typed Iterable first: iterating the raw chart's rule set directly yields Object.
            Iterable<Rule> rules = chart.getRuleSet();
            for (Rule rule : rules) {
                double posterior = 0.0d;
                Double outsideWeight = outside.get(Integer.valueOf(rule.getParent()));
                if (outsideWeight != null) {
                    double ruleMass = outsideWeight.doubleValue() * rule.getWeight();
                    for (int child : rule.getChildren()) {
                        Double insideWeight = inside.get(Integer.valueOf(child));
                        ruleMass = insideWeight != null ? ruleMass * insideWeight.doubleValue() : 0.0d;
                    }
                    posterior = ruleMass / totalWeight;
                }
                double[] values = MaximumEntropyIrtg.this.getOrComputeFeatureValues(rule, chart, inputs);
                for (int i = 0; i < values.length; i++) {
                    expected[i] += values[i] * posterior;
                }
            }
        }

        @Override // cc.mallet.optimize.Optimizable.ByGradientValue
        public void getValueGradient(double[] buffer) {
            if (this.cachedStale) {
                getValue();
            }
            assert buffer != null && buffer.length == this.cachedGradient.length;
            System.arraycopy(this.cachedGradient, 0, buffer, 0, this.cachedGradient.length);
        }

        @Override // cc.mallet.optimize.Optimizable
        public int getNumParameters() {
            double[] current = MaximumEntropyIrtg.this.getFeatureWeights();
            return current == null ? 0 : current.length;
        }

        @Override // cc.mallet.optimize.Optimizable
        public void getParameters(double[] buffer) {
            double[] current = MaximumEntropyIrtg.this.getFeatureWeights();
            if (current != null) {
                System.arraycopy(current, 0, buffer, 0, current.length);
            }
        }

        @Override // cc.mallet.optimize.Optimizable
        public double getParameter(int i) {
            return MaximumEntropyIrtg.this.getFeatureWeight(i);
        }

        @Override // cc.mallet.optimize.Optimizable
        public void setParameters(double[] params) {
            MaximumEntropyIrtg.this.setFeatureWeights(params);
            this.cachedStale = true;
        }

        @Override // cc.mallet.optimize.Optimizable
        public void setParameter(int i, double d) {
            MaximumEntropyIrtg.this.setFeatureWeight(i, d);
            this.cachedStale = true;
        }
    }

    /**
     * Creates a maxent IRTG over {@code treeAutomaton} with the given named features.
     *
     * @param treeAutomaton the underlying derivation automaton
     * @param map           feature name to feature function; may be {@code null} or empty
     */
    public MaximumEntropyIrtg(TreeAutomaton<String> treeAutomaton, Map<String, FeatureFunction> map) {
        super(treeAutomaton);
        setFeatures(map);
    }

    /**
     * Replaces the feature set and resets every weight to {@value #INITIAL_WEIGHT}. A
     * {@code null} or empty map clears features, names and weights (all set to {@code null}).
     */
    public final void setFeatures(Map<String, FeatureFunction> map) {
        if (map == null || map.isEmpty()) {
            this.featureNames = null;
            this.features = null;
            this.weights = null;
            return;
        }
        this.featureNames = new ArrayList<>(map.keySet());
        int count = this.featureNames.size();
        this.features = new FeatureFunction[count];
        this.weights = new double[count];
        for (int i = 0; i < count; i++) {
            this.features[i] = map.get(this.featureNames.get(i));
            this.weights[i] = INITIAL_WEIGHT;
        }
    }

    /** Replaces the weight vector; the array is stored as-is (no copy). */
    public void setFeatureWeights(double[] dArr) {
        this.weights = dArr;
    }

    public void setFeatureWeight(int i, double d) {
        this.weights[i] = d;
    }

    /** Returns the weight of feature {@code i}, or NaN if there is no such feature. */
    public double getFeatureWeight(int i) {
        if (this.weights == null || i < 0 || i >= this.weights.length) {
            return Double.NaN;
        }
        return this.weights[i];
    }

    /** Returns the internal weight array (not a copy); {@code null} if no features are set. */
    public double[] getFeatureWeights() {
        return this.weights;
    }

    public List<String> getFeatureNames() {
        return this.featureNames;
    }

    /** Returns the feature function named {@code str}, or {@code null} if there is none. */
    public FeatureFunction getFeatureFunction(String str) {
        if (this.featureNames == null) {
            return null;
        }
        return getFeatureFunction(this.featureNames.indexOf(str));
    }

    /** Returns feature function {@code i}, or {@code null} if the index is out of range. */
    public FeatureFunction getFeatureFunction(int i) {
        FeatureFunction[] all = getFeatures();
        if (all == null || i < 0 || i >= all.length) {
            return null;
        }
        return all[i];
    }

    public int getNumFeatures() {
        return this.featureNames == null ? 0 : this.featureNames.size();
    }

    /**
     * Parses the inputs like the superclass and then reweights the chart's rules with the
     * log-linear model. A {@code null} chart is returned unchanged; callers (e.g. training)
     * treat it as an unparsable instance.
     */
    @Override // de.up.ling.irtg.InterpretedTreeAutomaton
    public TreeAutomaton parseInputObjects(Map<String, Object> map) {
        TreeAutomaton chart = super.parseInputObjects(map);
        if (chart != null) {
            setWeightsOnChart(chart, map);
        }
        return chart;
    }

    /**
     * Returns the feature vector of {@code rule}. It is computed on first use and cached in
     * the rule's "extra" slot, so later calls return the cached array.
     */
    public double[] getOrComputeFeatureValues(Rule rule, TreeAutomaton treeAutomaton, Map<String, Object> map) {
        if (rule.getExtra() == null) {
            int count = getNumFeatures();
            double[] values = new double[count];
            for (int i = 0; i < count; i++) {
                values[i] = ((Double) getFeatures()[i].evaluate(rule, treeAutomaton, this, map)).doubleValue();
            }
            rule.setExtra(values);
        }
        return (double[]) rule.getExtra();
    }

    /** Dot product of the rule's feature vector with the current weights. */
    private double getRuleScore(Rule rule, TreeAutomaton treeAutomaton, Map<String, Object> map) {
        double score = 0.0d;
        double[] values = getOrComputeFeatureValues(rule, treeAutomaton, map);
        for (int i = 0; i < getNumFeatures(); i++) {
            score += values[i] * this.weights[i];
        }
        return score;
    }

    /** Sets every rule weight of {@code treeAutomaton} to exp(score); no-op without features. */
    private void setWeightsOnChart(TreeAutomaton treeAutomaton, Map<String, Object> map) {
        if (treeAutomaton == null || getFeatures() == null) {
            return;
        }
        Iterable<Rule> ruleSet = treeAutomaton.getRuleSet();
        for (Rule rule : ruleSet) {
            rule.setWeight(Math.exp(getRuleScore(rule, treeAutomaton, map)));
        }
    }

    public boolean trainMaxent(Corpus corpus) {
        return trainMaxent(corpus, null);
    }

    /**
     * Trains the feature weights on {@code corpus} with L-BFGS.
     *
     * @param corpus           annotated training data
     * @param progressListener optional per-instance progress callback; may be {@code null}
     * @return whether the optimizer reports convergence; optimizer exceptions are logged, not thrown
     */
    public boolean trainMaxent(Corpus corpus, ProgressListener progressListener) {
        LimitedMemoryBFGS optimizer = new LimitedMemoryBFGS(new MaxEntIrtgOptimizable(corpus, progressListener));
        try {
            optimizer.optimize();
        } catch (OptimizationException e) {
            // Keep the original message but attach the throwable so the stack trace is logged too.
            log.log(Level.WARNING, e.toString(), e);
        }
        boolean converged = optimizer.isConverged();
        if (converged) {
            log.info("Optimization was successful.");
        } else {
            log.info("Optimization was unsuccessful.");
        }
        return converged;
    }

    /**
     * Reads "name=weight" pairs in {@link Properties} format and assigns the weights of the
     * features with those names. Unknown names are ignored; if no features are configured,
     * nothing is assigned.
     *
     * @throws IOException           if reading fails
     * @throws NumberFormatException if a value is not a number
     */
    public void readWeights(Reader reader) throws IOException {
        Properties properties = new Properties();
        properties.load(reader);
        if (this.featureNames == null || this.weights == null) {
            return;
        }
        for (Map.Entry<Object, Object> entry : properties.entrySet()) {
            String name = String.valueOf(entry.getKey());
            double value = Double.parseDouble(String.valueOf(entry.getValue()));
            int index = this.featureNames.indexOf(name);
            if (index >= 0 && index < this.weights.length) {
                this.weights[index] = value;
            }
        }
    }

    /** Writes the current weights as "name=weight" pairs in {@link Properties} format. */
    public void writeWeights(Writer writer) throws IOException {
        Properties properties = new Properties();
        if (this.weights != null && this.featureNames != null) {
            int count = Math.min(this.weights.length, this.featureNames.size());
            for (int i = 0; i < count; i++) {
                properties.setProperty(this.featureNames.get(i), String.valueOf(this.weights[i]));
            }
        }
        properties.store(writer, null);
    }

    /** Returns the superclass representation followed by one "feature NAME: FUNCTION" line per feature. */
    @Override // de.up.ling.irtg.InterpretedTreeAutomaton
    public String toString() {
        StringWriter out = new StringWriter();
        out.append(super.toString());
        if (this.featureNames != null) {
            for (int i = 0; i < this.featureNames.size(); i++) {
                out.append("feature ");
                out.append(this.featureNames.get(i));
                out.append(": ");
                out.append(getFeatures()[i].toString());
                out.append(FEATURE_SEPARATOR);
            }
        }
        return out.toString();
    }

    public FeatureFunction[] getFeatures() {
        return this.features;
    }

    public static void setLoggingLevel(Level level) {
        log.setLevel(level);
    }
}
