package edu.cmu.minorthird.classify.experiments;

import cern.colt.matrix.impl.AbstractFormatter;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.MutableInstance;
import edu.cmu.minorthird.classify.SGMExample;
import edu.cmu.minorthird.classify.relational.RealRelationalDataset;
import edu.cmu.minorthird.classify.relational.StackedGraphicalLearner;
import edu.cmu.minorthird.classify.semisupervised.SemiSupervisedClassifier;
import edu.cmu.minorthird.classify.semisupervised.SemiSupervisedDataset;
import edu.cmu.minorthird.classify.sequential.SequenceClassifier;
import edu.cmu.minorthird.classify.sequential.SequenceDataset;
import edu.cmu.minorthird.util.MathUtil;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.Saveable;
import edu.cmu.minorthird.util.StringUtil;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.LineCharter;
import edu.cmu.minorthird.util.gui.ParallelViewer;
import edu.cmu.minorthird.util.gui.VanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.ViewerFrame;
import edu.cmu.minorthird.util.gui.Visible;
import java.awt.Color;
import java.awt.Component;
import java.awt.GridBagConstraints;
import java.awt.GridBagLayout;
import java.awt.event.ActionEvent;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.LineNumberReader;
import java.io.PrintStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Properties;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import javax.swing.AbstractAction;
import javax.swing.JButton;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.JTable;
import javax.swing.JTextField;
import javax.swing.table.DefaultTableCellRenderer;
import org.apache.log4j.Logger;
import org.jfree.data.xml.DatasetTags;

/* loaded from: input_file:edu/cmu/minorthird/classify/experiments/Evaluation.class */
public class Evaluation implements Visible, Serializable, Saveable {
    // Class-wide log4j logger; extend() writes per-example debug lines through it.
    private static Logger log = Logger.getLogger(Evaluation.class);
    private static final long serialVersionUID = 1;
    // Default partition ID (0) for entries that do not come from a separate partition.
    public static final int DEFAULT_PARTITION_ID = 0;
    // Label schema; supplies the class count and class indices used by the per-class statistics.
    private ExampleSchema schema;
    // True iff schema equals ExampleSchema.BINARY_EXAMPLE_SCHEMA. Several binary-only statistics
    // return -1 when this is false.
    private boolean isBinary;
    // NOTE(review): format name / file extension presumably used by the save/load code, which is
    // not visible in this chunk — confirm.
    public static final String EVAL_FORMAT_NAME = "Minorthird Evaluation";
    public static final String EVAL_EXT = ".eval";
    // NOTE(review): presumably the on-disk format version for save/load — confirm.
    private final int CURRENT_VERSION_NUMBER = 1;
    // One Entry per classified example, in insertion order (Entry.index is its position here).
    private ArrayList entryList = new ArrayList();
    // Cached derived matrices. They are transient, so they are not serialized; presumably rebuilt
    // on demand when null — TODO confirm against the accessor methods.
    private transient Matrix cachedPRCMatrix = null;
    private transient Matrix cachedTPFPMatrix = null;
    private transient Matrix cachedConfusionMatrix = null;
    // Free-form key/value annotations set through setProperty().
    private Properties properties = new Properties();
    // Property keys in first-insertion order; PropertyViewer displays rows in this order.
    private ArrayList propertyKeyList = new ArrayList();

    /* loaded from: input_file:edu/cmu/minorthird/classify/experiments/Evaluation$ConfusionMatrixViewer.class */
    public static class ConfusionMatrixViewer extends ComponentViewer {
        /**
         * Lays out an {@link Evaluation}'s confusion matrix as a grid of labels.
         *
         * <p>The layout has a "Predicted Class" header spanning the value columns and a row of
         * class-name column headers. Below that, each row starts with its class name followed by
         * that row's matrix values.
         */
        @Override // edu.cmu.minorthird.util.gui.ComponentViewer
        public JComponent componentFor(Object obj) {
            Evaluation eval = (Evaluation) obj;
            JPanel grid = new JPanel();
            Matrix cm = eval.confusionMatrix();
            String[] classNames = eval.getClasses();
            int n = classNames.length;
            grid.setLayout(new GridBagLayout());

            GridBagConstraints header = cmGBC(0, 1);
            header.gridwidth = n;
            grid.add(new JLabel("Predicted Class"), header);

            for (int col = 0; col < n; col++) {
                grid.add(new JLabel(classNames[col]), cmGBC(1, col + 1));
            }
            for (int row = 0; row < n; row++) {
                grid.add(new JLabel(classNames[row]), cmGBC(row + 2, 0));
                double[] rowValues = cm.values[row];
                for (int col = 0; col < n; col++) {
                    grid.add(new JLabel(Double.toString(rowValues[col])), cmGBC(row + 2, col + 1));
                }
            }
            return grid;
        }

        /** Constraints for grid cell (row, col): 20px internal padding, no extra weight. */
        private GridBagConstraints cmGBC(int row, int col) {
            GridBagConstraints c = new GridBagConstraints();
            c.gridy = row;
            c.gridx = col;
            c.weightx = 0.0d;
            c.weighty = 0.0d;
            c.ipadx = 20;
            c.ipady = 20;
            return c;
        }
    }

