package edu.cmu.minorthird.classify;

import edu.cmu.minorthird.classify.algorithms.knn.KnnLearner;
import edu.cmu.minorthird.classify.algorithms.linear.MaxEntLearner;
import edu.cmu.minorthird.classify.algorithms.linear.NaiveBayes;
import edu.cmu.minorthird.classify.algorithms.linear.PoissonLearner;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.algorithms.trees.AdaBoost;
import edu.cmu.minorthird.classify.algorithms.trees.DecisionTreeLearner;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.experiments.SubsamplingCrossValSplitter;
import edu.cmu.minorthird.classify.experiments.Tester;
import edu.cmu.minorthird.classify.semisupervised.SemiSupervisedNaiveBayesLearner;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import junit.textui.TestRunner;
import org.apache.log4j.Logger;

/**
 * JUnit 3 regression suite for the {@code edu.cmu.minorthird.classify} package.
 *
 * <p>The suite built by {@link #suite()} contains three kinds of tests:
 * <ul>
 *   <li>{@link LearnerTest}: trains a learner on a named sample dataset and checks its
 *       error rate on the matching test data against an expected value;</li>
 *   <li>{@link LogisticRegressionTest}: a sanity check of {@link MaxEntLearner} on
 *       synthetic logistic-regression data;</li>
 *   <li>{@link XValTest}: checks that a cross-validation splitter produces disjoint
 *       train/test sets per fold whose test sets together have the size of the data.</li>
 * </ul>
 */
public class TestPackage extends TestSuite {
    private static final Logger log = Logger.getLogger(TestPackage.class);

    /**
     * Error rate / tolerance shared by several small-dataset tests (exactly 1/7).
     * NOTE(review): presumably one error among seven test examples; confirm in SampleDatasets.
     */
    private static final double ONE_SEVENTH = 1.0d / 7.0d;

    /**
     * Trains a learner on the training data of a named sample dataset and asserts that its
     * error rate on the matching test data is within a tolerance of an expected value.
     */
    public static class LearnerTest extends TestCase {
        /**
         * Slack added to every tolerance so that boundary cases (including a tolerance of
         * zero) are not defeated by floating-point rounding in the computed error rate.
         */
        private static final double TOLERANCE_SLACK = 0.001d;

        private final ClassifierLearner learner;
        private final double expectedTestError;
        private final double allowedVariance;
        private final String testName;

        /**
         * @param testName dataset name passed to {@code SampleDatasets.sampleData}
         * @param learner the learner under test
         * @param expectedTestError expected error rate on the test data
         * @param allowedVariance allowed absolute deviation from {@code expectedTestError};
         *     {@value #TOLERANCE_SLACK} is added on top
         */
        public LearnerTest(String testName, ClassifierLearner learner,
                double expectedTestError, double allowedVariance) {
            // JUnit 3 runs the method whose name is passed here.
            super("doTest");
            this.learner = learner;
            this.expectedTestError = expectedTestError;
            this.testName = testName;
            this.allowedVariance = allowedVariance;
        }

        /** Trains on the shuffled training data, then checks the test-data error rate. */
        public void doTest() {
            Dataset trainData = SampleDatasets.sampleData(testName, false);
            // A fixed seed keeps the shuffle, and therefore training, deterministic.
            trainData.shuffle(new Random(0L));
            Classifier classifier = new DatasetClassifierTeacher(trainData).train(learner);
            TestPackage.log.debug("classifier is " + classifier);

            // NOTE(review): the boolean presumably selects the test split; confirm in SampleDatasets.
            double errorRate = Tester.errorRate(classifier, SampleDatasets.sampleData(testName, true));
            TestPackage.log.debug("error of " + learner + " is " + errorRate);
            assertEquals("error rate of " + learner + " on dataset '" + testName + "'",
                    expectedTestError, errorRate, allowedVariance + TOLERANCE_SLACK);
        }
    }

