package edu.cmu.minorthird.classify.algorithms.linear;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.WeightedSet;
import edu.cmu.minorthird.classify.algorithms.random.Arithmetic;
import edu.cmu.minorthird.classify.algorithms.random.Estimate;
import edu.cmu.minorthird.util.MathUtil;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.Controllable;
import edu.cmu.minorthird.util.gui.ControlledViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.ViewerControls;
import edu.cmu.minorthird.util.gui.Visible;
import gnu.trove.TObjectDoubleHashMap;
import gnu.trove.TObjectDoubleIterator;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.TreeMap;
import javax.swing.ButtonGroup;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JRadioButton;
import javax.swing.JScrollPane;
import javax.swing.JTable;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/linear/MultinomialClassifier.class */
public class MultinomialClassifier implements Classifier, Visible, Serializable {
    /** Log4j logger for this class; not referenced elsewhere in the visible code. */
    private static Logger log = Logger.getLogger(MultinomialClassifier.class);
    /** Divisor applied to an instance's total feature weight in score(); assigned via setScale(). */
    private double SCALE;
    /** Class names in class-index order; entries are appended by addValidLabel(). */
    private ArrayList classNames = new ArrayList();
    /** One Double per class index; score() starts each class score at the log of this value. */
    private ArrayList classParameters = new ArrayList();
    /** Maps a Feature to the model name registered for it through setFeatureModel(). */
    private HashMap featureModels = new HashMap();
    /**
     * Per class index, the parameters of each feature. setFeatureGivenClassParameter() stores a
     * HashMap (Feature to Estimate) per class, while the constructor seeds the list with a
     * WeightedSet and getLogLikelihood() casts entries to WeightedSet.
     * NOTE(review): these two representations look inconsistent - confirm which one is intended.
     */
    private ArrayList featureGivenClassParameters = new ArrayList();
    /** Assigned by setPrior(); not read anywhere in the visible code. */
    private double featurePrior;
    /** Assigned by setUnseenModel(); not read anywhere in the visible code. */
    private String unseenModel;

    /* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/linear/MultinomialClassifier$MultinomialClassifierControls.class */
    /**
     * Control panel for the classifier viewer: a "Sort by" label followed by three mutually
     * exclusive radio buttons ("name", "weight", "|weight|"), with "name" selected initially.
     */
    private static class MultinomialClassifierControls extends ViewerControls {
        private JRadioButton absoluteValueButton;
        private JRadioButton valueButton;
        private JRadioButton nameButton;
        /** Declared but never assigned in the visible code. */
        private JRadioButton noneButton;

        private MultinomialClassifierControls() {
        }

        /** Populates the panel with the label and the three sort-order buttons. */
        @Override
        public void initialize() {
            final ButtonGroup sortGroup = new ButtonGroup();
            add(new JLabel("Sort by"));
            nameButton = addButton("name", sortGroup, true);
            valueButton = addButton("weight", sortGroup, false);
            absoluteValueButton = addButton("|weight|", sortGroup, false);
        }

        /**
         * Creates a radio button, registers it with the group and this panel, and makes this
         * panel its action listener.
         *
         * @param str         button caption
         * @param buttonGroup group that keeps the buttons mutually exclusive
         * @param z           whether the button starts selected
         * @return the newly created button
         */
        private JRadioButton addButton(String str, ButtonGroup buttonGroup, boolean z) {
            final JRadioButton button = new JRadioButton(str, z);
            buttonGroup.add(button);
            add(button);
            button.addActionListener(this);
            return button;
        }
    }

    /* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/linear/MultinomialClassifier$MyViewer.class */
    /**
     * Viewer that renders a MultinomialClassifier as a table with one row per feature and one
     * column per class (plus a leading feature-name column). Once controls have been applied,
     * rows are sorted according to the selected radio button.
     */
    private static class MyViewer extends ComponentViewer implements Controllable {
        private MultinomialClassifierControls controls;
        private MultinomialClassifier h;

