package edu.cmu.minorthird.classify.algorithms.active;

import edu.cmu.minorthird.classify.BatchClassifierLearner;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.RandomAccessDataset;
import edu.cmu.minorthird.classify.algorithms.trees.DecisionTreeLearner;
import gnu.trove.TObjectDoubleHashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.TreeMap;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/active/QueryByCommittee.class */
/**
 * Active learner that chooses queries by "query by committee".
 *
 * <p>Unlabeled instances supplied through {@link #setInstancePool} are kept in a pool
 * keyed by pseudo-random doubles drawn from a fixed seed, so iterating the pool visits
 * the instances in a reproducible shuffled order. While fewer than
 * {@code minLabelsBeforeQuerying} labeled examples have been added, queries are taken
 * from the front of that pool. Afterwards, a committee of classifiers is obtained from
 * {@code CommitteeLearner.batchTrainCommittee} on the labeled examples. The next query is
 * the pooled instance on which the committee agrees least, where agreement is the fraction
 * of members voting for the most popular class. Every labeled example is also passed to
 * an inner learner, whose classifier is returned by {@link #getClassifier()}.
 *
 * <p>No synchronization is performed; all state lives in plain fields.
 */
public class QueryByCommittee implements ClassifierLearner {
    private static final Logger log = Logger.getLogger(QueryByCommittee.class);

    /** Committee-size argument used by the no-argument constructor. */
    private static final int DEFAULT_COMMITTEE_SIZE = 5;

    /** Labeled examples to collect before the committee is consulted, by default. */
    private static final int DEFAULT_MIN_LABELS_BEFORE_QUERYING = 5;

    /** Fixed seed for shuffling the unlabeled pool, so query order is reproducible. */
    private static final long POOL_SHUFFLE_SEED = 0L;

    /** Receives every labeled example and builds the final classifier. */
    private ClassifierLearner innerLearner;

    /** Until this many labeled examples exist, queries are taken in pool order. */
    private final int minLabelsBeforeQuerying;

    /** Produces the committee used to rank unlabeled instances. */
    private final CommitteeLearner committeeLearner;

    /** Most recent schema passed to {@link #setSchema}. */
    private ExampleSchema schema;

    /** Unlabeled instances keyed by a random double; ascending key order is a shuffle. */
    private TreeMap<Double, Instance> unlabeled;

    /** Examples received through {@link #addExample}; the committee is trained on these. */
    private RandomAccessDataset labeled;

    /** Creates a learner using a {@link DecisionTreeLearner} and default settings. */
    public QueryByCommittee() {
        this(new DecisionTreeLearner(), DEFAULT_COMMITTEE_SIZE);
    }

    /**
     * Creates a learner that consults the committee once the default number of labeled
     * examples has been collected.
     *
     * @param batchClassifierLearner learner used both as the inner learner and as the
     *     base learner of the committee
     * @param committeeSize forwarded to {@code CommitteeLearner}; presumably the number of
     *     committee members
     */
    public QueryByCommittee(BatchClassifierLearner batchClassifierLearner, int committeeSize) {
        this(batchClassifierLearner, committeeSize, DEFAULT_MIN_LABELS_BEFORE_QUERYING);
    }

    /**
     * Creates a learner with an explicit warm-up threshold.
     *
     * @param batchClassifierLearner learner used both as the inner learner and as the
     *     base learner of the committee
     * @param committeeSize forwarded to {@code CommitteeLearner}; presumably the number of
     *     committee members
     * @param minLabelsBeforeQuerying number of labeled examples to collect (taking queries
     *     in pool order) before the committee is used to choose queries
     * @throws IllegalArgumentException if {@code minLabelsBeforeQuerying} is negative
     */
    public QueryByCommittee(
            BatchClassifierLearner batchClassifierLearner, int committeeSize, int minLabelsBeforeQuerying) {
        if (minLabelsBeforeQuerying < 0) {
            throw new IllegalArgumentException(
                    "minLabelsBeforeQuerying must be >= 0 but was " + minLabelsBeforeQuerying);
        }
        this.minLabelsBeforeQuerying = minLabelsBeforeQuerying;
        // NOTE(review): the same learner instance is handed to the committee and used as
        // the inner learner. If CommitteeLearner trains that instance directly, committee
        // training could disturb the inner learner's state; confirm against CommitteeLearner.
        this.committeeLearner = new CommitteeLearner(batchClassifierLearner, committeeSize);
        this.innerLearner = batchClassifierLearner;
        reset();
    }

    /**
     * Returns a fresh, reset copy of this learner with its own inner learner.
     *
     * <p>{@code clone()} is shallow, so the inner learner is copied explicitly. Otherwise
     * the {@code reset()} below would reset this learner's inner learner too. The committee
     * learner is still shared with the copy.
     *
     * @return the copy, or {@code null} if copying failed (the failure is logged). This
     *     relies on the class being Cloneable, presumably via ClassifierLearner, which is
     *     not visible here.
     */
    @Override // edu.cmu.minorthird.classify.ClassifierLearner
    public ClassifierLearner copy() {
        try {
            QueryByCommittee duplicate = (QueryByCommittee) clone();
            ClassifierLearner innerCopy = innerLearner.copy();
            if (innerCopy == null) {
                log.error("could not copy inner learner " + innerLearner);
                return null;
            }
            duplicate.innerLearner = innerCopy;
            duplicate.reset();
            return duplicate;
        } catch (Exception e) {
            // Best effort, as before: a failed copy is reported and yields null.
            log.error("could not copy " + getClass().getName(), e);
            return null;
        }
    }