    /**
     * Sanity check for {@link MaxEntLearner}: trains on 1000 synthetic examples and asserts
     * that the error rate on that same data is 0.415 +/- 0.05.
     */
    public static class LogisticRegressionTest extends TestCase {
        public LogisticRegressionTest() {
            super("doTest");
        }

        /** Trains and evaluates on the same generated dataset (i.e. measures training error). */
        public void doTest() {
            MaxEntLearner learner = new MaxEntLearner();
            // NOTE(review): the meaning of 0.2 and 0.3 is defined by
            // SampleDatasets.makeLogisticRegressionData; confirm there.
            Dataset data = SampleDatasets.makeLogisticRegressionData(new Random(0L), 1000, 0.2d, 0.3d);
            assertEquals("MaxEnt training error rate",
                    0.415d, Tester.errorRate(learner.batchTrain(data), data), 0.05d);
        }
    }

    /**
     * Checks a cross-validation splitter on synthetic data. For every fold the train and test
     * sets must be disjoint, and the test-set sizes summed over all folds must equal the
     * number of instances.
     */
    public static class XValTest extends TestCase {
        /** Number of folds requested from the splitter. */
        private static final int NUM_FOLDS = 3;

        private final int numSites;
        private final int numPagesPerSite;
        private final boolean subsample;

        /** Equivalent to {@code XValTest(numSites, numPagesPerSite, false)}. */
        public XValTest(int numSites, int numPagesPerSite) {
            this(numSites, numPagesPerSite, false);
        }

        /**
         * @param numSites number of synthetic sites
         * @param numPagesPerSite number of instances generated per site
         * @param subsample if true, test a {@link SubsamplingCrossValSplitter}, for which each
         *     fold's train+test size is only required not to exceed the data size; otherwise
         *     test a {@link CrossValSplitter}, whose folds must each cover the data exactly
         */
        public XValTest(int numSites, int numPagesPerSite, boolean subsample) {
            super("doTest");
            this.numSites = numSites;
            this.numPagesPerSite = numPagesPerSite;
            this.subsample = subsample;
        }

        /** Splits the synthetic instances and checks every fold. */
        public void doTest() {
            TestPackage.log.debug("[XValTest sites: " + numSites + " pages/site: " + numPagesPerSite + "]");
            List<MutableInstance> instances = makeInstances();
            int size = instances.size();

            Splitter splitter = subsample
                    ? new SubsamplingCrossValSplitter(NUM_FOLDS, 0.2d)
                    : new CrossValSplitter(NUM_FOLDS);
            splitter.split(instances.iterator());
            assertEquals("number of partitions", NUM_FOLDS, splitter.getNumPartitions());

            int totalTestSize = 0;
            for (int fold = 0; fold < NUM_FOLDS; fold++) {
                TestPackage.log.debug("partition " + (fold + 1) + ":");
                Set<Object> train = asSet(splitter.getTrain(fold));
                Set<Object> test = asSet(splitter.getTest(fold));
                checkFold(fold + 1, size, train, test);
                totalTestSize += test.size();
            }
            assertEquals("total size of all test sets", size, totalTestSize);
        }

        /**
         * Builds numSites * numPagesPerSite instances. Each is constructed with
         * ("page{p}.html", "www.site{s}.com") and carries one binary feature
         * "site{s}.page{p}" that no other instance shares.
         * NOTE(review): the second constructor argument is presumably the subpopulation id.
         */
        private List<MutableInstance> makeInstances() {
            List<MutableInstance> instances = new ArrayList<MutableInstance>();
            for (int site = 1; site <= numSites; site++) {
                String siteName = "www.site" + site + ".com";
                for (int page = 1; page <= numPagesPerSite; page++) {
                    MutableInstance instance = new MutableInstance("page" + page + ".html", siteName);
                    instance.addBinary(new Feature("site" + site + ".page" + page));
                    instances.add(instance);
                    TestPackage.log.debug("instance: " + instance);
                }
            }
            return instances;
        }

