package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureSequenceWithBigrams;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;
import de.up.ling.irtg.laboratory.Program;
import gnu.trove.TObjectIntHashMap;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.InterruptedIOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Formatter;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.TreeSet;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.logging.Logger;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.springframework.jdbc.datasource.init.ScriptUtils;

/* loaded from: input_file:cc/mallet/topics/ParallelTopicModel.class */
public class ParallelTopicModel implements Serializable {
    /** Sentinel topic value; tokens carrying it are skipped by buildInitialTypeTopicCounts(). */
    public static final int UNASSIGNED_TOPIC = -1;
    public static Logger logger = MalletLogger.getLogger(ParallelTopicModel.class.getName());
    /** One TopicAssignment (instance plus its per-token topic sequence) per added instance. */
    public ArrayList<TopicAssignment> data;
    /** Word-type alphabet, taken from the instance list in addInstances(); null until then. */
    public Alphabet alphabet;
    /** Topic labels; its size determines numTopics. */
    public LabelAlphabet topicAlphabet;
    public int numTopics;
    /*
     * typeTopicCounts entries pack a topic index into the low bits and a count into the high bits:
     * entry = (count << topicBits) + topic, so topic = entry & topicMask and count = entry >> topicBits.
     * topicMask is an all-ones mask wide enough for every topic index; topicBits is its bit width.
     */
    public int topicMask;
    public int topicBits;
    /** Number of word types (size of alphabet). */
    public int numTypes;
    /** Total number of tokens over all documents, computed by initializeHistograms(). */
    public int totalTokens;
    /** Per-topic document-topic prior; alphaSum is kept as its total. */
    public double[] alpha;
    public double alphaSum;
    /** Symmetric topic-word prior; betaSum is maintained as beta * numTypes. */
    public double beta;
    public double betaSum;
    /**
     * When true, optimizeAlpha() pools all topics and sets every alpha[k] to alphaSum / numTopics;
     * otherwise each alpha[k] is learned separately.
     */
    public boolean usingSymmetricAlpha;
    public static final double DEFAULT_BETA = 0.01d;
    /**
     * Per word type, packed (count << topicBits) + topic entries. Each row is kept in descending
     * order with unused (zero) slots at the end, and has min(numTopics, typeTotals[type]) slots.
     */
    public int[][] typeTopicCounts;
    /** Number of tokens currently assigned to each topic. */
    public int[] tokensPerTopic;
    /*
     * Histograms used for alpha optimization, aggregated from the workers in optimizeAlpha().
     * Both are indexed up to the longest document length (see initializeHistograms()).
     */
    public int[] docLengthCounts;
    public int[][] topicDocCounts;
    /** Number of sampling iterations run by estimate(). */
    public int numIterations;
    /** Alpha statistics collection and alpha/beta optimization only happen after this iteration. */
    public int burninPeriod;
    /** After burn-in, workers collect alpha statistics every saveSampleInterval iterations. */
    public int saveSampleInterval;
    /** After burn-in, alpha and beta are optimized every optimizeInterval iterations; 0 disables. */
    public int optimizeInterval;
    // NOTE(review): not read by any code visible in this file section; confirm external use.
    public int temperingInterval;
    /** estimate() logs the top wordsPerTopic words every showTopicsInterval iterations; 0 disables. */
    public int showTopicsInterval;
    public int wordsPerTopic;
    /** estimate() calls printState(stateFilename + "." + iteration) at this interval; 0 disables. */
    public int saveStateInterval;
    public String stateFilename;
    /** estimate() calls write(modelFilename + "." + iteration) at this interval; 0 disables. */
    public int saveModelInterval;
    public String modelFilename;
    /** Seed for Randoms; -1 means an unseeded Randoms is created instead. */
    public int randomSeed;
    /** Number format for logged values (at most 5 fraction digits, set in the constructor). */
    public NumberFormat formatter;
    /** When true, every 10th iteration logs the log-likelihood per token. */
    public boolean printLogLikelihood;
    /** Number of tokens of each word type, computed in addInstances(). */
    int[] typeTotals;
    /** Largest entry of typeTotals; sizes the count histogram in optimizeBeta(). */
    int maxTypeCount;
    /** Number of sampling workers used by estimate(). */
    int numThreads;
    private static final long serialVersionUID = 1;
    // Presumably used by serialization code outside this view — NOTE(review): confirm.
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    public ParallelTopicModel(int i) {
        this(i, i, 0.01d);
    }

    public ParallelTopicModel(int i, double d, double d2) {
        this(newLabelAlphabet(i), d, d2);
    }

    private static LabelAlphabet newLabelAlphabet(int i) {
        LabelAlphabet labelAlphabet = new LabelAlphabet();
        for (int i2 = 0; i2 < i; i2++) {
            labelAlphabet.lookupIndex("topic" + i2);
        }
        return labelAlphabet;
    }

    /**
     * Creates a model over the given topic labels with default sampler settings.
     *
     * <p>alpha starts symmetric, with every entry equal to {@code alphaTotal / numTopics}.
     * topicMask and topicBits are chosen so that every topic index fits in the low bits of a
     * packed type-topic count entry.
     *
     * @param labels topic labels; their number determines numTopics
     * @param alphaTotal initial sum of alpha
     * @param betaValue topic-word smoothing prior
     */
    public ParallelTopicModel(LabelAlphabet labels, double alphaTotal, double betaValue) {
        // Sampler configuration defaults.
        this.numIterations = 1000;
        this.burninPeriod = 200;
        this.saveSampleInterval = 10;
        this.optimizeInterval = 50;
        this.temperingInterval = 0;
        this.showTopicsInterval = 50;
        this.wordsPerTopic = 7;
        this.saveStateInterval = 0;
        this.stateFilename = null;
        this.saveModelInterval = 0;
        this.modelFilename = null;
        this.randomSeed = -1;
        this.printLogLikelihood = true;
        this.numThreads = 1;
        this.usingSymmetricAlpha = false;

        this.data = new ArrayList<>();
        this.topicAlphabet = labels;
        this.numTopics = labels.size();

        // All-ones mask covering every topic index 0 .. numTopics-1.
        this.topicMask = Integer.bitCount(this.numTopics) == 1
                ? this.numTopics - 1
                : Integer.highestOneBit(this.numTopics) * 2 - 1;
        this.topicBits = Integer.bitCount(this.topicMask);

        this.alphaSum = alphaTotal;
        this.alpha = new double[this.numTopics];
        Arrays.fill(this.alpha, alphaTotal / this.numTopics);
        this.beta = betaValue;
        this.tokensPerTopic = new int[this.numTopics];

        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(5);
        logger.info("Coded LDA: " + this.numTopics + " topics, " + this.topicBits + " topic bits, "
                + Integer.toBinaryString(this.topicMask) + " topic mask");
    }

    /** Returns the word-type alphabet set by addInstances(); null before any instances are added. */
    public Alphabet getAlphabet() {
        return alphabet;
    }

    /** Returns the topic label alphabet given at construction. */
    public LabelAlphabet getTopicAlphabet() {
        return topicAlphabet;
    }

    /** Returns the number of topics. */
    public int getNumTopics() {
        return numTopics;
    }

    /** Returns the model's own (live, not copied) list of topic assignments. */
    public ArrayList<TopicAssignment> getData() {
        return data;
    }

    /** Sets how many sampling iterations estimate() runs. */
    public void setNumIterations(int iterations) {
        numIterations = iterations;
    }

    /** Sets the iteration after which alpha statistics are collected and hyperparameters optimized. */
    public void setBurninPeriod(int iterations) {
        burninPeriod = iterations;
    }

    /**
     * Configures periodic logging of top words during estimate().
     *
     * @param interval log every {@code interval} iterations; 0 disables
     * @param numWords value passed to displayTopWords as the word count
     */
    public void setTopicDisplay(int interval, int numWords) {
        showTopicsInterval = interval;
        wordsPerTopic = numWords;
    }

    /** Sets the seed for Randoms; -1 requests an unseeded generator. */
    public void setRandomSeed(int seed) {
        randomSeed = seed;
    }

