package edu.cmu.minorthird.classify.algorithms.linear;

import cern.colt.matrix.impl.AbstractFormatter;
import edu.cmu.minorthird.classify.BatchClassifierLearner;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.sequential.BeamSearcher;
import edu.cmu.minorthird.classify.sequential.CMM;
import edu.cmu.minorthird.classify.sequential.CRFLearner;
import edu.cmu.minorthird.classify.sequential.SequenceDataset;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.io.Serializable;

/* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/linear/MaxEntLearner.class */
/**
 * Batch classifier learner that trains a {@link CRFLearner} on single-element sequences.
 *
 * <p>Each training {@link Example} is wrapped as a sequence of length one, so no transitions
 * between adjacent positions appear in the training data. The per-token classifier of the
 * resulting {@link CMM} is returned, wrapped in {@link MyClassifier}. (Presumably this makes the
 * model a maximum-entropy / multinomial logistic-regression classifier, hence the name.)
 *
 * <p>If the option string given to {@link #MaxEntLearner(String)} contains
 * {@code "scaleScores 1"}, trained classifiers report rescaled scores; see
 * {@link MyClassifier}.
 */
public class MaxEntLearner extends BatchClassifierLearner {
    /** Underlying sequence learner that performs the actual training. */
    private CRFLearner crfLearner;
    /** Whether trained classifiers should rescale their scores. */
    private boolean scaleScores = false;
    /**
     * Mirrors the value last passed to {@link #setLogSpace(boolean)}. Public for backward
     * compatibility; assigning it directly does NOT update the underlying CRF learner.
     */
    public boolean logSpace = true;

    /**
     * Classifier produced by {@link MaxEntLearner#batchTrain}. It wraps the per-token classifier
     * of the trained CMM and maps every instance through
     * {@code BeamSearcher.getBeamInstance(instance, 1)} before delegating. Optionally, it replaces
     * each class score with the log-odds of the softmax-normalized scores.
     */
    public static class MyClassifier implements Classifier, Serializable, Visible {
        private static final long serialVersionUID = 1;
        // Deliberately an instance field (not static): kept as-is so the serialized form is unchanged.
        private final int CURRENT_SERIAL_VERSION = 1;
        /** Wrapped per-token classifier. */
        private Classifier c;
        /** Schema whose class names are enumerated when rescaling scores. */
        private ExampleSchema schema;
        /** If true, {@link #classification} returns rescaled (log-odds) scores. */
        private boolean scaleScores;

        /**
         * @param classifier    the per-token classifier to wrap
         * @param exampleSchema schema listing the classes to rescale over
         * @param z             whether to rescale scores in {@link #classification}
         */
        public MyClassifier(Classifier classifier, ExampleSchema exampleSchema, boolean z) {
            this.c = classifier;
            this.schema = exampleSchema;
            this.scaleScores = z;
        }

        /**
         * Classifies the beam-augmented form of {@code instance}. If score scaling is enabled,
         * the result is passed through {@link #transformScores}.
         */
        @Override // edu.cmu.minorthird.classify.Classifier
        public ClassLabel classification(Instance instance) {
            ClassLabel classification = this.c.classification(BeamSearcher.getBeamInstance(instance, 1));
            return this.scaleScores ? transformScores(classification) : classification;
        }

        /**
         * Returns the augmented instance followed by the wrapped classifier's explanation. When
         * scaling is enabled, the transformed classification is appended as well.
         */
        @Override // edu.cmu.minorthird.classify.Classifier
        public String explain(Instance instance) {
            Instance beamInstance = BeamSearcher.getBeamInstance(instance, 1);
            String base = "Augmented instance: " + beamInstance
                    + AbstractFormatter.DEFAULT_ROW_SEPARATOR + this.c.explain(beamInstance);
            if (this.scaleScores) {
                return base + "\nTransformed score: " + classification(instance);
            }
            return base;
        }

        /**
         * Builds a tree-structured explanation. The root holds the parsed explanation of the
         * wrapped classifier. When scaling is enabled, the root also holds a node with the
         * transformed classification.
         */
        @Override // edu.cmu.minorthird.classify.Classifier
        public Explanation getExplanation(Instance instance) {
            Explanation.Node root = new Explanation.Node("MaxEntClassifier Explanation");
            Instance beamInstance = BeamSearcher.getBeamInstance(instance, 1);
            root.add(buildExplanationTree(beamInstance));
            if (this.scaleScores) {
                root.add(new Explanation.Node("\nTransformed score: " + classification(instance)));
            }
            return new Explanation(root);
        }