        /**
         * Logs one fold's contents and asserts that its train and test sets are disjoint and
         * that their combined size is consistent with the data size.
         *
         * @param partition 1-based fold number, used only in messages
         * @param size total number of instances given to the splitter
         */
        private void checkFold(int partition, int size, Set<Object> train, Set<Object> test) {
            for (Object instance : test) {
                TestPackage.log.debug("  test:  " + instance);
                if (train.contains(instance)) {
                    fail("partition " + partition + ": test instance also in train: " + instance);
                }
            }
            TestPackage.log.debug("  -----\n  " + test.size() + " total");
            for (Object instance : train) {
                TestPackage.log.debug("  train:  " + instance);
                if (test.contains(instance)) {
                    fail("partition " + partition + ": train instance also in test: " + instance);
                }
            }
            TestPackage.log.debug("  -----\n  " + train.size() + " total");

            int foldSize = train.size() + test.size();
            if (subsample) {
                // A subsampling splitter may leave instances out, so only an upper bound holds.
                assertTrue("partition " + partition + ": train+test larger than the data", size >= foldSize);
            } else {
                assertEquals("partition " + partition + ": train+test size", size, foldSize);
            }
        }

        /** Drains an iterator into a new {@link HashSet}. */
        private static Set<Object> asSet(Iterator<?> it) {
            Set<Object> set = new HashSet<Object>();
            while (it.hasNext()) {
                set.add(it.next());
            }
            return set;
        }
    }

    /** @param name name of this suite */
    public TestPackage(String name) {
        super(name);
    }

    /** Builds the full regression suite for the classify package. */
    public static TestSuite suite() {
        TestSuite suite = new TestSuite();

        // Learner tests: (dataset name, learner, expected test error, allowed deviation).
        suite.addTest(new LearnerTest("bayesUnlabeled", new SemiSupervisedNaiveBayesLearner(), 0.0d, 0.0d));
        suite.addTest(new LearnerTest("bayesExtreme", new PoissonLearner(), 0.0d, 0.0d));
        suite.addTest(new LearnerTest("bayesExtreme", new NaiveBayes(), 0.5d, 0.5d));
        suite.addTest(new LearnerTest("toy", new NaiveBayes(), ONE_SEVENTH, ONE_SEVENTH));
        suite.addTest(new LearnerTest("bayes", new PoissonLearner(), ONE_SEVENTH, ONE_SEVENTH));
        suite.addTest(new LearnerTest("toy", new BinaryBatchVersion(new VotedPerceptron()), 0.0d, ONE_SEVENTH));
        suite.addTest(new LearnerTest("toy", new VotedPerceptron(), ONE_SEVENTH, ONE_SEVENTH));
        suite.addTest(new LearnerTest("toy", new DecisionTreeLearner(5, 2), ONE_SEVENTH, ONE_SEVENTH));
        suite.addTest(new LearnerTest("toy", new KnnLearner(10), 0.0d, 0.1d));
        suite.addTest(new LearnerTest("toy3", new KnnLearner(10), 0.2d, 0.1d));
        suite.addTest(new LearnerTest("toy", new AdaBoost(new DecisionTreeLearner(5, 2), 10), ONE_SEVENTH, ONE_SEVENTH));
        suite.addTest(new LearnerTest("num", new DecisionTreeLearner(5, 2), 0.05d, 0.1d));
        suite.addTest(new LearnerTest("sparseNum", new DecisionTreeLearner(5, 2), 0.0d, 0.1d));

        suite.addTest(new LogisticRegressionTest());

        // Cross-validation splitter tests: (sites, pages per site[, subsample]).
        suite.addTest(new XValTest(10, 1));
        suite.addTest(new XValTest(3, 5));
        suite.addTest(new XValTest(50, 1, true));
        suite.addTest(new XValTest(3, 25, true));
        return suite;
    }

    /** Runs the suite with the JUnit text runner. */
    public static void main(String[] args) {
        TestRunner.run(suite());
    }
}
