package edu.cmu.minorthird.classify.experiments;

import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.DatasetClassifierTeacher;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.sequential.DatasetSequenceClassifierTeacher;
import edu.cmu.minorthird.classify.sequential.SequenceClassifier;
import edu.cmu.minorthird.classify.sequential.SequenceClassifierLearner;
import edu.cmu.minorthird.classify.sequential.SequenceDataset;
import edu.cmu.minorthird.util.ProgressCounter;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/classify/experiments/Tester.class */
/**
 * Static helpers that train a learner on part of a dataset, test the resulting
 * classifier on another part, and collect the outcome in an {@link Evaluation};
 * plus simple per-dataset loss measures.
 */
public class Tester {
    private static final Logger log = Logger.getLogger(Tester.class);

    /**
     * Runs one train/test round per partition produced by {@code splitter} and
     * accumulates the results.
     *
     * @param classifierLearner learner handed to a fresh
     *        {@link DatasetClassifierTeacher} for every partition
     * @param dataset data to be split; its schema seeds the returned evaluation
     * @param splitter strategy passed to {@code dataset.split(...)}
     * @return the evaluation extended with every partition's classifier, test
     *         set and zero-based partition index
     */
    public static Evaluation evaluate(ClassifierLearner classifierLearner, Dataset dataset, Splitter splitter) {
        Evaluation evaluation = new Evaluation(dataset.getSchema());
        Dataset.Split split = dataset.split(splitter);
        ProgressCounter progressCounter = new ProgressCounter("train/test", "fold", split.getNumPartitions());
        for (int i = 0; i < split.getNumPartitions(); i++) {
            Dataset trainData = split.getTrain(i);
            Dataset testData = split.getTest(i);
            log.info("splitting with " + splitter + ", preparing to train on " + trainData.size()
                    + " and test on " + testData.size());
            Classifier classifier = new DatasetClassifierTeacher(trainData).train(classifierLearner);
            // Ask the logger at the point of use: only build the (potentially large)
            // classifier dump when debug output will actually be emitted.
            if (log.isDebugEnabled()) {
                log.debug("classifier for fold " + (i + 1) + "/" + split.getNumPartitions() + " is:\n" + classifier);
            }
            evaluation.extend(classifier, testData, i);
            log.info("splitting with " + splitter + ", completed train-test round");
            progressCounter.progress();
        }
        progressCounter.finished();
        return evaluation;
    }

    /**
     * Sequence-learning counterpart of
     * {@link #evaluate(ClassifierLearner, Dataset, Splitter)}: runs one train/test
     * round per partition and accumulates the results.
     *
     * <p>Each train and test partition returned by the split is cast to
     * {@link SequenceDataset}.
     *
     * <p>NOTE(review): unlike the non-sequence overload, the partition index is not
     * passed to {@code Evaluation.extend} here - confirm whether that is intended.
     *
     * @param sequenceClassifierLearner learner handed to a fresh
     *        {@link DatasetSequenceClassifierTeacher} for every partition
     * @param sequenceDataset data to be split; its schema seeds the returned evaluation
     * @param splitter strategy passed to {@code sequenceDataset.split(...)}
     * @return the evaluation extended with every partition's classifier and test set
     */
    public static Evaluation evaluate(SequenceClassifierLearner sequenceClassifierLearner, SequenceDataset sequenceDataset, Splitter splitter) {
        Evaluation evaluation = new Evaluation(sequenceDataset.getSchema());
        Dataset.Split split = sequenceDataset.split(splitter);
        ProgressCounter progressCounter = new ProgressCounter("train/test", "fold", split.getNumPartitions());
        for (int i = 0; i < split.getNumPartitions(); i++) {
            SequenceDataset trainData = (SequenceDataset) split.getTrain(i);
            SequenceDataset testData = (SequenceDataset) split.getTest(i);
            log.info("splitting with " + splitter + ", preparing to train on " + trainData.size()
                    + " and test on " + testData.size());
            SequenceClassifier classifier = new DatasetSequenceClassifierTeacher(trainData).train(sequenceClassifierLearner);
            // Same guard as the non-sequence overload: skip building the dump unless
            // debug output is enabled right now.
            if (log.isDebugEnabled()) {
                log.debug("classifier for fold " + (i + 1) + "/" + split.getNumPartitions() + " is:\n" + classifier);
            }
            evaluation.extend(classifier, testData);
            log.info("splitting with " + splitter + ", completed train-test round");
            progressCounter.progress();
        }
        progressCounter.finished();
        return evaluation;
    }

