package de.up.ling.irtg.learning_rates;

import de.up.ling.irtg.util.NumbersCombine;
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;

/* loaded from: input_file:de/up/ling/irtg/learning_rates/AdaGrad.class */
/**
 * Per-parameter AdaGrad learning-rate schedule.
 *
 * <p>For every key this class keeps a running total of the squared values passed to
 * {@link #getLearningRate(int, int, double)}. It scales a fixed base rate by the inverse
 * square root of that total. No synchronization is performed; callers sharing one
 * instance across threads must coordinate access themselves.
 */
public class AdaGrad implements LearningRate {
    /** Base rate used by the no-argument constructor. */
    private static final double DEFAULT_BASE_RATE = 0.5d;

    /** Running sum of squared values, keyed by the combined (i, i2) identifier. */
    private final Long2DoubleOpenHashMap sums;

    /** Multiplier applied to every rate this schedule returns. */
    private final double baseRate;

    /**
     * Creates a schedule with the given base rate.
     *
     * @param d the base learning rate; every returned rate is this value times
     *          {@code 1 / sqrt(total)}
     */
    public AdaGrad(double d) {
        baseRate = d;
        sums = new Long2DoubleOpenHashMap();
        // Keys that have never been seen behave as if their accumulated total were zero.
        sums.defaultReturnValue(0.0d);
    }

    /** Creates a schedule whose base rate is {@code 0.5}. */
    public AdaGrad() {
        this(DEFAULT_BASE_RATE);
    }

    /**
     * Adds {@code d * d} to the running total for the key built from ({@code i}, {@code i2})
     * and returns the resulting AdaGrad rate.
     *
     * <p>NOTE(review): {@code i}/{@code i2} presumably identify a single model parameter, and
     * {@code d} is presumably its current gradient. Confirm against callers.
     *
     * @param i  first component of the key, combined with {@code i2} via
     *           {@code NumbersCombine.combine}
     * @param i2 second component of the key
     * @param d  value whose square is accumulated for this key
     * @return {@code baseRate * (1 / sqrt(total))} for the updated total, or {@code baseRate}
     *         when the updated total is exactly zero
     */
    @Override
    public double getLearningRate(int i, int i2, double d) {
        final long key = NumbersCombine.combine(i, i2);
        final double squared = d * d;
        // fastutil's addTo returns the value held *before* the increment, so re-add the
        // increment to obtain the freshly stored total.
        final double total = sums.addTo(key, squared) + squared;
        if (total == 0.0d) {
            // Nothing (or only zeros) accumulated yet: use the plain base rate
            // instead of dividing by zero.
            return baseRate;
        }
        return baseRate * (1.0d / Math.sqrt(total));
    }

    /** Discards every accumulated total, so each key starts again from zero. */
    @Override
    public void reset() {
        sums.clear();
    }
}
