package de.up.ling.irtg.sampling.rule_weighting;

import de.up.ling.irtg.automata.Rule;
import de.up.ling.irtg.automata.TreeAutomaton;
import de.up.ling.irtg.learning_rates.LearningRate;
import de.up.ling.irtg.sampling.RuleWeighting;
import de.up.ling.irtg.sampling.TreeSample;
import de.up.ling.tree.Tree;
import it.unimi.dsi.fastutil.ints.Int2BooleanMap;
import it.unimi.dsi.fastutil.ints.Int2BooleanOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntIterator;
import it.unimi.dsi.fastutil.objects.Object2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.util.ArrayList;
import java.util.Arrays;

/* loaded from: input_file:de/up/ling/irtg/sampling/rule_weighting/RegularizedKLRuleWeighting.class */
/**
 * Rule weighting over a basis {@link TreeAutomaton} whose rule and start-state distributions are
 * softmax-normalized real-valued parameters, trained by gradient steps on sampled trees.
 *
 * <p>For every state, the rules enumerated by {@code foreachRuleTopDown} get one parameter each;
 * the rule probabilities of that state are the softmax of those parameters. The final states of
 * the automaton likewise get one parameter each, forming the start distribution. Each call to
 * {@link #adapt(TreeSample, boolean)} moves the parameters against the gradient made of
 * (model probability * weighted parent count - weighted observed count) plus the regularization
 * term {@code sign(w) * |w|^(order - 1) / divisor}, scaled by the supplied {@link LearningRate}.
 *
 * <p>States that do not occur in a sample are not touched by {@code adapt}; the
 * regularization-only steps they missed are applied lazily the next time
 * {@link #prepareProbability(int)} is called for them.
 *
 * <p>NOTE(review): the class performs no synchronization; presumably intended for
 * single-threaded use.
 */
public abstract class RegularizedKLRuleWeighting implements RuleWeighting {

    /**
     * Tolerance used by {@link #getRuleNumber(int, double)} and {@link #getStartStateNumber(double)}:
     * once the remaining draw is at most this value the current entry is returned, which absorbs
     * rounding errors in probabilities that sum to slightly less than one.
     */
    public static double ALMOST_ZERO = 1.0E-15d;

    private final TreeAutomaton basis;
    /** Per state: update counter value up to which that state's rule parameters are current. */
    private final Int2IntMap lastUpdated;
    /** Final states of the basis automaton, sorted ascending so they can be binary-searched. */
    private final int[] startStates;
    /** Start-state parameters, parallel to {@link #startStates}. */
    private final double[] startParameters;
    /** Softmax of {@link #startParameters}, refreshed by {@link #prepareStartProbability()}. */
    private final double[] startProbabilities;
    /** Exponent applied to |w| in the regularization gradient (norm order minus one). */
    private final int regularizationExponent;
    /** Multiplier of the regularization gradient, i.e. the reciprocal of the divisor. */
    private final double regularizationScale;
    private final LearningRate rate;
    /** Number of completed calls to {@link #adapt(TreeSample, boolean)} since the last reset. */
    private int updateNumber = 0;
    /** Per state: whether {@link #ruleProbs} reflects the current {@link #ruleParameters}. */
    private final Int2BooleanMap currentProbs = new Int2BooleanOpenHashMap();
    /** Per state: cached rule probabilities, parallel to {@link #listRules}. */
    private final Int2ObjectMap<double[]> ruleProbs = new Int2ObjectOpenHashMap<>();
    /** Per state: its rules, sorted so a rule's index can be found by binary search. */
    private final Int2ObjectMap<Rule[]> listRules = new Int2ObjectOpenHashMap<>();
    /** Per state: rule parameters, parallel to {@link #listRules}. */
    private final Int2ObjectMap<double[]> ruleParameters = new Int2ObjectOpenHashMap<>();
    /** Base last returned by {@code TreeSample.makeMaxBase}; a change resets the learning rate. */
    private double underFlowPreventer = Double.NEGATIVE_INFINITY;