        private MyViewer() {
            this.controls = null;
            this.h = null;
        }

        /** Remembers the controls and rebuilds the displayed content for the current classifier. */
        @Override
        public void applyControls(ViewerControls viewerControls) {
            this.controls = (MultinomialClassifierControls) viewerControls;
            setContent(this.h, true);
            revalidate();
        }

        /** Accepts only MultinomialClassifier instances. */
        @Override
        public boolean canReceive(Object obj) {
            return obj instanceof MultinomialClassifier;
        }

        /**
         * Builds the scrollable feature/class table for the given classifier.
         *
         * @param obj the MultinomialClassifier to display
         * @return a scroll pane wrapping the table
         */
        @Override
        public JComponent componentFor(Object obj) {
            this.h = (MultinomialClassifier) obj;
            final int classCount = this.h.classNames.size();
            Object[][] rows = new Object[this.h.keys().length][classCount + 1];
            int row = 0;
            Feature.Looper features = this.h.featureIterator();
            while (features.hasNext()) {
                Feature feature = features.nextFeature();
                rows[row][0] = feature;
                for (int c = 0; c < classCount; c++) {
                    Estimate estimate = (Estimate) ((HashMap) this.h.featureGivenClassParameters.get(c)).get(feature);
                    // The rows are the union of features over all classes, so a feature may have
                    // no estimate for a particular class; leave that cell empty instead of
                    // failing with a NullPointerException (score() tolerates the same situation).
                    rows[row][c + 1] = estimate == null ? null : estimate.toTableInViewer();
                }
                row++;
            }
            if (this.controls != null) {
                final boolean byName = this.controls.nameButton.isSelected();
                final boolean bySignedWeight = this.controls.valueButton.isSelected();
                Arrays.sort(rows, new Comparator<Object[]>() {
                    @Override
                    public int compare(Object[] left, Object[] right) {
                        if (byName) {
                            return left[0].toString().compareTo(right[0].toString());
                        }
                        // Weight ordering looks only at the first class column, largest first.
                        double a = weightOf(left[1]);
                        double b = weightOf(right[1]);
                        return bySignedWeight ? MathUtil.sign(b - a) : MathUtil.sign(Math.abs(b) - Math.abs(a));
                    }
                });
            }
            String[] header = new String[classCount + 1];
            header[0] = "Feature Name";
            for (int c = 0; c < classCount; c++) {
                header[c + 1] = "Class " + this.h.classNames.get(c);
            }
            JTable table = new JTable(rows, header);
            monitorSelections(table, 0);
            return new JScrollPane(table);
        }

        /**
         * Extracts a sortable number from a table cell. Numeric cells yield their value; other
         * non-null cells are parsed from their text; anything else counts as zero.
         * NOTE(review): the type returned by Estimate.toTableInViewer() is not visible here; the
         * previous code cast the cell to Double unconditionally, which fails for any other type.
         */
        private static double weightOf(Object cell) {
            if (cell instanceof Number) {
                return ((Number) cell).doubleValue();
            }
            if (cell != null) {
                try {
                    return Double.parseDouble(cell.toString().trim());
                } catch (NumberFormatException ignored) {
                    // Non-numeric text has no meaningful weight; fall through to zero.
                }
            }
            return 0.0d;
        }
    }

    /**
     * Creates a classifier with no classes, a feature prior of zero and no unseen-feature model.
     */
    public MultinomialClassifier() {
        featurePrior = 0.0d;
        unseenModel = null;
        // NOTE(review): this seeds the per-class parameter list with a WeightedSet, whereas
        // setFeatureGivenClassParameter() stores HashMap entries and reset() leaves the list
        // empty - confirm whether this placeholder is still needed.
        featureGivenClassParameters.add(new WeightedSet());
    }