    /** Empties the unlabeled pool and labeled set, and resets the inner learner. */
    @Override // edu.cmu.minorthird.classify.ClassifierLearner
    public final void reset() {
        this.unlabeled = new TreeMap<Double, Instance>();
        this.labeled = new RandomAccessDataset();
        this.innerLearner.reset();
    }

    /**
     * Records the schema and forwards it to the inner learner, which builds the final
     * classifier from the same examples.
     */
    @Override // edu.cmu.minorthird.classify.ClassifierLearner
    public final void setSchema(ExampleSchema exampleSchema) {
        this.schema = exampleSchema;
        this.innerLearner.setSchema(exampleSchema);
    }

    /**
     * Replaces the unlabeled pool with the instances produced by {@code looper}, in a
     * reproducible shuffled order.
     */
    @Override // edu.cmu.minorthird.classify.ClassifierLearner
    public final void setInstancePool(Instance.Looper looper) {
        unlabeled.clear();
        Random random = new Random(POOL_SHUFFLE_SEED);
        while (looper.hasNext()) {
            Double key = Double.valueOf(random.nextDouble());
            // Redraw on the (rare) key collision so no instance silently replaces another.
            while (unlabeled.containsKey(key)) {
                key = Double.valueOf(random.nextDouble());
            }
            unlabeled.put(key, looper.nextInstance());
        }
        log.info(unlabeled.size() + " unlabeled examples available");
    }

    /** Returns whether any unlabeled instance remains to be queried. */
    @Override // edu.cmu.minorthird.classify.ClassifierLearner
    public final boolean hasNextQuery() {
        return !unlabeled.isEmpty();
    }

    /**
     * Removes and returns the next instance to be labeled.
     *
     * <p>Before {@code minLabelsBeforeQuerying} labels exist, this is the first instance
     * in pool order. Afterwards it is the instance with the lowest committee agreement
     * (ties go to the earliest in pool order). It falls back to pool order if the
     * committee cannot rank anything.
     *
     * @throws NoSuchElementException if the pool is empty; check {@link #hasNextQuery()}
     */
    @Override // edu.cmu.minorthird.classify.ClassifierLearner
    public final Instance nextQuery() {
        if (unlabeled.isEmpty()) {
            throw new NoSuchElementException("no unlabeled instances left to query");
        }
        Double key = null;
        if (labeled.size() < minLabelsBeforeQuerying) {
            log.info("will pick next unlabeled example");
        } else {
            log.info("will use committee to pick an unlabeled example");
            key = keyOfBestUnlabeledInstance(committeeLearner.batchTrainCommittee(labeled));
        }
        if (key == null) {
            key = unlabeled.firstKey();
        }
        return unlabeled.remove(key);
    }

    /**
     * Finds the pooled instance on which the committee agrees least.
     *
     * <p>Agreement is the number of votes for the most popular predicted class divided by
     * the committee size.
     *
     * @param committee classifiers whose predictions are compared
     * @return pool key of the least-agreed instance, or {@code null} if the committee is
     *     empty or the pool has no instances
     */
    private Double keyOfBestUnlabeledInstance(Classifier[] committee) {
        if (committee == null || committee.length == 0) {
            log.warn("committee is empty; cannot rank unlabeled instances");
            return null;
        }
        double lowestAgreement = Double.POSITIVE_INFINITY;
        Double bestKey = null;
        for (Map.Entry<Double, Instance> entry : unlabeled.entrySet()) {
            Instance instance = entry.getValue();
            TObjectDoubleHashMap votes = new TObjectDoubleHashMap();
            double topVotes = 0.0d;
            for (Classifier member : committee) {
                String predicted = member.classification(instance).bestClassName();
                // Assumes get() yields 0 for a class not yet seen (trove's no-entry value).
                double count = votes.get(predicted) + 1.0d;
                votes.put(predicted, count);
                topVotes = Math.max(topVotes, count);
            }
            double agreement = topVotes / committee.length;
            // Logged at DEBUG: this runs for every pooled instance on every query.
            if (log.isDebugEnabled()) {
                log.debug("instance: " + instance + " committee: " + votes + " agreement: " + agreement);
            }
            if (agreement < lowestAgreement) {
                lowestAgreement = agreement;
                bestKey = entry.getKey();
                log.debug(" ==> best");
            }
        }
        log.info("queryInstance is: " + (bestKey == null ? null : unlabeled.get(bestKey)));
        return bestKey;
    }

    /** Records a labeled example for committee training and passes it to the inner learner. */
    @Override // edu.cmu.minorthird.classify.ClassifierLearner
    public void addExample(Example example) {
        log.info("adding example: " + example);
        labeled.add(example);
        innerLearner.addExample(example);
    }

    /** Completes training of the inner learner. */
    @Override // edu.cmu.minorthird.classify.ClassifierLearner
    public final void completeTraining() {
        innerLearner.completeTraining();
    }

    /** Returns the inner learner's classifier. */
    @Override // edu.cmu.minorthird.classify.ClassifierLearner
    public Classifier getClassifier() {
        return innerLearner.getClassifier();
    }
}