    /**
     * Creates a weighting with all parameters at zero, i.e. uniform rule and start distributions.
     *
     * @param basis the automaton whose rules and final states are weighted
     * @param regularizationOrder order p of the regularizer; its gradient uses |w|^(p - 1)
     * @param regularizationDivisor the regularization gradient is divided by this value
     * @param learningRate supplies the step size for every parameter update
     */
    public RegularizedKLRuleWeighting(TreeAutomaton basis, int regularizationOrder,
            double regularizationDivisor, LearningRate learningRate) {
        this.basis = basis;

        IntArrayList finals = new IntArrayList();
        IntIterator finalIterator = basis.getFinalStates().iterator();
        while (finalIterator.hasNext()) {
            finals.add(finalIterator.nextInt());
        }
        this.startStates = finals.toIntArray();
        // sorted so that makeAmounts can locate a sample's root state by binary search
        Arrays.sort(this.startStates);

        // Java zero-initializes new arrays, so all start parameters begin at 0
        this.startParameters = new double[this.startStates.length];
        this.startProbabilities = new double[this.startStates.length];

        this.lastUpdated = new Int2IntOpenHashMap();
        this.lastUpdated.defaultReturnValue(0);
        this.regularizationExponent = regularizationOrder - 1;
        this.regularizationScale = 1.0d / regularizationDivisor;
        this.rate = learningRate;
        this.currentProbs.defaultReturnValue(false);
    }

    /**
     * Returns the log of the cached probability of the {@code ruleNumber}-th rule of {@code state}.
     * {@link #prepareProbability(int)} must have been called for the state first; otherwise there
     * is no cached entry or it may be stale.
     */
    @Override // de.up.ling.irtg.sampling.RuleWeighting
    public double getLogProbability(int state, int ruleNumber) {
        return Math.log(this.ruleProbs.get(state)[ruleNumber]);
    }

    /**
     * Brings the rule probabilities of {@code state} up to date. First any regularization-only
     * steps missed since the state was last updated are applied, then the probabilities are
     * recomputed as the softmax of the parameters. Does nothing if they are already current.
     */
    @Override // de.up.ling.irtg.sampling.RuleWeighting
    public void prepareProbability(int state) {
        int lastUpdate = this.lastUpdated.get(state);
        boolean stale = lastUpdate < this.updateNumber;
        if (!stale && this.currentProbs.get(state)) {
            return;
        }

        Rule[] rules = ensureRules(state);
        double[] parameters = this.ruleParameters.get(state);
        if (stale) {
            // lazy regularization: one regularization-only step per missed update
            for (int step = lastUpdate; step < this.updateNumber; step++) {
                for (int r = 0; r < rules.length; r++) {
                    updateRuleParameter(r, rules[r], parameters, null, null, -1.0d);
                }
            }
            this.lastUpdated.put(state, this.updateNumber);
        }

        softmax(parameters, this.ruleProbs.get(state));
        this.currentProbs.put(state, true);
    }

    /**
     * Returns the log of the cached probability of the start state at position {@code i}, i.e.
     * the position used by {@link #getStartStateByNumber(int)}, not the state id.
     */
    @Override // de.up.ling.irtg.sampling.RuleWeighting
    public double getStateStartLogProbability(int i) {
        return Math.log(this.startProbabilities[i]);
    }

    /** Recomputes the start-state probabilities as the softmax of the start parameters. */
    @Override // de.up.ling.irtg.sampling.RuleWeighting
    public void prepareStartProbability() {
        softmax(this.startParameters, this.startProbabilities);
    }

    /**
     * Resets all parameters to zero, the update counter, the cached-probability flags, the
     * learning rate and the remembered underflow base. Rule lists already computed are kept.
     */
    @Override // de.up.ling.irtg.sampling.RuleWeighting
    public void reset() {
        Arrays.fill(this.startParameters, 0.0d);
        this.lastUpdated.clear();
        this.updateNumber = 0;
        this.currentProbs.clear();
        for (double[] parameters : this.ruleParameters.values()) {
            Arrays.fill(parameters, 0.0d);
        }
        this.rate.reset();
        this.underFlowPreventer = Double.NEGATIVE_INFINITY;
    }