    /**
     * Labels the instance with the class whose score() entry is largest. When several classes
     * tie for the maximum, the one with the highest index is chosen.
     *
     * @param instance the instance to classify
     * @return a ClassLabel built from the winning class name
     */
    @Override
    public ClassLabel classification(Instance instance) {
        final double[] scores = score(instance);
        int best = 0;
        // Index 0 is the starting candidate, so the scan can begin at 1.
        for (int candidate = 1; candidate < scores.length; candidate++) {
            if (scores[candidate] >= scores[best]) {
                best = candidate;
            }
        }
        String winner = (String) classNames.get(best);
        return new ClassLabel(winner);
    }

    /**
     * Computes one log-score per class for the instance. Each class score starts at the log of
     * its class parameter; every feature of the instance then adds a contribution that depends
     * on the model and parameterization of the Estimate stored for that feature and class.
     * Features without an Estimate for a class contribute nothing to that class.
     *
     * @param instance the instance to score
     * @return an array with one entry per class, in class-index order
     * @throws IllegalStateException if an Estimate names a model this classifier does not know
     */
    public double[] score(Instance instance) {
        // Total weight over all features of the instance; several models scale by it.
        double totalWeight = 0.0d;
        Feature.Looper allFeatures = instance.featureIterator();
        while (allFeatures.hasNext()) {
            totalWeight += instance.getWeight(allFeatures.nextFeature());
        }

        final int classCount = this.classNames.size();
        double[] scores = new double[classCount];
        for (int c = 0; c < classCount; c++) {
            scores[c] = Math.log(((Double) this.classParameters.get(c)).doubleValue());
        }

        Feature.Looper features = instance.featureIterator();
        while (features.hasNext()) {
            Feature feature = features.nextFeature();
            double weight = instance.getWeight(feature);
            for (int c = 0; c < classCount; c++) {
                Estimate estimate = (Estimate) ((HashMap) this.featureGivenClassParameters.get(c)).get(feature);
                if (estimate == null) {
                    // Feature never seen for this class: leave the class score untouched.
                    continue;
                }
                scores[c] = addFeatureLogScore(scores[c], estimate, weight, totalWeight);
            }
        }
        return scores;
    }

    /**
     * Returns the running class score with one feature's contribution added. The running score
     * is passed in (rather than a delta returned) so that the floating-point summation order is
     * the same as a direct in-place accumulation. A recognized model with an unrecognized
     * parameterization leaves the score unchanged.
     */
    private double addFeatureLogScore(double current, Estimate estimate, double weight, double totalWeight) {
        String model = estimate.getModel();
        if ("Poisson".equals(model) || "Dirichlet-Poisson MCMC".equals(model)) {
            String parameterization = estimate.getParameterization();
            if (parameterization.equals("weighted-lambda") || parameterization.equals("lambda")) {
                double lambda = ((Double) estimate.getPms().get("lambda")).doubleValue();
                return current + (((-lambda) * totalWeight) / this.SCALE) + (weight * Math.log(lambda));
            }
            return current;
        }
        if ("Naive-Bayes".equals(model)) {
            String parameterization = estimate.getParameterization();
            if (parameterization.equals("weighted-mean") || parameterization.equals("mean")) {
                return current + (weight * Math.log(((Double) estimate.getPms().get("mean")).doubleValue()));
            }
            return current;
        }
        if ("Negative-Binomial".equals(model)) {
            if (estimate.getParameterization().equals("mu/delta")) {
                return current + logProbNegativeBinomialMuDelta(weight, totalWeight / this.SCALE, estimate.getPms());
            }
            return current;
        }
        if ("Binomial".equals(model)) {
            String parameterization = estimate.getParameterization();
            if (parameterization.equals("p/N")) {
                return current + logProbBinomialPN(weight, totalWeight / this.SCALE, estimate.getPms());
            }
            if (parameterization.equals("mu/delta")) {
                return current + logProbBinomialMuDelta(weight, totalWeight / this.SCALE, estimate.getPms());
            }
            return current;
        }
        if ("unseen".equals(model)) {
            return current;
        }
        // Fail with an exception instead of terminating the whole JVM from library code.
        throw new IllegalStateException("error: model " + model + " not found!");
    }