    /**
     * Sets how often (in iterations, after burn-in) estimate() optimizes alpha and beta; 0 disables
     * optimization.
     *
     * <p>For a non-zero interval, saveSampleInterval is lowered to the interval if it is larger, so
     * alpha statistics are collected at least as often as they are consumed.
     *
     * @param i optimization interval in iterations, or 0 to disable
     */
    public void setOptimizeInterval(int i) {
        this.optimizeInterval = i;
        // Do not clamp when disabling. Clamping to 0 would leave saveSampleInterval at 0 after a
        // later setOptimizeInterval(n > 0), and estimate() would then evaluate
        // iteration % saveSampleInterval, throwing ArithmeticException.
        if (i != 0 && this.saveSampleInterval > i) {
            this.saveSampleInterval = i;
        }
    }

    /**
     * Chooses whether optimizeAlpha() learns one shared concentration spread evenly over all topics
     * ({@code true}) or a separate alpha per topic ({@code false}).
     */
    public void setSymmetricAlpha(boolean symmetric) {
        usingSymmetricAlpha = symmetric;
    }

    /**
     * Stores the tempering interval. NOTE(review): estimate() as written here never reads this
     * field; confirm whether any caller relies on it.
     */
    public void setTemperingInterval(int interval) {
        temperingInterval = interval;
    }

    /** Sets the number of sampling workers; with 1, estimate() runs the worker on the calling thread. */
    public void setNumThreads(int threads) {
        numThreads = threads;
    }

    /**
     * Every {@code interval} iterations estimate() calls printState on
     * {@code prefix + "." + iteration}; 0 disables.
     */
    public void setSaveState(int interval, String prefix) {
        saveStateInterval = interval;
        stateFilename = prefix;
    }

    /**
     * Every {@code interval} iterations estimate() calls write on
     * {@code prefix + "." + iteration}; 0 disables.
     */
    public void setSaveSerializedModel(int interval, String prefix) {
        saveModelInterval = interval;
        modelFilename = prefix;
    }

    /**
     * Adds instances to the model and initializes the sampler state.
     *
     * <p>Takes the data alphabet from {@code instanceList} and sizes each type's row of
     * {@link #typeTopicCounts} to {@code min(numTopics, occurrences of that type)}. Each token is
     * assigned a uniformly random topic, using {@link #randomSeed} unless it is -1. The count
     * arrays and document-length histograms are then rebuilt.
     *
     * @param instanceList instances whose data are {@link FeatureSequence}s over one alphabet
     */
    public void addInstances(InstanceList instanceList) {
        this.alphabet = instanceList.getDataAlphabet();
        this.numTypes = this.alphabet.size();
        this.betaSum = this.beta * this.numTypes;
        // Outer array only; each row is allocated below once per-type totals are known.
        this.typeTopicCounts = new int[this.numTypes][];
        this.typeTotals = new int[this.numTypes];

        for (Instance instance : instanceList) {
            FeatureSequence tokens = (FeatureSequence) instance.getData();
            for (int position = 0; position < tokens.getLength(); position++) {
                this.typeTotals[tokens.getIndexAtPosition(position)]++;
            }
        }

        this.maxTypeCount = 0;
        for (int type = 0; type < this.numTypes; type++) {
            this.maxTypeCount = Math.max(this.maxTypeCount, this.typeTotals[type]);
            // A type can occupy at most one slot per topic and at most one slot per occurrence.
            this.typeTopicCounts[type] = new int[Math.min(this.numTopics, this.typeTotals[type])];
        }

        Randoms random = this.randomSeed == -1 ? new Randoms() : new Randoms(this.randomSeed);
        for (Instance instance : instanceList) {
            int length = ((FeatureSequence) instance.getData()).size();
            LabelSequence topicSequence = new LabelSequence(this.topicAlphabet, new int[length]);
            int[] topics = topicSequence.getFeatures();
            for (int position = 0; position < topics.length; position++) {
                topics[position] = random.nextInt(this.numTopics);
            }
            this.data.add(new TopicAssignment(instance, topicSequence));
        }

        buildInitialTypeTopicCounts();
        initializeHistograms();
    }

    /**
     * Restores token topic assignments from a gzip-compressed sampling-state file, then rebuilds
     * the count arrays and histograms.
     *
     * <p>Leading lines starting with {@code #} are skipped. Each following line is split on single
     * spaces and corresponds to the next token of {@link #data}, in order. Field 3 must equal that
     * token's type index, and field 5 is taken as its topic.
     *
     * @param file gzip-compressed state file
     * @throws IOException if the file cannot be read
     * @throws IllegalStateException if the file does not match the current instances, including
     *     when it ends before every token has been assigned
     */
    public void initializeFromState(File file) throws IOException {
        // Separate resources so the FileInputStream is closed even if the gzip header is invalid.
        try (FileInputStream fileIn = new FileInputStream(file);
             BufferedReader reader = new BufferedReader(new InputStreamReader(new GZIPInputStream(fileIn)))) {
            String line = reader.readLine();
            while (line != null && line.startsWith("#")) {
                line = reader.readLine();
            }
            for (TopicAssignment assignment : this.data) {
                FeatureSequence tokens = (FeatureSequence) assignment.instance.getData();
                int[] topics = assignment.topicSequence.getFeatures();
                for (int position = 0; position < tokens.size(); position++) {
                    if (line == null) {
                        throw new IllegalStateException(
                                "state file " + file + " ended before all tokens were assigned");
                    }
                    String[] fields = line.split(" ");
                    if (tokens.getIndexAtPosition(position) != Integer.parseInt(fields[3])) {
                        System.err.println("instance list and state do not match: " + line);
                        throw new IllegalStateException("instance list and state do not match: " + line);
                    }
                    topics[position] = Integer.parseInt(fields[5]);
                    line = reader.readLine();
                }
            }
        }
        buildInitialTypeTopicCounts();
        initializeHistograms();
    }

    /**
     * Recomputes {@link #tokensPerTopic} and {@link #typeTopicCounts} from the topic assignments in
     * {@link #data}. Tokens whose topic is {@link #UNASSIGNED_TOPIC} are skipped.
     *
     * <p>Each typeTopicCounts row holds packed {@code (count << topicBits) + topic} entries in
     * descending order, with unused zero slots at the end.
     *
     * @throws IllegalStateException if a type's row has no free slot for a new topic (the row was
     *     sized inconsistently with the data)
     */
    public void buildInitialTypeTopicCounts() {
        Arrays.fill(this.tokensPerTopic, 0);
        for (int type = 0; type < this.numTypes; type++) {
            int[] counts = this.typeTopicCounts[type];
            // Used slots come first, so clearing can stop at the first empty one.
            for (int slot = 0; slot < counts.length && counts[slot] > 0; slot++) {
                counts[slot] = 0;
            }
        }

        for (TopicAssignment assignment : this.data) {
            FeatureSequence tokens = (FeatureSequence) assignment.instance.getData();
            int[] topics = assignment.topicSequence.getFeatures();
            for (int position = 0; position < tokens.size(); position++) {
                int topic = topics[position];
                if (topic == UNASSIGNED_TOPIC) {
                    continue;
                }
                this.tokensPerTopic[topic]++;

                int type = tokens.getIndexAtPosition(position);
                int[] counts = this.typeTopicCounts[type];

                // Find the slot already holding this topic, or else the first empty slot.
                int index = 0;
                while (index < counts.length && counts[index] > 0
                        && (counts[index] & this.topicMask) != topic) {
                    index++;
                }
                if (index == counts.length) {
                    logger.info("overflow on type " + type);
                    throw new IllegalStateException(
                            "overflow on type " + type + ": no free slot for topic " + topic);
                }

                int currentCount = counts[index] >> this.topicBits;
                if (currentCount == 0) {
                    // Fresh slot: first occurrence of this topic for the type.
                    counts[index] = (1 << this.topicBits) + topic;
                } else {
                    counts[index] = ((currentCount + 1) << this.topicBits) + topic;
                    // Bubble the incremented entry forward to keep the row in descending order.
                    while (index > 0 && counts[index] > counts[index - 1]) {
                        int swap = counts[index];
                        counts[index] = counts[index - 1];
                        counts[index - 1] = swap;
                        index--;
                    }
                }
            }
        }
    }

