package edu.cmu.minorthird.classify.algorithms.linear;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.MutableInstance;
import edu.cmu.minorthird.classify.OnlineBinaryClassifierLearner;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/linear/RegretWinnow.class */
/**
 * Online binary Winnow learner over a single positive hyperplane, with optional voting and an
 * optional regret-based per-feature reweighting step.
 *
 * <p>Every example is first normalized with {@code Winnow.normalizeWeights(example, true)} and
 * scored as {@code pos_t . x - theta}. On a mistake ({@code label * score <= margin}) the weights
 * of the example's features are multiplied by {@code alpha} for a positive example (only while
 * below {@code W_MAX}) or by {@code beta} for a negative one (only while above {@code W_MIN}).
 *
 * <p>When regret is enabled ({@code mode != 0}), each feature is also treated as an expert.
 * {@code lossH} counts, per feature, the learner's mistakes on examples containing it, and
 * {@code lossF} accumulates the feature's own loss. The feature weight is then multiplied by
 * {@code beta2^(lossF - beta2 * lossH)}. {@code mode} selects how the feature's loss is measured:
 * <ul>
 *   <li>1 - loss 1 when {@code label * x_f <= 0} ({@code x_f} is the normalized feature value);</li>
 *   <li>2 - loss {@code 1 - (count of examples with this feature and the same label) /
 *       (count of examples with this feature)};</li>
 *   <li>3 - loss 1 when {@code x_f * w_f * |x|} is above 1 for a negative example or below 1 for
 *       a positive one;</li>
 *   <li>4 - loss 1 when the balance of the previous (up to {@code LIST_SIZE}) labels seen with the
 *       feature, as computed by {@link #getHistory}, disagrees in sign with the label;</li>
 *   <li>5 - a loss drawn from {@link Math#random()}.</li>
 * </ul>
 */
public class RegretWinnow extends OnlineBinaryClassifierLearner implements Serializable {
    /** Hyperplane being learned; non-voted classifiers share (not copy) this instance. */
    private Hyperplane pos_t;
    /** Mode 2 only: per-feature count of positive examples containing the feature. */
    private Hyperplane numGivenPos;
    /** Mode 2 only: per-feature count of non-positive examples containing the feature. */
    private Hyperplane numGivenNeg;
    /** Voted mode only: sum of past hyperplanes, each scaled by its survival count. */
    private Hyperplane vpos_t;
    private double theta = 1.0d;
    private double alpha;
    private double beta;
    /** Number of examples passed to {@link #addExample} since the last reset. */
    private int excount;
    private double margin = 0.0d;
    private boolean voted;
    private boolean regret;
    /** Regret mode: per-feature count of learner mistakes on examples containing the feature. */
    private Hyperplane lossH;
    /** Regret mode: per-feature accumulated loss of the feature treated as an expert. */
    private Hyperplane lossF;
    private double W_MAX = Math.pow(2.0d, 200.0d);
    private double W_MIN = 1.0d / Math.pow(2.0d, 200.0d);
    double beta2 = 0.95d;
    /** Voted mode: number of examples the current {@code pos_t} has survived. */
    private int votedCount = 0;
    private int mode;
    /** Number of previous labels remembered per feature in mode 4. */
    private final int LIST_SIZE = 5;
    /** Mode 4 only: recent labels per feature, newest first (the current label is at index 0). */
    private Map<Feature, ArrayList<ClassLabel>> fmap;

    /** Classifier produced by {@link RegretWinnow#getClassifier()}: thresholds a hyperplane score. */
    public class MyClassifier implements Classifier, Serializable, Visible {
        private static final long serialVersionUID = 1;
        // Not read anywhere in this class; kept so the serialized form stays unchanged.
        private final int CURRENT_SERIAL_VERSION = 1;
        private Hyperplane lpos_h;
        // Not read anywhere in this class; kept so the serialized form stays unchanged.
        private Hyperplane lneg_h;
        // Not read anywhere in this class; kept so the serialized form stays unchanged.
        private ExampleSchema schema;
        private double mytheta;

