package edu.cmu.minorthird.classify.algorithms.trees;

import cern.colt.matrix.impl.AbstractFormatter;
import edu.cmu.minorthird.classify.BatchBinaryClassifierLearner;
import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.algorithms.trees.DecisionTree;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.StringUtil;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.VanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.awt.GridBagConstraints;
import java.awt.GridBagLayout;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Vector;
import java.util.concurrent.Semaphore;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/trees/RandomForests.class */
/**
 * Random-forest learner for binary classification.
 *
 * <p>{@link #batchTrain(Dataset)} grows {@code numComponents} trees with {@code baseLearner},
 * each trained on its own bootstrap resample of the training data. It then combines them in an
 * unweighted {@link VotingClassifier}. After training, tree-shape statistics and an
 * out-of-bag error estimate are written to the log.
 *
 * <p>Each resample draws {@code n} examples with replacement from the {@code n} training
 * examples. An example that is drawn more than once is still passed to the base learner only
 * once.
 */
public class RandomForests extends BatchBinaryClassifierLearner {
    private static final Logger log = Logger.getLogger(RandomForests.class);
    /** Learner used to grow every tree of the forest. */
    private RandomTreeLearner baseLearner;
    /** Number of trees to grow. */
    private int numComponents;
    /** Whether trees are grown on background threads (set to true by the constructors here). */
    private boolean isThreaded;
    /** Thread budget; at most {@code threadCount - 1} trees are trained concurrently. */
    private int threadCount;

    /**
     * Grows one tree of the forest.
     *
     * <p>The tree is appended to the shared {@code classifiers} list, and its out-of-bag
     * examples (those not drawn into its resample) are stored in {@code results}, keyed by the
     * tree. In threaded mode one permit of {@code s} is released when the thread finishes,
     * even if training throws, so {@code batchTrain} can never wait forever.
     *
     * <p>NOTE(review): in threaded mode {@code baseLearner.batchTrain} is called from several
     * threads at once on the same learner instance; it is assumed to be thread-safe. Confirm
     * this against {@code RandomTreeLearner}.
     */
    private class LearnerThread extends Thread {
        private final Vector<Example> examples;
        private final Vector<Feature> features;
        private final List<BinaryClassifier> classifiers;
        private final Hashtable<BinaryClassifier, HashSet<Example>> results;
        private final Semaphore s;

        public LearnerThread(Vector<Example> examples, Vector<Feature> features,
                List<BinaryClassifier> classifiers,
                Hashtable<BinaryClassifier, HashSet<Example>> results, Semaphore s) {
            this.examples = examples;
            this.features = features;
            this.classifiers = classifiers;
            this.results = results;
            this.s = s;
        }

        @Override // java.lang.Thread, java.lang.Runnable
        public void run() {
            try {
                trainOneTree();
            } finally {
                // Always return the permit. batchTrain waits for every permit before it
                // returns, so a failing tree would otherwise deadlock training.
                if (RandomForests.this.isThreaded) {
                    this.s.release();
                }
            }
        }

        /** Draws the bootstrap resample, trains the tree, and records the tree and its out-of-bag set. */
        private void trainOneTree() {
            int n = this.examples.size();
            LinkedList<Example> bag = new LinkedList<>();
            HashSet<Example> inBag = new HashSet<>();
            for (int draw = 0; draw < n; draw++) {
                Example picked = this.examples.elementAt((int) Math.floor(Math.random() * n));
                if (inBag.add(picked)) {
                    bag.add(picked);
                }
            }
            HashSet<Example> outOfBag = new HashSet<>();
            for (Example example : this.examples) {
                if (!inBag.contains(example)) {
                    outOfBag.add(example);
                }
            }
            log.info("RandomForest is building tree  with " + bag.size() + " elements");
            BinaryClassifier tree =
                    (BinaryClassifier) RandomForests.this.baseLearner.batchTrain(bag, this.features);
            // The list is a plain ArrayList shared by every worker thread, so the append
            // must be synchronized. The Hashtable below synchronizes internally.
            synchronized (this.classifiers) {
                this.classifiers.add(tree);
            }
            this.results.put(tree, outOfBag);
        }
    }

