package edu.cmu.minorthird.classify;

import cern.colt.matrix.impl.AbstractFormatter;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.algorithms.linear.MaxEntLearner;
import edu.cmu.minorthird.classify.algorithms.trees.AdaBoost;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.awt.BorderLayout;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/classify/StackedLearner.class */
/**
 * A two-level ("stacked") classifier learner.
 *
 * <p>Training proceeds as follows:
 * <ol>
 *   <li>The dataset is split with the configured {@link Splitter}. For every partition, each inner
 *       learner is reset and trained on that partition's training part.</li>
 *   <li>The examples of that partition's test part are then re-encoded with
 *       {@link #transformInstance}, keeping the original label, and collected into a level-1
 *       dataset.</li>
 *   <li>The final learner is trained on the level-1 dataset.</li>
 *   <li>Every inner learner is reset and retrained on the full dataset.</li>
 * </ol>
 *
 * <p>The classifier returned by {@link #batchTrain} feeds the output of these full-data inner
 * classifiers to the final classifier.
 */
public class StackedLearner extends BatchClassifierLearner {
    private static final Logger log = Logger.getLogger(StackedLearner.class);

    private ExampleSchema schema;
    private BatchClassifierLearner[] innerLearners;
    private BatchClassifierLearner finalLearner;
    private Splitter splitter;

    /**
     * Classifier produced by {@link StackedLearner#batchTrain}.
     *
     * <p>An instance is first re-encoded through the inner classifiers (see
     * {@link StackedLearner#transformInstance}), and the result is classified by the final
     * classifier.
     */
    private static class StackedClassifier implements Classifier, Visible {
        private final ExampleSchema schema;
        private final Classifier[] innerClassifiers;
        private final Classifier finalClassifier;

        /**
         * @param exampleSchema schema whose class names define the transformed features
         * @param classifierArr the inner (level-0) classifiers
         * @param classifier the final (level-1) classifier
         */
        public StackedClassifier(ExampleSchema exampleSchema, Classifier[] classifierArr, Classifier classifier) {
            this.schema = exampleSchema;
            this.innerClassifiers = classifierArr;
            this.finalClassifier = classifier;
        }

        /** Classifies the instance by running the final classifier on its transformed form. */
        @Override
        public ClassLabel classification(Instance instance) {
            return this.finalClassifier.classification(
                    StackedLearner.transformInstance(this.schema, instance, this.innerClassifiers));
        }

        /** Returns the weight the stacked classification assigns to class {@code str}. */
        public double score(Instance instance, String str) {
            return classification(instance).getWeight(str);
        }

        /**
         * Builds a textual explanation.
         *
         * <p>The text contains the inner classifiers' per-class weights and explanations,
         * followed by the final classifier's explanation of the transformed instance.
         */
        @Override
        public String explain(Instance instance) {
            StringBuffer stringBuffer = new StringBuffer("");
            stringBuffer.append(StackedLearner.explainTransformedInstance(this.schema, instance, this.innerClassifiers));
            Instance transformed = StackedLearner.transformInstance(this.schema, instance, this.innerClassifiers);
            stringBuffer.append("final classifier:\n");
            stringBuffer.append(this.finalClassifier.explain(transformed));
            return stringBuffer.toString();
        }

        /** Wraps {@link #explain(Instance)} in an {@link Explanation}. */
        @Override
        public Explanation getExplanation(Instance instance) {
            return new Explanation(explain(instance));
        }