    /**
     * Replaces the model's {@link #tokensPerTopic} and {@link #typeTopicCounts} with the sum of
     * the first {@link #numThreads} workers' private counts.
     *
     * <p>Worker rows are read up to their first zero slot. Merged rows keep the packed
     * {@code (count << topicBits) + topic} layout in descending order.
     *
     * @param workerRunnableArr sampling workers whose counts are merged
     * @throws IllegalStateException if a merged row has no free slot for a topic
     */
    public void sumTypeTopicCounts(WorkerRunnable[] workerRunnableArr) {
        Arrays.fill(this.tokensPerTopic, 0);
        for (int type = 0; type < this.numTypes; type++) {
            int[] counts = this.typeTopicCounts[type];
            for (int slot = 0; slot < counts.length && counts[slot] > 0; slot++) {
                counts[slot] = 0;
            }
        }

        for (int thread = 0; thread < this.numThreads; thread++) {
            WorkerRunnable worker = workerRunnableArr[thread];
            int[] workerTokensPerTopic = worker.getTokensPerTopic();
            for (int topic = 0; topic < this.numTopics; topic++) {
                this.tokensPerTopic[topic] += workerTokensPerTopic[topic];
            }

            int[][] workerTypeTopicCounts = worker.getTypeTopicCounts();
            for (int type = 0; type < this.numTypes; type++) {
                int[] source = workerTypeTopicCounts[type];
                int[] target = this.typeTopicCounts[type];
                for (int slot = 0; slot < source.length && source[slot] > 0; slot++) {
                    int topic = source[slot] & this.topicMask;
                    int count = source[slot] >> this.topicBits;

                    // Find the target slot holding this topic, or else the first empty slot.
                    int index = 0;
                    while (index < target.length && target[index] > 0
                            && (target[index] & this.topicMask) != topic) {
                        index++;
                    }
                    if (index == target.length) {
                        logger.info("overflow in merging on type " + type);
                        throw new IllegalStateException(
                                "overflow in merging on type " + type + ": no free slot for topic " + topic);
                    }

                    target[index] = (((target[index] >> this.topicBits) + count) << this.topicBits) + topic;
                    // Keep the row in descending order.
                    while (index > 0 && target[index] > target[index - 1]) {
                        int swap = target[index];
                        target[index] = target[index - 1];
                        target[index - 1] = swap;
                        index--;
                    }
                }
            }
        }
    }

    /**
     * Recomputes {@link #totalTokens} and allocates zeroed document-length histograms.
     *
     * <p>{@link #docLengthCounts} and each of the numTopics rows of {@link #topicDocCounts} get one
     * slot for every length from 0 up to the longest document in {@link #data}.
     */
    private void initializeHistograms() {
        this.totalTokens = 0;
        int longestDocument = 0;
        for (TopicAssignment assignment : this.data) {
            int length = ((FeatureSequence) assignment.instance.getData()).getLength();
            longestDocument = Math.max(longestDocument, length);
            this.totalTokens += length;
        }
        logger.info("max tokens: " + longestDocument);
        logger.info("total tokens: " + this.totalTokens);

        int histogramSize = longestDocument + 1;
        this.docLengthCounts = new int[histogramSize];
        this.topicDocCounts = new int[this.numTopics][histogramSize];
    }

    /**
     * Re-estimates alpha from the document statistics gathered by the workers.
     *
     * <p>Each worker's document-length and topic-document histograms are added into the model's
     * histograms and then zeroed. With a symmetric alpha, all topics are pooled into row 0, a single
     * concentration is learned, and it is split evenly over the topics. Otherwise every alpha[k]
     * is learned separately with Dirichlet.learnParameters.
     *
     * @param workerRunnableArr sampling workers (the first numThreads are read)
     */
    public void optimizeAlpha(WorkerRunnable[] workerRunnableArr) {
        Arrays.fill(this.docLengthCounts, 0);
        for (int[] row : this.topicDocCounts) {
            Arrays.fill(row, 0);
        }

        for (int thread = 0; thread < this.numThreads; thread++) {
            WorkerRunnable worker = workerRunnableArr[thread];
            int[] workerLengthCounts = worker.getDocLengthCounts();
            int[][] workerTopicCounts = worker.getTopicDocCounts();

            for (int length = 0; length < workerLengthCounts.length; length++) {
                int count = workerLengthCounts[length];
                if (count > 0) {
                    this.docLengthCounts[length] += count;
                    workerLengthCounts[length] = 0;
                }
            }

            for (int topic = 0; topic < this.numTopics; topic++) {
                // Symmetric alpha pools every topic into a single histogram row.
                int destinationRow = this.usingSymmetricAlpha ? 0 : topic;
                int[] source = workerTopicCounts[topic];
                for (int n = 0; n < source.length; n++) {
                    int count = source[n];
                    if (count > 0) {
                        this.topicDocCounts[destinationRow][n] += count;
                        source[n] = 0;
                    }
                }
            }
        }

        if (this.usingSymmetricAlpha) {
            this.alphaSum = Dirichlet.learnSymmetricConcentration(
                    this.topicDocCounts[0], this.docLengthCounts, this.numTopics, this.alphaSum);
            for (int topic = 0; topic < this.numTopics; topic++) {
                this.alpha[topic] = this.alphaSum / this.numTopics;
            }
        } else {
            this.alphaSum = Dirichlet.learnParameters(
                    this.alpha, this.topicDocCounts, this.docLengthCounts, 1.001d, 1.0d, 1);
        }
    }

    /**
     * Discards all collected alpha statistics and resets alpha to 1.0 for every topic, so that
     * alphaSum becomes numTopics.
     *
     * <p>Clears the model's histograms and zeroes every positive entry of the first numThreads
     * workers' histograms. The new alpha is not pushed to the workers here.
     *
     * @param workerRunnableArr sampling workers whose statistics are cleared
     */
    public void temperAlpha(WorkerRunnable[] workerRunnableArr) {
        Arrays.fill(this.docLengthCounts, 0);
        for (int[] row : this.topicDocCounts) {
            Arrays.fill(row, 0);
        }

        int thread = 0;
        while (thread < this.numThreads) {
            WorkerRunnable worker = workerRunnableArr[thread];
            int[] lengthCounts = worker.getDocLengthCounts();
            int[][] topicCounts = worker.getTopicDocCounts();

            for (int n = 0; n < lengthCounts.length; n++) {
                if (lengthCounts[n] > 0) {
                    lengthCounts[n] = 0;
                }
            }
            for (int topic = 0; topic < this.numTopics; topic++) {
                int[] row = topicCounts[topic];
                for (int n = 0; n < row.length; n++) {
                    if (row[n] > 0) {
                        row[n] = 0;
                    }
                }
            }
            thread++;
        }

        for (int topic = 0; topic < this.numTopics; topic++) {
            this.alpha[topic] = 1.0d;
        }
        this.alphaSum = this.numTopics;
    }

    /**
     * Re-estimates the symmetric topic-word prior beta and pushes it to the workers.
     *
     * <p>Builds two histograms: one of the non-zero (type, topic) counts and one of the per-topic
     * token totals. It passes them to Dirichlet.learnSymmetricConcentration, then sets
     * beta = betaSum / numTypes and logs the result.
     *
     * @param workerRunnableArr sampling workers (the first numThreads receive the new beta)
     */
    public void optimizeBeta(WorkerRunnable[] workerRunnableArr) {
        // countHistogram[c] = number of (type, topic) pairs whose count is c.
        int[] countHistogram = new int[this.maxTypeCount + 1];
        for (int type = 0; type < this.numTypes; type++) {
            for (int entry : this.typeTopicCounts[type]) {
                if (entry <= 0) {
                    break; // the rest of the row is unused
                }
                countHistogram[entry >> this.topicBits]++;
            }
        }

        int largestTopic = 0;
        for (int topic = 0; topic < this.numTopics; topic++) {
            largestTopic = Math.max(largestTopic, this.tokensPerTopic[topic]);
        }
        // topicSizeHistogram[s] = number of topics with exactly s tokens.
        int[] topicSizeHistogram = new int[largestTopic + 1];
        for (int topic = 0; topic < this.numTopics; topic++) {
            topicSizeHistogram[this.tokensPerTopic[topic]]++;
        }

        this.betaSum = Dirichlet.learnSymmetricConcentration(
                countHistogram, topicSizeHistogram, this.numTypes, this.betaSum);
        this.beta = this.betaSum / this.numTypes;
        logger.info("[beta: " + this.formatter.format(this.beta) + "] ");

        for (int thread = 0; thread < this.numThreads; thread++) {
            workerRunnableArr[thread].resetBeta(this.beta, this.betaSum);
        }
    }

