package cc.mallet.types;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Trial;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.Iterator;

/* loaded from: input_file:cc/mallet/types/ROCData.class */
/**
 * Accumulates, for every label of a {@link LabelAlphabet} and every one of a fixed set of
 * score thresholds, the four confusion counts (true/false positives, false/true negatives).
 *
 * <p>Counts are stored as {@code counts[labelIndex][thresholdIndex][kind]} where {@code kind}
 * is one of {@link #TRUE_POSITIVE}, {@link #FALSE_POSITIVE}, {@link #FALSE_NEGATIVE} or
 * {@link #TRUE_NEGATIVE}. An instance is treated as "predicted positive" for a label at a
 * threshold when its score for that label is {@code >=} the threshold.
 *
 * <p>Several accessors hand out the internal arrays without copying; see the individual
 * method documentation.
 */
public class ROCData implements AlphabetCarrying, Serializable {
    private static final long serialVersionUID = -2060194953037720640L;

    /** Position of the true-positive count inside a per-threshold count array. */
    public static final int TRUE_POSITIVE = 0;
    /** Position of the false-positive count inside a per-threshold count array. */
    public static final int FALSE_POSITIVE = 1;
    /** Position of the false-negative count inside a per-threshold count array. */
    public static final int FALSE_NEGATIVE = 2;
    /** Position of the true-negative count inside a per-threshold count array. */
    public static final int TRUE_NEGATIVE = 3;

    private final LabelAlphabet labelAlphabet;
    /** Indexed as {@code [label index][threshold index][count kind]}. */
    private final int[][][] counts;
    /** Sorted ascending by the constructor; the lookups below rely on that ordering. */
    private final double[] thresholds;

    /**
     * Creates an empty set of counts for the given thresholds and labels.
     *
     * <p>Note: {@code dArr} is sorted <em>in place</em> and retained without copying, so the
     * caller's array is modified and remains shared with this object.
     *
     * @param dArr          the score thresholds; sorted ascending by this constructor
     * @param labelAlphabet the labels for which counts are kept
     */
    public ROCData(double[] dArr, LabelAlphabet labelAlphabet) {
        Arrays.sort(dArr);
        this.counts = new int[labelAlphabet.size()][dArr.length][4];
        this.labelAlphabet = labelAlphabet;
        this.thresholds = dArr;
    }

    /**
     * Adds one classification result to the counts of every label and every threshold.
     *
     * <p>For each label, the classification's score for that label is compared with each
     * threshold. Where the score is {@code >=} the threshold the instance is counted as a
     * true positive (if the label is the instance's best label) or a false positive
     * (otherwise); where it is below, as a false negative or true negative respectively.
     *
     * @param classification the result to record
     * @throws IllegalArgumentException if {@code Alphabet.alphabetsMatch} reports that the
     *                                  classification's label vector does not match this object
     */
    public void add(Classification classification) {
        int trueLabelIndex = classification.getInstance().getLabeling().getBestIndex();
        LabelVector labelVector = classification.getLabelVector();
        double[] values = labelVector.getValues();
        if (!Alphabet.alphabetsMatch(this, labelVector)) {
            throw new IllegalArgumentException("Alphabets do not match");
        }
        int numLabels = this.labelAlphabet.size();
        int numThresholds = this.thresholds.length;
        for (int label = 0; label < numLabels; label++) {
            double score = values[label];
            boolean isTrueLabel = trueLabelIndex == label;
            int acceptedKind = isTrueLabel ? TRUE_POSITIVE : FALSE_POSITIVE;
            int rejectedKind = isTrueLabel ? FALSE_NEGATIVE : TRUE_NEGATIVE;
            int[][] labelCounts = this.counts[label];

            // Thresholds are ascending, so the score passes a prefix of them and fails the rest.
            int t = 0;
            while (t < numThresholds && score >= this.thresholds[t]) {
                labelCounts[t][acceptedKind]++;
                t++;
            }
            while (t < numThresholds) {
                labelCounts[t][rejectedKind]++;
                t++;
            }
        }
    }

    /**
     * Adds every classification produced by iterating over the given trial.
     *
     * @param trial the trial whose classifications are recorded
     * @throws IllegalArgumentException if any classification's alphabet does not match
     */
    public void add(Trial trial) {
        Iterator<Classification> it = trial.iterator();
        while (it.hasNext()) {
            add(it.next());
        }
    }