    /**
     * Log-score term for the "Negative-Binomial" model in its "mu/delta" parameterization.
     * When delta is exactly zero the value is {@code d*log(mu) - d2*mu}; otherwise it combines
     * two logGamma terms with log terms in delta. Any exception raised while reading the
     * parameters or evaluating the formula (for example a missing "mu" or "delta" entry)
     * results in 0.0.
     *
     * @param d       the feature's weight in the instance (as passed by score())
     * @param d2      the instance's total weight divided by SCALE (as passed by score())
     * @param treeMap parameter map expected to hold Double values under "mu" and "delta"
     * @return the log-score term, or 0.0 on any failure
     */
    private double logProbNegativeBinomialMuDelta(double d, double d2, TreeMap treeMap) {
        try {
            final double mu = ((Double) treeMap.get("mu")).doubleValue();
            final double delta = ((Double) treeMap.get("delta")).doubleValue();
            if (delta == 0.0d) {
                return (d * Math.log(mu)) - (d2 * mu);
            }
            final double ratio = mu / delta;
            final double gammaTerm = Arithmetic.logGamma(d + ratio) - Arithmetic.logGamma(ratio);
            return (gammaTerm + (d * Math.log(delta))) - (d * Math.log(1.0d + (d2 * delta)));
        } catch (Exception e) {
            return 0.0d;
        }
    }

    /**
     * Log-score term for the "Binomial" model in its "p/N" parameterization. When N is exactly
     * zero the value is {@code d*log(p) - d2*p}; otherwise it is a difference of log-factorials
     * plus {@code d*log(p)} and {@code (N-d)*log(1-p)}. Any exception raised while reading the
     * parameters or evaluating the formula results in 0.0.
     *
     * @param d       the feature's weight in the instance (as passed by score())
     * @param d2      the instance's total weight divided by SCALE (as passed by score())
     * @param treeMap parameter map expected to hold Double values under "p" and "N"
     * @return the log-score term, or 0.0 on any failure
     */
    private double logProbBinomialPN(double d, double d2, TreeMap treeMap) {
        try {
            final double p = ((Double) treeMap.get("p")).doubleValue();
            final double n = ((Double) treeMap.get("N")).doubleValue();
            if (n == 0.0d) {
                return (d * Math.log(p)) - (d2 * p);
            }
            // Both N and the observed weight are truncated to int for the factorial terms.
            final int trials = (int) n;
            final double factorialTerm = Arithmetic.logFactorial(trials) - Arithmetic.logFactorial(trials - ((int) d));
            return factorialTerm + (d * Math.log(p)) + ((n - d) * Math.log(1.0d - p));
        } catch (Exception e) {
            return 0.0d;
        }
    }

    /**
     * Log-score term for the "Binomial" model in its "mu/delta" parameterization. When delta is
     * exactly zero the value is {@code d*log(mu) - d2*mu}. Otherwise a trial count is taken as
     * the rounded maximum of mu/delta and d, a probability is taken as d2*delta clamped to
     * [1e-7, 0.9999999], and the result combines logGamma terms with log terms. Any exception
     * raised while reading the parameters or evaluating the formula results in 0.0.
     *
     * @param d       the feature's weight in the instance (as passed by score())
     * @param d2      the instance's total weight divided by SCALE (as passed by score())
     * @param treeMap parameter map expected to hold Double values under "mu" and "delta"
     * @return the log-score term, or 0.0 on any failure
     */
    private double logProbBinomialMuDelta(double d, double d2, TreeMap treeMap) {
        try {
            final double mu = ((Double) treeMap.get("mu")).doubleValue();
            final double delta = ((Double) treeMap.get("delta")).doubleValue();
            if (delta == 0.0d) {
                return (d * Math.log(mu)) - (d2 * mu);
            }
            final double trials = Math.round(Math.max(mu / delta, d));
            final double clamped = Math.min(Math.max(1.0E-7d, d2 * delta), 0.9999999d);
            final double logComplement = Math.log(1.0d - clamped);
            final double gammaTerm = Arithmetic.logGamma(trials + 1.0d) - Arithmetic.logGamma((trials - d) + 1.0d);
            return ((gammaTerm + (d * Math.log(delta))) - (d * logComplement)) + (trials * logComplement);
        } catch (Exception e) {
            return 0.0d;
        }
    }