        /**
         * @param hyperplane weights used to score instances (stored by reference, not copied)
         * @param d threshold subtracted from the hyperplane score
         */
        public MyClassifier(Hyperplane hyperplane, double d) {
            this.lpos_h = hyperplane;
            this.mytheta = d;
        }

        /**
         * Scores the instance the same way training does, then thresholds at zero.
         *
         * <p>The instance is wrapped in an {@link Example} with a fixed positive label. Features
         * unknown to the hyperplane are dropped before normalization.
         */
        @Override
        public ClassLabel classification(Instance instance) {
            Example wrapped = new Example(instance, new ClassLabel(ExampleSchema.POS_CLASS_NAME));
            Example normalized = Winnow.normalizeWeights(filterFeat(wrapped), true);
            double score = this.lpos_h.score(normalized.asInstance()) - this.mytheta;
            return score >= 0.0d ? ClassLabel.positiveLabel(score) : ClassLabel.negativeLabel(score);
        }

        /** Returns a copy of {@code example} keeping only the features present in the hyperplane. */
        public Example filterFeat(Example example) {
            MutableInstance kept = new MutableInstance();
            Feature.Looper features = example.asInstance().featureIterator();
            while (features.hasNext()) {
                Feature feature = features.nextFeature();
                if (this.lpos_h.hasFeature(feature)) {
                    kept.addNumeric(feature, example.getWeight(feature));
                }
            }
            return new Example(kept, example.getLabel());
        }

        @Override
        public String toString() {
            return "POS = " + this.lpos_h.toString();
        }

        @Override
        public String explain(Instance instance) {
            return "RegretWinnow: Not implemented yet";
        }

        @Override
        public Explanation getExplanation(Instance instance) {
            return new Explanation(new Explanation.Node("RegretWinnow Explanation"));
        }

        /** Shows the underlying hyperplane in a {@link SmartVanillaViewer}. */
        @Override
        public Viewer toGUI() {
            TransformedViewer transformedViewer = new TransformedViewer(new SmartVanillaViewer()) {
                @Override
                public Object transform(Object obj) {
                    return ((MyClassifier) obj).lpos_h;
                }
            };
            transformedViewer.setContent(this);
            return transformedViewer;
        }
    }

    /** Creates a learner with alpha=1.5, beta=0.5, no voting and regret mode 1. */
    public RegretWinnow() {
        this(1.5d, 0.5d, false, 1);
    }

    /**
     * @param alpha promotion factor applied on mistakes on positive examples; must be at least 1
     * @param beta demotion factor applied on mistakes on negative examples; must be in [0, 1]
     * @param voted whether {@link #getClassifier()} returns the survival-weighted average hyperplane
     * @param mode 0 disables the regret step; 1-5 select the per-feature loss (see class doc)
     * @throws IllegalArgumentException if alpha or beta is out of range (or NaN)
     */
    public RegretWinnow(double alpha, double beta, boolean voted, int mode) {
        // Negated comparisons so that NaN arguments are rejected as well.
        if (!(alpha >= 1.0d) || !(beta >= 0.0d && beta <= 1.0d)) {
            throw new IllegalArgumentException(
                    "RegretWinnow requires alpha >= 1 and 0 <= beta <= 1, but got alpha=" + alpha
                            + ", beta=" + beta);
        }
        this.alpha = alpha;
        this.beta = beta;
        this.voted = voted;
        this.mode = mode;
        this.regret = mode != 0;
        reset();
    }

    /** Discards all learned weights, counters and per-mode statistics. */
    @Override
    public void reset() {
        this.pos_t = new Hyperplane();
        this.excount = 0;
        this.votedCount = 0;
        if (this.voted) {
            this.vpos_t = new Hyperplane();
        }
        if (this.regret) {
            this.lossH = new Hyperplane();
            this.lossF = new Hyperplane();
            if (this.mode == 4) {
                this.fmap = new HashMap<Feature, ArrayList<ClassLabel>>();
            }
            if (this.mode == 2) {
                this.numGivenPos = new Hyperplane();
                this.numGivenNeg = new Hyperplane();
            }
        }
    }

