package edu.cmu.minorthird.classify.algorithms.linear;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.OnlineBinaryClassifierLearner;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.io.Serializable;
import java.util.ArrayList;
import org.apache.log4j.Logger;
import org.jfree.chart.ChartPanelConstants;

/**
 * Online binary learner implementing a voted (or averaged) perceptron whose hyperplane scores are
 * passed through a polynomial transform.
 *
 * <p>Training keeps a current hyperplane {@code v_k} and a count {@code c_k} of how many
 * consecutive examples it has classified correctly. On each mistake the current
 * (hyperplane, count) pair is snapshotted into {@code listVK}/{@code listCK} and the hyperplane is
 * then updated with the misclassified example. Every score goes through {@link #Kernel}, which
 * returns {@code (coef0 + gamma * score)^degree}, or the raw score when {@code degree == 0}.
 *
 * <p>Field names and declared field types are intentionally left unchanged: instances are
 * {@link Serializable}, and renaming or retyping fields would break previously serialized objects.
 */
public class KernelVotedPerceptron extends OnlineBinaryClassifierLearner implements Serializable {
    private static final long serialVersionUID = 1;

    /**
     * Default number of most-recent vectors consulted when speedup is enabled.
     * NOTE(review): the decompiled source referenced this jfree constant, most likely because a
     * literal with the same value was inlined at compile time. TODO confirm its value and replace it
     * with the literal so this learner does not depend on a charting library.
     */
    private static final int DEFAULT_MAXVEC = ChartPanelConstants.DEFAULT_MINIMUM_DRAW_WIDTH;

    private final int CURRENT_SERIAL_VERSION = 1;
    private static Logger log = Logger.getLogger(KernelVotedPerceptron.class);

    /** Hyperplane currently being trained. */
    private Hyperplane v_k;
    /** Number of consecutive examples {@code v_k} has classified correctly. */
    private int c_k;
    /** Snapshots of earlier hyperplanes, one taken at each training mistake. */
    private ArrayList<Hyperplane> listVK;
    /** Survival counts, paired index-by-index with {@code listVK}. */
    private ArrayList<Integer> listCK;
    /** Prediction mode: "voted" or "averaged" (compared case-insensitively). */
    private String mode = "voted";
    /** Polynomial degree applied to hyperplane scores; 0 means the raw score is used. */
    private int degree = 3;
    /** Multiplier applied to the raw score inside the polynomial transform. */
    private double gamma = 10.0d;
    /** Additive offset inside the polynomial transform. */
    private double coef0 = 1.0d;
    /** When true, classification only consults the last {@code MAXVEC} stored vectors. */
    private boolean speedup = false;
    /** Maximum number of most-recent vectors consulted when {@code speedup} is on. */
    private int MAXVEC = DEFAULT_MAXVEC;

    /**
     * Classifier built from a list of (hyperplane, survival count) pairs.
     *
     * <p>This is a non-static inner class: {@code mode}, {@code degree}, {@code gamma},
     * {@code coef0}, {@code speedup} and {@code MAXVEC} are read from the enclosing learner at
     * classification time, so changing the learner's settings later also affects this classifier.
     */
    public class MyClassifier implements Classifier, Serializable, Visible {
        private static final long serialVersionUID = 1;
        private final int CURRENT_SERIAL_VERSION = 1;
        /** Stored hyperplanes, oldest first. */
        ArrayList<Hyperplane> listVK;
        /** Survival count of each hyperplane in {@code listVK}, same indexing. */
        ArrayList<Integer> counts;

        /**
         * @param arrayList hyperplanes, oldest first (stored by reference)
         * @param arrayList2 survival counts paired index-by-index with {@code arrayList}
         */
        public MyClassifier(ArrayList<Hyperplane> arrayList, ArrayList<Integer> arrayList2) {
            this.listVK = arrayList;
            this.counts = arrayList2;
            log.info("info: KernelVotedPerceptron: number sup vectors = " + this.listVK.size() + " mode=" + mode + " kernel=" + degree);
        }

