package cc.mallet.classify;

import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.Labeling;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.logging.Logger;
import org.antlr.v4.runtime.tree.xpath.XPath;

/* loaded from: input_file:cc/mallet/classify/Trial.class */
public class Trial extends ArrayList<Classification> {
    private static final Logger logger = Logger.getLogger(Trial.class.getName());

    /** The classifier that must have produced every Classification stored in this trial. */
    Classifier classifier;

    /**
     * Classifies every instance of {@code instanceList} with {@code classifier} and stores the
     * resulting classifications in iteration order.
     *
     * @param classifier   the classifier being evaluated
     * @param instanceList the instances to classify
     */
    public Trial(Classifier classifier, InstanceList instanceList) {
        super(instanceList.size());
        this.classifier = classifier;
        Iterator<Instance> instances = instanceList.iterator();
        while (instances.hasNext()) {
            add(classifier.classify(instances.next()));
        }
    }

    /**
     * Appends a classification produced by this trial's classifier.
     *
     * @throws IllegalArgumentException if the classification came from a different classifier
     */
    @Override
    public boolean add(Classification classification) {
        checkSameClassifier(classification);
        return super.add(classification);
    }

    /**
     * Inserts a classification produced by this trial's classifier at {@code index}.
     *
     * @throws IllegalArgumentException if the classification came from a different classifier
     */
    @Override
    public void add(int index, Classification classification) {
        checkSameClassifier(classification);
        super.add(index, classification);
    }

    /**
     * Appends each classification in turn via {@link #add(Classification)}. Elements preceding
     * an offending one are already added when the exception is thrown.
     *
     * @return {@code false} if any individual {@code add} reported no change, otherwise {@code true}
     * @throws IllegalArgumentException if an element came from a different classifier
     */
    @Override
    public boolean addAll(Collection<? extends Classification> collection) {
        boolean allAdded = true;
        for (Classification classification : collection) {
            if (!add(classification)) {
                allAdded = false;
            }
        }
        return allAdded;
    }

    /**
     * Inserts all classifications at {@code index}. Every element is validated first, so
     * nothing is inserted if any element came from a different classifier.
     *
     * @throws IllegalArgumentException if an element came from a different classifier
     */
    @Override
    public boolean addAll(int index, Collection<? extends Classification> collection) {
        for (Classification classification : collection) {
            checkSameClassifier(classification);
        }
        return super.addAll(index, collection);
    }

    /** Returns the classifier whose classifications this trial holds. */
    public Classifier getClassifier() {
        return this.classifier;
    }

    /**
     * Returns the fraction of classifications whose best label is correct.
     * Returns NaN for an empty trial (0.0 / 0).
     */
    public double getAccuracy() {
        int correct = 0;
        for (Classification classification : this) {
            if (classification.bestLabelIsCorrect()) {
                correct++;
            }
        }
        // Cast before dividing; integer division would truncate the result to 0 or 1.
        return (double) correct / size();
    }

    /**
     * Precision for a label given either as a {@link Labeling} (its best index is used) or as
     * an entry looked up in the classifier's label alphabet.
     *
     * @throws IllegalArgumentException if the label is not in the alphabet
     */
    public double getPrecision(Object label) {
        return getPrecision(labelIndexOf(label));
    }

    /** Precision for the best label of {@code labeling}. */
    public double getPrecision(Labeling labeling) {
        return getPrecision(labeling.getBestIndex());
    }

    /**
     * Precision for label {@code index}: among the classifications predicting {@code index},
     * the fraction whose true label is also {@code index}. If nothing predicts the label, logs
     * a warning and returns 1.0.
     */
    public double getPrecision(int index) {
        int correct = 0;
        int predicted = 0;
        for (Classification classification : this) {
            int trueIndex = classification.getInstance().getLabeling().getBestIndex();
            int predictedIndex = classification.getLabeling().getBestIndex();
            if (predictedIndex == index) {
                predicted++;
                if (trueIndex == index) {
                    correct++;
                }
            }
        }
        if (predicted == 0) {
            // 'correct' can only be incremented alongside 'predicted', so it is 0 here as well.
            logger.warning("No examples with predicted label "
                    + this.classifier.getLabelAlphabet().lookupLabel(index) + "!");
            return 1.0d;
        }
        return (double) correct / predicted;
    }

