package cc.mallet.topics;

import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.Randoms;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.Serializable;
import java.util.Iterator;

/* loaded from: input_file:cc/mallet/topics/MarginalProbEstimator.class */
/**
 * Estimates held-out document probabilities for a trained topic model by running several
 * independent left-to-right samplers ("particles") over each document and averaging the
 * per-token probability estimates they report.
 *
 * <p>{@code typeTopicCounts[type]} is read as a packed list: each entry stores a topic id in its
 * low {@link #topicBits} bits and a count in the remaining high bits, and the list ends at the
 * first non-positive entry. Types whose index is {@code >= typeTopicCounts.length} or whose row
 * is {@code null} are treated as out-of-vocabulary and skipped.
 *
 * <p>Not thread-safe: {@link #leftToRight} temporarily mutates {@link #cachedCoefficients} and
 * draws from the shared {@link #random}, with no synchronization.
 */
public class MarginalProbEstimator implements Serializable {
    protected int numTopics;
    /** Mask selecting the topic id from a packed {@code typeTopicCounts} entry. */
    protected int topicMask;
    /** Number of low bits holding the topic id in a packed entry; counts sit above them. */
    protected int topicBits;
    protected double[] alpha;
    protected double alphaSum;
    protected double beta;
    /** {@code beta * typeTopicCounts.length}. */
    protected double betaSum;
    /** Sum over topics of {@code alpha[t] * beta / (tokensPerTopic[t] + betaSum)}. */
    protected double smoothingOnlyMass;
    /**
     * Per-topic {@code (alpha[t] + localCount[t]) / (tokensPerTopic[t] + betaSum)}. Between
     * documents {@code localCount} is 0, i.e. the value is {@code alpha[t] / (...)}.
     */
    protected double[] cachedCoefficients;
    protected int[][] typeTopicCounts;
    protected int[] tokensPerTopic;
    protected Randoms random;

    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    /**
     * Creates an estimator from the statistics of a trained model.
     *
     * @param numTopics number of topics; must be positive
     * @param alpha per-topic Dirichlet parameters; at least {@code numTopics} long
     * @param alphaSum presumably the sum of {@code alpha} (not checked); used as the
     *     document-length normalizer in {@link #leftToRight}
     * @param beta topic-word smoothing parameter
     * @param typeTopicCounts packed per-type topic counts (see class docs); stored, not copied
     * @param tokensPerTopic total token count per topic; stored, not copied
     * @throws IllegalArgumentException if {@code numTopics} is not positive
     */
    public MarginalProbEstimator(int numTopics, double[] alpha, double alphaSum, double beta,
            int[][] typeTopicCounts, int[] tokensPerTopic) {
        if (numTopics <= 0) {
            throw new IllegalArgumentException("numTopics must be positive: " + numTopics);
        }
        this.numTopics = numTopics;
        // Topic ids must fit in the low bits of a packed entry: use numTopics - 1 for an exact
        // power of two, otherwise round the mask up to the next power of two minus one.
        this.topicMask = Integer.bitCount(numTopics) == 1
                ? numTopics - 1
                : Integer.highestOneBit(numTopics) * 2 - 1;
        this.topicBits = Integer.bitCount(this.topicMask);
        this.typeTopicCounts = typeTopicCounts;
        this.tokensPerTopic = tokensPerTopic;
        this.alphaSum = alphaSum;
        this.alpha = alpha;
        this.beta = beta;
        this.betaSum = beta * typeTopicCounts.length;
        this.random = new Randoms();

        this.cachedCoefficients = new double[numTopics];
        this.smoothingOnlyMass = 0.0;
        for (int topic = 0; topic < numTopics; topic++) {
            this.smoothingOnlyMass += (alpha[topic] * beta) / (tokensPerTopic[topic] + this.betaSum);
            this.cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + this.betaSum);
        }
        System.err.println("Topic Evaluator: " + numTopics + " topics, " + this.topicBits
                + " topic bits, " + Integer.toBinaryString(this.topicMask) + " topic mask");
    }

    /** Returns the internal tokens-per-topic array (not a copy). */
    public int[] getTokensPerTopic() {
        return this.tokensPerTopic;
    }

    /** Returns the internal packed type/topic count arrays (not a copy). */
    public int[][] getTypeTopicCounts() {
        return this.typeTopicCounts;
    }

    /**
     * Estimates the log-probability of every document in {@code instanceList}.
     *
     * <p>For each document, {@code numParticles} independent runs of {@link #leftToRight} are
     * made. At each token position the per-run values are summed, and
     * {@code log(sum) - log(numParticles)} (the log of their mean) is added to the document's
     * log-likelihood. Positions whose sum is 0, such as skipped out-of-vocabulary tokens,
     * contribute nothing. The random generator is re-created on every call.
     *
     * @param instanceList documents whose data are {@link FeatureSequence}s
     * @param numParticles number of independent runs per document; must be positive
     * @param usingResampling passed to {@link #leftToRight}
     * @param docProbabilityStream if non-null, receives one log-likelihood line per document
     * @return the sum of the documents' log-likelihoods
     * @throws IllegalArgumentException if {@code numParticles} is not positive
     */
    public double evaluateLeftToRight(InstanceList instanceList, int numParticles,
            boolean usingResampling, PrintStream docProbabilityStream) {
        if (numParticles <= 0) {
            throw new IllegalArgumentException("numParticles must be positive: " + numParticles);
        }
        this.random = new Randoms();
        double logNumParticles = Math.log(numParticles);
        double totalLogLikelihood = 0.0;

        for (Instance instance : instanceList) {
            FeatureSequence tokens = (FeatureSequence) instance.getData();

            // One probability vector per particle (the decompiled code declared this 1-D).
            double[][] particleProbabilities = new double[numParticles][];
            for (int particle = 0; particle < numParticles; particle++) {
                particleProbabilities[particle] = leftToRight(tokens, usingResampling);
            }

            double docLogLikelihood = 0.0;
            int positions = particleProbabilities[0].length;
            for (int position = 0; position < positions; position++) {
                double sum = 0.0;
                for (double[] probabilities : particleProbabilities) {
                    sum += probabilities[position];
                }
                if (sum > 0.0) {
                    docLogLikelihood += Math.log(sum) - logNumParticles;
                }
            }

            if (docProbabilityStream != null) {
                docProbabilityStream.println(docLogLikelihood);
            }
            totalLogLikelihood += docLogLikelihood;
        }
        return totalLogLikelihood;
    }

    /**
     * Runs one left-to-right sampler over {@code featureSequence}.
     *
     * <p>Positions are visited in order. For each position {@code limit}:
     * <ul>
     *   <li>If {@code usingResampling} is true, every earlier in-vocabulary token's topic is first
     *       resampled.</li>
     *   <li>The total sampling mass for the token at {@code limit}, divided by
     *       {@code alphaSum + tokensSoFar}, is recorded as its probability.</li>
     *   <li>A topic is sampled for that token.</li>
     * </ul>
     * Out-of-vocabulary tokens are skipped and keep probability 0.
     *
     * <p>{@link #cachedCoefficients} is updated while the document is processed. The entries for
     * topics touched by this document are reset to {@code alpha[t] / (tokensPerTopic[t] + betaSum)}
     * before returning, including when an exception escapes.
     *
     * @param featureSequence the document's token type indices
     * @param usingResampling whether to resample earlier tokens before each new one
     * @return per-position probability estimates (0 for skipped tokens)
     */
    protected double[] leftToRight(FeatureSequence featureSequence, boolean usingResampling) {
        int docLength = featureSequence.getLength();
        int[] docTopics = new int[docLength];
        double[] wordProbabilities = new double[docLength];
        DocumentState doc = new DocumentState(numTopics);
        // Number of in-vocabulary tokens whose probability has been recorded so far.
        int tokensSoFar = 0;

        try {
            for (int limit = 0; limit < docLength; limit++) {
                if (usingResampling) {
                    for (int position = 0; position < limit; position++) {
                        int[] typeCounts = countsForType(featureSequence.getIndexAtPosition(position));
                        if (typeCounts == null) {
                            continue; // out-of-vocabulary
                        }
                        removeToken(doc, docTopics[position]);
                        double topicTermMass = scoreTopicTerms(typeCounts, doc.topicTermScores);
                        double sample = random.nextUniform()
                                * (smoothingOnlyMass + doc.topicBetaMass + topicTermMass);
                        int newTopic = sampleTopic(doc, typeCounts, topicTermMass, sample);
                        docTopics[position] = newTopic;
                        addToken(doc, newTopic);
                    }
                }

                int[] typeCounts = countsForType(featureSequence.getIndexAtPosition(limit));
                if (typeCounts == null) {
                    continue; // out-of-vocabulary
                }
                double topicTermMass = scoreTopicTerms(typeCounts, doc.topicTermScores);
                double totalMass = smoothingOnlyMass + doc.topicBetaMass + topicTermMass;
                double sample = random.nextUniform() * totalMass;
                // The sampling weights leave out the (alphaSum + tokensSoFar) normalizer; apply it
                // to the recorded probability.
                wordProbabilities[limit] += totalMass / (alphaSum + tokensSoFar);
                tokensSoFar++;

                int newTopic = sampleTopic(doc, typeCounts, topicTermMass, sample);
                docTopics[limit] = newTopic;
                addToken(doc, newTopic);
            }
        } finally {
            // Restore the shared cache for every topic this document touched. Topics whose count
            // dropped back to 0 were already reset by removeToken.
            for (int dense = 0; dense < doc.nonZeroTopics; dense++) {
                int topic = doc.nonZeroIndex[dense];
                cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + betaSum);
            }
        }
        return wordProbabilities;
    }

    /** Returns the packed counts for {@code type}, or null if it is out of vocabulary. */
    private int[] countsForType(int type) {
        return type < typeTopicCounts.length ? typeTopicCounts[type] : null;
    }

    /**
     * Fills {@code scores[i]} with {@code cachedCoefficients[topic_i] * count_i} for each packed
     * entry of {@code typeCounts} (up to the first non-positive one) and returns their sum.
     */
    private double scoreTopicTerms(int[] typeCounts, double[] scores) {
        double mass = 0.0;
        for (int i = 0; i < typeCounts.length && typeCounts[i] > 0; i++) {
            double score = cachedCoefficients[typeCounts[i] & topicMask] * (typeCounts[i] >> topicBits);
            mass += score;
            scores[i] = score;
        }
        return mass;
    }

    /** Returns {@code beta * localCount / (tokensPerTopic[topic] + betaSum)} for {@code topic}. */
    private double topicBetaWeight(DocumentState doc, int topic) {
        return beta * doc.topicCounts[topic] / (tokensPerTopic[topic] + betaSum);
    }

    /** Recomputes {@code cachedCoefficients[topic]} from the document's current local count. */
    private void updateCachedCoefficient(DocumentState doc, int topic) {
        cachedCoefficients[topic] =
                (alpha[topic] + doc.topicCounts[topic]) / (tokensPerTopic[topic] + betaSum);
    }

    /** Removes one token of {@code topic} from the document's counts and derived quantities. */
    private void removeToken(DocumentState doc, int topic) {
        doc.topicBetaMass -= topicBetaWeight(doc, topic);
        doc.topicCounts[topic]--;
        if (doc.topicCounts[topic] == 0) {
            doc.removeNonZero(topic);
        }
        doc.topicBetaMass += topicBetaWeight(doc, topic);
        updateCachedCoefficient(doc, topic);
    }

    /** Adds one token of {@code topic} to the document's counts and derived quantities. */
    private void addToken(DocumentState doc, int topic) {
        doc.topicBetaMass -= topicBetaWeight(doc, topic);
        doc.topicCounts[topic]++;
        if (doc.topicCounts[topic] == 1) {
            doc.insertNonZero(topic);
        }
        updateCachedCoefficient(doc, topic);
        doc.topicBetaMass += topicBetaWeight(doc, topic);
    }

    /**
     * Picks a topic for a token given a draw {@code sample} in
     * {@code [0, smoothingOnlyMass + doc.topicBetaMass + topicTermMass)}.
     *
     * <p>The mass is split into three buckets, checked in order: the topic-term scores from
     * {@link #scoreTopicTerms}, the document's non-zero topics, then all topics' smoothing terms.
     * If no topic is found, the draw is reported on stderr and the last topic is returned.
     */
    private int sampleTopic(DocumentState doc, int[] typeCounts, double topicTermMass,
            double sample) {
        double origSample = sample;
        int newTopic = -1;

        if (sample < topicTermMass) {
            // do/while so a draw of exactly 0 selects the first entry instead of index -1.
            int i = -1;
            do {
                i++;
                sample -= doc.topicTermScores[i];
            } while (sample > 0.0);
            newTopic = typeCounts[i] & topicMask;
        } else {
            sample -= topicTermMass;
            if (sample < doc.topicBetaMass) {
                sample /= beta;
                for (int dense = 0; dense < doc.nonZeroTopics; dense++) {
                    int topic = doc.nonZeroIndex[dense];
                    sample -= doc.topicCounts[topic] / (tokensPerTopic[topic] + betaSum);
                    if (sample <= 0.0) {
                        newTopic = topic;
                        break;
                    }
                }
            } else {
                sample -= doc.topicBetaMass;
                sample /= beta;
                newTopic = 0;
                sample -= alpha[newTopic] / (tokensPerTopic[newTopic] + betaSum);
                // Bounded so floating-point round-off cannot walk past the last topic.
                while (sample > 0.0 && newTopic < numTopics - 1) {
                    newTopic++;
                    sample -= alpha[newTopic] / (tokensPerTopic[newTopic] + betaSum);
                }
            }
        }

        if (newTopic == -1) {
            System.err.println("sampling error: " + origSample + " " + sample + " "
                    + smoothingOnlyMass + " " + doc.topicBetaMass + " " + topicTermMass);
            newTopic = numTopics - 1;
        }
        return newTopic;
    }

    /**
     * Custom serialized form: a version int followed by the fields in a fixed order.
     * {@link #readObject} must read them back in the same order.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeInt(this.numTopics);
        out.writeInt(this.topicMask);
        out.writeInt(this.topicBits);
        out.writeObject(this.alpha);
        out.writeDouble(this.alphaSum);
        out.writeDouble(this.beta);
        out.writeDouble(this.betaSum);
        out.writeObject(this.typeTopicCounts);
        out.writeObject(this.tokensPerTopic);
        out.writeObject(this.random);
        out.writeDouble(this.smoothingOnlyMass);
        out.writeObject(this.cachedCoefficients);
    }

    /** Reads the form written by {@link #writeObject}, field for field. */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.readInt(); // serial version; currently ignored
        this.numTopics = in.readInt();
        this.topicMask = in.readInt();
        this.topicBits = in.readInt();
        this.alpha = (double[]) in.readObject();
        this.alphaSum = in.readDouble();
        this.beta = in.readDouble();
        this.betaSum = in.readDouble();
        this.typeTopicCounts = (int[][]) in.readObject();
        this.tokensPerTopic = (int[]) in.readObject();
        this.random = (Randoms) in.readObject();
        this.smoothingOnlyMass = in.readDouble();
        this.cachedCoefficients = (double[]) in.readObject();
    }

    /**
     * Loads an estimator previously written with Java serialization.
     *
     * <p>NOTE(review): Java native deserialization of an untrusted file can execute arbitrary
     * code; only load files from trusted sources.
     *
     * @param file the serialized estimator
     * @return the deserialized estimator
     * @throws Exception on I/O failure or if the file does not contain this class
     */
    public static MarginalProbEstimator read(File file) throws Exception {
        // Both streams are closed even if the header read or readObject throws.
        try (FileInputStream fileIn = new FileInputStream(file);
                ObjectInputStream in = new ObjectInputStream(fileIn)) {
            return (MarginalProbEstimator) in.readObject();
        }
    }

    /** Mutable per-document sampler state used by {@link #leftToRight}. */
    private static final class DocumentState {
        /** Number of this document's tokens currently assigned to each topic. */
        final int[] topicCounts;
        /** Topics with non-zero counts in ascending order; only the first {@code nonZeroTopics} are valid. */
        final int[] nonZeroIndex;
        int nonZeroTopics;
        /** Sum over topics of {@code beta * topicCounts[t] / (tokensPerTopic[t] + betaSum)}. */
        double topicBetaMass;
        /** Scratch buffer filled by {@code scoreTopicTerms}. */
        final double[] topicTermScores;

        DocumentState(int numTopics) {
            topicCounts = new int[numTopics];
            nonZeroIndex = new int[numTopics];
            topicTermScores = new double[numTopics];
        }

        /** Removes {@code topic}, which must be present, keeping the list sorted. */
        void removeNonZero(int topic) {
            int pos = 0;
            while (nonZeroIndex[pos] != topic) {
                pos++;
            }
            System.arraycopy(nonZeroIndex, pos + 1, nonZeroIndex, pos, nonZeroTopics - pos - 1);
            nonZeroTopics--;
        }

        /** Inserts {@code topic}, which must be absent, keeping the list sorted. */
        void insertNonZero(int topic) {
            int pos = nonZeroTopics;
            while (pos > 0 && nonZeroIndex[pos - 1] > topic) {
                nonZeroIndex[pos] = nonZeroIndex[pos - 1];
                pos--;
            }
            nonZeroIndex[pos] = topic;
            nonZeroTopics++;
        }
    }
}