    /**
     * Adds the counts of another {@code ROCData} element-wise into this one.
     *
     * @param rOCData the counts to merge in; left unmodified
     * @throws IllegalArgumentException if the alphabets do not match or the threshold arrays
     *                                  are not equal
     */
    public void add(ROCData rOCData) {
        if (!Alphabet.alphabetsMatch(this, rOCData)) {
            throw new IllegalArgumentException("Alphabets do not match");
        }
        if (!Arrays.equals(this.thresholds, rOCData.thresholds)) {
            throw new IllegalArgumentException("Thresholds do not match");
        }
        for (int label = 0; label < this.counts.length; label++) {
            int[][] mine = this.counts[label];
            int[][] theirs = rOCData.counts[label];
            for (int t = 0; t < mine.length; t++) {
                int[] target = mine[t];
                int[] source = theirs[t];
                for (int kind = 0; kind < target.length; kind++) {
                    target[kind] += source[kind];
                }
            }
        }
    }

    /** Returns the label alphabet, typed as a plain {@link Alphabet}. */
    @Override // cc.mallet.types.AlphabetCarrying
    public Alphabet getAlphabet() {
        return this.labelAlphabet;
    }

    /** Returns a new single-element array holding the label alphabet. */
    @Override // cc.mallet.types.AlphabetCarrying
    public Alphabet[] getAlphabets() {
        return new Alphabet[]{this.labelAlphabet};
    }

    /**
     * Returns the counts of one label for all thresholds.
     *
     * @param label the label to look up
     * @return the internal {@code [threshold index][count kind]} array (not a copy)
     */
    public int[][] getCounts(Label label) {
        return this.counts[label.getIndex()];
    }

    /**
     * Returns the counts of one label at the largest threshold that is {@code <= d}.
     *
     * @param label the label to look up
     * @param d     the threshold value
     * @return the internal four-element count array (not a copy)
     * @throws ArrayIndexOutOfBoundsException if no threshold is {@code <= d}
     */
    public int[] getCounts(Label label, double d) {
        int index = thresholdIndex(d);
        return this.counts[label.getIndex()][index];
    }

    /** Returns the label alphabet these counts are kept for. */
    public LabelAlphabet getLabelAlphabet() {
        return this.labelAlphabet;
    }

    /**
     * Computes precision, {@code TP / (TP + FP)}, for a label at a threshold.
     *
     * @param label the label to evaluate
     * @param d     the threshold value
     * @return the precision as a fraction; {@code NaN} when {@code TP + FP} is zero
     * @throws ArrayIndexOutOfBoundsException if no threshold is {@code <= d}
     */
    public double getPrecision(Label label, double d) {
        int[] c = getCounts(label, d);
        // Widen before dividing: int / int would truncate the ratio to 0 or 1.
        double truePositives = c[TRUE_POSITIVE];
        return truePositives / (truePositives + c[FALSE_POSITIVE]);
    }

    /**
     * Computes precision restricted to the score band that starts at the largest threshold
     * {@code <= d} and ends just below the next threshold (or is unbounded above for the last
     * threshold). The band's counts are obtained by subtracting the next threshold's
     * positive counts from this threshold's.
     *
     * @param label the label to evaluate
     * @param d     the score
     * @return {@code TP / (TP + FP)} within the band; {@code NaN} when the band is empty
     * @throws ArrayIndexOutOfBoundsException if no threshold is {@code <= d}
     */
    public double getPrecisionForScore(Label label, double d) {
        int[][] labelCounts = this.counts[label.getIndex()];
        int index = thresholdIndex(d);
        double truePositives = labelCounts[index][TRUE_POSITIVE];
        double falsePositives = labelCounts[index][FALSE_POSITIVE];
        if (index != this.thresholds.length - 1) {
            truePositives -= labelCounts[index + 1][TRUE_POSITIVE];
            falsePositives -= labelCounts[index + 1][FALSE_POSITIVE];
        }
        return truePositives / (truePositives + falsePositives);
    }

