package edu.cmu.minorthird.classify.algorithms.trees;

import edu.cmu.minorthird.classify.BatchBinaryClassifierLearner;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.algorithms.trees.DecisionTree;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Vector;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/trees/RandomTreeLearner.class */
/**
 * Learns a single randomized binary decision tree.
 *
 * <p>At every internal node a {@link TreeSplitter} picks a feature and a numeric threshold.
 * Examples whose weight for that feature is {@code >=} the threshold go to the "true" branch and
 * the rest go to the "false" branch. A node becomes a {@link DecisionTree.Leaf} when the summed
 * example weight of either class is zero, or when no candidate features remain.
 */
public class RandomTreeLearner extends BatchBinaryClassifierLearner {
    private static final Logger log = Logger.getLogger(RandomTreeLearner.class);

    /** Strategy used to choose the split at each internal node. */
    private final TreeSplitter splitter;

    /**
     * Chooses how the examples reaching a node are split.
     *
     * <p>An implementation returns either {@code {Feature, Double threshold}}, in which case the
     * learner partitions the examples itself, or
     * {@code {Feature, Double threshold, List<Example> trueData, List<Example> falseData}} when the
     * partition has already been computed.
     */
    public interface TreeSplitter {
        /**
         * @param list examples reaching the node being split
         * @param i depth of the node ({@code batchTrain} passes 0 at the root, +1 per level)
         * @param vector candidate features; {@code batchTrain} never passes an empty vector
         * @return a 2- or 4-element split description as described on this interface
         */
        Object[] getSplit(List<Example> list, int i, Vector<Feature> vector);
    }

    /** Splits on one uniformly chosen feature at a uniformly random threshold. */
    public static class RandomTreeSplitter implements TreeSplitter {
        @Override
        public Object[] getSplit(List<Example> list, int i, Vector<Feature> vector) {
            Feature feature = pickRandomFeature(vector);
            return new Object[] {feature, Double.valueOf(randomThreshold(list, feature))};
        }
    }

    /**
     * Tries up to {@code featureCount} random (feature, threshold) candidates, limited by the
     * number of available features, and keeps the candidate with the highest
     * {@link RandomTreeLearner#entropy} score of its true/false partition sizes. This score favors
     * more evenly sized partitions.
     */
    public static class BestOfNRandomTreeSplitter implements TreeSplitter {
        /** Number of random candidates to try per node; values below 1 are treated as 1. */
        int featureCount;

        public BestOfNRandomTreeSplitter(int featureCount) {
            this.featureCount = featureCount;
        }

        @Override
        public Object[] getSplit(List<Example> list, int i, Vector<Feature> vector) {
            // Always try at least one candidate. With featureCount <= 0 the loop would never run,
            // a null feature and null partitions would be returned, and batchTrain would NPE.
            int tries = Math.min(Math.max(1, this.featureCount), vector.size());
            Feature bestFeature = null;
            double bestThreshold = 0.0d;
            double bestScore = 0.0d;
            List<Example> bestTrue = null;
            List<Example> bestFalse = null;
            for (int t = 0; t < tries; t++) {
                Feature candidate = pickRandomFeature(vector);
                double threshold = randomThreshold(list, candidate);
                List<Example> trueData = new ArrayList<>();
                List<Example> falseData = new ArrayList<>();
                partition(list, candidate, threshold, trueData, falseData);
                double score = entropy(trueData.size(), falseData.size(),
                        trueData.size(), falseData.size());
                if (bestFeature == null || score > bestScore) {
                    bestFeature = candidate;
                    bestThreshold = threshold;
                    bestScore = score;
                    bestTrue = trueData;
                    bestFalse = falseData;
                }
            }
            return new Object[] {bestFeature, Double.valueOf(bestThreshold), bestTrue, bestFalse};
        }
    }

    /** Creates a learner that uses {@link RandomTreeSplitter}. */
    public RandomTreeLearner() {
        this.splitter = new RandomTreeSplitter();
    }