    /* loaded from: input_file:edu/cmu/minorthird/classify/experiments/Evaluation$ElevenPointPrecisionViewer.class */
    public static class ElevenPointPrecisionViewer extends ComponentViewer {
        /**
         * Charts {@link Evaluation#elevenPointPrecision()} against recall.
         *
         * <p>The k-th value of the precision array is plotted at recall k/10.
         */
        @Override // edu.cmu.minorthird.util.gui.ComponentViewer
        public JComponent componentFor(Object obj) {
            Evaluation eval = (Evaluation) obj;
            double[] precisionAtRecall = eval.elevenPointPrecision();
            LineCharter chart = new LineCharter();
            chart.startCurve("Interpolated Precision");
            int step = 0;
            for (double precision : precisionAtRecall) {
                chart.addPoint(step / 10.0d, precision);
                step++;
            }
            return chart.getPanel("11-Pt Interpolated Precision vs. Recall", "Recall", "Precision");
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/cmu/minorthird/classify/experiments/Evaluation$Entry.class */
    public static class Entry implements Serializable {
        private static final long serialVersionUID = -4069980043842319179L;
        /** The classified instance; transient, so it is dropped on serialization. */
        public transient Instance instance;
        /** Partition ID passed to {@code extend} when this entry was recorded. */
        public int partitionID;
        /** Position of this entry in the evaluation's entry list at insertion time. */
        public int index;
        /** Label predicted by the classifier. */
        public ClassLabel predicted;
        /** True label of the example. */
        public ClassLabel actual;
        /** hashCode() of {@link #instance}, captured at construction. */
        public int h;
        /** Weight this entry contributes to the statistics; defaults to 1. */
        public double w = 1.0d;

        /**
         * @param instance the classified instance (must be non-null: its hashCode is taken)
         * @param classLabel the predicted label
         * @param classLabel2 the actual label
         * @param i index of this entry in the evaluation
         * @param i2 partition ID
         */
        public Entry(Instance instance, ClassLabel classLabel, ClassLabel classLabel2, int i, int i2) {
            this.instance = instance;
            this.predicted = classLabel;
            this.actual = classLabel2;
            this.index = i;
            this.partitionID = i2;
            this.h = instance.hashCode();
        }

        /** Tab-separated "predicted, actual, instance". */
        @Override
        public String toString() {
            // The result is unused; the call is kept so a null predicted label still fails here.
            this.predicted.bestWeight();
            return new StringBuilder()
                    .append(this.predicted).append('\t')
                    .append(this.actual).append('\t')
                    .append(this.instance)
                    .toString();
        }
    }

    /* loaded from: input_file:edu/cmu/minorthird/classify/experiments/Evaluation$Matrix.class */
    public static class Matrix {
        /** Backing array indexed [row][column]; exposed directly, not copied. */
        public double[][] values;

        /** Wraps (does not copy) the given array. */
        public Matrix(double[][] dArr) {
            this.values = dArr;
        }

        /** One line per row, each formatted by StringUtil and ended by the colt row separator. */
        @Override
        public String toString() {
            StringBuilder out = new StringBuilder();
            for (double[] row : this.values) {
                out.append(StringUtil.toString(row)).append(AbstractFormatter.DEFAULT_ROW_SEPARATOR);
            }
            return out.toString();
        }

        /** Returns the value at row {@code i}, column {@code i2}. */
        public double getValue(int i, int i2) {
            double[] row = this.values[i];
            return row[i2];
        }
    }

    /* loaded from: input_file:edu/cmu/minorthird/classify/experiments/Evaluation$MyTableCellRenderer.class */
    public class MyTableCellRenderer extends DefaultTableCellRenderer {
        public MyTableCellRenderer() {
        }

        /**
         * Renders a cell with a zebra-striped, opaque background: light gray on odd rows, white on
         * even rows. Everything else (font, border, text) comes from DefaultTableCellRenderer.
         *
         * @param jTable the table being rendered
         * @param obj the cell value
         * @param z whether the cell is selected
         * @param z2 whether the cell has focus
         * @param i the row index; its parity selects the background colour
         * @param i2 the column index
         * @return this renderer, configured for the cell
         */
        @Override
        public Component getTableCellRendererComponent(JTable jTable, Object obj, boolean z, boolean z2, int i, int i2) {
            // DefaultTableCellRenderer is a JLabel and returns itself, so this cast always succeeds.
            // Without the cast, assigning the declared Component return type to JLabel does not compile.
            JLabel label = (JLabel) super.getTableCellRendererComponent(jTable, obj, z, z2, i, i2);
            label.setBackground(i % 2 != 0 ? Color.lightGray : Color.white);
            label.setOpaque(true);
            return label;
        }
    }

    /* loaded from: input_file:edu/cmu/minorthird/classify/experiments/Evaluation$PropertyViewer.class */
    public static class PropertyViewer extends ComponentViewer {
        /**
         * Shows an {@link Evaluation}'s properties in a table.
         *
         * <p>Below the table is a control row: an "Insert Property" button plus key and value text
         * fields. Pressing the button stores the pair and rebuilds the table.
         */
        @Override // edu.cmu.minorthird.util.gui.ComponentViewer
        public JComponent componentFor(Object obj) {
            final Evaluation eval = (Evaluation) obj;
            final JPanel panel = new JPanel();
            final JTextField keyField = new JTextField(10);
            final JTextField valueField = new JTextField(10);
            final JScrollPane tableScroller = new JScrollPane(makePropertyTable(eval));

            AbstractAction insertAction = new AbstractAction("Insert Property") { // from class: edu.cmu.minorthird.classify.experiments.Evaluation.PropertyViewer.1
                @Override
                public void actionPerformed(ActionEvent event) {
                    eval.setProperty(keyField.getText(), valueField.getText());
                    // Swap in a freshly built table so the new row shows up.
                    tableScroller.getViewport().setView(PropertyViewer.this.makePropertyTable(eval));
                    tableScroller.revalidate();
                    panel.revalidate();
                }
            };
            JButton insertButton = new JButton(insertAction);

            panel.setLayout(new GridBagLayout());
            GridBagConstraints tableConstraints = fillerGBC();
            tableConstraints.gridwidth = 3;
            panel.add(tableScroller, tableConstraints);
            panel.add(insertButton, myGBC(0));
            panel.add(keyField, myGBC(1));
            panel.add(valueField, myGBC(2));
            return panel;
        }

        /** Control-row constraints: filler defaults, row 1, column {@code i}, stretched horizontally. */
        private GridBagConstraints myGBC(int i) {
            GridBagConstraints c = fillerGBC();
            c.gridx = i;
            c.gridy = 1;
            c.fill = GridBagConstraints.HORIZONTAL;
            return c;
        }

        /** Builds a two-column (key, value) table, with rows in the evaluation's key-insertion order. */
        public JTable makePropertyTable(Evaluation evaluation) {
            int rowCount = evaluation.propertyKeyList.size();
            Object[][] cells = new Object[rowCount][];
            for (int r = 0; r < rowCount; r++) {
                Object key = evaluation.propertyKeyList.get(r);
                cells[r] = new Object[]{key, evaluation.properties.get(key)};
            }
            return new JTable(cells, new String[]{"Property", "Property's Value"});
        }
    }

    /* loaded from: input_file:edu/cmu/minorthird/classify/experiments/Evaluation$ROCViewer.class */
    public static class ROCViewer extends ComponentViewer {
        /**
         * Plots {@link Evaluation#thousandPointROC()} and reports its trapezoidal AUC.
         *
         * <p>Column 1 of each row is the x coordinate (FP / all negatives) and column 0 is the y
         * coordinate (TP / all positives). The AUC appears in the x-axis label.
         */
        @Override // edu.cmu.minorthird.util.gui.ComponentViewer
        public JComponent componentFor(Object obj) {
            double[][] roc = ((Evaluation) obj).thousandPointROC().values;
            LineCharter chart = new LineCharter();
            chart.startCurve("Actual ROC");
            for (double[] point : roc) {
                chart.addPoint(point[1], point[0]);
            }
            // Trapezoid rule between consecutive points.
            double auc = 0.0d;
            for (int k = 1; k < roc.length; k++) {
                double width = roc[k][1] - roc[k - 1][1];
                auc += ((roc[k - 1][0] + roc[k][0]) * width) / 2.0d;
            }
            return chart.getPanel("Actual ROC Curve", "False Positive / All Negative   (AUC = " + auc + ")", "True Positive / All Positive");
        }
    }

    /* loaded from: input_file:edu/cmu/minorthird/classify/experiments/Evaluation$SummaryViewer.class */
    public class SummaryViewer extends ComponentViewer {
        public SummaryViewer() {
        }

        /** Shows summaryStatisticNames()/summaryStatistics() as a striped two-column table. */
        @Override // edu.cmu.minorthird.util.gui.ComponentViewer
        public JComponent componentFor(Object obj) {
            Evaluation eval = (Evaluation) obj;
            double[] stats = eval.summaryStatistics();
            String[] names = eval.summaryStatisticNames();
            Object[][] rows = new Object[stats.length][];
            for (int k = 0; k < stats.length; k++) {
                rows[k] = new Object[]{names[k], Double.valueOf(stats[k])};
            }
            JTable table = new JTable(rows, new String[]{"Statistic", DatasetTags.VALUE_TAG});
            table.setDefaultRenderer(Object.class, new MyTableCellRenderer());
            table.setVisible(true);
            return new JScrollPane(table);
        }
    }

    /**
     * Creates an empty evaluation over the given label schema.
     *
     * @param exampleSchema schema of the class labels; binary-only statistics are enabled iff it
     *     equals {@link ExampleSchema#BINARY_EXAMPLE_SCHEMA}
     */
    public Evaluation(ExampleSchema exampleSchema) {
        this.schema = exampleSchema;
        this.isBinary = exampleSchema.equals(ExampleSchema.BINARY_EXAMPLE_SCHEMA);
    }

    /**
     * Records one entry per example classified by a stacked graphical classifier.
     *
     * <p>The classifier labels the whole relational dataset in one call. Each prediction is paired
     * with the example that has the same ID and stored under partition {@code i}.
     *
     * @param stackedGraphicalClassifier produces a map from example ID to predicted label
     * @param realRelationalDataset dataset whose examples are looked up by ID
     * @param i partition ID stored on every new entry
     * @throws IllegalArgumentException if a prediction has no best class or an example has no label
     */
    public void extend4SGM(StackedGraphicalLearner.StackedGraphicalClassifier stackedGraphicalClassifier, RealRelationalDataset realRelationalDataset, int i) {
        ProgressCounter progressCounter = new ProgressCounter("classifying", "example", realRelationalDataset.size());
        HashMap classification = stackedGraphicalClassifier.classification(realRelationalDataset);
        // The map is held as a raw type, so its key set yields Object. Keys are example IDs (Strings).
        for (Object key : classification.keySet()) {
            String id = (String) key;
            ClassLabel predicted = (ClassLabel) classification.get(id);
            SGMExample example = realRelationalDataset.getExamplewithID(id);
            if (predicted.bestClassName() == null) {
                throw new IllegalArgumentException("predicted can't be null! for example: " + example);
            }
            if (example.getLabel() == null) {
                throw new IllegalArgumentException("actual label can't be null! for example: " + example);
            }
            if (log.isDebugEnabled()) {
                log.debug("ok: " + (predicted.isCorrect(example.getLabel()) ? "Y" : "N") + "\tpredict: " + predicted + "\ton: " + example);
            }
            this.entryList.add(new Entry(example.asInstance(), predicted, example.getLabel(), this.entryList.size(), i));
            extendSchema(example.getLabel());
            extendSchema(predicted);
            // Every cached matrix is derived from entryList, so invalidate all of them.
            this.cachedPRCMatrix = null;
            this.cachedTPFPMatrix = null;
            this.cachedConfusionMatrix = null;
            progressCounter.progress();
        }
        progressCounter.finished();
    }

    /**
     * Classifies every example in {@code dataset} and records the results under partition {@code i},
     * reporting progress as it goes.
     */
    public void extend(Classifier classifier, Dataset dataset, int i) {
        ProgressCounter counter = new ProgressCounter("classifying", "example", dataset.size());
        for (Example.Looper looper = dataset.iterator(); looper.hasNext(); ) {
            Example example = looper.nextExample();
            extend(classifier.classification(example), example, i);
            counter.progress();
        }
        counter.finished();
    }

    /**
     * Classifies each sequence in {@code sequenceDataset} as a whole and records one entry per
     * position, all under partition 0.
     */
    public void extend(SequenceClassifier sequenceClassifier, SequenceDataset sequenceDataset) {
        for (Iterator sequences = sequenceDataset.sequenceIterator(); sequences.hasNext(); ) {
            Example[] sequence = (Example[]) sequences.next();
            ClassLabel[] labels = sequenceClassifier.classification(sequence);
            int position = 0;
            for (Example example : sequence) {
                extend(labels[position], example, 0);
                position++;
            }
        }
    }

    /**
     * Classifies every example in a semi-supervised dataset and records the results under
     * partition {@code i}, reporting progress as it goes.
     */
    public void extend(SemiSupervisedClassifier semiSupervisedClassifier, SemiSupervisedDataset semiSupervisedDataset, int i) {
        ProgressCounter counter = new ProgressCounter("classifying", "example", semiSupervisedDataset.size());
        for (Example.Looper looper = semiSupervisedDataset.iterator(); looper.hasNext(); ) {
            Example example = looper.nextExample();
            extend(semiSupervisedClassifier.classification(example), example, i);
            counter.progress();
        }
        counter.finished();
    }

    /**
     * Records a single prediction for {@code example} under partition {@code i}.
     *
     * @param classLabel the predicted label; must have a best class
     * @param example the classified example; must carry its actual label
     * @param i partition ID stored on the new entry
     * @throws IllegalArgumentException if the prediction has no best class or the example has no label
     */
    public void extend(ClassLabel classLabel, Example example, int i) {
        if (classLabel.bestClassName() == null) {
            throw new IllegalArgumentException("predicted can't be null! for example: " + example);
        }
        if (example.getLabel() == null) {
            throw new IllegalArgumentException("actual label can't be null! for example: " + example);
        }
        if (log.isDebugEnabled()) {
            log.debug("ok: " + (classLabel.isCorrect(example.getLabel()) ? "Y" : "N") + "\tpredict: " + classLabel + "\ton: " + example);
        }
        this.entryList.add(new Entry(example.asInstance(), classLabel, example.getLabel(), this.entryList.size(), i));
        extendSchema(example.getLabel());
        extendSchema(classLabel);
        // Every cached matrix is derived from entryList, so invalidate all of them.
        this.cachedPRCMatrix = null;
        this.cachedTPFPMatrix = null;
        this.cachedConfusionMatrix = null;
    }

    /**
     * Sets a free-form property. The first time a key is seen it is appended to the key-order list;
     * later calls with the same key only replace the value.
     */
    public void setProperty(String str, String str2) {
        boolean firstTime = this.properties.getProperty(str) == null;
        if (firstTime) {
            this.propertyKeyList.add(str);
        }
        this.properties.setProperty(str, str2);
    }

    /** Returns the value stored for {@code str}, or {@code "=unassigned="} if none. */
    public String getProperty(String str) {
        String value = this.properties.getProperty(str);
        return value != null ? value : "=unassigned=";
    }

    /** Predicted label of the {@code i}-th recorded entry. */
    public ClassLabel getPrediction(int i) {
        Entry entry = (Entry) this.entryList.get(i);
        return entry.predicted;
    }

    /** Actual label of the {@code i}-th recorded entry. */
    public ClassLabel getActual(int i) {
        Entry entry = (Entry) this.entryList.get(i);
        return entry.actual;
    }

    /** Whether the {@code i}-th entry's prediction matches its actual label. */
    public boolean isCorrect(int i) {
        ClassLabel predicted = getPrediction(i);
        return predicted.isCorrect(getActual(i));
    }

    /**
     * Total weight of misclassified entries.
     *
     * @throws IllegalArgumentException if an entry's actual label has no best class
     */
    public double errors() {
        double weightedErrors = 0.0d;
        for (int k = 0; k < this.entryList.size(); k++) {
            Entry e = getEntry(k);
            if (e.actual.bestClassName() == null) {
                throw new IllegalArgumentException("actual label is null?");
            }
            if (!e.predicted.isCorrect(e.actual)) {
                weightedErrors += e.w;
            }
        }
        return weightedErrors;
    }

    /** Total weight of misclassified entries in partition {@code i}. */
    public double errors(int i) {
        double weightedErrors = 0.0d;
        for (int k = 0; k < this.entryList.size(); k++) {
            Entry e = getEntry(k);
            if (e.partitionID == i && !e.predicted.isCorrect(e.actual)) {
                weightedErrors += e.w;
            }
        }
        return weightedErrors;
    }

    /** Misclassified weight per actual class, indexed by schema class index. */
    public double[] errorsByClass() {
        double[] perClass = new double[this.schema.getNumberOfClasses()];
        for (int k = 0; k < this.entryList.size(); k++) {
            Entry e = getEntry(k);
            int c = this.schema.getClassIndex(e.actual.bestClassName());
            perClass[c] += e.predicted.isCorrect(e.actual) ? 0.0d : e.w;
        }
        return perClass;
    }

    /** Misclassified weight per actual class, restricted to partition {@code i}. */
    public double[] errorsByClass(int i) {
        double[] perClass = new double[this.schema.getNumberOfClasses()];
        for (int k = 0; k < this.entryList.size(); k++) {
            Entry e = getEntry(k);
            if (e.partitionID != i) {
                continue;
            }
            int c = this.schema.getClassIndex(e.actual.bestClassName());
            perClass[c] += e.predicted.isCorrect(e.actual) ? 0.0d : e.w;
        }
        return perClass;
    }

    /** Misclassified weight among actual-positive entries; -1 for a non-binary schema. */
    public double errorsPos() {
        if (!this.isBinary) {
            return -1.0d;
        }
        double total = 0.0d;
        for (int k = 0; k < this.entryList.size(); k++) {
            Entry e = getEntry(k);
            boolean positive = ExampleSchema.POS_CLASS_NAME.equals(e.actual.bestClassName());
            if (positive && !e.predicted.isCorrect(e.actual)) {
                total += e.w;
            }
        }
        return total;
    }

    /** Misclassified weight among actual-positive entries of partition {@code i}; -1 if non-binary. */
    public double errorsPos(int i) {
        if (!this.isBinary) {
            return -1.0d;
        }
        double total = 0.0d;
        for (int k = 0; k < this.entryList.size(); k++) {
            Entry e = getEntry(k);
            boolean positive = ExampleSchema.POS_CLASS_NAME.equals(e.actual.bestClassName());
            if (positive && e.partitionID == i && !e.predicted.isCorrect(e.actual)) {
                total += e.w;
            }
        }
        return total;
    }

    /** Misclassified weight among actual-negative entries. */
    public double errorsNeg() {
        double total = 0.0d;
        for (int k = 0; k < this.entryList.size(); k++) {
            Entry e = getEntry(k);
            boolean negative = ExampleSchema.NEG_CLASS_NAME.equals(e.actual.bestClassName());
            if (negative && !e.predicted.isCorrect(e.actual)) {
                total += e.w;
            }
        }
        return total;
    }

    /** Misclassified weight among actual-negative entries of partition {@code i}. */
    public double errorsNeg(int i) {
        double total = 0.0d;
        for (int k = 0; k < this.entryList.size(); k++) {
            Entry e = getEntry(k);
            boolean negative = ExampleSchema.NEG_CLASS_NAME.equals(e.actual.bestClassName());
            if (negative && e.partitionID == i && !e.predicted.isCorrect(e.actual)) {
                total += e.w;
            }
        }
        return total;
    }

    /**
     * Population standard deviation of the per-partition error rate around the overall error rate.
     *
     * <p>Partitions are numbered 0 .. (largest partitionID). Each one contributes equally,
     * regardless of its size.
     */
    public double stDevErrors() {
        // Partition count = largest ID + 1. Using max() makes the result independent of the order
        // in which entries were added.
        int numPartitions = 0;
        for (int k = 0; k < this.entryList.size(); k++) {
            numPartitions = Math.max(numPartitions, getEntry(k).partitionID + 1);
        }
        double overallRate = errorRate();
        double variance = 0.0d;
        for (int p = 0; p < numPartitions; p++) {
            variance += Math.pow((errors(p) / numberOfInstances(p)) - overallRate, 2.0d) / numPartitions;
        }
        return Math.sqrt(variance);
    }

    /**
     * Per-class population standard deviation of the per-partition error rate around that class's
     * overall error rate.
     *
     * <p>Partitions are numbered 0 .. (largest partitionID).
     */
    public double[] stDevErrorsByClass() {
        int numberOfClasses = this.schema.getNumberOfClasses();
        // Partition count = largest ID + 1. Using max() makes the result independent of the order
        // in which entries were added.
        int numPartitions = 0;
        for (int k = 0; k < this.entryList.size(); k++) {
            numPartitions = Math.max(numPartitions, getEntry(k).partitionID + 1);
        }
        double[] overallRates = errorRateByClass();
        double[] result = new double[numberOfClasses];
        for (int p = 0; p < numPartitions; p++) {
            double[] partErrors = errorsByClass(p);
            double[] partCounts = numberOfExamplesByClass(p);
            for (int c = 0; c < numberOfClasses; c++) {
                result[c] += Math.pow((partErrors[c] / partCounts[c]) - overallRates[c], 2.0d) / numPartitions;
            }
        }
        for (int c = 0; c < numberOfClasses; c++) {
            result[c] = Math.sqrt(result[c]);
        }
        return result;
    }

    /**
     * Population standard deviation of the per-partition positive-class error rate around the
     * overall positive-class error rate. Returns -1 for a non-binary schema.
     */
    public double stDevErrorsPos() {
        if (!this.isBinary) {
            return -1.0d;
        }
        // Partition count = largest ID + 1. Using max() makes the result independent of the order
        // in which entries were added.
        int numPartitions = 0;
        for (int k = 0; k < this.entryList.size(); k++) {
            numPartitions = Math.max(numPartitions, getEntry(k).partitionID + 1);
        }
        double overallRate = errorsPos() / numberOfPositiveExamples();
        double variance = 0.0d;
        for (int p = 0; p < numPartitions; p++) {
            variance += Math.pow((errorsPos(p) / numberOfPositiveExamples(p)) - overallRate, 2.0d) / numPartitions;
        }
        return Math.sqrt(variance);
    }

    /**
     * Population standard deviation of the per-partition negative-class error rate around the
     * overall negative-class error rate. Returns -1 for a non-binary schema.
     */
    public double stDevErrorsNeg() {
        if (!this.isBinary) {
            return -1.0d;
        }
        // Partition count = largest ID + 1. Using max() makes the result independent of the order
        // in which entries were added.
        int numPartitions = 0;
        for (int k = 0; k < this.entryList.size(); k++) {
            numPartitions = Math.max(numPartitions, getEntry(k).partitionID + 1);
        }
        double overallRate = errorsNeg() / numberOfNegativeExamples();
        double variance = 0.0d;
        for (int p = 0; p < numPartitions; p++) {
            variance += Math.pow((errorsNeg(p) / numberOfNegativeExamples(p)) - overallRate, 2.0d) / numPartitions;
        }
        return Math.sqrt(variance);
    }

    /** Total weight of all recorded entries. */
    public double numberOfInstances() {
        double totalWeight = 0.0d;
        int n = this.entryList.size();
        for (int k = 0; k < n; k++) {
            totalWeight += getEntry(k).w;
        }
        return totalWeight;
    }

    /** Total weight of the entries in partition {@code i}. */
    public double numberOfInstances(int i) {
        double totalWeight = 0.0d;
        int n = this.entryList.size();
        for (int k = 0; k < n; k++) {
            Entry e = getEntry(k);
            if (e.partitionID != i) {
                continue;
            }
            totalWeight += e.w;
        }
        return totalWeight;
    }

    /** Total entry weight per actual class, indexed by schema class index. */
    public double[] numberOfExamplesByClass() {
        double[] counts = new double[this.schema.getNumberOfClasses()];
        for (int k = 0; k < this.entryList.size(); k++) {
            Entry e = getEntry(k);
            counts[this.schema.getClassIndex(e.actual.bestClassName())] += e.w;
        }
        return counts;
    }

    /** Total entry weight per actual class, restricted to partition {@code i}. */
    public double[] numberOfExamplesByClass(int i) {
        double[] counts = new double[this.schema.getNumberOfClasses()];
        for (int k = 0; k < this.entryList.size(); k++) {
            Entry e = getEntry(k);
            // The class index is looked up for every entry, then only partition i is counted.
            int c = this.schema.getClassIndex(e.actual.bestClassName());
            if (e.partitionID == i) {
                counts[c] += e.w;
            }
        }
        return counts;
    }

    /** Total weight of actual-positive entries; -1 for a non-binary schema. */
    public double numberOfPositiveExamples() {
        if (!this.isBinary) {
            return -1.0d;
        }
        double total = 0.0d;
        for (int k = 0; k < this.entryList.size(); k++) {
            Entry e = getEntry(k);
            if (!ExampleSchema.POS_CLASS_NAME.equals(e.actual.bestClassName())) {
                continue;
            }
            total += e.w;
        }
        return total;
    }

    /** Total weight of actual-positive entries in partition {@code i}; -1 if non-binary. */
    public double numberOfPositiveExamples(int i) {
        if (!this.isBinary) {
            return -1.0d;
        }
        double total = 0.0d;
        for (int k = 0; k < this.entryList.size(); k++) {
            Entry e = getEntry(k);
            boolean positive = ExampleSchema.POS_CLASS_NAME.equals(e.actual.bestClassName());
            if (positive && e.partitionID == i) {
                total += e.w;
            }
        }
        return total;
    }

    /** Total weight of actual-negative entries. */
    public double numberOfNegativeExamples() {
        double total = 0.0d;
        for (int k = 0; k < this.entryList.size(); k++) {
            Entry e = getEntry(k);
            if (!ExampleSchema.NEG_CLASS_NAME.equals(e.actual.bestClassName())) {
                continue;
            }
            total += e.w;
        }
        return total;
    }

    /** Total weight of actual-negative entries in partition {@code i}. */
    public double numberOfNegativeExamples(int i) {
        double total = 0.0d;
        for (int k = 0; k < this.entryList.size(); k++) {
            Entry e = getEntry(k);
            boolean negative = ExampleSchema.NEG_CLASS_NAME.equals(e.actual.bestClassName());
            if (negative && e.partitionID == i) {
                total += e.w;
            }
        }
        return total;
    }

    /** Weighted error rate: errors() divided by numberOfInstances(). */
    public double errorRate() {
        double weightedErrors = errors();
        return weightedErrors / numberOfInstances();
    }

    /** Per-class error rate: errorsByClass()[c] divided by numberOfExamplesByClass()[c]. */
    public double[] errorRateByClass() {
        double[] rates = new double[this.schema.getNumberOfClasses()];
        double[] errs = errorsByClass();
        double[] counts = numberOfExamplesByClass();
        for (int c = 0; c < rates.length; c++) {
            rates[c] = errs[c] / counts[c];
        }
        return rates;
    }

    /**
     * Weighted error rate on actual-positive entries.
     *
     * @return errorsPos() / numberOfPositiveExamples(), or -1 for a non-binary schema. Without the
     *     guard, both helpers return the -1 sentinel and the quotient would be a bogus 1.0.
     */
    public double errorRatePos() {
        if (!this.isBinary) {
            return -1.0d;
        }
        return errorsPos() / numberOfPositiveExamples();
    }

    /** Weighted error rate on actual-negative entries. */
    public double errorRateNeg() {
        double negErrors = errorsNeg();
        return negErrors / numberOfNegativeExamples();
    }

    /** Unweighted mean of the per-class error rates (each class counts 1 / numberOfClasses). */
    public double errorRateBalanced() {
        int numberOfClasses = this.schema.getNumberOfClasses();
        double[] errs = errorsByClass();
        double[] counts = numberOfExamplesByClass();
        double classWeight = 1.0d / numberOfClasses;
        double balanced = 0.0d;
        for (int c = 0; c < numberOfClasses; c++) {
            balanced += (classWeight * errs[c]) / counts[c];
        }
        return balanced;
    }

    /**
     * Fraction of positives found among the first {@code i} rows of precisionRecallScore().
     *
     * <p>A row counts as a hit when its recall (column 1) rises above the previous row's and its
     * score (column 2) exceeds {@code d}. Returns -1 for a non-binary schema and 1 when there are
     * no positives.
     */
    public double recallTopK(int i, double d) {
        if (!this.isBinary) {
            return -1.0d;
        }
        if (numberOfPositiveExamples() == 0.0d) {
            return 1.0d;
        }
        double[][] rows = precisionRecallScore().values;
        int limit = Math.min(rows.length, i);
        double previousRecall = 0.0d;
        double hits = 0.0d;
        for (int k = 0; k < limit; k++) {
            double recall = rows[k][1];
            // Presumably a rise in recall marks a positive example at this rank — TODO confirm.
            if (recall > previousRecall && rows[k][2] > d) {
                hits += 1.0d;
            }
            previousRecall = recall;
        }
        return hits / numberOfPositiveExamples();
    }

    /**
     * Mean precision (column 0) over the rows of precisionRecallScore() where recall (column 1)
     * increases. Returns -1 for a non-binary schema and NaN when nothing has been recorded.
     */
    public double averagePrecision() {
        if (!this.isBinary) {
            return -1.0d;
        }
        if (numberOfInstances() == 0.0d) {
            return Double.NaN;
        }
        double[][] rows = precisionRecallScore().values;
        double precisionSum = 0.0d;
        double recallSteps = 0.0d;
        double lastRecall = 0.0d;
        for (double[] row : rows) {
            if (row[1] > lastRecall) {
                precisionSum += row[0];
                recallSteps += 1.0d;
            }
            lastRecall = row[1];
        }
        return precisionSum / recallSteps;
    }

    /**
     * Best F1 over all score thresholds, i.e. {@link #maxF1(double)} with no lower bound on the
     * score.
     *
     * <p>{@code Double.MIN_VALUE} is the smallest <em>positive</em> double, not the most negative
     * one. Passing it discarded every threshold whose score was zero or negative.
     * {@code Double.NEGATIVE_INFINITY} admits every threshold.
     */
    public double maxF1() {
        return maxF1(Double.NEGATIVE_INFINITY);
    }

    /**
     * Best F1 over the rows of precisionRecallScore() whose score (column 2) is at least {@code d}.
     *
     * <p>Rows with zero precision and zero recall are skipped. Returns -1 for a non-binary schema
     * and 1 when there are no positives.
     */
    public double maxF1(double d) {
        if (!this.isBinary) {
            return -1.0d;
        }
        if (numberOfPositiveExamples() == 0.0d) {
            return 1.0d;
        }
        double best = 0.0d;
        for (double[] row : precisionRecallScore().values) {
            double p = row[0];
            double r = row[1];
            boolean hasSignal = p > 0.0d || r > 0.0d;
            if (hasSignal && row[2] >= d) {
                best = Math.max(best, ((2.0d * p) * r) / (p + r));
            }
        }
        return best;
    }

    /**
     * Cohen's kappa computed from confusionMatrix(), with the entry count as the total.
     *
     * <p>Observed agreement is the diagonal sum. Chance agreement is the sum over classes of the
     * product of the row and column marginals, each divided by the total.
     */
    public double kappa() {
        Matrix cm = confusionMatrix();
        double total = this.entryList.size();
        int k = this.schema.getNumberOfClasses();
        double[] rowTotals = new double[k];
        double[] colTotals = new double[k];
        double agree = 0.0d;
        for (int a = 0; a < k; a++) {
            agree += cm.values[a][a];
            for (int b = 0; b < k; b++) {
                rowTotals[a] += cm.values[a][b];
                colTotals[a] += cm.values[b][a];
            }
        }
        double chance = 0.0d;
        for (int a = 0; a < k; a++) {
            chance += (rowTotals[a] / total) * (colTotals[a] / total);
        }
        return ((agree / total) - chance) / (1.0d - chance);
    }

    /** Number of recorded entries (unweighted). */
    public int numExamples() {
        int count = this.entryList.size();
        return count;
    }

    /**
     * Mean over entries (unweighted) of log(1 + exp(s * sign)).
     *
     * <p>Here s is the predicted label's weight for the actual class, and sign is +1 when the
     * prediction is correct, -1 otherwise.
     *
     * <p>NOTE(review): for a correct prediction the exponent is +s, so the loss grows with the
     * score given to the true class. Standard logistic loss is log(1 + exp(-margin)); confirm the
     * intended sign.
     */
    public double averageLogLoss() {
        double totalLoss = 0.0d;
        for (int k = 0; k < this.entryList.size(); k++) {
            Entry e = getEntry(k);
            double score = e.predicted.getWeight(e.actual.bestClassName());
            double sign = e.predicted.isCorrect(e.actual) ? 1.0d : -1.0d;
            totalLoss += Math.log(1.0d + Math.exp(score * sign));
        }
        return totalLoss / numExamples();
    }

    /**
     * Positive-class precision from confusionMatrix(): TP / (TP + cm[NEG][POS]).
     * Returns -1 for a non-binary schema.
     */
    public double precision() {
        if (!this.isBinary) {
            return -1.0d;
        }
        double[][] cm = confusionMatrix().values;
        int pos = classIndexOf(ExampleSchema.POS_CLASS_NAME);
        double truePos = cm[pos][pos];
        return truePos / (truePos + cm[classIndexOf(ExampleSchema.NEG_CLASS_NAME)][pos]);
    }

    /**
     * Positive-class recall from confusionMatrix(): TP / (TP + cm[POS][NEG]).
     * Returns -1 for a non-binary schema.
     */
    public double recall() {
        if (!this.isBinary) {
            return -1.0d;
        }
        double[][] cm = confusionMatrix().values;
        int pos = classIndexOf(ExampleSchema.POS_CLASS_NAME);
        double truePos = cm[pos][pos];
        return truePos / (truePos + cm[pos][classIndexOf(ExampleSchema.NEG_CLASS_NAME)]);
    }

    /** Harmonic mean of precision() and recall(); -1 for a non-binary schema. */
    public double f1() {
        if (!this.isBinary) {
            return -1.0d;
        }
        double p = precision();
        double r = recall();
        return ((2.0d * p) * r) / (p + r);
    }

    /**
     * Collects the summary statistics shown for this evaluation.
     *
     * <p>Layout: error rate, its standard deviation, balanced error rate, then an
     * (error rate, std. deviation) pair per class; binary evaluations then add
     * average precision, max F1, average log loss, recall, precision and F1;
     * every evaluation ends with kappa. Positions line up with
     * {@link #summaryStatisticNames()}.
     *
     * @return the statistic values
     */
    public double[] summaryStatistics() {
        final int numClasses = schema.getNumberOfClasses();
        final int lastPerClassSlot = 2 + 2 * numClasses;
        double[] stats = new double[isBinary ? lastPerClassSlot + 8 : lastPerClassSlot + 2];

        int pos = 0;
        stats[pos++] = errorRate();
        stats[pos++] = stDevErrors();
        stats[pos++] = errorRateBalanced();
        double[] byClass = errorRateByClass();
        double[] stDevByClass = stDevErrorsByClass();
        for (int c = 0; c < numClasses; c++) {
            stats[pos++] = byClass[c];
            stats[pos++] = stDevByClass[c];
        }
        if (isBinary) {
            stats[pos++] = averagePrecision();
            stats[pos++] = maxF1();
            stats[pos++] = averageLogLoss();
            stats[pos++] = recall();
            stats[pos++] = precision();
            stats[pos++] = f1();
        }
        stats[pos] = kappa();
        return stats;
    }

    /**
     * Returns display names for the values produced by {@code summaryStatistics()},
     * position for position.
     *
     * <p>Layout: overall error rate, its standard deviation, balanced error rate,
     * then an (error rate, std. deviation) pair per class in schema order; binary
     * evaluations then add average precision, maximum F1, average log loss, recall,
     * precision and F1; every evaluation ends with Kappa.
     *
     * @return the statistic names
     */
    public String[] summaryStatisticNames() {
        int numberOfClasses = this.schema.getNumberOfClasses();
        String[] names = new String[(this.isBinary ? 10 : 4) + 2 * numberOfClasses];
        names[0] = "Error Rate";
        names[1] = ". std. deviation error rate";
        names[2] = "Balanced Error Rate";
        for (int i = 0; i < numberOfClasses; i++) {
            String className = this.schema.getClassName(i);
            names[3 + 2 * i] = ". error Rate on " + className;
            names[4 + 2 * i] = ". std. deviation on " + className;
        }
        int next = 3 + 2 * numberOfClasses;
        if (this.isBinary) {
            names[next++] = "Average Precision";
            names[next++] = "Maximum F1";
            names[next++] = "Average Log Loss";
            names[next++] = "Recall";
            names[next++] = "Precision";
            names[next++] = "F1";
        }
        names[next] = "Kappa";
        return names;
    }

    /**
     * Returns the confusion matrix, building and caching it on first use.
     *
     * <p>Cell {@code [a][p]} counts entries whose actual class has index {@code a}
     * and whose predicted class has index {@code p} in the schema.
     *
     * @return the (cached) confusion matrix
     */
    public Matrix confusionMatrix() {
        if (cachedConfusionMatrix == null) {
            int numClasses = getClasses().length;
            double[][] counts = new double[numClasses][numClasses];
            for (int k = 0; k < entryList.size(); k++) {
                Entry e = getEntry(k);
                counts[classIndexOf(e.actual)][classIndexOf(e.predicted)] += 1.0d;
            }
            cachedConfusionMatrix = new Matrix(counts);
        }
        return cachedConfusionMatrix;
    }

    /**
     * Counts misclassified entries: the sum of every off-diagonal cell of the
     * confusion matrix (actual class differs from predicted class).
     *
     * <p>Works for any number of classes; for a two-class schema this equals the
     * sum of cells (0,1) and (1,0).
     *
     * @return the number of errors
     */
    public double numErrors() {
        double[][] counts = confusionMatrix().values;
        double errors = 0.0d;
        for (int actual = 0; actual < counts.length; actual++) {
            for (int predicted = 0; predicted < counts[actual].length; predicted++) {
                if (actual != predicted) {
                    errors += counts[actual][predicted];
                }
            }
        }
        return errors;
    }

    /**
     * Returns the class names known to this evaluation's schema.
     *
     * @return the schema's valid class names
     */
    public String[] getClasses() {
        ExampleSchema s = schema;
        return s.validClassNames();
    }

    /**
     * Builds (and caches) ROC curve points for a binary evaluation.
     *
     * <p>Each row of the returned matrix is {@code [TP fraction, FP fraction, posWeight]}.
     * Row 0 is {@code (0, 0, score of the first ranked entry)}. The last row is
     * {@code (1, 1, score of the last ranked entry)}. Rows in between hold the cumulative
     * fractions for ranked entries {@code min-1 .. max}, where {@code min}/{@code max}
     * are the smaller/larger of (index of the last positive, index of the first negative).
     *
     * <p>Side effect: re-sorts {@code entryList} via {@code byBinaryScore()}, presumably
     * by decreasing positive weight.
     *
     * @return the cached or newly built matrix
     * @throws IllegalArgumentException if the evaluation is not binary
     */
    public Matrix TPfractionFPfractionScore() {
        if (this.cachedTPFPMatrix != null) {
            return this.cachedTPFPMatrix;
        }
        if (!this.isBinary) {
            // NOTE(review): the message names precisionRecallScore although this is the ROC routine.
            throw new IllegalArgumentException("can't compute precisionRecallScore for non-binary data");
        }
        byBinaryScore();
        // First pass. i counts positives and i2 counts negatives.
        // i3 is the index of the last positive; i4 is the index of the first negative.
        int i = 0;
        int i2 = 0;
        int i3 = 0;
        int i4 = 0;
        boolean z = true; // true until the first negative has been seen
        ProgressCounter progressCounter = new ProgressCounter("counting positive examples", "examples", this.entryList.size());
        for (int i5 = 0; i5 < this.entryList.size(); i5++) {
            if (getEntry(i5).actual.isPositive()) {
                i++;
                i3 = i5;
            } else {
                i2++;
                if (z) {
                    i4 = i5;
                    z = false;
                }
            }
            progressCounter.progress();
        }
        progressCounter.finished();
        // Row count is max - min + 4: the start row, one row for entry min-1,
        // rows for entries min..max, and the final (1, 1) row.
        int abs = Math.abs(i3 - i4) + 4;
        int min = Math.min(i3, i4);
        int max = Math.max(i3, i4);
        double d = 0.0d;  // positives seen so far
        double d2 = 0.0d; // negatives seen so far
        double d3 = 1.0d; // TP fraction; stays 1.0 when there are no positives
        double d4 = 1.0d; // FP fraction; stays 1.0 when there are no negatives
        ProgressCounter progressCounter2 = new ProgressCounter("computing statistics", "examples", this.entryList.size());
        // With an empty entry list the loop never runs and every row stays (0, 0, 0).
        double[][] dArr = new double[abs][3];
        int i6 = 0;
        while (i6 < this.entryList.size()) {
            Entry entry = getEntry(i6);
            double posWeight = entry.predicted.posWeight();
            if (entry.actual.isPositive()) {
                d += 1.0d;
            } else {
                d2 += 1.0d;
            }
            if (i > 0) {
                d3 = d / i;
            }
            if (i2 > 0) {
                d4 = d2 / i2;
            }
            if (i6 == 0) {
                dArr[0][0] = 0.0d;
                dArr[0][1] = 0.0d;
                dArr[0][2] = posWeight;
            }
            // Non-short-circuit '&' behaves like '&&' here. When min == 0 no iteration
            // maps to row 1, so that row stays (0, 0, 0).
            if ((i6 >= min - 1) & (i6 <= max)) {
                dArr[(i6 - min) + 2][0] = d3;
                dArr[(i6 - min) + 2][1] = d4;
                dArr[(i6 - min) + 2][2] = posWeight;
            }
            // Overwritten on every iteration, so it ends with the last entry's score.
            dArr[abs - 1][0] = 1.0d;
            dArr[abs - 1][1] = 1.0d;
            dArr[abs - 1][2] = posWeight;
            progressCounter2.progress();
            i6++;
        }
        progressCounter2.finished();
        this.cachedTPFPMatrix = new Matrix(dArr);
        return this.cachedTPFPMatrix;
    }

    /**
     * Returns the ROC matrix from {@code TPfractionFPfractionScore()} down-sampled to
     * at most 1002 rows.
     *
     * <p>The first and last rows are always kept. When there are more than 1000
     * interior rows, every {@code (interior / 1000)}-th interior row is taken. A
     * matrix that is already small enough is returned unchanged (not copied).
     *
     * @return the full or sampled ROC matrix
     */
    public Matrix thousandPointROC() {
        Matrix full = TPfractionFPfractionScore();
        double[][] rows = full.values;
        int interior = rows.length - 2;
        if (interior <= 1000) {
            return full;
        }
        int stride = interior / 1000;
        double[][] sampled = new double[1002][3];
        for (int out = 0; out < sampled.length; out++) {
            int src;
            if (out == 0) {
                src = 0;
            } else if (out == 1001) {
                src = interior + 1;
            } else {
                src = (out - 1) * stride + 1;
            }
            System.arraycopy(rows[src], 0, sampled[out], 0, 3);
        }
        return new Matrix(sampled);
    }

    /**
     * Builds (and caches) the precision/recall curve for a binary evaluation.
     *
     * <p>Entries are ranked with {@code byBinaryScore()}, which re-sorts
     * {@code entryList}. One row {@code [precision, recall, posWeight]} is produced
     * for each ranked entry up to and including the last positive one. Recall stays
     * 1.0 when there are no positives.
     *
     * <p>NOTE(review): when {@code entryList} is empty, {@code lastPositive} stays 0
     * and {@code getEntry(0)} is still called.
     *
     * @return the cached or newly built matrix
     * @throws IllegalArgumentException if the evaluation is not binary
     */
    public Matrix precisionRecallScore() {
        if (cachedPRCMatrix != null) {
            return cachedPRCMatrix;
        }
        if (!isBinary) {
            throw new IllegalArgumentException("can't compute precisionRecallScore for non-binary data");
        }
        byBinaryScore();

        // Pass 1: total positives and the rank of the lowest-ranked positive.
        int totalPositives = 0;
        int lastPositive = 0;
        ProgressCounter counting = new ProgressCounter("counting positive examples", "examples", entryList.size());
        for (int k = 0; k < entryList.size(); k++) {
            if (getEntry(k).actual.isPositive()) {
                totalPositives++;
                lastPositive = k;
            }
            counting.progress();
        }
        counting.finished();

        // Pass 2: cumulative precision/recall down the ranking.
        double truePos = 0.0d;
        double falsePos = 0.0d;
        double precision = 1.0d;
        double recall = 1.0d;
        ProgressCounter computing = new ProgressCounter("computing statistics", "examples", lastPositive);
        double[][] curve = new double[lastPositive + 1][3];
        for (int k = 0; k <= lastPositive; k++) {
            Entry e = getEntry(k);
            double score = e.predicted.posWeight();
            if (e.actual.isPositive()) {
                truePos += 1.0d;
            } else {
                falsePos += 1.0d;
            }
            double seen = truePos + falsePos;
            if (seen > 0.0d) {
                precision = truePos / seen;
            }
            if (totalPositives > 0) {
                recall = truePos / totalPositives;
            }
            double[] row = curve[k];
            row[0] = precision;
            row[1] = recall;
            row[2] = score;
            computing.progress();
        }
        computing.finished();
        cachedPRCMatrix = new Matrix(curve);
        return cachedPRCMatrix;
    }

    /**
     * Computes interpolated precision at recall levels 0.0, 0.1, ..., 1.0.
     *
     * <p>Level 0 is fixed at 1.0. Each other level holds the highest precision found
     * on any point of {@code precisionRecallScore()} whose recall reaches that level
     * (0.0 when no point does).
     *
     * @return an 11-element array of interpolated precisions
     */
    public double[] elevenPointPrecision() {
        double[][] curve = precisionRecallScore().values;
        double[] best = new double[11];
        best[0] = 1.0d;
        for (double[] point : curve) {
            double pointRecall = point[1];
            for (int level = 1; level <= 10; level++) {
                if (pointRecall >= level / 10.0d) {
                    best[level] = Math.max(best[level], point[0]);
                }
            }
        }
        return best;
    }

    /**
     * Renders every entry, each followed by the colt default row separator.
     *
     * @return the concatenated entry descriptions
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        int n = entryList.size();
        for (int row = 0; row < n; row++) {
            text.append(getEntry(row)).append(AbstractFormatter.DEFAULT_ROW_SEPARATOR);
        }
        return text.toString();
    }

    /**
     * Prints each summary statistic to standard output, one per line.
     *
     * <p>Each line is {@code "name: "} padded with column separators so the values
     * line up under the longest name, followed by the value.
     */
    public void summarize() {
        double[] values = summaryStatistics();
        String[] names = summaryStatisticNames();
        int width = 0;
        for (int k = 0; k < names.length; k++) {
            if (names[k].length() > width) {
                width = names[k].length();
            }
        }
        for (int k = 0; k < names.length; k++) {
            StringBuilder line = new StringBuilder(names[k]).append(": ");
            for (int pad = names[k].length(); pad < width; pad++) {
                line.append(AbstractFormatter.DEFAULT_COLUMN_SEPARATOR);
            }
            System.out.print(line);
            System.out.println(values[k]);
        }
    }

    /**
     * Builds a ParallelViewer with sub-views for the summary, properties,
     * (binary only) 11-point precision/recall and ROC, confusion matrix, and a
     * debug view, with this evaluation as its content.
     *
     * @return the assembled viewer
     */
    @Override // edu.cmu.minorthird.util.gui.Visible
    public Viewer toGUI() {
        ParallelViewer views = new ParallelViewer();
        views.addSubView("Summary", new SummaryViewer());
        views.addSubView("Properties", new PropertyViewer());
        if (isBinary) {
            views.addSubView("11Pt Precision/Recall", new ElevenPointPrecisionViewer());
            views.addSubView(" ROC & AUC ", new ROCViewer());
        }
        views.addSubView("Confusion Matrix", new ConfusionMatrixViewer());
        views.addSubView("Debug", new VanillaViewer());
        views.setContent(this);
        return views;
    }

    /**
     * Lists the single save format this class supports.
     *
     * @return a one-element array holding {@code EVAL_FORMAT_NAME}
     */
    @Override // edu.cmu.minorthird.util.Saveable
    public String[] getFormatNames() {
        String[] formats = {EVAL_FORMAT_NAME};
        return formats;
    }

    /**
     * Returns the file extension for saved evaluations.
     *
     * @param str the format name (ignored; only one format exists)
     * @return {@code EVAL_EXT}
     */
    @Override // edu.cmu.minorthird.util.Saveable
    public String getExtensionFor(String str) {
        String extension = EVAL_EXT;
        return extension;
    }

    /**
     * Saves this evaluation to {@code file}; the format name is ignored.
     *
     * @param file destination file
     * @param str format name (ignored)
     * @throws IOException if writing fails
     */
    @Override // edu.cmu.minorthird.util.Saveable
    public void saveAs(File file, String str) throws IOException {
        this.save(file);
    }

    /**
     * Reads an evaluation previously written by {@code saveAs}.
     *
     * @param file source file
     * @return the loaded evaluation
     * @throws IOException if reading fails
     */
    @Override // edu.cmu.minorthird.util.Saveable
    public Object restore(File file) throws IOException {
        Evaluation restored = Evaluation.load(file);
        return restored;
    }

    /**
     * Writes this evaluation, gzip-compressed, to {@code file}.
     *
     * @param file destination file (created or overwritten)
     * @throws IOException if the file cannot be opened or written
     */
    public void save(File file) throws IOException {
        FileOutputStream fileOut = new FileOutputStream(file);
        try {
            save(new PrintStream(new GZIPOutputStream(fileOut)));
        } finally {
            // save(PrintStream) normally closes the whole stream chain, which makes this
            // a no-op. It still releases the file if the GZIP header write or the
            // serialization throws part-way.
            fileOut.close();
        }
    }

    /**
     * Writes this evaluation in the text format read by {@code load}, then closes
     * {@code printStream}.
     *
     * <p>Output: the class names via {@code StringUtil.toString}, one
     * {@code key=value} line per property in {@code propertyKeyList} order, then one
     * line per entry (predicted class, predicted weight, actual class, separated by
     * the colt default column separator). Side effect: re-sorts {@code entryList} with
     * {@code byOriginalPosition()}.
     *
     * @param printStream destination; always closed, even on failure
     * @throws IOException if the stream reported a write error (PrintStream swallows
     *     I/O errors, so its error flag is checked explicitly)
     */
    public void save(PrintStream printStream) throws IOException {
        try {
            printStream.println(StringUtil.toString(this.schema.validClassNames()));
            for (Iterator it = this.propertyKeyList.iterator(); it.hasNext(); ) {
                String key = (String) it.next();
                printStream.println(key + "=" + this.properties.getProperty(key));
            }
            byOriginalPosition();
            for (Iterator it = this.entryList.iterator(); it.hasNext(); ) {
                Entry entry = (Entry) it.next();
                printStream.println(entry.predicted.bestClassName() + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR + entry.predicted.bestWeight() + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR + entry.actual.bestClassName());
            }
        } finally {
            printStream.close();
        }
        // After close, checkError() reports errors from writing and from closing
        // (e.g. a failed gzip trailer write).
        if (printStream.checkError()) {
            throw new IOException("error while writing evaluation data");
        }
    }

    /**
     * Reads a gzip-compressed evaluation file in the format written by {@code save}.
     *
     * <p>Line 1 holds the class list inside a pair of delimiter characters (presumably
     * the brackets from {@code StringUtil.toString}), comma-separated. Later lines
     * containing '=' are properties: the text before the first '=' is the key and the
     * rest is the value. All other lines are entries with at least three fields
     * separated by the colt default column separator: predicted class, predicted
     * weight, actual class.
     *
     * @param file gzip-compressed evaluation file
     * @return the reconstructed evaluation
     * @throws IOException if the file cannot be read or is not gzip data
     * @throws IllegalArgumentException if the class list or an entry line is malformed
     */
    public static Evaluation load(File file) throws IOException {
        FileInputStream fileIn = new FileInputStream(file);
        try {
            LineNumberReader reader = new LineNumberReader(new InputStreamReader(new GZIPInputStream(fileIn)));
            try {
                return parseEvaluationLines(reader, file.getName());
            } finally {
                reader.close();
            }
        } finally {
            // Redundant after reader.close(), but releases the file if the
            // GZIPInputStream constructor rejected the header.
            fileIn.close();
        }
    }

    /** Parses the lines of an evaluation file from an already-open reader. */
    private static Evaluation parseEvaluationLines(LineNumberReader reader, String fileName) throws IOException {
        String classLine = reader.readLine();
        if (classLine == null || classLine.length() < 2) {
            throw new IllegalArgumentException("no class list on line 1 of file " + fileName);
        }
        Evaluation evaluation = new Evaluation(new ExampleSchema(classLine.substring(1, classLine.length() - 1).split(",")));
        for (String line = reader.readLine(); line != null; line = reader.readLine()) {
            if (line.indexOf('=') >= 0) {
                // Limit 2: only the first '=' separates key from value, so values that
                // themselves contain '=' survive a save/load round trip.
                String[] keyValue = line.split("=", 2);
                evaluation.setProperty(keyValue[0], keyValue[1]);
            } else {
                String[] fields = line.split(AbstractFormatter.DEFAULT_COLUMN_SEPARATOR);
                if (fields.length < 3) {
                    throw new IllegalArgumentException(fileName + " line " + reader.getLineNumber() + ": illegal format");
                }
                evaluation.extend(new ClassLabel(fields[0], StringUtil.atof(fields[1])), new Example(new MutableInstance("dummy"), new ClassLabel(fields[2])), 0);
            }
        }
        return evaluation;
    }

    /**
     * Reports whether every label seen so far was binary.
     *
     * @return the current binary flag
     */
    public boolean isBinary() {
        boolean binary = this.isBinary;
        return binary;
    }

    /**
     * Returns the schema accumulated from the labels seen so far.
     *
     * @return the (mutable) schema instance held by this evaluation
     */
    public ExampleSchema getSchema() {
        ExampleSchema current = this.schema;
        return current;
    }

    /** Returns the entry at position {@code i} of the (possibly re-sorted) entry list. */
    private Entry getEntry(int i) {
        Object item = this.entryList.get(i);
        return (Entry) item;
    }

    /** Returns the schema index of the label's best class name. */
    private int classIndexOf(ClassLabel classLabel) {
        String name = classLabel.bestClassName();
        return classIndexOf(name);
    }

    /** Returns the schema index of class {@code str}, as reported by {@code getClassIndex}. */
    private int classIndexOf(String str) {
        ExampleSchema s = this.schema;
        return s.getClassIndex(str);
    }

    /**
     * Updates the binary flag and schema for a newly seen label.
     *
     * <p>Clears {@code isBinary} if the label is not binary. Adds the label's best
     * class name to the schema when {@code classIndexOf} reports it as unknown
     * (negative index).
     */
    private void extendSchema(ClassLabel classLabel) {
        boolean labelIsBinary = classLabel.isBinary();
        if (!labelIsBinary) {
            this.isBinary = false;
        }
        String className = classLabel.bestClassName();
        boolean known = classIndexOf(className) >= 0;
        if (!known) {
            this.schema.extend(className);
        }
    }

    /**
     * Sorts {@code entryList} in place by decreasing predicted positive weight.
     *
     * <p>The comparator is a consistent total order, as {@code Collections.sort}
     * requires. Equal scores (including +0.0 vs -0.0) compare as ties, so the stable
     * sort keeps their relative order. NaN scores sort first instead of breaking
     * transitivity.
     */
    private void byBinaryScore() {
        Collections.sort(this.entryList, new Comparator() { // from class: edu.cmu.minorthird.classify.experiments.Evaluation.1
            @Override // java.util.Comparator
            public int compare(Object obj, Object obj2) {
                double first = ((Entry) obj).predicted.posWeight();
                double second = ((Entry) obj2).predicted.posWeight();
                if (first == second) {
                    return 0;
                }
                // Reversed arguments give descending order.
                return Double.compare(second, first);
            }
        });
    }

    /** Sorts {@code entryList} in place by ascending {@code Entry.index}. */
    private void byOriginalPosition() {
        Collections.sort(this.entryList, new Comparator() { // from class: edu.cmu.minorthird.classify.experiments.Evaluation.2
            @Override // java.util.Comparator
            public int compare(Object left, Object right) {
                int a = ((Entry) left).index;
                int b = ((Entry) right).index;
                return a < b ? -1 : (a > b ? 1 : 0);
            }
        });
    }

    /**
     * Command-line entry point: loads the evaluation named by the first argument,
     * re-saves it to the optional second argument, and opens a viewer frame on it.
     * Any failure (including missing arguments) prints a usage line and the stack trace.
     */
    public static void main(String[] strArr) {
        try {
            String source = strArr[0];
            Evaluation evaluation = load(new File(source));
            boolean hasTarget = strArr.length > 1;
            if (hasTarget) {
                evaluation.save(new File(strArr[1]));
            }
            new ViewerFrame("From file " + source, evaluation.toGUI());
        } catch (Exception e) {
            System.out.println("usage: Evaluation [serializedFile|evaluationFile] [evaluationFile]");
            e.printStackTrace();
        }
    }
}