    /**
     * Computes the percentage of recorded instances predicted positive for a label at a
     * threshold: {@code (TP + FP) / (TP + FP + FN + TN) * 100}.
     *
     * @param label the label to evaluate
     * @param d     the threshold value
     * @return the percentage in the range 0..100; {@code NaN} when nothing has been recorded
     * @throws ArrayIndexOutOfBoundsException if no threshold is {@code <= d}
     */
    public double getPositivePercent(Label label, double d) {
        int[] c = getCounts(label, d);
        double positives = (double) c[TRUE_POSITIVE] + c[FALSE_POSITIVE];
        double total = positives + c[FALSE_NEGATIVE] + c[TRUE_NEGATIVE];
        return (positives / total) * 100.0d;
    }

    /**
     * Computes recall, {@code TP / (TP + FN)}, for a label at a threshold.
     *
     * @param label the label to evaluate
     * @param d     the threshold value
     * @return the recall as a fraction; {@code NaN} when {@code TP + FN} is zero
     * @throws ArrayIndexOutOfBoundsException if no threshold is {@code <= d}
     */
    public double getRecall(Label label, double d) {
        int[] c = getCounts(label, d);
        // Widen before dividing: int / int would truncate the ratio to 0 or 1.
        double truePositives = c[TRUE_POSITIVE];
        return truePositives / (truePositives + c[FALSE_NEGATIVE]);
    }

    /**
     * Returns the thresholds in ascending order.
     *
     * @return the internal array (not a copy); modifying it would invalidate the lookups
     */
    public double[] getThresholds() {
        return this.thresholds;
    }

    /**
     * Overwrites the four counts of a label at the largest threshold that is {@code <= d}.
     *
     * @param label the label to update
     * @param d     the threshold value
     * @param iArr  the new counts, ordered TP, FP, FN, TN; copied into the internal array
     * @throws IllegalArgumentException       if {@code iArr} does not have the expected length
     * @throws ArrayIndexOutOfBoundsException if no threshold is {@code <= d}
     */
    public void setCounts(Label label, double d, int[] iArr) {
        int index = thresholdIndex(d);
        int[] target = this.counts[label.getIndex()][index];
        if (iArr.length != target.length) {
            throw new IllegalArgumentException("Array of counts must contain " + target.length + " elements.");
        }
        System.arraycopy(iArr, 0, target, 0, target.length);
    }

    /**
     * Renders one tab-separated table per label with a row per threshold: the threshold, the
     * four counts, then precision and recall (each printed as 0 when its denominator is zero).
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        DecimalFormat decimalFormat = new DecimalFormat("0.####");
        for (int label = 0; label < this.labelAlphabet.size(); label++) {
            int[][] labelCounts = this.counts[label];
            sb.append("ROC data for ");
            sb.append(this.labelAlphabet.lookupObject(label).toString());
            sb.append('\n');
            sb.append("THR\tTP\tFP\tFN\tTN\tPrecis\tRecall\n");
            for (int t = 0; t < this.thresholds.length; t++) {
                int[] row = labelCounts[t];
                sb.append(this.thresholds[t]);
                for (int count : row) {
                    sb.append('\t').append(count);
                }
                double truePositives = row[TRUE_POSITIVE];
                double predictedPositives = truePositives + row[FALSE_POSITIVE];
                double actualPositives = truePositives + row[FALSE_NEGATIVE];
                double precision = predictedPositives != 0.0d ? truePositives / predictedPositives : 0.0d;
                double recall = actualPositives != 0.0d ? truePositives / actualPositives : 0.0d;
                sb.append('\t').append(decimalFormat.format(precision));
                sb.append('\t').append(decimalFormat.format(recall));
                sb.append('\n');
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    /**
     * Finds the index of the largest threshold that is {@code <=} the given value.
     *
     * @throws ArrayIndexOutOfBoundsException if every threshold is greater than the value
     *                                        (including when there are no thresholds)
     */
    private int thresholdIndex(double value) {
        int index = Arrays.binarySearch(this.thresholds, value);
        if (index < 0) {
            // A miss is reported as -(insertionPoint) - 1; the applicable threshold is the
            // one just before the insertion point.
            index = -index - 2;
        }
        if (index < 0) {
            throw new ArrayIndexOutOfBoundsException("No threshold at or below " + value);
        }
        return index;
    }
}