    /**
     * Performs one gradient step using the weighted trees in {@code treeSample}.
     *
     * @param treeSample the sampled trees and their weights
     * @param flag passed through to {@code TreeSample.makeMaxBase}; its meaning is defined there
     */
    @Override // de.up.ling.irtg.sampling.RuleWeighting
    public void adapt(TreeSample<Rule> treeSample, boolean flag) {
        Int2DoubleOpenHashMap stateAmounts = new Int2DoubleOpenHashMap();
        stateAmounts.defaultReturnValue(0.0d);
        Object2DoubleOpenHashMap<Rule> ruleAmounts = new Object2DoubleOpenHashMap<>();
        ruleAmounts.defaultReturnValue(0.0d);
        double[] startAmounts = new double[this.startStates.length];

        double maxBase = treeSample.makeMaxBase(flag, this.underFlowPreventer);
        if (maxBase != this.underFlowPreventer) {
            // a new base was chosen for the sample weights: restart the learning rate
            this.rate.reset();
            this.underFlowPreventer = maxBase;
        }

        double totalWeight = makeAmounts(treeSample, stateAmounts, ruleAmounts, startAmounts);
        int nextUpdate = this.updateNumber + 1;

        IntIterator observedStates = stateAmounts.keySet().iterator();
        while (observedStates.hasNext()) {
            int state = observedStates.nextInt();
            Rule[] rules = ensureRules(state);
            // catch up on missed regularization and get the probabilities used in the gradient
            prepareProbability(state);
            double[] probabilities = this.ruleProbs.get(state);
            double[] parameters = this.ruleParameters.get(state);
            for (int r = 0; r < rules.length; r++) {
                updateRuleParameter(r, rules[r], parameters, ruleAmounts, stateAmounts, probabilities[r]);
            }
            this.lastUpdated.put(state, nextUpdate);
            this.currentProbs.put(state, false);
        }

        prepareStartProbability();
        for (int s = 0; s < this.startStates.length; s++) {
            updateStart(this.startParameters, s, startAmounts, totalWeight, this.startProbabilities[s]);
        }
        this.updateNumber = nextUpdate;
    }

    /**
     * Returns the sorted rules of {@code state}. On first access it collects them from the basis
     * automaton and allocates zeroed probability and parameter arrays for them.
     */
    private Rule[] ensureRules(int state) {
        Rule[] rules = this.listRules.get(state);
        if (rules == null) {
            ArrayList<Rule> collected = new ArrayList<>();
            this.basis.foreachRuleTopDown(state, r -> collected.add((Rule) r));
            rules = collected.toArray(new Rule[0]);
            // sorted so getLogProbability(Rule) can find a rule's index by binary search
            Arrays.sort(rules);
            this.listRules.put(state, rules);
            this.ruleProbs.put(state, new double[rules.length]);
            this.ruleParameters.put(state, new double[rules.length]);
        }
        return rules;
    }

    /**
     * Accumulates the self-normalized weights of all samples into per-state, per-rule and
     * per-start-state totals.
     *
     * @return the sum of all self-normalized weights
     * @throws IllegalArgumentException if a sample's root rule has a parent that is not a final
     *     state of the basis automaton
     */
    private double makeAmounts(TreeSample<Rule> treeSample, Int2DoubleOpenHashMap stateAmounts,
            Object2DoubleOpenHashMap<Rule> ruleAmounts, double[] startAmounts) {
        double total = 0.0d;
        for (int k = 0; k < treeSample.populationSize(); k++) {
            double weight = treeSample.getSelfNormalizedWeight(k);
            Tree<Rule> tree = treeSample.getSample(k);
            total += weight;

            int rootState = tree.getLabel().getParent();
            int startIndex = Arrays.binarySearch(this.startStates, rootState);
            if (startIndex < 0) {
                throw new IllegalArgumentException("Sample " + k + " is rooted in state " + rootState
                        + ", which is not a final state of the automaton.");
            }
            startAmounts[startIndex] += weight;
            addAmounts(tree, stateAmounts, ruleAmounts, weight);
        }
        return total;
    }

    @Override // de.up.ling.irtg.sampling.RuleWeighting
    public TreeAutomaton getAutomaton() {
        return this.basis;
    }

    /** Adds {@code weight} to the totals of every rule in {@code tree} and of its parent state. */
    private void addAmounts(Tree<Rule> tree, Int2DoubleOpenHashMap stateAmounts,
            Object2DoubleOpenHashMap<Rule> ruleAmounts, double weight) {
        Rule rule = tree.getLabel();
        stateAmounts.addTo(rule.getParent(), weight);
        ruleAmounts.addTo(rule, weight);
        for (Tree<Rule> child : tree.getChildren()) {
            addAmounts(child, stateAmounts, ruleAmounts, weight);
        }
    }

    /**
     * Gradient of the regularizer at parameter value {@code w}:
     * {@code sign(w) * |w|^regularizationExponent * regularizationScale}.
     */
    private double regularizationGradient(double w) {
        return Math.signum(w) * Math.pow(Math.abs(w), this.regularizationExponent) * this.regularizationScale;
    }