    /**
     * Binary classifier that sums the scores of its member classifiers.
     *
     * <p>{@link #score(Instance)} returns {@code 1.0} when the sum is strictly positive and
     * {@code -1.0} otherwise (a tie therefore yields {@code -1.0}). Every member must be a
     * {@link BinaryClassifier}.
     */
    public static class VotingClassifier extends BinaryClassifier implements Serializable, Visible {
        private List<?> classifiers;

        /** @param list the member classifiers; each element must be a {@link BinaryClassifier} */
        public VotingClassifier(List list) {
            this.classifiers = list;
        }

        /** @return the live list of member classifiers (not a copy) */
        public List getClassifiers() {
            return this.classifiers;
        }

        @Override // edu.cmu.minorthird.classify.BinaryClassifier
        public double score(Instance instance) {
            double total = 0.0d;
            for (Object member : this.classifiers) {
                total += ((BinaryClassifier) member).score(instance);
            }
            return total > 0.0d ? 1.0d : -1.0d;
        }

        /** Lists each member's score and explanation, followed by the raw (unthresholded) total. */
        @Override // edu.cmu.minorthird.classify.Classifier
        public String explain(Instance instance) {
            StringBuilder out = new StringBuilder();
            double total = 0.0d;
            for (Object o : this.classifiers) {
                BinaryClassifier member = (BinaryClassifier) o;
                double memberScore = member.score(instance);
                total += memberScore;
                out.append("score of " + member + ": " + memberScore + AbstractFormatter.DEFAULT_ROW_SEPARATOR);
                out.append(StringUtil.indent(1, member.explain(instance)) + AbstractFormatter.DEFAULT_ROW_SEPARATOR);
            }
            out.append("total score: " + total);
            return out.toString();
        }

        /** Builds a tree with one child per member (its own score and explanation), followed by the raw total. */
        @Override // edu.cmu.minorthird.classify.Classifier
        public Explanation getExplanation(Instance instance) {
            Explanation.Node root = new Explanation.Node("Random Forest Explanation");
            double total = 0.0d;
            for (Object o : this.classifiers) {
                BinaryClassifier member = (BinaryClassifier) o;
                double memberScore = member.score(instance);
                total += memberScore;
                Explanation.Node memberNode = new Explanation.Node("score of " + member);
                // Show this member's own score (not the running total), matching explain().
                Explanation.Node scoreNode =
                        new Explanation.Node(memberScore + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR);
                Explanation.Node detailNode = member.getExplanation(instance).getTopNode();
                memberNode.add(scoreNode);
                memberNode.add(detailNode);
                root.add(memberNode);
            }
            root.add(new Explanation.Node("total score: " + total));
            return new Explanation(root);
        }

        public String toString() {
            StringBuilder out = new StringBuilder("[voting classifiers:\n");
            for (Object member : this.classifiers) {
                out.append(((BinaryClassifier) member).toString() + AbstractFormatter.DEFAULT_ROW_SEPARATOR);
            }
            out.append("]");
            return out.toString();
        }

        @Override // edu.cmu.minorthird.util.gui.Visible
        public Viewer toGUI() {
            VotingClassifierViewer viewer = new VotingClassifierViewer();
            viewer.setContent(this);
            return viewer;
        }
    }

    /** Shows each member of a {@link VotingClassifier} in one column of a scrollable panel. */
    private static class VotingClassifierViewer extends ComponentViewer {
        private VotingClassifierViewer() {
        }

        @Override // edu.cmu.minorthird.util.gui.ComponentViewer
        public JComponent componentFor(Object obj) {
            JPanel panel = new JPanel();
            panel.setLayout(new GridBagLayout());
            int row = 0;
            for (Object o : ((VotingClassifier) obj).classifiers) {
                Classifier member = (Classifier) o;
                GridBagConstraints constraints = new GridBagConstraints();
                constraints.fill = GridBagConstraints.HORIZONTAL;
                constraints.weighty = 0.0d;
                constraints.weightx = 0.0d;
                constraints.gridx = 0;
                constraints.gridy = row++;
                Viewer gui = member instanceof Visible ? ((Visible) member).toGUI() : new VanillaViewer(member);
                gui.setSuperView(this);
                panel.add(gui, constraints);
            }
            JScrollPane scrollPane = new JScrollPane(panel);
            scrollPane.setHorizontalScrollBarPolicy(JScrollPane.HORIZONTAL_SCROLLBAR_AS_NEEDED);
            return scrollPane;
        }
    }