    /**
     * Builds a plain-text explanation: one line per feature of the instance (feature followed by
     * "&lt;" and its weight, the same entry format getExplanation() uses), then a final line
     * listing the per-class scores returned by score().
     *
     * @param instance the instance to explain
     * @return the explanation text
     */
    @Override
    public String explain(Instance instance) {
        StringBuilder text = new StringBuilder();
        Feature.Looper features = instance.featureIterator();
        while (features.hasNext()) {
            Feature feature = features.nextFeature();
            text.append(text.length() > 0 ? "\n + " : "   ");
            // Previously only the separators were emitted, producing blank lines.
            text.append(feature).append("<").append(instance.getWeight(feature));
        }
        // Arrays.toString renders the values; concatenating the array printed "[D@<hash>".
        text.append("\n = ").append(Arrays.toString(score(instance)));
        return text.toString();
    }

    /**
     * Builds a tree-shaped explanation: a root node with a "Features" child (one node per
     * feature of the instance showing the feature and its weight, plus a "bias" node) and a
     * sibling node listing the per-class scores returned by score().
     *
     * @param instance the instance to explain
     * @return the explanation wrapping the root node
     */
    @Override
    public Explanation getExplanation(Instance instance) {
        Explanation.Node root = new Explanation.Node("MultinomialClassifier Explanation");
        Explanation.Node featuresNode = new Explanation.Node("Features");
        Feature.Looper features = instance.featureIterator();
        while (features.hasNext()) {
            Feature feature = features.nextFeature();
            featuresNode.add(new Explanation.Node(feature + "<" + instance.getWeight(feature)));
        }
        featuresNode.add(new Explanation.Node("bias"));
        root.add(featuresNode);
        // Arrays.toString renders the values; concatenating the array printed "[D@<hash>".
        root.add(new Explanation.Node("\n = " + Arrays.toString(score(instance))));
        return new Explanation(root);
    }

    /**
     * Sets the scale that score() divides the instance's total feature weight by.
     *
     * @param d the new scale
     */
    public void setScale(double d) {
        SCALE = d;
    }

    /**
     * Stores the feature prior. The stored value is not read by any code visible in this file.
     *
     * @param d the new feature prior
     */
    public void setPrior(double d) {
        featurePrior = d;
    }

    /**
     * Stores the name of the model for unseen features. The stored value is not read by any
     * code visible in this file.
     *
     * @param str the model name
     */
    public void setUnseenModel(String str) {
        unseenModel = str;
    }