        /** Shows the final classifier on top and the inner classifiers below it, in a scroll pane. */
        @Override
        public Viewer toGUI() {
            return new ComponentViewer() {
                @Override
                public JComponent componentFor(Object obj) {
                    // Read both the final and the inner classifiers from the viewed object, so the
                    // view is consistent even if it is reused for a different StackedClassifier.
                    StackedClassifier stacked = (StackedClassifier) obj;

                    JPanel main = new JPanel();
                    main.setLayout(new BorderLayout());
                    main.setBorder(new TitledBorder("Stacked Classifier"));

                    JPanel finalPanel = new JPanel();
                    finalPanel.setBorder(new TitledBorder("Final classifier"));
                    SmartVanillaViewer finalViewer = new SmartVanillaViewer(stacked.finalClassifier);
                    finalPanel.add(finalViewer);
                    finalViewer.setSuperView(this);
                    main.add(finalPanel, "North");

                    JPanel innerPanel = new JPanel();
                    innerPanel.setBorder(new TitledBorder("Inner classifier(s)"));
                    for (int i = 0; i < stacked.innerClassifiers.length; i++) {
                        SmartVanillaViewer innerViewer = new SmartVanillaViewer(stacked.innerClassifiers[i]);
                        innerPanel.add(innerViewer);
                        innerViewer.setSuperView(this);
                    }
                    main.add(innerPanel, "South");
                    return new JScrollPane(main);
                }
            };
        }
    }

    /** Stacks a single inner learner under a {@link MaxEntLearner}, using the given splitter. */
    public StackedLearner(BatchClassifierLearner batchClassifierLearner, Splitter splitter) {
        this(new BatchClassifierLearner[]{batchClassifierLearner}, new MaxEntLearner(), splitter);
    }

    /** Stacks a single inner learner under a {@link MaxEntLearner}, using 3-fold cross-validation. */
    public StackedLearner(BatchClassifierLearner batchClassifierLearner) {
        this(new BatchClassifierLearner[]{batchClassifierLearner}, new MaxEntLearner(), new CrossValSplitter(3));
    }

    /** Stacks {@link AdaBoost} under a {@link MaxEntLearner}, using 3-fold cross-validation. */
    public StackedLearner() {
        this(new BatchClassifierLearner[]{new AdaBoost()}, new MaxEntLearner(), new CrossValSplitter(3));
    }

    /**
     * @param batchClassifierLearnerArr the inner (level-0) learners
     * @param batchClassifierLearner the final (level-1) learner, trained on the inner learners' outputs
     * @param splitter splitter used to produce held-out predictions for the level-1 training data
     */
    public StackedLearner(BatchClassifierLearner[] batchClassifierLearnerArr, BatchClassifierLearner batchClassifierLearner, Splitter splitter) {
        this.innerLearners = batchClassifierLearnerArr;
        this.finalLearner = batchClassifierLearner;
        this.splitter = splitter;
    }

    /** Returns the splitter used to build the level-1 training data. */
    public Splitter getSplitter() {
        return this.splitter;
    }

    /** Sets the splitter used to build the level-1 training data. */
    public void setSplitter(Splitter splitter) {
        this.splitter = splitter;
    }

    /** Replaces all inner learners with the single given learner. */
    public void setInnerLearner(BatchClassifierLearner batchClassifierLearner) {
        this.innerLearners = new BatchClassifierLearner[]{batchClassifierLearner};
    }

    /**
     * Returns the single inner learner.
     *
     * @throws IllegalStateException if there is not exactly one inner learner
     */
    public BatchClassifierLearner getInnerLearner() {
        if (this.innerLearners.length != 1) {
            throw new IllegalStateException("multiple inner learners");
        }
        return this.innerLearners[0];
    }

    /** Records the schema and propagates it to every inner learner and to the final learner. */
    @Override
    public final void setSchema(ExampleSchema exampleSchema) {
        this.schema = exampleSchema;
        for (int i = 0; i < this.innerLearners.length; i++) {
            this.innerLearners[i].setSchema(exampleSchema);
        }
        this.finalLearner.setSchema(exampleSchema);
    }