    /** Creates a threaded forest of 100 trees grown by a {@link FastRandomTreeLearner}. */
    public RandomForests() {
        this(100);
    }

    /** @param i number of trees, each grown by a new {@link FastRandomTreeLearner} */
    public RandomForests(int i) {
        this(new FastRandomTreeLearner(), i);
    }

    /**
     * @param randomTreeLearner learner used to grow every tree
     * @param i number of trees to grow
     */
    public RandomForests(RandomTreeLearner randomTreeLearner, int i) {
        this.isThreaded = true;
        this.threadCount = 4;
        this.baseLearner = randomTreeLearner;
        this.numComponents = i;
    }

    /**
     * Trains {@code numComponents} trees on bootstrap resamples of {@code dataset} and returns
     * them as a {@link VotingClassifier}.
     *
     * <p>In threaded mode, a semaphore with {@code threadCount - 1} permits (at least one)
     * bounds how many trees train at once. The method returns only after every started tree
     * has finished.
     */
    @Override // edu.cmu.minorthird.classify.BatchClassifierLearner
    public Classifier batchTrain(Dataset dataset) {
        Vector<Example> examples = new Vector<>(dataset.size());
        Vector<Feature> features = getDatasetFeatures(dataset);
        int size = dataset.size();
        Example.Looper looper = dataset.iterator();
        for (int i = 0; i < size; i++) {
            examples.add(looper.nextExample());
        }
        Hashtable<BinaryClassifier, HashSet<Example>> outOfBag = new Hashtable<>();
        List<BinaryClassifier> trees = new ArrayList<>(this.numComponents);
        ProgressCounter progressCounter = new ProgressCounter("RandomForest", "treecounts", this.numComponents);
        // Guard against a zero-permit semaphore, which would block the first acquire forever.
        int permits = this.isThreaded ? Math.max(1, this.threadCount - 1) : 1;
        Semaphore semaphore = new Semaphore(permits);
        log.info("Random forests starting with " + dataset.size() + " elements");
        log.info("example size: " + examples.size());
        log.info("Learning classifier with " + this.baseLearner);
        for (int t = 0; t < this.numComponents; t++) {
            if (this.isThreaded) {
                semaphore.acquireUninterruptibly();
            }
            // Each tree gets its own copy of the feature vector.
            LearnerThread learnerThread =
                    new LearnerThread(examples, new Vector<>(features), trees, outOfBag, semaphore);
            if (this.isThreaded) {
                learnerThread.start();
            } else {
                learnerThread.run();
            }
            progressCounter.progress();
        }
        if (this.isThreaded) {
            // Acquiring every permit means all workers have released theirs, i.e. finished.
            // The acquire also makes their writes to trees/outOfBag visible to this thread.
            semaphore.acquireUninterruptibly(permits);
            semaphore.release(permits);
        }
        progressCounter.finished();
        if (trees.size() < this.numComponents) {
            log.warn("RandomForest trained only " + trees.size() + " of " + this.numComponents
                    + " trees; some tree-building threads failed");
        }
        printSomeStats(examples, outOfBag);
        return new VotingClassifier(trees);
    }

    /** Logs tree-shape statistics and the out-of-bag error estimate. */
    private void printSomeStats(Vector<Example> examples, Hashtable<BinaryClassifier, HashSet<Example>> outOfBag) {
        printTreeShapeInfo(outOfBag);
        printOobErrorEstimate(examples, outOfBag);
    }