    /** Creates a learner that uses the given split strategy. */
    public RandomTreeLearner(TreeSplitter treeSplitter) {
        this.splitter = treeSplitter;
    }

    /**
     * Grows a tree over the given examples and candidate features, starting at depth 0.
     *
     * @param list training examples
     * @param vector candidate features (copied, not modified, during growth)
     * @return the learned tree
     */
    public Classifier batchTrain(List<Example> list, Vector<Feature> vector) {
        DecisionTree tree = batchTrain(list, 0, vector);
        if (log.isInfoEnabled()) {
            log.info("built tree: " + tree);
        }
        return tree;
    }

    /**
     * Collects every example of the dataset and grows a tree over the dataset's features, as
     * returned by {@code RandomForests.getDatasetFeatures}.
     */
    @Override
    public Classifier batchTrain(Dataset dataset) {
        List<Example> examples = new ArrayList<>();
        for (Example.Looper it = dataset.iterator(); it.hasNext();) {
            examples.add(it.nextExample());
        }
        return batchTrain(examples, RandomForests.getDatasetFeatures(dataset));
    }

    /**
     * Recursively grows a (sub)tree.
     *
     * <p>An example counts as positive when {@code getLabel().numericLabel() > 0}; otherwise it
     * counts as negative. If either class has zero summed weight, or no features remain, a leaf
     * is returned: 1 if positive weight dominates, -1 if negative weight dominates, 0 on a tie.
     * Otherwise the splitter chooses a split. If the split leaves one side empty, that feature is
     * dropped and the same examples are retried at the same depth.
     *
     * @param list examples reaching this node
     * @param i depth of this node
     * @param vector candidate features; not modified (a reduced copy is passed to children)
     * @return the grown (sub)tree
     */
    public DecisionTree batchTrain(List<Example> list, int i, Vector<Feature> vector) {
        double posWeight = 0.0d;
        double negWeight = 0.0d;
        for (Example example : list) {
            if (example.getLabel().numericLabel() > 0.0d) {
                posWeight += example.getWeight();
            } else {
                negWeight += example.getWeight();
            }
        }
        if (log.isDebugEnabled()) {
            log.debug("build (sub)tree with posWeight: " + posWeight + " negWeight: " + negWeight);
        }
        if (negWeight == 0.0d || posWeight == 0.0d || vector.isEmpty()) {
            int label = posWeight > negWeight ? 1 : (posWeight == negWeight ? 0 : -1);
            log.debug("leaf");
            return new DecisionTree.Leaf(label);
        }

        Object[] split = this.splitter.getSplit(list, i, vector);
        Feature feature = (Feature) split[0];
        double threshold = ((Double) split[1]).doubleValue();
        List<Example> trueData;
        List<Example> falseData;
        if (split.length == 4 && split[2] != null && split[3] != null) {
            // The splitter already partitioned the examples.
            trueData = castToExampleList(split[2]);
            falseData = castToExampleList(split[3]);
        } else {
            trueData = new ArrayList<>();
            falseData = new ArrayList<>();
            partition(list, feature, threshold, trueData, falseData);
        }
        if (log.isDebugEnabled()) {
            log.debug("split on: " + feature + " with threshold " + threshold);
            log.debug("trueData size: " + trueData.size() + " falseData size: " + falseData.size());
        }

        // Children never reconsider the feature used here; removal also guarantees termination
        // when a split fails to separate the data.
        Vector<Feature> remaining = new Vector<>(vector);
        remaining.removeElement(feature);
        if (!falseData.isEmpty() && !trueData.isEmpty()) {
            return new DecisionTree.InternalNode(feature, threshold,
                    batchTrain(trueData, i + 1, remaining),
                    batchTrain(falseData, i + 1, remaining));
        }
        log.debug("didn't split data with this feature");
        return batchTrain(list, i, remaining);
    }

