package edu.cmu.minorthird.classify;

import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.algorithms.svm.SVMLearner;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.ViewerFrame;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/classify/OnlineVersion.class */
/**
 * Presents a {@link BatchClassifierLearner} through the online-learner interface.
 *
 * <p>Every example passed to {@link #addExample(Example)} is accumulated in an internal
 * dataset. While that dataset holds fewer than {@code minBatchTrainingSize} examples,
 * classifiers are supplied by a separate "bootstrap" online learner. Once the dataset is
 * large enough, the batch learner is trained on the whole dataset, and the resulting
 * classifier is cached until the dataset has grown beyond {@code loadFactor} times the
 * size it had when the cached classifier was trained.
 */
public class OnlineVersion extends OnlineClassifierLearner {
    private static final Logger log = Logger.getLogger(OnlineVersion.class);

    /** Batch learner that is (re-)trained on the accumulated dataset. */
    private BatchClassifierLearner innerLearner;
    /** Online learner consulted while the dataset is smaller than {@link #minBatchTrainingSize}. */
    private OnlineClassifierLearner bootstrapLearner;
    /** Growth ratio of the dataset (relative to the last training) that triggers re-training. */
    private double loadFactor;
    /** Dataset size below which the bootstrap learner is used instead of the batch learner. */
    private int minBatchTrainingSize;
    /** Most recently trained batch classifier, or {@code null} if none has been trained yet. */
    private Classifier storedClassifier;
    /** Size of the dataset at the time {@link #storedClassifier} was trained. */
    private int lastTrainingSetSize;
    /** All examples received since the last {@link #reset()}. */
    private Dataset dataset;

    /**
     * Creates an online wrapper with every setting given explicitly, then resets it.
     *
     * @param batchClassifierLearner  batch learner to train on the accumulated examples
     * @param d                       load factor; re-training happens once the dataset exceeds
     *                                this multiple of its size at the previous training
     * @param onlineClassifierLearner bootstrap learner used while the dataset is small
     * @param i                       minimum dataset size before the batch learner is used
     */
    public OnlineVersion(BatchClassifierLearner batchClassifierLearner, double d, OnlineClassifierLearner onlineClassifierLearner, int i) {
        this.innerLearner = batchClassifierLearner;
        this.loadFactor = d;
        this.bootstrapLearner = onlineClassifierLearner;
        this.minBatchTrainingSize = i;
        reset();
    }

    /**
     * Creates an online wrapper that bootstraps with a {@link VotedPerceptron} and switches
     * to the batch learner at 10 examples.
     *
     * @param batchClassifierLearner batch learner to train on the accumulated examples
     * @param d                      load factor controlling how often re-training happens
     */
    public OnlineVersion(BatchClassifierLearner batchClassifierLearner, double d) {
        this(batchClassifierLearner, d, new VotedPerceptron(), 10);
    }

    /**
     * Creates an online wrapper with a load factor of 1.5 and the default bootstrap settings.
     *
     * @param batchClassifierLearner batch learner to train on the accumulated examples
     */
    public OnlineVersion(BatchClassifierLearner batchClassifierLearner) {
        this(batchClassifierLearner, 1.5d);
    }

    /** Creates an online wrapper around a new {@link SVMLearner} with all default settings. */
    public OnlineVersion() {
        this(new SVMLearner());
    }

    /** @return the wrapped batch learner */
    public BatchClassifierLearner getInnerLearner() {
        return this.innerLearner;
    }

    /** @param batchClassifierLearner the batch learner to use for subsequent (re-)training */
    public void setInnerLearner(BatchClassifierLearner batchClassifierLearner) {
        this.innerLearner = batchClassifierLearner;
    }

    /** @return the online learner used while the dataset is below the minimum batch size */
    public OnlineClassifierLearner getBootstrapLearner() {
        return this.bootstrapLearner;
    }

    /** @param onlineClassifierLearner the online learner to use while the dataset is small */
    public void setBootstrapLearner(OnlineClassifierLearner onlineClassifierLearner) {
        this.bootstrapLearner = onlineClassifierLearner;
    }

    /** @return the load factor that controls when the batch learner is re-trained */
    public double getBatchLoadFactor() {
        return this.loadFactor;
    }

    /** @param d the new load factor */
    public void setBatchLoadFactor(double d) {
        this.loadFactor = d;
    }

    /** @return the dataset size below which the bootstrap learner is used */
    public int getMinBatchTrainingSize() {
        return this.minBatchTrainingSize;
    }