    /**
     * Sums a per-feature log-likelihood term for the example under its own labelled class.
     * For a "Poisson" feature the term is {@code -p + w*log(p)}, for a "Naive-Bayes" feature
     * {@code w*log(p)}, where w is the feature's weight in the example and p the stored
     * parameter for that feature and class. Features whose model is "unseen" are reported on
     * standard output and contribute nothing.
     *
     * @param example the labelled example to evaluate
     * @return the accumulated log-likelihood
     * @throws IllegalArgumentException if the example's best class name is not a known class
     * @throws IllegalStateException    if a feature is registered with an unsupported model
     */
    public double getLogLikelihood(Example example) {
        String className = example.getLabel().bestClassName();
        int classIndex = this.classNames.indexOf(className);
        if (classIndex < 0) {
            // Previously this fell through to a list lookup at index -1.
            throw new IllegalArgumentException("unknown class label: " + className);
        }
        Instance instance = example.asInstance();
        double logLikelihood = 0.0d;
        Feature.Looper features = instance.featureIterator();
        while (features.hasNext()) {
            Feature feature = features.nextFeature();
            double weight = instance.getWeight(feature);
            // NOTE(review): this reads the per-class entry as a WeightedSet, but
            // setFeatureGivenClassParameter() stores HashMap entries - confirm which
            // representation this method is supposed to work with.
            double parameter = ((WeightedSet) this.featureGivenClassParameters.get(classIndex)).getWeight(feature);
            String featureModel = getFeatureModel(feature);
            if (featureModel.equals("Poisson")) {
                logLikelihood += (-parameter) + (weight * Math.log(parameter));
            } else if (featureModel.equals("Naive-Bayes")) {
                logLikelihood += weight * Math.log(parameter);
            } else if (featureModel.equals("unseen")) {
                System.out.println("unseen: " + feature);
            } else {
                // Fail with an exception instead of terminating the whole JVM from library code.
                throw new IllegalStateException("error: model " + featureModel + " not found!");
            }
        }
        return logLikelihood;
    }

    /**
     * Discards the class parameters and the per-class feature parameters by replacing both
     * lists with fresh empty ones. Class names and feature models are left as they are.
     */
    public void reset() {
        featureGivenClassParameters = new ArrayList();
        classParameters = new ArrayList();
    }

    /**
     * Tells whether the label's best class name is among the registered class names.
     *
     * @param classLabel the label to look up
     * @return true if the class name has been registered
     */
    public boolean isPresent(ClassLabel classLabel) {
        // Library lookup instead of a hand-written scan that kept going after a match.
        return this.classNames.contains(classLabel.bestClassName());
    }

    /**
     * Appends the label's best class name to the list of class names; its position in the list
     * is its class index. No duplicate check is performed.
     *
     * @param classLabel the label whose class name is registered
     */
    public void addValidLabel(ClassLabel classLabel) {
        final String className = classLabel.bestClassName();
        classNames.add(className);
    }

    /**
     * Returns a new ClassLabel for the class name stored at the given class index.
     *
     * @param i the class index
     * @return a label built from the class name at that index
     */
    public ClassLabel getLabel(int i) {
        final String className = (String) classNames.get(i);
        return new ClassLabel(className);
    }

    /**
     * Returns the class index of the label's best class name.
     *
     * @param classLabel the label to look up
     * @return the index of the first matching class name, or -1 if it is not registered
     */
    public int indexOf(ClassLabel classLabel) {
        final String className = classLabel.bestClassName();
        return classNames.indexOf(className);
    }

    /**
     * Stores the estimate for a feature under the given class index. If the list already holds
     * a HashMap at that index the estimate is put into it; otherwise (no entry yet, or an entry
     * that is not a HashMap) a new HashMap containing the estimate is inserted at that index,
     * shifting any later entries up by one.
     *
     * @param feature  the feature the estimate belongs to
     * @param i        the class index
     * @param estimate the estimate to store
     * @throws IndexOutOfBoundsException if {@code i} is negative or larger than the list size
     */
    public void setFeatureGivenClassParameter(Feature feature, int i, Estimate estimate) {
        // Explicit checks replace the former catch-all exception handler used as control flow.
        Object existing = null;
        if (i >= 0 && i < this.featureGivenClassParameters.size()) {
            existing = this.featureGivenClassParameters.get(i);
        }
        if (existing instanceof HashMap) {
            ((HashMap) existing).put(feature, estimate);
            return;
        }
        HashMap perClass = new HashMap();
        perClass.put(feature, estimate);
        this.featureGivenClassParameters.add(i, perClass);
    }

    /**
     * Overload taking a raw double parameter. It stores nothing: it only prints a notice to
     * standard output and ignores all of its arguments.
     *
     * @param feature ignored
     * @param i       ignored
     * @param d       ignored
     */
    public void setFeatureGivenClassParameter(Feature feature, int i, double d) {
        final String notice = "Should not happen!";
        System.out.println(notice);
    }