    /* JADX WARN: Type inference failed for: r0v180, types: [int[], int[][]] */
    public void estimate() throws IOException {
        long currentTimeMillis = System.currentTimeMillis();
        WorkerRunnable[] workerRunnableArr = new WorkerRunnable[this.numThreads];
        int size = this.data.size() / this.numThreads;
        int i = 0;
        if (this.numThreads > 1) {
            for (int i2 = 0; i2 < this.numThreads; i2++) {
                int[] iArr = new int[this.numTopics];
                System.arraycopy(this.tokensPerTopic, 0, iArr, 0, this.numTopics);
                ?? r0 = new int[this.numTypes];
                for (int i3 = 0; i3 < this.numTypes; i3++) {
                    int[] iArr2 = new int[this.typeTopicCounts[i3].length];
                    System.arraycopy(this.typeTopicCounts[i3], 0, iArr2, 0, iArr2.length);
                    r0[i3] = iArr2;
                }
                if (i2 == this.numThreads - 1) {
                    size = this.data.size() - i;
                }
                workerRunnableArr[i2] = new WorkerRunnable(this.numTopics, this.alpha, this.alphaSum, this.beta, this.randomSeed == -1 ? new Randoms() : new Randoms(this.randomSeed), this.data, r0, iArr, i, size);
                workerRunnableArr[i2].initializeAlphaStatistics(this.docLengthCounts.length);
                i += size;
            }
        } else {
            workerRunnableArr[0] = new WorkerRunnable(this.numTopics, this.alpha, this.alphaSum, this.beta, this.randomSeed == -1 ? new Randoms() : new Randoms(this.randomSeed), this.data, this.typeTopicCounts, this.tokensPerTopic, 0, size);
            workerRunnableArr[0].initializeAlphaStatistics(this.docLengthCounts.length);
            workerRunnableArr[0].makeOnlyThread();
        }
        ExecutorService newFixedThreadPool = Executors.newFixedThreadPool(this.numThreads);
        for (int i4 = 1; i4 <= this.numIterations; i4++) {
            long currentTimeMillis2 = System.currentTimeMillis();
            if (this.showTopicsInterval != 0 && i4 != 0 && i4 % this.showTopicsInterval == 0) {
                logger.info(ScriptUtils.FALLBACK_STATEMENT_SEPARATOR + displayTopWords(this.wordsPerTopic, false));
            }
            if (this.saveStateInterval != 0 && i4 % this.saveStateInterval == 0) {
                printState(new File(this.stateFilename + '.' + i4));
            }
            if (this.saveModelInterval != 0 && i4 % this.saveModelInterval == 0) {
                write(new File(this.modelFilename + '.' + i4));
            }
            if (this.numThreads > 1) {
                for (int i5 = 0; i5 < this.numThreads; i5++) {
                    if (i4 > this.burninPeriod && this.optimizeInterval != 0 && i4 % this.saveSampleInterval == 0) {
                        workerRunnableArr[i5].collectAlphaStatistics();
                    }
                    logger.fine("submitting thread " + i5);
                    newFixedThreadPool.submit(workerRunnableArr[i5]);
                }
                try {
                    Thread.sleep(20L);
                } catch (InterruptedException e) {
                }
                boolean z = false;
                while (!z) {
                    try {
                        Thread.sleep(10L);
                    } catch (InterruptedException e2) {
                    }
                    z = true;
                    for (int i6 = 0; i6 < this.numThreads; i6++) {
                        z = z && workerRunnableArr[i6].isFinished;
                    }
                }
                sumTypeTopicCounts(workerRunnableArr);
                for (int i7 = 0; i7 < this.numThreads; i7++) {
                    System.arraycopy(this.tokensPerTopic, 0, workerRunnableArr[i7].getTokensPerTopic(), 0, this.numTopics);
                    int[][] typeTopicCounts = workerRunnableArr[i7].getTypeTopicCounts();
                    for (int i8 = 0; i8 < this.numTypes; i8++) {
                        int[] iArr3 = typeTopicCounts[i8];
                        int[] iArr4 = this.typeTopicCounts[i8];
                        for (int i9 = 0; i9 < iArr4.length; i9++) {
                            if (iArr4[i9] != 0) {
                                iArr3[i9] = iArr4[i9];
                            } else if (iArr3[i9] != 0) {
                                iArr3[i9] = 0;
                            }
                        }
                    }
                }
            } else {
                if (i4 > this.burninPeriod && this.optimizeInterval != 0 && i4 % this.saveSampleInterval == 0) {
                    workerRunnableArr[0].collectAlphaStatistics();
                }
                workerRunnableArr[0].run();
            }
            long currentTimeMillis3 = System.currentTimeMillis() - currentTimeMillis2;
            if (currentTimeMillis3 < 1000) {
                logger.fine(currentTimeMillis3 + "ms ");
            } else {
                logger.fine((currentTimeMillis3 / 1000) + "s ");
            }
            if (i4 > this.burninPeriod && this.optimizeInterval != 0 && i4 % this.optimizeInterval == 0) {
                optimizeAlpha(workerRunnableArr);
                optimizeBeta(workerRunnableArr);
                logger.fine("[O " + (System.currentTimeMillis() - currentTimeMillis2) + "] ");
            }
            if (i4 % 10 == 0) {
                if (this.printLogLikelihood) {
                    logger.info(Program.LEFT_INPUT_DELIMITER + i4 + "> LL/token: " + this.formatter.format(modelLogLikelihood() / this.totalTokens));
                } else {
                    logger.info(Program.LEFT_INPUT_DELIMITER + i4 + Program.RIGHT_INPUT_DELIMITER);
                }
            }
        }
        newFixedThreadPool.shutdownNow();
        long round = Math.round((System.currentTimeMillis() - currentTimeMillis) / 1000.0d);
        long j = round / 60;
        long j2 = round % 60;
        long j3 = j / 60;
        long j4 = j % 60;
        long j5 = j3 / 24;
        long j6 = j3 % 24;
        StringBuilder sb = new StringBuilder();
        sb.append("\nTotal time: ");
        if (j5 != 0) {
            sb.append(j5);
            sb.append(" days ");
        }
        if (j6 != 0) {
            sb.append(j6);
            sb.append(" hours ");
        }
        if (j4 != 0) {
            sb.append(j4);
            sb.append(" minutes ");
        }
        sb.append(j2);
        sb.append(" seconds");
        logger.info(sb.toString());
    }

    /**
     * Writes {@code displayTopWords(numWords, usingNewLines)} to {@code file}.
     *
     * @param file destination file (created or truncated)
     * @param numWords word count passed to displayTopWords
     * @param usingNewLines layout flag passed to displayTopWords
     * @throws IOException if the file cannot be opened or writing to it fails
     */
    public void printTopWords(File file, int numWords, boolean usingNewLines) throws IOException {
        try (PrintStream out = new PrintStream(file)) {
            printTopWords(out, numWords, usingNewLines);
            // PrintStream never throws on write failure; check explicitly so output is not silently lost.
            if (out.checkError()) {
                throw new IOException("error writing top words to " + file);
            }
        }
    }

