package edu.cmu.minorthird.classify.sequential;

import cern.colt.matrix.impl.AbstractFormatter;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Hashtable;
import java.util.Iterator;
import javax.swing.JComponent;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;

/* loaded from: input_file:edu/cmu/minorthird/classify/sequential/MultiClassHMMClassifier.class */
/**
 * Sequence classifier backed by an {@link HMM} whose hidden states are the class names of the
 * training dataset's schema and whose emission symbols are tokens taken from the instances.
 *
 * <p>The token for an instance is the last part of the first feature returned by its numeric
 * feature iterator. The constructor only builds the token dictionary and hands zero-filled
 * transition/emission matrices to the {@code HMM}; {@link #baumwelch(double)} replaces the model
 * with the result of {@code HMM.baumwelch} over the stored training sequences.
 *
 * <p>NOTE(review): this class is {@link Serializable} and declares no {@code serialVersionUID},
 * so the set of fields and non-private members is kept stable on purpose; unused fields are
 * retained for that reason.
 */
public class MultiClassHMMClassifier implements SequenceClassifier, SequenceConstants, Visible, Serializable {
    /** Dictionary key added after all training tokens have been counted. */
    private static final String UNSEEN_TOKEN = "UNSEEN";

    private ExampleSchema schema;
    public HMM hmmModel;
    private int numStates;
    private int numEmissions;
    /** State names, one per schema class, in schema index order. */
    String[] state;
    /** State-transition matrix handed to the HMM, sized numStates x numStates. */
    double[][] aprob;
    /** Emission matrix handed to the HMM, sized numStates x numEmissions. */
    double[][] eprob;
    /** Token sequences (one {@code String[]} per training sequence) extracted in the constructor. */
    ArrayList<String[]> training_seq;
    /** Maps token -> occurrence count stored as a decimal String; passed as-is to the HMM. */
    private Hashtable dict_tok;
    // Not read or written anywhere in this class; kept so the serialized form is unchanged.
    private Hashtable dict_tok2idx;
    private Hashtable dict_idx2tok;

    /**
     * Builds the state list and token dictionary from a dataset and creates the initial HMM.
     *
     * @param sequenceDataset dataset supplying the schema and the training sequences; every
     *     example is expected to expose at least one numeric feature
     */
    public MultiClassHMMClassifier(SequenceDataset sequenceDataset) {
        this.schema = sequenceDataset.getSchema();
        this.numStates = this.schema.getNumberOfClasses();
        this.state = new String[this.numStates];
        for (int s = 0; s < this.numStates; s++) {
            this.state[s] = this.schema.getClassName(s);
        }

        this.dict_tok = new Hashtable();
        this.training_seq = new ArrayList<String[]>();
        Iterator sequences = sequenceDataset.sequenceIterator();
        while (sequences.hasNext()) {
            Example[] sequence = (Example[]) sequences.next();
            String[] tokens = new String[sequence.length];
            for (int t = 0; t < sequence.length; t++) {
                String token = tokenOf(sequence[t]);
                tokens[t] = token;
                countToken(token);
            }
            this.training_seq.add(tokens);
        }

        // Extra dictionary entry with count "1"; presumably the bucket for tokens that were
        // never observed during training (verify against HMM.convert_Ob_seq).
        this.dict_tok.put(UNSEEN_TOKEN, "1");
        this.numEmissions = this.dict_tok.size();
        this.aprob = new double[this.numStates][this.numStates];
        this.eprob = new double[this.numStates][this.numEmissions];
        this.hmmModel = new HMM(this.state, this.aprob, this.dict_tok, this.eprob);
    }

    /**
     * Returns the last part of the first feature produced by the instance's numeric feature
     * iterator. Two iterators are requested, exactly as the original inline expression did:
     * one to obtain the feature and one to obtain its size.
     */
    private static String tokenOf(Instance instance) {
        return instance.numericFeatureIterator().nextFeature()
                .getPart(instance.numericFeatureIterator().nextFeature().size() - 1);
    }

    /** Increments the count stored for {@code token} in the dictionary, starting at "1". */
    @SuppressWarnings("unchecked") // dict_tok stays raw because it is passed to HMM unchanged
    private void countToken(String token) {
        // Hashtable never stores null values, so a null result means "not present".
        Object previous = this.dict_tok.get(token);
        if (previous == null) {
            this.dict_tok.put(token, "1");
        } else {
            this.dict_tok.put(token, String.valueOf(Integer.parseInt((String) previous) + 1));
        }
    }