    /**
     * Trains on one example: a Winnow promotion/demotion on a mistake, followed by the regret
     * reweighting when it is enabled.
     */
    @Override
    public void addExample(Example example) {
        this.excount++;
        Example normalized = Winnow.normalizeWeights(example, true);
        registerFeatures(normalized);

        double numericLabel = normalized.getLabel().numericLabel();
        // Scored with the pre-update hyperplane; the regret step reuses this mistake decision.
        double score = this.pos_t.score(normalized.asInstance()) - this.theta;
        boolean mistake = numericLabel * score <= this.margin;

        if (mistake) {
            if (this.voted) {
                // Credit the outgoing hyperplane with its survival count (at least 1).
                updateVotedHyperplane(Math.max(1, this.votedCount));
                this.votedCount = 1;
            }
            winnowUpdate(normalized);
        } else if (this.voted) {
            this.votedCount++;
        }

        if (this.regret) {
            regretUpdate(example, normalized, numericLabel, mistake);
        }
    }

    /** Adds unseen features to pos_t with weight 1 and records the mode 2 / mode 4 statistics. */
    private void registerFeatures(Example normalized) {
        ClassLabel label = normalized.getLabel();
        boolean positive = label.isPositive();
        Feature.Looper features = normalized.asInstance().featureIterator();
        while (features.hasNext()) {
            Feature feature = features.nextFeature();
            if (this.mode == 2) {
                (positive ? this.numGivenPos : this.numGivenNeg).increment(feature, 1.0d);
            }
            if (!this.pos_t.hasFeature(feature)) {
                this.pos_t.increment(feature, 1.0d);
            }
            if (this.mode == 4) {
                rememberLabel(feature, label);
            }
        }
    }

    /** Prepends {@code label} to the feature's history, keeping at most LIST_SIZE + 1 entries. */
    private void rememberLabel(Feature feature, ClassLabel label) {
        ArrayList<ClassLabel> history = this.fmap.get(feature);
        if (history == null) {
            history = new ArrayList<ClassLabel>(this.LIST_SIZE + 1);
            this.fmap.put(feature, history);
        }
        history.add(0, label);
        if (history.size() > this.LIST_SIZE + 1) {
            history.remove(this.LIST_SIZE + 1);
        }
    }

    /** Multiplies each feature weight by alpha (positive example) or beta (otherwise), within bounds. */
    private void winnowUpdate(Example normalized) {
        boolean positive = normalized.getLabel().isPositive();
        Feature.Looper features = normalized.featureIterator();
        while (features.hasNext()) {
            Feature feature = features.nextFeature();
            double weight = this.pos_t.featureScore(feature);
            if (positive) {
                if (weight < this.W_MAX) {
                    this.pos_t.multiply(feature, this.alpha);
                }
            } else if (weight > this.W_MIN) {
                this.pos_t.multiply(feature, this.beta);
            }
        }
    }

    /** Updates lossH/lossF for every feature of the example and applies the regret factor. */
    private void regretUpdate(Example example, Example normalized, double numericLabel, boolean mistake) {
        ClassLabel label = normalized.getLabel();
        // Mode 3 scales by the size of the raw (un-normalized) example; it is the same for every feature.
        int exampleSize = this.mode == 3 ? example.featureIterator().estimatedSize() : 0;
        Feature.Looper features = normalized.featureIterator();
        while (features.hasNext()) {
            Feature feature = features.nextFeature();
            if (mistake) {
                this.lossH.increment(feature, 1.0d);
            }
            accumulateFeatureLoss(feature, normalized, numericLabel, label, exampleSize);
            applyRegretFactor(feature);
        }
    }