    /**
     * Groups word types by topic.
     *
     * <p>Returns one TreeSet per topic, in topic order. Each set holds an IDSorter of
     * (type index, count) for every non-zero entry in {@link #typeTopicCounts}.
     */
    public ArrayList<TreeSet<IDSorter>> getSortedWords() {
        ArrayList<TreeSet<IDSorter>> perTopic = new ArrayList<>(this.numTopics);
        for (int topic = 0; topic < this.numTopics; topic++) {
            perTopic.add(new TreeSet<IDSorter>());
        }
        for (int type = 0; type < this.numTypes; type++) {
            for (int entry : this.typeTopicCounts[type]) {
                if (entry <= 0) {
                    break; // unused slots only from here on
                }
                int topic = entry & this.topicMask;
                int count = entry >> this.topicBits;
                perTopic.get(topic).add(new IDSorter(type, count));
            }
        }
        return perTopic;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v4, types: [java.lang.Object[], java.lang.Object[][]] */
    public Object[][] getTopWords(int i) {
        ArrayList<TreeSet<IDSorter>> sortedWords = getSortedWords();
        ?? r0 = new Object[this.numTopics];
        for (int i2 = 0; i2 < this.numTopics; i2++) {
            TreeSet<IDSorter> treeSet = sortedWords.get(i2);
            int i3 = i;
            if (treeSet.size() < i) {
                i3 = treeSet.size();
            }
            r0[i2] = new Object[i3];
            Iterator<IDSorter> it2 = treeSet.iterator();
            for (int i4 = 0; i4 < i3; i4++) {
                r0[i2][i4] = this.alphabet.lookupObject(it2.next().getID());
            }
        }
        return r0;
    }

    /** Prints {@code displayTopWords(numWords, usingNewLines)} to {@code out} without a trailing newline. */
    public void printTopWords(PrintStream out, int numWords, boolean usingNewLines) {
        String report = displayTopWords(numWords, usingNewLines);
        out.print(report);
    }

    /**
     * Renders each topic's index, alpha and top words as text.
     *
     * <p>With {@code usingNewLines}, each topic is rendered as a header line
     * {@code topic\talpha} followed by one {@code word\tweight} line per word. Otherwise each
     * topic is one line: {@code topic\talpha\t} followed by space-separated words. Line breaks use
     * ScriptUtils.FALLBACK_STATEMENT_SEPARATOR, as before.
     *
     * @param numWords maximum number of words shown per topic (previously only numWords - 1 were
     *     shown because the counter started at 1)
     * @param usingNewLines layout selector, as described above
     * @return the rendered text
     */
    public String displayTopWords(int numWords, boolean usingNewLines) {
        StringBuilder sb = new StringBuilder();
        ArrayList<TreeSet<IDSorter>> sortedWords = getSortedWords();
        for (int topic = 0; topic < this.numTopics; topic++) {
            Iterator<IDSorter> words = sortedWords.get(topic).iterator();
            if (usingNewLines) {
                sb.append(topic).append('\t').append(this.formatter.format(this.alpha[topic]))
                        .append(ScriptUtils.FALLBACK_STATEMENT_SEPARATOR);
                for (int shown = 0; shown < numWords && words.hasNext(); shown++) {
                    IDSorter next = words.next();
                    sb.append(this.alphabet.lookupObject(next.getID())).append('\t')
                            .append(this.formatter.format(next.getWeight()))
                            .append(ScriptUtils.FALLBACK_STATEMENT_SEPARATOR);
                }
            } else {
                sb.append(topic).append('\t').append(this.formatter.format(this.alpha[topic])).append('\t');
                for (int shown = 0; shown < numWords && words.hasNext(); shown++) {
                    sb.append(this.alphabet.lookupObject(words.next().getID())).append(' ');
                }
                sb.append(ScriptUtils.FALLBACK_STATEMENT_SEPARATOR);
            }
        }
        return sb.toString();
    }

    /**
     * Writes an XML report with one {@code <topic>} element per topic. Each element carries the
     * topic's id, alpha and totalTokens, and lists up to {@code numWords} ranked words.
     *
     * <p>Word text is XML-escaped so that words containing {@code &}, {@code <} or {@code >}
     * still produce well-formed output. Ranks run from 1 to {@code numWords} inclusive; previously
     * only {@code numWords - 1} words were written.
     *
     * @param printWriter destination; it is not flushed or closed here
     * @param numWords maximum number of words per topic
     */
    public void topicXMLReport(PrintWriter printWriter, int numWords) {
        ArrayList<TreeSet<IDSorter>> sortedWords = getSortedWords();
        printWriter.println("<?xml version='1.0' ?>");
        printWriter.println("<topicModel>");
        for (int topic = 0; topic < this.numTopics; topic++) {
            printWriter.println("  <topic id='" + topic + "' alpha='" + this.alpha[topic]
                    + "' totalTokens='" + this.tokensPerTopic[topic] + "'>");
            Iterator<IDSorter> words = sortedWords.get(topic).iterator();
            for (int rank = 1; rank <= numWords && words.hasNext(); rank++) {
                Object word = this.alphabet.lookupObject(words.next().getID());
                printWriter.println("\t<word rank='" + rank + "'>" + escapeXmlText(String.valueOf(word)) + "</word>");
            }
            printWriter.println("  </topic>");
        }
        printWriter.println("</topicModel>");
    }

    /** Escapes the characters that may not appear literally in XML character data. */
    private static String escapeXmlText(String text) {
        return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;");
    }

    public void topicPhraseXMLReport(PrintWriter printWriter, int i) {
        int numTopics = getNumTopics();
        TObjectIntHashMap[] tObjectIntHashMapArr = new TObjectIntHashMap[numTopics];
        Alphabet alphabet = getAlphabet();
        for (int i2 = 0; i2 < numTopics; i2++) {
            tObjectIntHashMapArr[i2] = new TObjectIntHashMap();
        }
        for (int i3 = 0; i3 < getData().size(); i3++) {
            FeatureSequence featureSequence = (FeatureSequence) getData().get(i3).instance.getData();
            boolean z = featureSequence instanceof FeatureSequenceWithBigrams;
            int i4 = -1;
            int i5 = -1;
            StringBuffer stringBuffer = null;
            int size = featureSequence.size();
            for (int i6 = 0; i6 < size; i6++) {
                int indexAtPosition = featureSequence.getIndexAtPosition(i6);
                int indexAtPosition2 = getData().get(i3).topicSequence.getIndexAtPosition(i6);
                if (indexAtPosition2 != i4 || (z && ((FeatureSequenceWithBigrams) featureSequence).getBiIndexAtPosition(i6) == -1)) {
                    if (stringBuffer != null) {
                        String stringBuffer2 = stringBuffer.toString();
                        if (tObjectIntHashMapArr[i4].get(stringBuffer2) == 0) {
                            tObjectIntHashMapArr[i4].put(stringBuffer2, 0);
                        }
                        tObjectIntHashMapArr[i4].increment(stringBuffer2);
                        i5 = -1;
                        i4 = -1;
                        stringBuffer = null;
                    } else {
                        i4 = indexAtPosition2;
                        i5 = indexAtPosition;
                    }
                } else if (stringBuffer == null) {
                    stringBuffer = new StringBuffer(alphabet.lookupObject(i5).toString() + " " + alphabet.lookupObject(indexAtPosition));
                } else {
                    stringBuffer.append(" ");
                    stringBuffer.append(alphabet.lookupObject(indexAtPosition));
                }
            }
        }
        printWriter.println("<?xml version='1.0' ?>");
        printWriter.println("<topics>");
        ArrayList<TreeSet<IDSorter>> sortedWords = getSortedWords();
        double[] dArr = new double[alphabet.size()];
        for (int i7 = 0; i7 < numTopics; i7++) {
            printWriter.print("  <topic id=\"" + i7 + "\" alpha=\"" + this.alpha[i7] + "\" totalTokens=\"" + this.tokensPerTopic[i7] + "\" ");
            ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
            PrintStream printStream = new PrintStream(byteArrayOutputStream);
            AugmentableFeatureVector augmentableFeatureVector = new AugmentableFeatureVector(new Alphabet());
            int i8 = 1;
            Iterator<IDSorter> it2 = sortedWords.get(i7).iterator();
            while (it2.hasNext() && i8 < i) {
                IDSorter next = it2.next();
                printStream.println("\t<word weight=\"" + (next.getWeight() / this.tokensPerTopic[i7]) + "\" count=\"" + Math.round(next.getWeight()) + "\">" + alphabet.lookupObject(next.getID()) + "</word>");
                i8++;
                if (i8 < 20) {
                    augmentableFeatureVector.add(alphabet.lookupObject(next.getID()), next.getWeight());
                }
            }
            Object[] keys = tObjectIntHashMapArr[i7].keys();
            int[] values = tObjectIntHashMapArr[i7].getValues();
            double[] dArr2 = new double[keys.length];
            for (int i9 = 0; i9 < dArr2.length; i9++) {
                dArr2[i9] = values[i9];
            }
            double sum = MatrixOps.sum(dArr2);
            Alphabet alphabet2 = new Alphabet(keys);
            RankedFeatureVector rankedFeatureVector = new RankedFeatureVector(alphabet2, dArr2);
            int numLocations = rankedFeatureVector.numLocations() < i ? rankedFeatureVector.numLocations() : i;
            for (int i10 = 0; i10 < numLocations; i10++) {
                int indexAtRank = rankedFeatureVector.getIndexAtRank(i10);
                printStream.println("\t<phrase weight=\"" + (dArr2[indexAtRank] / sum) + "\" count=\"" + values[indexAtRank] + "\">" + alphabet2.lookupObject(indexAtRank) + "</phrase>");
                if (i10 < 20 && values[indexAtRank] > 20) {
                    augmentableFeatureVector.add(alphabet2.lookupObject(indexAtRank), 100 * values[indexAtRank]);
                }
            }
            StringBuffer stringBuffer3 = new StringBuffer();
            RankedFeatureVector rankedFeatureVector2 = new RankedFeatureVector(augmentableFeatureVector.getAlphabet(), augmentableFeatureVector);
            int i11 = 10;
            for (int i12 = 0; i12 < i11 && i12 < rankedFeatureVector2.numLocations(); i12++) {
                if (stringBuffer3.indexOf(rankedFeatureVector2.getObjectAtRank(i12).toString()) == -1) {
                    stringBuffer3.append(rankedFeatureVector2.getObjectAtRank(i12));
                    if (i12 < i11 - 1) {
                        stringBuffer3.append(", ");
                    }
                } else {
                    i11++;
                }
            }
            printWriter.println("titles=\"" + stringBuffer3.toString() + "\">");
            printWriter.print(byteArrayOutputStream.toString());
            printWriter.println("  </topic>");
        }
        printWriter.println("</topics>");
    }

    /**
     * Writes the sparse type-topic count table to {@code file}, one line per word type:
     * {@code typeIndex word topic:count topic:count ...}.
     * <p>
     * Each packed entry of {@code typeTopicCounts[type]} yields the topic via {@code & topicMask}
     * and the count via {@code >> topicBits}; scanning a row stops at the first non-positive entry.
     *
     * @param file destination file (overwritten)
     * @throws IOException if the file cannot be opened or an error occurs while writing
     */
    public void printTypeTopicCounts(File file) throws IOException {
        // try-with-resources: the writer is closed even if an exception escapes the loop.
        try (PrintWriter printWriter = new PrintWriter(new FileWriter(file))) {
            for (int type = 0; type < this.numTypes; type++) {
                StringBuilder sb = new StringBuilder();
                sb.append(type + " " + this.alphabet.lookupObject(type));
                int[] packed = this.typeTopicCounts[type];
                for (int index = 0; index < packed.length && packed[index] > 0; index++) {
                    sb.append(" " + (packed[index] & this.topicMask) + ":" + (packed[index] >> this.topicBits));
                }
                printWriter.println(sb);
            }
            // PrintWriter never throws on write failures; surface them through the declared IOException.
            if (printWriter.checkError()) {
                throw new IOException("Error while writing type-topic counts to " + file);
            }
        }
    }

    /**
     * Writes the topic-word weights (see {@link #printTopicWordWeights(PrintWriter)}) to {@code file}.
     *
     * @param file destination file (overwritten)
     * @throws IOException if the file cannot be opened or an error occurs while writing
     */
    public void printTopicWordWeights(File file) throws IOException {
        // try-with-resources: the writer is closed even if writing fails part-way.
        try (PrintWriter printWriter = new PrintWriter(new FileWriter(file))) {
            printTopicWordWeights(printWriter);
            // PrintWriter swallows I/O errors; report them through the declared IOException.
            if (printWriter.checkError()) {
                throw new IOException("Error while writing topic-word weights to " + file);
            }
        }
    }

    /**
     * Writes one line per (topic, word type) pair: {@code topic <TAB> word <TAB> weight}, where the
     * weight is {@code beta} plus the number of tokens of that type assigned to the topic.
     * <p>
     * Each packed entry of {@code typeTopicCounts[type]} yields the topic via {@code & topicMask}
     * and the count via {@code >> topicBits}; scanning stops at the first non-positive entry. A type
     * with no entry for the topic gets weight {@code beta}.
     * <p>
     * Fix: the previous {@code while (true)} loop never terminated once it ran past the last
     * positive entry without finding the topic, hanging on the first type not assigned to it.
     *
     * @param printWriter destination; not closed by this method
     * @throws IOException declared for interface compatibility
     */
    public void printTopicWordWeights(PrintWriter printWriter) throws IOException {
        for (int topic = 0; topic < this.numTopics; topic++) {
            for (int type = 0; type < this.numTypes; type++) {
                int[] packed = this.typeTopicCounts[type];
                double weight = this.beta;
                for (int index = 0; index < packed.length && packed[index] > 0; index++) {
                    if ((packed[index] & this.topicMask) == topic) {
                        weight += packed[index] >> this.topicBits;
                        break;
                    }
                }
                printWriter.println(topic + "\t" + this.alphabet.lookupObject(type) + "\t" + weight);
            }
        }
    }

    /**
     * Returns the smoothed topic distribution of the document at index {@code i} in {@code data}.
     *
     * @param i index of the document
     * @return per-topic proportions, see {@link #getTopicProbabilities(LabelSequence)}
     */
    public double[] getTopicProbabilities(int i) {
        LabelSequence topicSequence = this.data.get(i).topicSequence;
        return getTopicProbabilities(topicSequence);
    }

    /**
     * Converts a sequence of per-token topic assignments into a topic distribution.
     * <p>
     * Each topic's score is its token count plus {@code alpha[topic]}. The scores are then divided
     * by their sum so the result sums to one.
     *
     * @param labelSequence topic index assigned to each token
     * @return an array of length {@code numTopics} holding the normalized proportions
     */
    public double[] getTopicProbabilities(LabelSequence labelSequence) {
        final double[] distribution = new double[this.numTopics];

        // Count how many tokens were assigned to each topic.
        for (int position = 0, n = labelSequence.getLength(); position < n; position++) {
            distribution[labelSequence.getIndexAtPosition(position)] += 1.0d;
        }

        // Add the per-topic alpha and accumulate the normalizer in the same pass.
        double normalizer = 0.0d;
        for (int topic = 0; topic < distribution.length; topic++) {
            distribution[topic] += this.alpha[topic];
            normalizer += distribution[topic];
        }

        for (int topic = 0; topic < distribution.length; topic++) {
            distribution[topic] /= normalizer;
        }
        return distribution;
    }

    /**
     * Writes every document's topic proportions (all topics, no threshold) to {@code file}.
     *
     * @param file destination file (overwritten)
     * @throws IOException if the file cannot be opened or an error occurs while writing
     */
    public void printDocumentTopics(File file) throws IOException {
        // try-with-resources: the writer is closed even if an exception escapes.
        try (PrintWriter printWriter = new PrintWriter(new FileWriter(file))) {
            printDocumentTopics(printWriter);
            // PrintWriter swallows I/O errors; report them through the declared IOException.
            if (printWriter.checkError()) {
                throw new IOException("Error while writing document topics to " + file);
            }
        }
    }

    /**
     * Writes every document's topic proportions with a zero threshold and no limit on the number
     * of topics per document.
     *
     * @param printWriter destination; not closed by this method
     */
    public void printDocumentTopics(PrintWriter printWriter) {
        final double minimumProportion = 0.0d;
        // A negative limit is replaced by numTopics in the three-argument overload.
        final int noTopicLimit = -1;
        printDocumentTopics(printWriter, minimumProportion, noTopicLimit);
    }

    /**
     * Writes one line per document with its topics ranked by smoothed proportion.
     * <p>
     * The line format is {@code docIndex <TAB> name <TAB> topic <TAB> weight <TAB> ...}, with
     * "no-name" when the instance has no name. Each weight is
     * {@code (alpha[t] + count[t]) / (docLength + alphaSum)}. Topics are emitted in the order
     * produced by {@code Arrays.sort} on the {@link IDSorter} array. Emission stops after {@code i}
     * topics or at the first weight that is not {@code >= d}.
     *
     * @param printWriter destination; not closed by this method
     * @param d minimum weight a topic needs to be printed (threshold)
     * @param i maximum topics per document; negative or larger than numTopics means all topics
     */
    public void printDocumentTopics(PrintWriter printWriter, double d, int i) {
        printWriter.print("#doc name topic proportion ...\n");
        final int topicsToShow = (i < 0 || i > this.numTopics) ? this.numTopics : i;

        final int[] topicCounts = new int[this.numTopics];
        final IDSorter[] sorters = new IDSorter[this.numTopics];
        for (int topic = 0; topic < sorters.length; topic++) {
            // Placeholder values; every sorter is overwritten for each document.
            sorters[topic] = new IDSorter(topic, topic);
        }

        int doc = 0;
        for (TopicAssignment assignment : this.data) {
            final int[] docTopics = assignment.topicSequence.getFeatures();
            final Object name = assignment.instance.getName();

            final StringBuilder line = new StringBuilder();
            line.append(doc).append('\t').append(name == null ? "no-name" : name).append('\t');

            for (int topic : docTopics) {
                topicCounts[topic]++;
            }
            final int docLength = docTopics.length;
            for (int topic = 0; topic < this.numTopics; topic++) {
                sorters[topic].set(topic, (this.alpha[topic] + topicCounts[topic]) / (docLength + this.alphaSum));
            }
            Arrays.sort(sorters);

            for (int rank = 0; rank < topicsToShow; rank++) {
                final IDSorter entry = sorters[rank];
                // Written as !(>=) so a NaN weight also stops the output.
                if (!(entry.getWeight() >= d)) {
                    break;
                }
                line.append(entry.getID()).append('\t').append(entry.getWeight()).append('\t');
            }
            printWriter.println(line);

            // Reset the shared counter array for the next document.
            Arrays.fill(topicCounts, 0);
            doc++;
        }
    }

    /**
     * Writes the sampler state (see {@link #printState(PrintStream)}) to a gzip-compressed file.
     *
     * @param file destination file (overwritten); contents are gzip-compressed
     * @throws IOException if the file cannot be opened or an error occurs while writing
     */
    public void printState(File file) throws IOException {
        // try-with-resources closes the stream on failure as well.
        // Closing also finishes the gzip trailer.
        try (PrintStream printStream = new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(file))))) {
            printState(printStream);
            // PrintStream never throws on write errors; surface them through the declared IOException.
            if (printStream.checkError()) {
                throw new IOException("Error while writing sampler state to " + file);
            }
        }
    }

    /**
     * Dumps the full sampler state.
     * <p>
     * The output is a header line, the alpha and beta values, and then one line per token of every
     * document: {@code doc source pos typeindex type topic}.
     * <p>
     * Per-token lines are formatted with {@link Locale#US}. A document whose instance has no source
     * is labelled "NA".
     *
     * @param printStream destination; not closed by this method
     */
    public void printState(PrintStream printStream) {
        printStream.println("#doc source pos typeindex type topic");
        printStream.print("#alpha : ");
        for (int topic = 0; topic < this.numTopics; topic++) {
            printStream.print(this.alpha[topic] + " ");
        }
        printStream.println();
        printStream.println("#beta : " + this.beta);

        int doc = 0;
        for (TopicAssignment assignment : this.data) {
            final FeatureSequence tokens = (FeatureSequence) assignment.instance.getData();
            final LabelSequence topics = assignment.topicSequence;
            final Object source = assignment.instance.getSource();
            final String sourceName = (source == null) ? "NA" : source.toString();

            // Buffer one document's lines, then emit them in a single print.
            final Formatter lines = new Formatter(new StringBuilder(), Locale.US);
            for (int pos = 0; pos < topics.getLength(); pos++) {
                final int type = tokens.getIndexAtPosition(pos);
                lines.format("%d %s %d %d %s %d\n", doc, sourceName, pos, type, this.alphabet.lookupObject(type), topics.getIndexAtPosition(pos));
            }
            printStream.print(lines);
            doc++;
        }
    }

    /**
     * Computes the log likelihood of the current topic assignments, using
     * {@code Dirichlet.logGammaStirling} for every log-gamma term.
     * <p>
     * The result is the sum of two parts:
     * <ul>
     *   <li>a document-topic part over every document's topic counts, {@code alpha} and
     *       {@code alphaSum};</li>
     *   <li>a topic-word part over the packed {@code typeTopicCounts}, {@code tokensPerTopic} and
     *       {@code beta}.</li>
     * </ul>
     * <p>
     * If a NaN or infinite value appears while the topic-word part is being accumulated, a message is
     * logged and 0.0 is returned. An infinite final value also returns 0.0. A NaN final value is only
     * logged ("at the end") and is returned as-is.
     * <p>
     * Fix: the decompiled body subtracted {@code logGamma(alphaSum + r0.length)} with {@code r0}
     * undefined (it did not compile). The intended operand is the length of the document's
     * topic-assignment array, which is now kept in {@code docTopics}.
     *
     * @return the log likelihood, or 0.0 when a numeric problem was detected as described above
     */
    public double modelLogLikelihood() {
        double logLikelihood = 0.0d;
        int[] topicCounts = new int[this.numTopics];

        // logGamma(alpha[t]) is reused for every document, so compute it once.
        double[] topicLogGammas = new double[this.numTopics];
        for (int topic = 0; topic < this.numTopics; topic++) {
            topicLogGammas[topic] = Dirichlet.logGammaStirling(this.alpha[topic]);
        }

        // Document-topic part.
        for (int doc = 0; doc < this.data.size(); doc++) {
            int[] docTopics = this.data.get(doc).topicSequence.getFeatures();
            for (int topic : docTopics) {
                topicCounts[topic]++;
            }
            for (int topic = 0; topic < this.numTopics; topic++) {
                if (topicCounts[topic] > 0) {
                    logLikelihood += Dirichlet.logGammaStirling(this.alpha[topic] + topicCounts[topic]) - topicLogGammas[topic];
                }
            }
            logLikelihood -= Dirichlet.logGammaStirling(this.alphaSum + docTopics.length);
            Arrays.fill(topicCounts, 0);
        }
        logLikelihood += this.data.size() * Dirichlet.logGammaStirling(this.alphaSum);

        // Topic-word part.
        // Count the (type, topic) pairs with a stored entry; each gets a -logGamma(beta) term below.
        int nonZeroTypeTopics = 0;
        for (int type = 0; type < this.numTypes; type++) {
            int[] packed = this.typeTopicCounts[type];
            for (int index = 0; index < packed.length && packed[index] > 0; index++) {
                nonZeroTypeTopics++;
                logLikelihood += Dirichlet.logGammaStirling(this.beta + (packed[index] >> this.topicBits));
                if (Double.isNaN(logLikelihood)) {
                    logger.warning("NaN in log likelihood calculation");
                    return 0.0d;
                }
                if (Double.isInfinite(logLikelihood)) {
                    logger.warning("infinite log likelihood");
                    return 0.0d;
                }
            }
        }
        for (int topic = 0; topic < this.numTopics; topic++) {
            logLikelihood -= Dirichlet.logGammaStirling((this.beta * this.numTypes) + this.tokensPerTopic[topic]);
            if (Double.isNaN(logLikelihood)) {
                logger.info("NaN after topic " + topic + " " + this.tokensPerTopic[topic]);
                return 0.0d;
            }
            if (Double.isInfinite(logLikelihood)) {
                logger.info("Infinite value after topic " + topic + " " + this.tokensPerTopic[topic]);
                return 0.0d;
            }
        }
        logLikelihood = (logLikelihood + (Dirichlet.logGammaStirling(this.beta * this.numTypes) * this.numTopics)) - (Dirichlet.logGammaStirling(this.beta) * nonZeroTypeTopics);
        if (Double.isNaN(logLikelihood)) {
            // NOTE(review): unlike the earlier checks this returns NaN rather than 0.0; kept as-is.
            logger.info("at the end");
        } else if (Double.isInfinite(logLikelihood)) {
            logger.info("Infinite value beta " + this.beta + " * " + this.numTypes);
            return 0.0d;
        }
        return logLikelihood;
    }

    /**
     * Builds a {@link TopicInferencer} that shares this model's count arrays and hyperparameters.
     * <p>
     * The vocabulary is taken from the first training instance, so this throws if {@code data} is
     * empty.
     *
     * @return a new inferencer backed by this model's current state
     */
    public TopicInferencer getInferencer() {
        Alphabet vocabulary = this.data.get(0).instance.getDataAlphabet();
        return new TopicInferencer(this.typeTopicCounts, this.tokensPerTopic, vocabulary, this.alpha, this.beta, this.betaSum);
    }

    /**
     * Builds a {@link MarginalProbEstimator} from this model's topic count, hyperparameters and
     * count arrays. The arrays are passed by reference, not copied.
     *
     * @return a new estimator backed by this model's current state
     */
    public MarginalProbEstimator getProbEstimator() {
        final MarginalProbEstimator estimator =
                new MarginalProbEstimator(numTopics, alpha, alphaSum, beta, typeTopicCounts, tokensPerTopic);
        return estimator;
    }

    /**
     * Custom Java serialization hook.
     * <p>
     * Writes a leading {@code int} 0 and then the persisted fields in a fixed order.
     * {@code readObject} reads the leading int and discards it; presumably it is a serial-format
     * version marker (confirm before relying on it).
     * <p>
     * {@code readObject} reads the same fields in exactly the same order. Any field added, removed or
     * reordered here must be mirrored there, otherwise streams written by one version will be misread
     * by the other. Only the fields listed below are written. {@link #read(File)} calls
     * {@code initializeHistograms()} after deserialization.
     *
     * @param objectOutputStream stream supplied by the serialization machinery
     * @throws IOException if writing to the stream fails
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.writeInt(0);
        // Corpus and vocabularies.
        objectOutputStream.writeObject(this.data);
        objectOutputStream.writeObject(this.alphabet);
        objectOutputStream.writeObject(this.topicAlphabet);
        // Dimensions and the bit layout of the packed typeTopicCounts entries.
        objectOutputStream.writeInt(this.numTopics);
        objectOutputStream.writeInt(this.topicMask);
        objectOutputStream.writeInt(this.topicBits);
        objectOutputStream.writeInt(this.numTypes);
        // Hyperparameters.
        objectOutputStream.writeObject(this.alpha);
        objectOutputStream.writeDouble(this.alphaSum);
        objectOutputStream.writeDouble(this.beta);
        objectOutputStream.writeDouble(this.betaSum);
        // Count statistics.
        objectOutputStream.writeObject(this.typeTopicCounts);
        objectOutputStream.writeObject(this.tokensPerTopic);
        objectOutputStream.writeObject(this.docLengthCounts);
        objectOutputStream.writeObject(this.topicDocCounts);
        // Training configuration.
        objectOutputStream.writeInt(this.numIterations);
        objectOutputStream.writeInt(this.burninPeriod);
        objectOutputStream.writeInt(this.saveSampleInterval);
        objectOutputStream.writeInt(this.optimizeInterval);
        objectOutputStream.writeInt(this.showTopicsInterval);
        objectOutputStream.writeInt(this.wordsPerTopic);
        objectOutputStream.writeInt(this.saveStateInterval);
        objectOutputStream.writeObject(this.stateFilename);
        objectOutputStream.writeInt(this.saveModelInterval);
        objectOutputStream.writeObject(this.modelFilename);
        objectOutputStream.writeInt(this.randomSeed);
        objectOutputStream.writeObject(this.formatter);
        objectOutputStream.writeBoolean(this.printLogLikelihood);
        objectOutputStream.writeInt(this.numThreads);
    }

    /**
     * Custom Java deserialization hook and the exact mirror of {@code writeObject}.
     * <p>
     * Reads and discards the leading {@code int}. It then restores each persisted field in the same
     * order it was written. Fields that are not in the stream keep their default values here;
     * {@link #read(File)} calls {@code initializeHistograms()} afterwards.
     * <p>
     * NOTE(review): {@code data} is cast from a raw {@code ArrayList} without element checks.
     *
     * @param objectInputStream stream supplied by the serialization machinery
     * @throws IOException if reading from the stream fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        // Leading int written by writeObject; its value is ignored.
        objectInputStream.readInt();
        // Corpus and vocabularies.
        this.data = (ArrayList) objectInputStream.readObject();
        this.alphabet = (Alphabet) objectInputStream.readObject();
        this.topicAlphabet = (LabelAlphabet) objectInputStream.readObject();
        // Dimensions and the bit layout of the packed typeTopicCounts entries.
        this.numTopics = objectInputStream.readInt();
        this.topicMask = objectInputStream.readInt();
        this.topicBits = objectInputStream.readInt();
        this.numTypes = objectInputStream.readInt();
        // Hyperparameters.
        this.alpha = (double[]) objectInputStream.readObject();
        this.alphaSum = objectInputStream.readDouble();
        this.beta = objectInputStream.readDouble();
        this.betaSum = objectInputStream.readDouble();
        // Count statistics.
        this.typeTopicCounts = (int[][]) objectInputStream.readObject();
        this.tokensPerTopic = (int[]) objectInputStream.readObject();
        this.docLengthCounts = (int[]) objectInputStream.readObject();
        this.topicDocCounts = (int[][]) objectInputStream.readObject();
        // Training configuration.
        this.numIterations = objectInputStream.readInt();
        this.burninPeriod = objectInputStream.readInt();
        this.saveSampleInterval = objectInputStream.readInt();
        this.optimizeInterval = objectInputStream.readInt();
        this.showTopicsInterval = objectInputStream.readInt();
        this.wordsPerTopic = objectInputStream.readInt();
        this.saveStateInterval = objectInputStream.readInt();
        this.stateFilename = (String) objectInputStream.readObject();
        this.saveModelInterval = objectInputStream.readInt();
        this.modelFilename = (String) objectInputStream.readObject();
        this.randomSeed = objectInputStream.readInt();
        this.formatter = (NumberFormat) objectInputStream.readObject();
        this.printLogLikelihood = objectInputStream.readBoolean();
        this.numThreads = objectInputStream.readInt();
    }

    /**
     * Serializes this model to {@code file} with Java object serialization.
     * <p>
     * This is best-effort: an {@link IOException} is not propagated but reported on
     * {@code System.err}. That includes an exception thrown while closing the stream.
     * <p>
     * Fix: the streams are now closed even when serialization fails part-way. Previously an
     * exception from {@code writeObject} leaked the open file handle.
     *
     * @param file destination file (overwritten)
     */
    public void write(File file) {
        try (FileOutputStream fileOut = new FileOutputStream(file);
             ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileOut)) {
            objectOutputStream.writeObject(this);
        } catch (IOException e) {
            System.err.println("Problem serializing ParallelTopicModel to file " + file + ": " + e);
        }
    }

    /**
     * Deserializes a model previously saved with {@link #write(File)} and rebuilds its histograms
     * via {@code initializeHistograms()}.
     * <p>
     * SECURITY NOTE: this uses Java native deserialization. Only load files from trusted sources.
     * Consider an {@code ObjectInputFilter} if untrusted input is possible.
     * <p>
     * Fix: the file is now closed even when deserialization fails. Previously a
     * {@code ClassNotFoundException} or {@code IOException} from {@code readObject} leaked the handle.
     *
     * @param file file written by {@link #write(File)}
     * @return the restored model
     * @throws Exception if the file cannot be read or does not contain a {@code ParallelTopicModel}
     */
    public static ParallelTopicModel read(File file) throws Exception {
        ParallelTopicModel parallelTopicModel;
        try (FileInputStream fileIn = new FileInputStream(file);
             ObjectInputStream objectInputStream = new ObjectInputStream(fileIn)) {
            parallelTopicModel = (ParallelTopicModel) objectInputStream.readObject();
        }
        parallelTopicModel.initializeHistograms();
        return parallelTopicModel;
    }

    /**
     * Command-line entry point.
     * <p>
     * Loads a serialized {@link InstanceList}, trains a model on it, and writes the final sampler
     * state to {@code state.gz} in the working directory.
     * <p>
     * Arguments: {@code <instance-list-file> [num-topics] [num-threads]}. The topic count defaults to
     * 200. When the thread count is omitted, {@code setNumThreads} is not called and the model keeps
     * whatever thread count its constructor set.
     * <p>
     * Fixes:
     * <ul>
     *   <li>A missing third argument used to crash with an {@code ArrayIndexOutOfBoundsException},
     *       even though the second argument was optional.</li>
     *   <li>Running with no arguments now prints a usage line instead of a stack trace.</li>
     * </ul>
     * <p>
     * NOTE(review): 50.0 and 0.01 look like the alpha sum and beta passed to the constructor; confirm.
     *
     * @param strArr command-line arguments as described above
     */
    public static void main(String[] strArr) {
        if (strArr.length < 1) {
            System.err.println("Usage: ParallelTopicModel <instance-list-file> [num-topics (default 200)] [num-threads]");
            return;
        }
        try {
            InstanceList load = InstanceList.load(new File(strArr[0]));
            int numberOfTopics = strArr.length > 1 ? Integer.parseInt(strArr[1]) : 200;
            ParallelTopicModel parallelTopicModel = new ParallelTopicModel(numberOfTopics, 50.0d, 0.01d);
            parallelTopicModel.printLogLikelihood = true;
            parallelTopicModel.setTopicDisplay(50, 7);
            parallelTopicModel.addInstances(load);
            if (strArr.length > 2) {
                parallelTopicModel.setNumThreads(Integer.parseInt(strArr[2]));
            }
            parallelTopicModel.estimate();
            logger.info("printing state");
            parallelTopicModel.printState(new File("state.gz"));
            logger.info("finished printing");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