    /**
     * Converts every stored training sequence with the current model's {@code convert_Ob_seq}
     * and replaces {@link #hmmModel} with the model returned by {@code HMM.baumwelch}.
     *
     * @param d value forwarded unchanged as the last argument of {@code HMM.baumwelch}
     */
    @SuppressWarnings("unchecked") // element type returned by convert_Ob_seq is not visible here
    public void baumwelch(double d) {
        ArrayList observations = new ArrayList(this.training_seq.size());
        for (String[] tokens : this.training_seq) {
            observations.add(this.hmmModel.convert_Ob_seq(tokens));
        }
        this.hmmModel = HMM.baumwelch(observations, this.state, this.dict_tok, d);
    }

    /**
     * Labels a sequence with the path computed by {@link Viterbi} over the current model.
     * The extracted tokens and the resulting tags are echoed to standard output.
     *
     * @param instanceArr the sequence to label
     * @return an array of the same length as {@code instanceArr}, filled from the Viterbi path
     */
    @Override // edu.cmu.minorthird.classify.sequential.SequenceClassifier
    public ClassLabel[] classification(Instance[] instanceArr) {
        ClassLabel[] labels = new ClassLabel[instanceArr.length];
        String[] tokens = new String[instanceArr.length];
        for (int i = 0; i < instanceArr.length; i++) {
            tokens[i] = tokenOf(instanceArr[i]);
            System.out.println("ob_seq[" + i + "] is " + tokens[i]);
        }
        String[] path = new Viterbi(this.hmmModel, this.hmmModel.convert_Ob_seq(tokens)).getPath();
        for (int i = 0; i < path.length; i++) {
            labels[i] = new ClassLabel(path[i]);
            System.out.println("tag_seq[" + i + "] is " + path[i]);
        }
        return labels;
    }

    /**
     * Produces one header line per class; the argument is not consulted.
     *
     * @param instanceArr ignored
     * @return the concatenated header lines
     */
    @Override // edu.cmu.minorthird.classify.sequential.SequenceClassifier
    public String explain(Instance[] instanceArr) {
        StringBuilder text = new StringBuilder();
        for (int s = 0; s < this.numStates; s++) {
            text.append("Hyperplane for class ").append(this.schema.getClassName(s)).append(":\n");
            text.append(AbstractFormatter.DEFAULT_ROW_SEPARATOR);
        }
        return text.toString();
    }

    /**
     * Builds an explanation tree with one child node per class; the argument is not consulted.
     *
     * @param instanceArr ignored
     * @return an {@link Explanation} rooted at a "MultiClassHMM Explanation" node
     */
    @Override // edu.cmu.minorthird.classify.sequential.SequenceClassifier
    public Explanation getExplanation(Instance[] instanceArr) {
        Explanation.Node root = new Explanation.Node("MultiClassHMM Explanation");
        for (int s = 0; s < this.numStates; s++) {
            root.add(new Explanation.Node("Hyperplane for class " + this.schema.getClassName(s) + ":\n"));
        }
        return new Explanation(root);
    }

    /**
     * Creates a viewer showing one empty, titled panel per class inside a scroll pane.
     *
     * @return a viewer whose content is this classifier
     */
    @Override // edu.cmu.minorthird.util.gui.Visible
    public Viewer toGUI() {
        // The anonymous class reads the same private fields, in the same way, as before so that
        // any compiler-generated accessors (and hence the default serial form) do not change.
        ComponentViewer componentViewer = new ComponentViewer() { // from class: edu.cmu.minorthird.classify.sequential.MultiClassHMMClassifier.1
            @Override // edu.cmu.minorthird.util.gui.ComponentViewer
            public JComponent componentFor(Object obj) {
                MultiClassHMMClassifier multiClassHMMClassifier = (MultiClassHMMClassifier) obj;
                JPanel jPanel = new JPanel();
                for (int i = 0; i < MultiClassHMMClassifier.this.numStates; i++) {
                    JPanel jPanel2 = new JPanel();
                    jPanel2.setBorder(new TitledBorder("Class " + multiClassHMMClassifier.schema.getClassName(i)));
                    jPanel.add(jPanel2);
                }
                return new JScrollPane(jPanel);
            }
        };
        componentViewer.setContent(this);
        return componentViewer;
    }

    /** Returns a fixed tag; the text is kept exactly as it was, unbalanced bracket included. */
    @Override
    public String toString() {
        return "[MultiClassHMMClassifier:";
    }
}