        /**
         * Scores the instance using the learner's current mode. A score {@code >= 0} is labeled
         * positive.
         *
         * @throws IllegalStateException if the learner's mode is neither "voted" nor "averaged"
         */
        @Override // edu.cmu.minorthird.classify.Classifier
        public ClassLabel classification(Instance instance) {
            double score;
            // Constant-first comparisons so a null mode reaches the error branch instead of NPE-ing.
            if ("voted".equalsIgnoreCase(mode)) {
                score = calculateVoted(instance);
            } else if ("averaged".equalsIgnoreCase(mode)) {
                score = calculateAveraged(instance);
            } else {
                // Previously System.exit(0): killing the JVM with a success code is never right
                // for a library; surface the misconfiguration to the caller instead.
                throw new IllegalStateException("Mode(" + mode + ") is not allowed\n Please use either \"voted\" or \"averaged\"");
            }
            return score >= 0.0d ? ClassLabel.positiveLabel(score) : ClassLabel.negativeLabel(score);
        }

        /**
         * Returns the index of the first stored vector to consult. With speedup enabled, only the
         * most recent {@code MAXVEC} vectors are used.
         */
        private int firstVectorIndex() {
            return speedup ? Math.max(0, listVK.size() - MAXVEC) : 0;
        }

        /** Sum over the window of count * sign(transformed score), with a sign of 0 treated as -1. */
        private double calculateVoted(Instance instance) {
            double total = 0.0d;
            for (int i = firstVectorIndex(); i < listVK.size(); i++) {
                double vote = KernelVotedPerceptron.this.Kernel(listVK.get(i), instance) > 0.0d ? 1.0d : -1.0d;
                total += counts.get(i).intValue() * vote;
            }
            return total;
        }

        /** Sum over the window of count * transformed score. */
        private double calculateAveraged(Instance instance) {
            double total = 0.0d;
            for (int i = firstVectorIndex(); i < listVK.size(); i++) {
                total += counts.get(i).intValue() * KernelVotedPerceptron.this.Kernel(listVK.get(i), instance);
            }
            return total;
        }

        @Override // edu.cmu.minorthird.classify.Classifier
        public String explain(Instance instance) {
            return "KernelVotedPerceptron: Not implemented yet";
        }

        @Override // edu.cmu.minorthird.classify.Classifier
        public Explanation getExplanation(Instance instance) {
            return new Explanation(new Explanation.Node("Kernel Perceptron Explanation (not valid!)"));
        }

        /** Shows the most recent stored hyperplane, or an empty one if nothing is stored. */
        @Override // edu.cmu.minorthird.util.gui.Visible
        public Viewer toGUI() {
            TransformedViewer viewer = new TransformedViewer(new SmartVanillaViewer()) {
                @Override // edu.cmu.minorthird.util.gui.TransformedViewer
                public Object transform(Object obj) {
                    // Use the viewed classifier's own list for both the size and the lookup.
                    ArrayList<Hyperplane> vectors = ((MyClassifier) obj).listVK;
                    return vectors.isEmpty() ? new Hyperplane() : vectors.get(vectors.size() - 1);
                }
            };
            viewer.setContent(this);
            return viewer;
        }
    }

    /**
     * Creates a learner with the given polynomial degree and prediction mode.
     *
     * @param degree polynomial degree applied to scores (0 = raw linear score)
     * @param mode prediction mode, "voted" or "averaged"; only validated when classifying
     */
    public KernelVotedPerceptron(int degree, String mode) {
        this();
        this.degree = degree;
        this.mode = mode;
    }

    /** Creates a learner in "voted" mode with degree 3, gamma 10 and coef0 1. */
    public KernelVotedPerceptron() {
        reset();
    }

    /** Sets the polynomial degree applied to hyperplane scores (0 = raw linear score). */
    public void setKernel(int i) {
        this.degree = i;
    }