    /** Returns a uniformly chosen element of a non-empty feature vector. */
    private static Feature pickRandomFeature(Vector<Feature> features) {
        return features.get((int) Math.floor(Math.random() * features.size()));
    }

    /**
     * Draws a threshold uniformly from [min, max) of the feature's weights over the examples.
     * Returns 0.0 when no comparable weight exists (empty list or only NaN weights).
     */
    private static double randomThreshold(List<Example> examples, Feature feature) {
        // Infinity sentinels: Double.MIN_VALUE is the smallest *positive* double and would
        // report a wrong maximum when every weight is negative.
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (Example example : examples) {
            double weight = example.getWeight(feature);
            if (weight < min) {
                min = weight;
            }
            if (weight > max) {
                max = weight;
            }
        }
        if (min > max) {
            return 0.0d;
        }
        return (Math.random() * (max - min)) + min;
    }

    /**
     * Adds each example to {@code trueData} when its weight for {@code feature} is
     * {@code >= threshold}, otherwise to {@code falseData}.
     */
    private static void partition(List<Example> examples, Feature feature, double threshold,
            List<Example> trueData, List<Example> falseData) {
        for (Example example : examples) {
            if (example.getWeight(feature) >= threshold) {
                trueData.add(example);
            } else {
                falseData.add(example);
            }
        }
    }

    /** Casts a splitter-supplied partition; the TreeSplitter contract says it holds Examples. */
    @SuppressWarnings("unchecked")
    private static List<Example> castToExampleList(Object value) {
        return (List<Example>) value;
    }

    /**
     * Computes {@code 2 * (sqrt(wp1 * wn1) + sqrt(wp0 * wn0))}, where
     * {@code total = d3 + d4}, {@code wp1 = pos/total}, {@code wp0 = (d3 - pos)/total},
     * {@code wn1 = neg/total} and {@code wn0 = (d4 - neg)/total}. Not referenced within this class.
     */
    private static double schapireSingerValue(double pos, double neg, double d3, double d4) {
        double total = d3 + d4;
        double wp1 = pos / total;
        double wp0 = (d3 - pos) / total;
        double wn1 = neg / total;
        double wn0 = (d4 - neg) / total;
        if (log.isDebugEnabled()) {
            log.debug("pos, neg, total = " + pos + ", " + neg + ", " + total);
            log.debug("wp1,wp0,wn1,wn0 = " + wp1 + "," + wp0 + "," + wn1 + "," + wn0);
        }
        return 2.0d * (Math.sqrt(wp1 * wn1) + Math.sqrt(wp0 * wn0));
    }

    /**
     * Smoothed four-term entropy-style score, used by {@link BestOfNRandomTreeSplitter} to rank
     * candidate splits. That splitter calls it as
     * {@code entropy(trueSize, falseSize, trueSize, falseSize)}.
     *
     * <p>With {@code total = part1 + part2} and {@code eps = 0.1 / total}, the result is
     * {@code -sum(w * ln(w))} over the four weights {@code pos/total + eps},
     * {@code neg/total + eps}, {@code (total - pos)/total + eps} and
     * {@code (total - neg)/total + eps}. The result is not meaningful when {@code total == 0}.
     */
    public static double entropy(double pos, double neg, double part1, double part2) {
        double total = part1 + part2;
        double eps = 0.1d / total; // smoothing keeps every log argument positive
        double w11 = (pos / total) + eps;
        double w10 = (neg / total) + eps;
        double w01 = ((total - pos) / total) + eps;
        double w00 = ((total - neg) / total) + eps;
        if (log.isDebugEnabled()) {
            log.debug("pos, neg, total = " + pos + ", " + neg + ", " + total);
            log.debug("w11,w10,w01,w00 = " + w11 + "," + w10 + "," + w01 + "," + w00);
        }
        return ((((-w11) * Math.log(w11)) - (w10 * Math.log(w10))) - (w01 * Math.log(w01)))
                - (w00 * Math.log(w00));
    }
}