    /**
     * Recall for a label given either as a {@link Labeling} (its best index is used) or as
     * an entry looked up in the classifier's label alphabet.
     *
     * @throws IllegalArgumentException if the label is not in the alphabet
     */
    public double getRecall(Object label) {
        return getRecall(labelIndexOf(label));
    }

    /** Recall for the best label of {@code labeling}. */
    public double getRecall(Labeling labeling) {
        return getRecall(labeling.getBestIndex());
    }

    /**
     * Recall for label {@code index}: among the classifications whose true label is
     * {@code index}, the fraction that also predict {@code index}. If no example has that
     * true label, logs a warning and returns 1.0.
     */
    public double getRecall(int index) {
        int correct = 0;
        int actual = 0;
        for (Classification classification : this) {
            int trueIndex = classification.getInstance().getLabeling().getBestIndex();
            int predictedIndex = classification.getLabeling().getBestIndex();
            if (trueIndex == index) {
                actual++;
                if (predictedIndex == index) {
                    correct++;
                }
            }
        }
        if (actual == 0) {
            // 'correct' can only be incremented alongside 'actual', so it is 0 here as well.
            logger.warning("No examples with true label "
                    + this.classifier.getLabelAlphabet().lookupLabel(index) + "!");
            return 1.0d;
        }
        return (double) correct / actual;
    }

    /**
     * F1 for a label given either as a {@link Labeling} (its best index is used) or as
     * an entry looked up in the classifier's label alphabet.
     *
     * @throws IllegalArgumentException if the label is not in the alphabet
     */
    public double getF1(Object label) {
        return getF1(labelIndexOf(label));
    }

    /** F1 for the best label of {@code labeling}. */
    public double getF1(Labeling labeling) {
        return getF1(labeling.getBestIndex());
    }

    /**
     * Harmonic mean of precision and recall for label {@code index}; 0.0 when both are 0
     * (avoids 0/0).
     */
    public double getF1(int index) {
        double precision = getPrecision(index);
        double recall = getRecall(index);
        if (precision == 0.0d && recall == 0.0d) {
            return 0.0d;
        }
        return (2.0d * precision * recall) / (precision + recall);
    }

    /**
     * Mean rank of each instance's target label within its predicted labeling.
     * Returns NaN for an empty trial.
     */
    public double getAverageRank() {
        double rankSum = 0.0d;
        for (Classification classification : this) {
            Instance instance = classification.getInstance();
            Labeling labeling = classification.getLabeling();
            rankSum += labeling.getRank((Label) instance.getTarget());
        }
        return rankSum / size();
    }

    /** Throws unless {@code classification} was produced by this trial's classifier. */
    private void checkSameClassifier(Classification classification) {
        if (classification.getClassifier() != this.classifier) {
            throw new IllegalArgumentException("Trying to add Classification from a different Classifier.");
        }
    }

    /**
     * Resolves a label argument to an index: a {@link Labeling} yields its best index; anything
     * else is looked up via the label alphabet's {@code lookupIndex(label, false)}.
     *
     * @throws IllegalArgumentException if the lookup returns -1
     */
    private int labelIndexOf(Object label) {
        int index = label instanceof Labeling
                ? ((Labeling) label).getBestIndex()
                : this.classifier.getLabelAlphabet().lookupIndex(label, false);
        if (index == -1) {
            throw new IllegalArgumentException("Label " + label.toString() + " is not a valid label.");
        }
        return index;
    }
}