    /**
     * Logs the average node count, the average max depth, and the largest max depth over the
     * forest's trees (the keys of {@code outOfBag}). Keys that are not {@link DecisionTree}s
     * are skipped.
     */
    private void printTreeShapeInfo(Hashtable<BinaryClassifier, HashSet<Example>> outOfBag) {
        int treeCount = 0;
        long totalNodes = 0;
        long totalDepth = 0;
        int deepest = 0;
        for (Object key : outOfBag.keySet()) {
            if (!(key instanceof DecisionTree)) {
                continue;
            }
            DecisionTree tree = (DecisionTree) key;
            int depth = maxDepth(tree);
            totalNodes += numNodes(tree);
            totalDepth += depth;
            deepest = Math.max(deepest, depth);
            treeCount++;
        }
        if (treeCount == 0) {
            log.info("No decision trees available for tree-shape statistics");
            return;
        }
        // Floating-point division so that Math.round actually rounds the averages.
        log.info("Average Number of nodes: " + Math.round((double) totalNodes / treeCount));
        log.info("Average Max depth of tree: " + Math.round((double) totalDepth / treeCount));
        log.info("Max Max depth of tree: " + deepest);
    }

    /** @return the number of nodes on the longest root-to-leaf path (a lone leaf has depth 1) */
    private int maxDepth(DecisionTree decisionTree) {
        if (decisionTree instanceof DecisionTree.Leaf) {
            return 1;
        }
        DecisionTree.InternalNode node = (DecisionTree.InternalNode) decisionTree;
        return Math.max(maxDepth(node.getTrueBranch()), maxDepth(node.getFalseBranch())) + 1;
    }

    /** @return the total number of nodes (internal nodes plus leaves) in the tree */
    private int numNodes(DecisionTree decisionTree) {
        if (decisionTree instanceof DecisionTree.Leaf) {
            return 1;
        }
        DecisionTree.InternalNode node = (DecisionTree.InternalNode) decisionTree;
        return numNodes(node.getTrueBranch()) + numNodes(node.getFalseBranch()) + 1;
    }

    /**
     * Logs the out-of-bag (OOB) error estimate.
     *
     * <p>Each example is scored by summing the scores of the trees for which it was out of
     * bag. A prediction is correct when that sum has the same strict sign as the example's
     * numeric label. Examples that no tree left out of bag carry no OOB information, so they
     * are not counted.
     */
    private void printOobErrorEstimate(Vector<Example> examples, Hashtable<BinaryClassifier, HashSet<Example>> outOfBag) {
        int correct = 0;
        int incorrect = 0;
        for (Example example : examples) {
            double sum = 0.0d;
            boolean scored = false;
            for (BinaryClassifier tree : outOfBag.keySet()) {
                if (outOfBag.get(tree).contains(example)) {
                    sum += tree.score(example.asInstance());
                    scored = true;
                }
            }
            if (!scored) {
                continue;
            }
            double label = example.getLabel().numericLabel();
            if ((label > 0.0d && sum > 0.0d) || (label < 0.0d && sum < 0.0d)) {
                correct++;
            } else {
                incorrect++;
            }
        }
        log.info("out of bag num correct: " + correct);
        log.info("out of bag num inCorrect: " + incorrect);
        int total = correct + incorrect;
        if (total > 0) {
            // Floating-point division; integer division would always log 0 (or 1).
            log.info("out of bag estimated error: " + ((double) incorrect / total));
        } else {
            log.info("out of bag estimated error: undefined (no out-of-bag predictions)");
        }
    }

    /**
     * Collects every distinct binary and numeric feature that appears in any example of the
     * dataset.
     *
     * @return the distinct features, in {@link HashSet} iteration order
     */
    public static Vector<Feature> getDatasetFeatures(Dataset dataset) {
        HashSet<Feature> seen = new HashSet<>();
        Example.Looper looper = dataset.iterator();
        while (looper.hasNext()) {
            Example example = looper.nextExample();
            Feature.Looper binary = example.binaryFeatureIterator();
            while (binary.hasNext()) {
                seen.add(binary.nextFeature());
            }
            Feature.Looper numeric = example.numericFeatureIterator();
            while (numeric.hasNext()) {
                seen.add(numeric.nextFeature());
            }
        }
        return new Vector<>(seen);
    }
}