    /**
     * Sets the parameters of the transform {@code (coef0 + gamma * score)^degree}.
     *
     * @param coef0 additive offset
     * @param gamma score multiplier
     */
    public void setPolyKernelParams(double coef0, double gamma) {
        this.coef0 = coef0;
        this.gamma = gamma;
    }

    /** Discards all training state: empty current hyperplane, zero count, no stored vectors. */
    @Override // edu.cmu.minorthird.classify.OnlineClassifierLearner, edu.cmu.minorthird.classify.ClassifierLearner
    public void reset() {
        this.v_k = new Hyperplane();
        this.listVK = new ArrayList<Hyperplane>();
        this.listCK = new ArrayList<Integer>();
        this.c_k = 0;
    }

    public void setModeVoted() {
        this.mode = "voted";
    }

    public void setModeAveraged() {
        this.mode = "averaged";
    }

    /** Restricts classification to the most recent {@code MAXVEC} stored vectors. */
    public void setSpeedUp() {
        this.speedup = true;
    }

    /**
     * Restricts classification to the given number of most recent stored vectors.
     *
     * @param maxVectors number of vectors to consult; must be positive
     * @throws IllegalArgumentException if {@code maxVectors <= 0}
     */
    public void setSpeedUp(int maxVectors) {
        if (maxVectors <= 0) {
            throw new IllegalArgumentException("maxVectors must be positive: " + maxVectors);
        }
        this.speedup = true;
        this.MAXVEC = maxVectors;
    }

    /** Appends a copy of {@code hyperplane} together with its survival count. */
    private void store(Hyperplane hyperplane, int i) {
        Hyperplane copy = new Hyperplane();
        copy.increment(hyperplane);
        this.listVK.add(copy);
        this.listCK.add(Integer.valueOf(i));
    }

    /**
     * Performs one online update. If the current hyperplane's transformed score agrees in sign
     * with the label, its survival count is incremented. Otherwise the current (hyperplane, count)
     * pair is snapshotted, the hyperplane is incremented by the example scaled by its numeric
     * label, and the count restarts at 1.
     */
    @Override // edu.cmu.minorthird.classify.OnlineClassifierLearner, edu.cmu.minorthird.classify.ClassifierLearner
    public void addExample(Example example) {
        double numericLabel = example.getLabel().numericLabel();
        if (Kernel(this.v_k, example.asInstance()) * numericLabel > 0.0d) {
            this.c_k++;
            return;
        }
        store(this.v_k, this.c_k);
        this.v_k.increment(example, numericLabel);
        this.c_k = 1;
    }

    /**
     * Returns the hyperplane's score on the instance, passed through
     * {@code (coef0 + gamma * score)^degree}; returns the raw score when {@code degree == 0}.
     * NOTE(review): this transforms a single linear score rather than computing a kernel between
     * examples. With an even degree the result is never negative — confirm this is intended.
     */
    double Kernel(Hyperplane hyperplane, Instance instance) {
        double score = hyperplane.score(instance);
        return this.degree == 0 ? score : Math.pow(this.coef0 + (score * this.gamma), this.degree);
    }

    /**
     * Returns a classifier over every stored (hyperplane, count) pair plus the current hyperplane
     * and its count. The current pair is needed because it is only snapshotted on the next mistake.
     * The lists are copied, so continued training or {@link #reset()} does not alter the returned
     * classifier's vectors. Mode and transform parameters are still read live from this learner.
     */
    @Override // edu.cmu.minorthird.classify.OnlineClassifierLearner, edu.cmu.minorthird.classify.ClassifierLearner
    public Classifier getClassifier() {
        ArrayList<Hyperplane> vectors = new ArrayList<Hyperplane>(this.listVK);
        ArrayList<Integer> counts = new ArrayList<Integer>(this.listCK);
        // Copy the current hyperplane so later updates to v_k do not leak into the classifier.
        Hyperplane current = new Hyperplane();
        current.increment(this.v_k);
        vectors.add(current);
        counts.add(Integer.valueOf(this.c_k));
        return new MyClassifier(vectors, counts);
    }

    public String toString() {
        return "Kernel Voted Perceptron";
    }
}