    /**
     * Sets the parameter of the class at the given index, replacing an existing value or
     * inserting a new one when the index equals the current list size.
     *
     * @param i the class index
     * @param d the class parameter (score() takes its logarithm)
     * @throws IndexOutOfBoundsException if {@code i} is negative or larger than the list size
     */
    public void setClassParameter(int i, double d) {
        Double value = Double.valueOf(d);
        if (i >= 0 && i < this.classParameters.size()) {
            // Previously an existing entry was silently left unchanged.
            this.classParameters.set(i, value);
        } else {
            this.classParameters.add(i, value);
        }
    }

    /**
     * Registers the model name for a feature, replacing any name registered before.
     *
     * @param feature the feature
     * @param str     the model name (for example "Poisson" or "Naive-Bayes")
     */
    public void setFeatureModel(Feature feature, String str) {
        featureModels.put(feature, str);
    }

    /**
     * Returns the model name registered for a feature.
     *
     * @param feature the feature to look up
     * @return the registered model name, or "unseen" if none has been registered
     */
    public String getFeatureModel(Feature feature) {
        // Plain null check instead of catching NullPointerException.
        Object model = this.featureModels.get(feature);
        return model == null ? "unseen" : model.toString();
    }

    /**
     * Returns a looper over every feature that has parameters for at least one class. The
     * features of all per-class maps are first collected as keys of a Trove map (the mapped
     * values are unused), so each feature is visited once.
     *
     * @return a Feature.Looper over the collected features
     */
    public Feature.Looper featureIterator() {
        TObjectDoubleHashMap collected = new TObjectDoubleHashMap();
        for (int c = 0; c < classNames.size(); c++) {
            HashMap perClass = (HashMap) featureGivenClassParameters.get(c);
            for (Object key : perClass.keySet()) {
                collected.put((Feature) key, 0.0d);
            }
        }
        final TObjectDoubleIterator cursor = collected.iterator();
        // Adapts the Trove cursor (advance, then read the key) to java.util.Iterator.
        Iterator adapter = new Iterator() {
            @Override
            public boolean hasNext() {
                return cursor.hasNext();
            }

            @Override
            public Object next() {
                cursor.advance();
                return cursor.key();
            }

            @Override
            public void remove() {
                cursor.remove();
            }
        };
        return new Feature.Looper(adapter);
    }

    /**
     * Returns every feature that has parameters for at least one class, each one once. The
     * features are collected as keys of a Trove map whose mapped values are unused.
     *
     * @return the collected features as an array
     */
    public Object[] keys() {
        TObjectDoubleHashMap collected = new TObjectDoubleHashMap();
        for (int c = 0; c < classNames.size(); c++) {
            HashMap perClass = (HashMap) featureGivenClassParameters.get(c);
            for (Object key : perClass.keySet()) {
                collected.put((Feature) key, 0.0d);
            }
        }
        return collected.keys();
    }

    /**
     * Creates a viewer for this classifier: a MyViewer paired with its sort controls, with
     * this classifier set as the content.
     *
     * @return the controlled viewer showing this classifier
     */
    @Override
    public Viewer toGUI() {
        MyViewer tableViewer = new MyViewer();
        MultinomialClassifierControls sortControls = new MultinomialClassifierControls();
        ControlledViewer viewer = new ControlledViewer(tableViewer, sortControls);
        viewer.setContent(this);
        return viewer;
    }

    /**
     * Returns a short textual summary listing the class names and class parameters.
     * Previously this returned null, which makes any caller that uses the result as a String
     * fail with a NullPointerException.
     *
     * @return a non-null description of this classifier
     */
    @Override
    public String toString() {
        return "[MultinomialClassifier: classes=" + this.classNames
                + ", classParameters=" + this.classParameters + "]";
    }
}