    /**
     * Evaluates a learner using a {@link FixedTestSetSplitter} built from the
     * examples of {@code dataset2}.
     *
     * @param classifierLearner learner to evaluate
     * @param dataset dataset that is split with the fixed-test-set splitter
     * @param dataset2 dataset whose iterator is given to the splitter
     * @return the result of
     *         {@link #evaluate(ClassifierLearner, Dataset, Splitter)} with that splitter
     */
    public static Evaluation evaluate(ClassifierLearner classifierLearner, Dataset dataset, Dataset dataset2) {
        return evaluate(classifierLearner, dataset, new FixedTestSetSplitter(dataset2.iterator()));
    }

    /**
     * Evaluates a sequence learner using a {@link FixedTestSetSplitter} built from
     * the examples of {@code sequenceDataset2}.
     *
     * @param sequenceClassifierLearner learner to evaluate
     * @param sequenceDataset dataset that is split with the fixed-test-set splitter
     * @param sequenceDataset2 dataset whose iterator is given to the splitter
     * @return the result of
     *         {@link #evaluate(SequenceClassifierLearner, SequenceDataset, Splitter)}
     *         with that splitter
     */
    public static Evaluation evaluate(SequenceClassifierLearner sequenceClassifierLearner, SequenceDataset sequenceDataset, SequenceDataset sequenceDataset2) {
        return evaluate(sequenceClassifierLearner, sequenceDataset, new FixedTestSetSplitter(sequenceDataset2.iterator()));
    }

    /**
     * Computes {@code log(1 + exp(y * s))} for one example, where {@code y} is
     * {@code example.getLabel().numericScore()} and {@code s} is
     * {@code binaryClassifier.score(example)}.
     *
     * <p>NOTE(review): the conventional logistic loss is {@code log(1 + exp(-y * s))};
     * as written, a larger positive product yields a larger value. Confirm the sign
     * convention of {@code numericScore()} / {@code score()} before relying on this
     * as a loss.
     *
     * @param binaryClassifier classifier that scores the example
     * @param example labelled example to score
     * @return the value of the expression above
     */
    public static double logLoss(BinaryClassifier binaryClassifier, Example example) {
        double labelScore = example.getLabel().numericScore();
        double prediction = binaryClassifier.score(example);
        return Math.log(1.0d + Math.exp(labelScore * prediction));
    }

    /**
     * Averages {@link #logLoss(BinaryClassifier, Example)} over the examples of a
     * dataset: the per-example values are summed and divided by
     * {@code dataset.size()}.
     *
     * @param binaryClassifier classifier that scores each example
     * @param dataset examples to iterate over
     * @return the sum of per-example values divided by {@code dataset.size()};
     *         if nothing is iterated and the size is zero this is {@code NaN}
     */
    public static double logLoss(BinaryClassifier binaryClassifier, Dataset dataset) {
        double total = 0.0d;
        for (Example.Looper it = dataset.iterator(); it.hasNext();) {
            total += logLoss(binaryClassifier, it.nextExample());
        }
        return total / dataset.size();
    }

    /**
     * Computes the fraction of examples for which the classifier's classification
     * is not correct with respect to the example's own label.
     *
     * @param classifier classifier under test
     * @param dataset labelled examples to iterate over
     * @return the number of incorrectly classified examples divided by
     *         {@code dataset.size()}; if nothing is iterated and the size is zero
     *         this is {@code NaN}
     */
    public static double errorRate(Classifier classifier, Dataset dataset) {
        double errors = 0.0d;
        for (Example.Looper it = dataset.iterator(); it.hasNext();) {
            Example example = it.nextExample();
            boolean correct = classifier.classification(example).isCorrect(example.getLabel());
            if (!correct) {
                errors += 1.0d;
            }
        }
        return errors / dataset.size();
    }
}