    /**
     * Applies one gradient step to the {@code index}-th rule parameter. When both amount maps are
     * {@code null}, only the regularization gradient is used and {@code probability} is ignored.
     */
    private void updateRuleParameter(int index, Rule rule, double[] parameters,
            Object2DoubleOpenHashMap<Rule> ruleAmounts, Int2DoubleOpenHashMap stateAmounts,
            double probability) {
        double gradient = regularizationGradient(parameters[index]);
        if (ruleAmounts != null && stateAmounts != null) {
            // data term: probability * weighted parent-state count - weighted rule count
            gradient = gradient + probability * stateAmounts.get(rule.getParent()) - ruleAmounts.getDouble(rule);
        }
        parameters[index] -= this.rate.getLearningRate(rule.getParent(), index, gradient) * gradient;
    }

    /**
     * Applies one gradient step to the {@code index}-th start parameter. When {@code startAmounts}
     * is {@code null}, only the regularization gradient is used.
     */
    private void updateStart(double[] parameters, int index, double[] startAmounts,
            double totalWeight, double probability) {
        double gradient = regularizationGradient(parameters[index]);
        if (startAmounts != null) {
            // data term: probability * total sample weight - weighted count of this start state
            gradient = gradient + probability * totalWeight - startAmounts[index];
        }
        // -1 stands in for the parent state when asking the learning rate about start parameters
        parameters[index] -= this.rate.getLearningRate(-1, index, gradient) * gradient;
    }

    @Override // de.up.ling.irtg.sampling.RuleWeighting
    public int getNumberOfStartStates() {
        return this.startStates.length;
    }

    @Override // de.up.ling.irtg.sampling.RuleWeighting
    public int getStartStateByNumber(int i) {
        return this.startStates[i];
    }

    /**
     * Selects a rule of {@code state} by subtracting the cached rule probabilities from
     * {@code draw} in order. It returns the first index at which the remainder drops to
     * {@link #ALMOST_ZERO} or below. Presumably {@code draw} is uniform in [0, 1).
     *
     * @throws IllegalStateException if the probabilities are exhausted before that happens
     */
    @Override // de.up.ling.irtg.sampling.RuleWeighting
    public int getRuleNumber(int state, double draw) {
        double[] probabilities = this.ruleProbs.get(state);
        for (int r = 0; r < probabilities.length; r++) {
            draw -= probabilities[r];
            if (draw <= ALMOST_ZERO) {
                return r;
            }
        }
        throw new IllegalStateException("Probabilities did not sum to one.");
    }

    @Override // de.up.ling.irtg.sampling.RuleWeighting
    public Rule getRuleByNumber(int state, int ruleNumber) {
        return this.listRules.get(state)[ruleNumber];
    }

    /**
     * Selects a start-state position by subtracting the cached start probabilities from
     * {@code draw} in order; see {@link #getRuleNumber(int, double)}.
     *
     * @throws IllegalStateException if the probabilities are exhausted first
     */
    @Override // de.up.ling.irtg.sampling.RuleWeighting
    public int getStartStateNumber(double draw) {
        for (int s = 0; s < this.startStates.length; s++) {
            draw -= this.startProbabilities[s];
            if (draw <= ALMOST_ZERO) {
                return s;
            }
        }
        throw new IllegalStateException("Probabilities did not sum to one.");
    }

    /** Returns the log of the cached probability of {@code rule} within its parent state. */
    @Override // de.up.ling.irtg.sampling.RuleWeighting
    public double getLogProbability(Rule rule) {
        int parent = rule.getParent();
        int index = Arrays.binarySearch(this.listRules.get(parent), rule);
        return Math.log(this.ruleProbs.get(parent)[index]);
    }

    /**
     * Writes the softmax of {@code parameters} into {@code target}, which must be at least as
     * long. The maximum is subtracted before exponentiating so large parameters cannot overflow
     * to infinity, which would turn every probability into NaN.
     */
    private static void softmax(double[] parameters, double[] target) {
        double max = Double.NEGATIVE_INFINITY;
        for (double p : parameters) {
            max = Math.max(max, p);
        }
        double sum = 0.0d;
        for (int k = 0; k < parameters.length; k++) {
            target[k] = Math.exp(parameters[k] - max);
            sum += target[k];
        }
        for (int k = 0; k < parameters.length; k++) {
            target[k] /= sum;
        }
    }
}