    /** @param i the new minimum dataset size for using the batch learner */
    public void setMinBatchTrainingSize(int i) {
        this.minBatchTrainingSize = i;
    }

    /**
     * Forwards the schema to both the batch learner and the bootstrap learner.
     *
     * @param exampleSchema schema to pass on
     */
    @Override // edu.cmu.minorthird.classify.ClassifierLearner
    public final void setSchema(ExampleSchema exampleSchema) {
        this.innerLearner.setSchema(exampleSchema);
        this.bootstrapLearner.setSchema(exampleSchema);
    }

    /**
     * Discards the cached classifier and all accumulated examples, and resets both
     * underlying learners.
     */
    @Override // edu.cmu.minorthird.classify.OnlineClassifierLearner, edu.cmu.minorthird.classify.ClassifierLearner
    public final void reset() {
        this.storedClassifier = null;
        this.lastTrainingSetSize = 0;
        this.dataset = new BasicDataset();
        this.innerLearner.reset();
        this.bootstrapLearner.reset();
    }

    /**
     * Stores the example in the dataset and, while the dataset (including this example)
     * is still smaller than the minimum batch size, also feeds it to the bootstrap learner.
     *
     * @param example the example to learn from
     */
    @Override // edu.cmu.minorthird.classify.OnlineClassifierLearner, edu.cmu.minorthird.classify.ClassifierLearner
    public final void addExample(Example example) {
        this.dataset.add(example);
        // The size check happens after the add, matching the threshold test in
        // getClassifier(): the bootstrap learner is only ever consulted while
        // dataset.size() < minBatchTrainingSize, so it never needs later examples.
        if (this.dataset.size() < this.minBatchTrainingSize) {
            this.bootstrapLearner.addExample(example);
        }
    }

    /**
     * Trains the batch learner on the full dataset if examples have arrived since the last
     * training, or if no classifier has been trained yet.
     *
     * <p>When a display is available, the dataset and any newly trained classifier are also
     * shown in viewer windows. In a headless JVM the windows are skipped so that training
     * can still complete.
     */
    @Override // edu.cmu.minorthird.classify.OnlineClassifierLearner, edu.cmu.minorthird.classify.ClassifierLearner
    public final void completeTraining() {
        // NOTE(review): ViewerFrame is assumed to be a windowing component that cannot be
        // created without a display - confirm. Guarding it keeps training usable on servers.
        final boolean displayAvailable = !java.awt.GraphicsEnvironment.isHeadless();
        if (displayAvailable) {
            new ViewerFrame("compete data", this.dataset.toGUI());
        }
        if (this.dataset.size() > this.lastTrainingSetSize || this.storedClassifier == null) {
            log.info("final training for " + this.innerLearner + " on " + this.dataset.size() + " examples");
            this.storedClassifier = this.innerLearner.batchTrain(this.dataset);
            if (displayAvailable) {
                new ViewerFrame("classifier", new SmartVanillaViewer(this.storedClassifier));
            }
            this.lastTrainingSetSize = this.dataset.size();
        }
    }

    /**
     * Returns the classifier appropriate for the examples seen so far.
     *
     * <ul>
     *   <li>Fewer than {@code minBatchTrainingSize} examples: the bootstrap learner's classifier.</li>
     *   <li>A cached batch classifier exists and the dataset has not grown beyond
     *       {@code lastTrainingSetSize * loadFactor}: the cached classifier.</li>
     *   <li>Otherwise: the batch learner is re-trained on the full dataset and the new
     *       classifier is cached and returned.</li>
     * </ul>
     *
     * @return the current classifier
     */
    @Override // edu.cmu.minorthird.classify.OnlineClassifierLearner, edu.cmu.minorthird.classify.ClassifierLearner
    public final Classifier getClassifier() {
        if (this.dataset.size() < this.minBatchTrainingSize) {
            return this.bootstrapLearner.getClassifier();
        }
        if (this.dataset.size() <= this.lastTrainingSetSize * this.loadFactor && this.storedClassifier != null) {
            return this.storedClassifier;
        }
        log.info("re-training " + this.innerLearner + " on " + this.dataset.size() + " examples");
        this.storedClassifier = this.innerLearner.batchTrain(this.dataset);
        log.info("batch classifier is " + this.storedClassifier);
        this.lastTrainingSetSize = this.dataset.size();
        return this.storedClassifier;
    }
}