        /**
         * Parses the wrapped classifier's textual explanation into a node tree.
         *
         * <p>Handling of each line:
         * <ul>
         *   <li>A line not starting with a space becomes a child of the returned node.</li>
         *   <li>A line starting with a space is attached under the most recent such line. If there
         *       is no such line yet, it goes under the returned node itself.</li>
         *   <li>An empty line is skipped, because it has no first character to inspect.</li>
         * </ul>
         */
        private Explanation.Node buildExplanationTree(Instance beamInstance) {
            Explanation.Node top = new Explanation.Node("Augmented instance: " + beamInstance);
            Explanation.Node parent = top;
            String[] lines = this.c.explain(beamInstance).split(AbstractFormatter.DEFAULT_ROW_SEPARATOR);
            for (String line : lines) {
                if (line.length() == 0) {
                    continue;
                }
                Explanation.Node child = new Explanation.Node(line);
                if (line.charAt(0) == ' ') {
                    parent.add(child);
                } else {
                    parent = child;
                    top.add(child);
                }
            }
            return top;
        }

        /**
         * Rescales raw per-class scores.
         *
         * <p>The weights of all schema classes are softmax-normalized into probabilities
         * {@code p_i}. Each class is then assigned {@code log(p_i / (1 - p_i))}. The largest
         * finite weight is subtracted before exponentiating, so large scores do not overflow
         * {@code Math.exp}. This leaves the normalized values mathematically unchanged.
         */
        private ClassLabel transformScores(ClassLabel classLabel) {
            int numClasses = this.schema.getNumberOfClasses();
            double[] weights = new double[numClasses];
            double maxWeight = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < numClasses; i++) {
                weights[i] = classLabel.getWeight(this.schema.getClassName(i));
                maxWeight = Math.max(maxWeight, weights[i]);
            }
            // Shift only by a finite maximum; for infinite/NaN maxima, fall back to the unshifted
            // computation so those degenerate inputs produce the same values as before.
            double shift = (Double.isInfinite(maxWeight) || Double.isNaN(maxWeight)) ? 0.0d : maxWeight;

            double[] expShifted = new double[numClasses];
            double total = 0.0d;
            for (int i = 0; i < numClasses; i++) {
                expShifted[i] = Math.exp(weights[i] - shift);
                total += expShifted[i];
            }

            ClassLabel result = new ClassLabel();
            for (int i = 0; i < numClasses; i++) {
                double p = expShifted[i] / total;
                result.add(this.schema.getClassName(i), Math.log(p / (1.0d - p)));
            }
            return result;
        }

        /** Returns the wrapped per-token classifier, without beam augmentation or rescaling. */
        public Classifier getRawClassifier() {
            return this.c;
        }

        /** Returns a viewer that displays the wrapped classifier rather than this wrapper. */
        @Override // edu.cmu.minorthird.util.gui.Visible
        public Viewer toGUI() {
            TransformedViewer transformedViewer = new TransformedViewer(new SmartVanillaViewer()) {
                @Override // edu.cmu.minorthird.util.gui.TransformedViewer
                public Object transform(Object obj) {
                    return ((MyClassifier) obj).c;
                }
            };
            transformedViewer.setContent(this);
            return transformedViewer;
        }
    }

    /** Creates a learner whose CRF learner is constructed with an empty option string. */
    public MaxEntLearner() {
        this.crfLearner = new CRFLearner("", 1);
    }

    /**
     * Creates a learner whose CRF learner is constructed from {@code str}. If {@code str}
     * contains {@code "scaleScores 1"}, trained classifiers rescale their scores; a notice is
     * printed to standard output when that happens.
     *
     * @param str option string forwarded to {@link CRFLearner}
     */
    public MaxEntLearner(String str) {
        this.crfLearner = new CRFLearner(str, 1);
        if (str.contains("scaleScores 1")) {
            this.scaleScores = true;
            System.out.println("scaleScores => true");
        }
    }

    /**
     * Enables or disables the CRF learner's log-space option and records the choice in
     * {@link #logSpace}.
     */
    public void setLogSpace(boolean z) {
        if (z) {
            this.crfLearner.setLogSpaceOption();
        } else {
            this.crfLearner.removeLogSpaceOption();
        }
        this.logSpace = z;
    }

    /** Returns the value of {@link #logSpace}. */
    public boolean getLogSpace() {
        return this.logSpace;
    }

    /** Forwards the schema to the underlying CRF learner. */
    @Override // edu.cmu.minorthird.classify.ClassifierLearner
    public void setSchema(ExampleSchema exampleSchema) {
        this.crfLearner.setSchema(exampleSchema);
    }

    /**
     * Trains on {@code dataset} by turning each example into a length-one sequence.
     *
     * <p>The CRF learner is trained on those sequences, and its result is cast to {@link CMM}.
     * The CMM's per-token classifier is then wrapped in {@link MyClassifier}, together with the
     * schema of the sequence dataset.
     */
    @Override // edu.cmu.minorthird.classify.BatchClassifierLearner
    public Classifier batchTrain(Dataset dataset) {
        SequenceDataset sequenceDataset = new SequenceDataset();
        Example.Looper it = dataset.iterator();
        while (it.hasNext()) {
            sequenceDataset.addSequence(new Example[]{it.nextExample()});
        }
        CMM cmm = (CMM) this.crfLearner.batchTrain(sequenceDataset);
        return new MyClassifier(cmm.getClassifier(), sequenceDataset.getSchema(), this.scaleScores);
    }
}