    /**
     * Trains the stacked classifier (see the class documentation for the procedure).
     *
     * <p>Uses the schema recorded by {@link #setSchema}. The result is also stored in the
     * inherited {@code classifier} field.
     */
    @Override
    public Classifier batchTrain(Dataset dataset) {
        BasicDataset levelOneData = new BasicDataset();
        Classifier[] innerClassifiers = new Classifier[this.innerLearners.length];
        Dataset.Split split = dataset.split(this.splitter);
        for (int fold = 0; fold < split.getNumPartitions(); fold++) {
            String foldTag = (fold + 1) + "/" + split.getNumPartitions();
            trainInnerLearners(split.getTrain(fold), innerClassifiers, " on fold " + foldTag);
            Dataset test = split.getTest(fold);
            log.info("transforming test examples of fold " + foldTag);
            // Held-out predictions of this fold's inner classifiers become level-1 features.
            for (Example.Looper it = test.iterator(); it.hasNext(); ) {
                Example example = it.nextExample();
                levelOneData.add(new Example(transformInstance(this.schema, example, innerClassifiers), example.getLabel()));
            }
        }
        log.info("training level-1 learner");
        Classifier levelOneClassifier = this.finalLearner.batchTrain(levelOneData);
        log.info("result is " + levelOneClassifier);
        // Retrain the inner learners on all data. Go through the same reset-then-train path as the
        // folds, so no state from the last fold carries over.
        trainInnerLearners(dataset, innerClassifiers, " on full dataset");
        this.classifier = new StackedClassifier(this.schema, innerClassifiers, levelOneClassifier);
        return this.classifier;
    }

    /**
     * Resets each inner learner, trains it on {@code data}, and stores the classifier in {@code out}.
     *
     * @param where suffix appended to the progress log message (for example " on fold 1/3")
     */
    private void trainInnerLearners(Dataset data, Classifier[] out, String where) {
        for (int k = 0; k < this.innerLearners.length; k++) {
            this.innerLearners[k].reset();
            log.info("training inner learner " + (k + 1) + "/" + this.innerLearners.length + where);
            out[k] = this.innerLearners[k].batchTrain(data);
        }
    }

    /**
     * Re-encodes an instance as the inner classifiers' per-class weights.
     *
     * <p>The result has one numeric feature per (inner classifier {@code i}, schema class
     * {@code c}) pair. Each feature is named {@code ["learner_" + i, "class_" + c]} and its value
     * is the weight classifier {@code i} assigns to class {@code c}.
     */
    public static Instance transformInstance(ExampleSchema exampleSchema, Instance instance, Classifier[] classifierArr) {
        MutableInstance mutableInstance = new MutableInstance();
        for (int i = 0; i < classifierArr.length; i++) {
            ClassLabel classification = classifierArr[i].classification(instance);
            String prefix = "learner_" + i;
            for (int j = 0; j < exampleSchema.getNumberOfClasses(); j++) {
                String className = exampleSchema.getClassName(j);
                mutableInstance.addNumeric(new Feature(new String[]{prefix, "class_" + className}), classification.getWeight(className));
            }
        }
        return mutableInstance;
    }

    /**
     * Describes the inner classifiers' view of an instance.
     *
     * <p>For each inner classifier, the text lists the weight it assigns to every schema class.
     * It then includes that classifier's own explanation, exactly once.
     */
    public static String explainTransformedInstance(ExampleSchema exampleSchema, Instance instance, Classifier[] classifierArr) {
        StringBuffer stringBuffer = new StringBuffer("");
        for (int i = 0; i < classifierArr.length; i++) {
            ClassLabel classification = classifierArr[i].classification(instance);
            stringBuffer.append("learner#" + (i + 1) + " predicts:\n");
            for (int j = 0; j < exampleSchema.getNumberOfClasses(); j++) {
                String className = exampleSchema.getClassName(j);
                stringBuffer.append("  " + className + ": " + classification.getWeight(className) + "\n");
            }
            // Previously the explanation was appended once per class; it only needs to appear once.
            stringBuffer.append(classifierArr[i].explain(instance) + AbstractFormatter.DEFAULT_ROW_SEPARATOR);
        }
        return stringBuffer.toString();
    }
}