    /** Adds this example's loss for {@code feature} to lossF, as selected by {@code mode}. */
    private void accumulateFeatureLoss(Feature feature, Example normalized, double numericLabel,
            ClassLabel label, int exampleSize) {
        switch (this.mode) {
            case 1:
                if (numericLabel * normalized.getWeight(feature) <= 0.0d) {
                    this.lossF.increment(feature, 1.0d);
                }
                break;
            case 2: {
                double pos = this.numGivenPos.featureScore(feature);
                double neg = this.numGivenNeg.featureScore(feature);
                double total = pos + neg;
                // The counts already include the current example, so total > 0 in practice;
                // the guard keeps a NaN out of lossF (and hence out of the weights).
                if (total > 0.0d) {
                    double agreeing = label.isPositive() ? pos : neg;
                    this.lossF.increment(feature, 1.0d - (agreeing / total));
                }
                break;
            }
            case 3: {
                double contribution =
                        normalized.getWeight(feature) * this.pos_t.featureScore(feature) * exampleSize;
                if ((label.isNegative() && contribution > 1.0d)
                        || (label.isPositive() && contribution < 1.0d)) {
                    this.lossF.increment(feature, 1.0d);
                }
                break;
            }
            case 4:
                if (numericLabel * getHistory(this.fmap.get(feature)) < 0.0d) {
                    this.lossF.increment(feature, 1.0d);
                }
                break;
            case 5:
                this.lossF.increment(feature, Math.random());
                break;
            default:
                break;
        }
    }

    /**
     * Multiplies the feature weight by {@code beta2^(lossF - beta2 * lossH)}.
     *
     * <p>Promotion (factor > 1) happens only while the weight is below W_MAX; demotion
     * (factor < 1) only while the weight is above W_MIN. Factors outside (W_MIN, W_MAX), and a
     * NaN factor, are ignored.
     */
    private void applyRegretFactor(Feature feature) {
        double factor = Math.pow(this.beta2,
                this.lossF.featureScore(feature) - (this.beta2 * this.lossH.featureScore(feature)));
        double weight = this.pos_t.featureScore(feature);
        if (factor > 1.0d && factor < this.W_MAX) {
            if (weight < this.W_MAX) {
                this.pos_t.multiply(feature, factor);
            }
        } else if (factor < 1.0d && factor > this.W_MIN && weight > this.W_MIN) {
            this.pos_t.multiply(feature, factor);
        }
    }

    /**
     * Adds {@code pos_t} scaled by {@code i} to the voted sum and resets the survival counter.
     * Only valid when voting is enabled; otherwise {@code vpos_t} is null.
     */
    public void updateVotedHyperplane(int i) {
        this.vpos_t.increment(this.pos_t, i);
        this.votedCount = 0;
    }

    /**
     * Returns the current classifier.
     *
     * <p>Without voting, the classifier shares the live {@code pos_t}, so later training is
     * visible through it. With voting, the pending survival count is flushed and the voted sum is
     * divided by the number of examples seen (at least 1, to avoid dividing by zero before any
     * training).
     */
    @Override
    public Classifier getClassifier() {
        if (!this.voted) {
            return new MyClassifier(this.pos_t, this.theta);
        }
        updateVotedHyperplane(this.votedCount);
        Hyperplane averaged = new Hyperplane();
        averaged.increment(this.vpos_t, 1.0d / Math.max(1, this.excount));
        return new MyClassifier(averaged, this.theta);
    }

    /**
     * Returns the label balance of a feature's history, excluding the newest entry (index 0).
     * Each positive label counts +1 and each other label -1; a tie returns 1.
     *
     * @param arrayList list of {@link ClassLabel}s, newest first
     */
    public int getHistory(ArrayList<?> arrayList) {
        int balance = 0;
        for (int k = 1; k < arrayList.size(); k++) {
            balance += ((ClassLabel) arrayList.get(k)).isPositive() ? 1 : -1;
        }
        return balance == 0 ? 1 : balance;
    }

    @Override
    public String toString() {
        return "RegretWinnow: voted=" + this.voted + ", regret=" + this.mode;
    }
}
