package edu.cmu.minorthird.text.learn;

import com.wcohen.ss.BasicStringWrapper;
import com.wcohen.ss.DistanceLearnerFactory;
import com.wcohen.ss.api.StringDistance;
import com.wcohen.ss.api.StringDistanceLearner;
import com.wcohen.ss.api.StringWrapper;
import com.wcohen.ss.lookup.SoftDictionary;
import edu.cmu.minorthird.classify.BinaryClassifier;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.OnlineBinaryClassifierLearner;
import edu.cmu.minorthird.classify.algorithms.linear.VotedPerceptron;
import edu.cmu.minorthird.classify.sequential.InstanceFromSequence;
import edu.cmu.minorthird.text.AbstractAnnotator;
import edu.cmu.minorthird.text.Annotator;
import edu.cmu.minorthird.text.BasicSpanLooper;
import edu.cmu.minorthird.text.EmptyLabels;
import edu.cmu.minorthird.text.MonotonicTextLabels;
import edu.cmu.minorthird.text.Span;
import edu.cmu.minorthird.text.TextLabels;
import edu.cmu.minorthird.text.learn.SampleFE;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ComponentViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.ViewerFrame;
import edu.cmu.minorthird.util.gui.Visible;
import java.awt.BorderLayout;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import javax.swing.JComponent;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.border.TitledBorder;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/text/learn/ConditionalSemiMarkovModel.class */
public class ConditionalSemiMarkovModel {
    /** Class logger; assigned in the static initializer of this class. */
    private static Logger log;
    /**
     * Snapshot of {@code log.isDebugEnabled()} taken once during class initialization;
     * every debug-only log call, classifier viewer and lattice dump in this file checks it.
     */
    private static final boolean DEBUG;
    // Cached Class object for ConditionalSemiMarkovModel, filled in by the static initializer
    // via class$(...). This appears to be the class-literal caching pattern emitted by older
    // javac versions; it is kept because the field is package-visible.
    static Class class$edu$cmu$minorthird$text$learn$ConditionalSemiMarkovModel;

    /**
     * One back-pointer cell of the semi-Markov Viterbi lattice built by
     * {@code bestSegments}. It records the segment that ends at a lattice cell, plus the
     * position and label of the cell that segment was appended to.
     * (Decompiler metadata indicates this class was originally private.)
     */
    public static class BackPointer {
        /** The segment ending at this lattice cell. */
        public Span span;
        /** Token offset at which {@link #span} starts, i.e. the predecessor cell's position. */
        public int lastT;
        /** Label of the predecessor cell (1 means its segment is an extracted one). */
        public int lastY;
        /** Set while back-tracing the best path; within this file it is read only by the debug dump. */
        public boolean onBestPath = false;

        public BackPointer(Span segment, int previousPosition, int previousLabel) {
            span = segment;
            lastT = previousPosition;
            lastY = previousLabel;
        }
    }

    /**
     * Annotator produced by {@link CSMMLearner}. For every document in the labels' text base it
     * computes {@link ConditionalSemiMarkovModel#bestSegments} and adds each returned segment
     * to the configured annotation type.
     */
    public static class CSMMAnnotator extends AbstractAnnotator implements Visible, ExtractorAnnotator, Serializable {
        private static final long serialVersionUID = 1;
        private final int CURRENT_VERSION_NUMBER = 1;
        private SpanFeatureExtractor fe;
        private BinaryClassifier classifier;
        private String annotationType;
        private int maxSegSize;

        /**
         * Builds a viewer showing the maximum segment size and a nested view of the
         * segment classifier.
         */
        @Override // edu.cmu.minorthird.util.gui.Visible
        public Viewer toGUI() {
            ComponentViewer componentViewer = new ComponentViewer() {
                @Override // edu.cmu.minorthird.util.gui.ComponentViewer
                public JComponent componentFor(Object obj) {
                    // Everything shown is read from the content object (set to this annotator
                    // below), rather than from a field holding the enclosing instance.
                    CSMMAnnotator annotator = (CSMMAnnotator) obj;
                    JPanel jPanel = new JPanel();
                    jPanel.setLayout(new BorderLayout());
                    jPanel.add(new JLabel("CSMM: segsize " + annotator.maxSegSize), "North");
                    SmartVanillaViewer smartVanillaViewer = new SmartVanillaViewer(annotator.classifier);
                    smartVanillaViewer.setSuperView(this);
                    jPanel.add(smartVanillaViewer, "South");
                    jPanel.setBorder(new TitledBorder("Conditional Semi-Markov-Model"));
                    return new JScrollPane(jPanel);
                }
            };
            componentViewer.setContent(this);
            return componentViewer;
        }

        /**
         * @param spanFeatureExtractor extractor used to build instances for candidate segments
         * @param binaryClassifier     classifier that scores candidate extracted segments
         * @param str                  annotation type added to each extracted segment
         * @param i                    maximum length, in tokens, of an extracted segment
         */
        public CSMMAnnotator(SpanFeatureExtractor spanFeatureExtractor, BinaryClassifier binaryClassifier, String str, int i) {
            this.fe = spanFeatureExtractor;
            this.classifier = binaryClassifier;
            this.annotationType = str;
            this.maxSegSize = i;
        }

        /** Returns the annotation type this annotator adds. */
        @Override // edu.cmu.minorthird.text.learn.ExtractorAnnotator
        public String getSpanType() {
            return this.annotationType;
        }

        /** Segments every document of the text base and records the extracted segments. */
        @Override // edu.cmu.minorthird.text.AbstractAnnotator
        public void doAnnotate(MonotonicTextLabels monotonicTextLabels) {
            ProgressCounter progressCounter = new ProgressCounter("annotating", "document", monotonicTextLabels.getTextBase().size());
            Span.Looper documents = monotonicTextLabels.getTextBase().documentSpanIterator();
            while (documents.hasNext()) {
                Segments segments = ConditionalSemiMarkovModel.bestSegments(documents.nextSpan(), monotonicTextLabels, this.fe, this.classifier, this.maxSegSize);
                for (Span.Looper it = segments.iterator(); it.hasNext(); ) {
                    monotonicTextLabels.addToType(it.nextSpan(), this.annotationType);
                }
                progressCounter.progress();
            }
            progressCounter.finished();
        }

        /** Explanations are not supported; always returns {@code "not implemented"}. */
        @Override // edu.cmu.minorthird.text.AbstractAnnotator, edu.cmu.minorthird.text.Annotator
        public String explainAnnotation(TextLabels textLabels, Span span) {
            return "not implemented";
        }
    }

    /**
     * Trains a {@link CSMMAnnotator} with a mistake-driven online procedure.
     *
     * <p>For each epoch and each stored {@link AnnotationExample}, the current classifier's best
     * segmentation (from {@link ConditionalSemiMarkovModel#bestSegments}) is compared with the
     * labeled segmentation. Predicted segments that are not labeled are added as negative
     * examples. Labeled segments that were not predicted are added as positive examples.
     */
    public static class CSMMLearner implements AnnotatorLearner {
        private SpanFeatureExtractor fe;
        private OnlineBinaryClassifierLearner classifierLearner;
        private int epochs;
        private int maxSegmentSize;
        private Span.Looper documentLooper;
        private List<AnnotationExample> exampleList;
        private String annotationType;

        /** {@link CSMMSpanFE}, {@link VotedPerceptron}, 5 epochs, maximum segment size 5. */
        public CSMMLearner() {
            this(new CSMMSpanFE(), new VotedPerceptron(), 5, 5, "");
        }

        /** @param i number of training epochs (maximum segment size 5) */
        public CSMMLearner(int i) {
            this(new CSMMSpanFE(), new VotedPerceptron(), i, 5, "");
        }

        /**
         * @param i  number of training epochs
         * @param i2 maximum segment size
         */
        public CSMMLearner(int i, int i2) {
            this(new CSMMSpanFE(), new VotedPerceptron(), i, i2, "");
        }

        /** @param str required annotation for the feature extractor (5 epochs, maximum segment size 5) */
        public CSMMLearner(String str) {
            this(new CSMMSpanFE(), new VotedPerceptron(), 5, 5, str);
        }

        /**
         * Dictionary-feature learner with 5 epochs.
         *
         * @param str  dictionary file for {@link CSMMWithDictionarySpanFE} (ignored when empty)
         * @param str2 distance specification for {@link CSMMWithDictionarySpanFE}
         * @param i    maximum segment size
         */
        public CSMMLearner(String str, String str2, int i) {
            this(str, str2, 5, i);
        }

        /**
         * Dictionary-feature learner.
         *
         * @param i  number of training epochs
         * @param i2 maximum segment size
         */
        public CSMMLearner(String str, String str2, int i, int i2) {
            this(str, str2, i, i2, "");
        }

        /** As above, without adding training segments to the dictionary; {@code str3} is the required annotation. */
        public CSMMLearner(String str, String str2, int i, int i2, String str3) {
            this(str, str2, i, i2, false, str3);
        }

        /**
         * Dictionary-feature learner with full control over the feature extractor.
         *
         * @param z    whether training segments are added to the dictionary
         * @param z2   whether dictionary lookups use the span's document id
         * @param str3 required annotation (ignored when empty)
         */
        public CSMMLearner(String str, String str2, int i, int i2, boolean z, boolean z2, String str3) {
            this(new CSMMWithDictionarySpanFE(str, str2, z, z2), new VotedPerceptron(), i, i2, str3);
        }

        /** As the 7-argument constructor with {@code z2 = true}. */
        public CSMMLearner(String str, String str2, int i, int i2, boolean z, String str3) {
            this(str, str2, i, i2, z, true, str3);
        }

        /**
         * @param spanFeatureExtractor          feature extractor for candidate segments
         * @param onlineBinaryClassifierLearner learner for the segment classifier
         * @param i                             number of training epochs
         * @param i2                            maximum segment size
         * @param str                           required annotation; when non-empty, the extractor
         *                                      must be a {@link CSMMSpanFE}
         * @throws IllegalArgumentException if {@code str} is non-empty and the extractor is not
         *                                  a {@link CSMMSpanFE}
         */
        public CSMMLearner(SpanFeatureExtractor spanFeatureExtractor, OnlineBinaryClassifierLearner onlineBinaryClassifierLearner, int i, int i2, String str) {
            this.fe = spanFeatureExtractor;
            if (str.length() > 0) {
                // Fail with a clear message instead of an unexplained ClassCastException.
                if (!(spanFeatureExtractor instanceof CSMMSpanFE)) {
                    throw new IllegalArgumentException("a required annotation ('" + str + "') needs a CSMMSpanFE feature extractor, got "
                            + (spanFeatureExtractor == null ? "null" : spanFeatureExtractor.getClass().getName()));
                }
                System.out.println("Reading annotations");
                CSMMSpanFE csmmSpanFE = (CSMMSpanFE) spanFeatureExtractor;
                csmmSpanFE.setRequiredAnnotation(str, str + ".mixup");
                csmmSpanFE.setTokenPropertyFeatures("*");
            }
            this.classifierLearner = onlineBinaryClassifierLearner;
            this.epochs = i;
            this.maxSegmentSize = i2;
            reset();
        }

        public OnlineBinaryClassifierLearner getLearner() {
            return this.classifierLearner;
        }

        public void setLearner(OnlineBinaryClassifierLearner onlineBinaryClassifierLearner) {
            this.classifierLearner = onlineBinaryClassifierLearner;
        }

        public int getEpochs() {
            return this.epochs;
        }

        public void setEpochs(int i) {
            this.epochs = i;
        }

        public int getMaxSegmentSize() {
            return this.maxSegmentSize;
        }

        public void setMaxSegmentSize(int i) {
            this.maxSegmentSize = i;
        }

        @Override // edu.cmu.minorthird.text.learn.AnnotatorLearner
        public SpanFeatureExtractor getSpanFeatureExtractor() {
            return this.fe;
        }

        @Override // edu.cmu.minorthird.text.learn.AnnotatorLearner
        public void setSpanFeatureExtractor(SpanFeatureExtractor spanFeatureExtractor) {
            this.fe = spanFeatureExtractor;
        }

        /** Discards all stored training examples. */
        @Override // edu.cmu.minorthird.text.learn.AnnotatorLearner
        public void reset() {
            this.exampleList = new ArrayList<AnnotationExample>();
        }

        @Override // edu.cmu.minorthird.text.learn.AnnotatorLearner
        public void setDocumentPool(Span.Looper looper) {
            this.documentLooper = looper;
        }

        @Override // edu.cmu.minorthird.text.learn.AnnotatorLearner
        public boolean hasNextQuery() {
            return this.documentLooper.hasNext();
        }

        @Override // edu.cmu.minorthird.text.learn.AnnotatorLearner
        public Span nextQuery() {
            return this.documentLooper.nextSpan();
        }

        /** Stores a labeled document; training happens in {@link #getAnnotator()}. */
        @Override // edu.cmu.minorthird.text.learn.AnnotatorLearner
        public void setAnswer(AnnotationExample annotationExample) {
            this.exampleList.add(annotationExample);
        }

        @Override // edu.cmu.minorthird.text.learn.AnnotatorLearner
        public void setAnnotationType(String str) {
            this.annotationType = str;
        }

        @Override // edu.cmu.minorthird.text.learn.AnnotatorLearner
        public String getAnnotationType() {
            return this.annotationType;
        }

        /**
         * Resets the classifier learner and trains it for {@link #getEpochs()} passes over the
         * stored examples.
         *
         * @return an annotator wrapping the trained classifier
         */
        @Override // edu.cmu.minorthird.text.learn.AnnotatorLearner
        public Annotator getAnnotator() {
            this.classifierLearner.reset();
            ConditionalSemiMarkovModel.log.debug("processing " + this.exampleList.size() + " examples for " + this.epochs + " epochs");
            ProgressCounter progressCounter = new ProgressCounter("training CSMM", "document", this.epochs * this.exampleList.size());
            // instanceof also accepts subclasses and cannot be fooled by an unrelated class
            // whose name merely ends the same way.
            if (this.fe instanceof CSMMWithDictionarySpanFE) {
                ((CSMMWithDictionarySpanFE) this.fe).train(this.exampleList.iterator());
            }
            for (int epoch = 0; epoch < this.epochs; epoch++) {
                for (AnnotationExample annotationExample : this.exampleList) {
                    updateFrom(annotationExample);
                    progressCounter.progress();
                }
                if (ConditionalSemiMarkovModel.DEBUG) {
                    new ViewerFrame("classifier after epoch " + epoch, new SmartVanillaViewer(this.classifierLearner.getBinaryClassifier()));
                }
            }
            // The counter covers all epochs, so it is closed once, after the last one.
            progressCounter.finished();
            return new CSMMAnnotator(this.fe, this.classifierLearner.getBinaryClassifier(), this.annotationType, this.maxSegmentSize);
        }

        /** Performs one mistake-driven update of the classifier from a single labeled document. */
        private void updateFrom(AnnotationExample annotationExample) {
            Span documentSpan = annotationExample.getDocumentSpan();
            if (ConditionalSemiMarkovModel.DEBUG) {
                ConditionalSemiMarkovModel.log.debug("updating from " + documentSpan);
            }
            Segments bestSegments = ConditionalSemiMarkovModel.bestSegments(documentSpan, annotationExample.getLabels(), this.fe, this.classifierLearner.getBinaryClassifier(), this.maxSegmentSize);
            if (ConditionalSemiMarkovModel.DEBUG) {
                ConditionalSemiMarkovModel.log.debug("viterbi solution:\n" + bestSegments);
            }
            Segments correctSegments = correctSegments(annotationExample);
            if (ConditionalSemiMarkovModel.DEBUG) {
                ConditionalSemiMarkovModel.log.debug("correct spans:\n" + correctSegments);
            }
            addExamplesForMisses(annotationExample, bestSegments, correctSegments, -1.0d, "false pos: ");
            addExamplesForMisses(annotationExample, correctSegments, bestSegments, 1.0d, "false neg: ");
        }

        /**
         * Adds a training example with label {@code label} for every span of {@code source}
         * that {@code reference} lacks. The history of each example is derived from the span
         * preceding it in {@code source}.
         */
        private void addExamplesForMisses(AnnotationExample annotationExample, Segments source, Segments reference, double label, String debugPrefix) {
            Span previous = null;
            for (Span.Looper it = source.iterator(); it.hasNext(); ) {
                Span current = it.nextSpan();
                if (!reference.contains(current)) {
                    if (ConditionalSemiMarkovModel.DEBUG) {
                        ConditionalSemiMarkovModel.log.debug(debugPrefix + current);
                    }
                    this.classifierLearner.addExample(exampleFor(annotationExample, current, previous, label));
                }
                previous = current;
            }
        }

        /**
         * Builds a classifier example for {@code span}. The history class is POS when
         * {@code span2} (the previous extracted span) ends exactly where {@code span} begins,
         * and NEG otherwise, including when there is no previous span.
         */
        private Example exampleFor(AnnotationExample annotationExample, Span span, Span span2, double d) {
            boolean adjacentToPrevious = span2 != null && span2.getRightBoundary().equals(span.getLeftBoundary());
            String history = adjacentToPrevious ? ExampleSchema.POS_CLASS_NAME : ExampleSchema.NEG_CLASS_NAME;
            InstanceFromSequence instanceFromSequence = new InstanceFromSequence(this.fe.extractInstance(annotationExample.getLabels(), span), new String[]{history});
            if (ConditionalSemiMarkovModel.DEBUG) {
                ConditionalSemiMarkovModel.log.debug("example for " + span + ": " + instanceFromSequence);
            }
            return new Example(instanceFromSequence, ClassLabel.binaryLabel(d));
        }

        /** Collects the labeled spans of the example's input type within its document. */
        private Segments correctSegments(AnnotationExample annotationExample) {
            TreeSet<Span> labeled = new TreeSet<Span>();
            String documentId = annotationExample.getDocumentSpan().getDocumentId();
            Span.Looper instanceIterator = annotationExample.getLabels().instanceIterator(annotationExample.getInputType(), documentId);
            while (instanceIterator.hasNext()) {
                labeled.add(instanceIterator.nextSpan());
            }
            return new Segments(labeled);
        }
    }

    /**
     * Span feature extractor used for CSMM segments.
     *
     * <p>On top of the features emitted by {@code SampleFE.ExtractionFE}, it emits:
     * <ul>
     *   <li>features of the whole span: lower-cased text, optional character-type patterns,
     *       and size;</li>
     *   <li>the same kinds of features for its first and last tokens;</li>
     *   <li>the configured token properties of the first, last and interior tokens.</li>
     * </ul>
     */
    public static class CSMMSpanFE extends SampleFE.ExtractionFE {
        public CSMMSpanFE() {
        }

        /** @param i passed unchanged to the {@code SampleFE.ExtractionFE(int)} constructor */
        public CSMMSpanFE(int i) {
            super(i);
        }

        /** Extracts features with no auxiliary labels (an {@link EmptyLabels}). */
        @Override // edu.cmu.minorthird.text.learn.SampleFE.ExtractionFE, edu.cmu.minorthird.text.learn.SpanFE
        public void extractFeatures(Span span) {
            extractFeatures(new EmptyLabels(), span);
        }

        @Override // edu.cmu.minorthird.text.learn.SampleFE.ExtractionFE, edu.cmu.minorthird.text.learn.SpanFE
        public void extractFeatures(TextLabels textLabels, Span span) {
            super.extractFeatures(textLabels, span);
            from(span).eq().lc().emit();
            if (this.useCharType) {
                from(span).eq().charTypes().emit();
            }
            if (this.useCompressedCharType) {
                from(span).eq().charTypePattern().emit();
            }
            from(span).size().emit();
            from(span).exactSize().emit();
            from(span).token(0).eq().lc().emit();
            from(span).token(-1).eq().lc().emit();
            if (this.useCharType) {
                from(span).token(0).eq().charTypes().lc().emit();
                from(span).token(-1).eq().charTypes().lc().emit();
            }
            if (this.useCompressedCharType) {
                from(span).token(0).eq().charTypePattern().lc().emit();
                from(span).token(-1).eq().charTypePattern().lc().emit();
            }
            for (int i = 0; i < this.tokenPropertyFeatures.length; i++) {
                String str = this.tokenPropertyFeatures[i];
                from(span).token(0).prop(str).emit();
                from(span).token(-1).prop(str).emit();
                // Interior tokens exclude the first and last. Spans of one or two tokens have
                // none, and for a one-token span size()-2 would be a negative length.
                if (span.size() > 2) {
                    from(span).subSpan(1, span.size() - 2).tokens().prop(str).emit();
                }
            }
        }
    }

    /* loaded from: input_file:edu/cmu/minorthird/text/learn/ConditionalSemiMarkovModel$CSMMWithDictionarySpanFE.class */
    public static class CSMMWithDictionarySpanFE extends CSMMSpanFE {
        boolean addTrainingSegsToDictionary;
        boolean useCrossVal;
        SoftDictionary dictionary;
        StringDistance[] distances;
        Feature[] features;

        public CSMMWithDictionarySpanFE(String str, String str2) {
            this(str, str2, false, false);
        }

        public CSMMWithDictionarySpanFE(String str, String str2, boolean z, boolean z2) {
            try {
                this.addTrainingSegsToDictionary = z;
                this.useCrossVal = z2;
                this.dictionary = new SoftDictionary();
                this.distances = DistanceLearnerFactory.buildArray(str2);
                if (str.length() > 0) {
                    this.dictionary.load(new File(str));
                    trainDistances();
                }
                this.features = new Feature[this.distances.length];
                for (int i = 0; i < this.distances.length; i++) {
                    this.features[i] = new Feature(this.distances[i].toString());
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        public void trainDistances() {
            for (int i = 0; i < this.distances.length; i++) {
                if (this.distances[i] instanceof StringDistanceLearner) {
                    this.distances[i] = this.dictionary.getTeacher().train((StringDistanceLearner) this.distances[i]);
                }
            }
        }

        public void train(Iterator it) {
            if (this.addTrainingSegsToDictionary) {
                int i = 0;
                while (it.hasNext()) {
                    AnnotationExample annotationExample = (AnnotationExample) it.next();
                    String documentId = annotationExample.getDocumentSpan().getDocumentId();
                    Span.Looper instanceIterator = annotationExample.getLabels().instanceIterator(annotationExample.getInputType(), documentId);
                    while (instanceIterator.hasNext()) {
                        i++;
                        this.dictionary.put(documentId, instanceIterator.nextSpan().asString(), (Object) null);
                    }
                }
                trainDistances();
            }
        }

        @Override // edu.cmu.minorthird.text.learn.ConditionalSemiMarkovModel.CSMMSpanFE, edu.cmu.minorthird.text.learn.SampleFE.ExtractionFE, edu.cmu.minorthird.text.learn.SpanFE
        public void extractFeatures(TextLabels textLabels, Span span) {
            super.extractFeatures(textLabels, span);
            BasicStringWrapper basicStringWrapper = new BasicStringWrapper(span.asString());
            Object lookup = this.dictionary.lookup((this.addTrainingSegsToDictionary && this.useCrossVal) ? span.getDocumentId() : null, basicStringWrapper);
            if (lookup != null) {
                for (int i = 0; i < this.distances.length; i++) {
                    double score = this.distances[i].score(basicStringWrapper, (StringWrapper) lookup);
                    if (score != 0.0d) {
                        this.instance.addNumeric(this.features[i], score);
                    }
                }
            }
        }
    }

    /**
     * A segmentation: a thin wrapper around a set of {@link Span}s.
     * The set is held by reference, not copied.
     */
    public static class Segments {
        private Set spanSet;

        public Segments(Set set) {
            spanSet = set;
        }

        /** Iterates over the wrapped spans in the set's own iteration order. */
        public Span.Looper iterator() {
            Iterator spans = spanSet.iterator();
            return new BasicSpanLooper(spans);
        }

        /** Tests whether {@code span} belongs to this segmentation. */
        public boolean contains(Span span) {
            boolean member = spanSet.contains(span);
            return member;
        }

        @Override
        public String toString() {
            return "[Segments: " + spanSet.toString() + "]";
        }
    }

    /**
     * Finds the highest-scoring segmentation of {@code span} with a semi-Markov Viterbi
     * search and returns the segments on that path that are labeled as extracted.
     *
     * <p>Every token is covered by one of two kinds of segment:
     * <ul>
     *   <li>a one-token segment with label 0, which always scores 0;</li>
     *   <li>a segment of 1..{@code i} tokens with label 1, scored by {@code binaryClassifier}
     *       on the segment's features plus the label of the preceding segment.</li>
     * </ul>
     *
     * @param span                 span (callers pass document spans) to segment
     * @param textLabels           labels handed to the feature extractor
     * @param spanFeatureExtractor builds the instance scored for each candidate segment
     * @param binaryClassifier     scores candidate extracted segments
     * @param i                    maximum length, in tokens, of an extracted segment
     * @return the label-1 segments on the best path
     */
    public static Segments bestSegments(Span span, TextLabels textLabels, SpanFeatureExtractor spanFeatureExtractor, BinaryClassifier binaryClassifier, int i) {
        int n = span.size();
        // best[t][y]: score of the best segmentation of the first t tokens whose last segment
        // has label y; back[t][y]: the segment and predecessor cell that achieved it.
        double[][] best = new double[n + 1][2];
        BackPointer[][] back = new BackPointer[n + 1][2];
        // Unreached cells start at -infinity instead of a finite sentinel. A finite value would
        // reject legitimate paths scoring below it. It could also let a never-reached cell
        // (null back-pointer) be taken as a predecessor, which would cut the back-trace short.
        for (int t = 1; t <= n; t++) {
            best[t][0] = Double.NEGATIVE_INFINITY;
            best[t][1] = Double.NEGATIVE_INFINITY;
        }
        best[0][0] = 0.0d;
        best[0][1] = 0.0d;
        for (int t = 1; t <= n; t++) {
            for (int y = 0; y < 2; y++) {
                int maxLength = (y == 0) ? 1 : i;
                for (int prevY = 0; prevY < 2; prevY++) {
                    for (int start = Math.max(0, t - maxLength); start < t; start++) {
                        Span segment = span.subSpan(start, t - start);
                        double candidate = score(textLabels, prevY, y, start, t, segment, spanFeatureExtractor, binaryClassifier) + best[start][prevY];
                        // Strict comparison: on ties the first candidate examined is kept.
                        if (candidate > best[t][y]) {
                            best[t][y] = candidate;
                            back[t][y] = new BackPointer(segment, start, prevY);
                        }
                    }
                }
            }
        }
        int label = best[n][1] > best[n][0] ? 1 : 0;
        TreeSet<Span> extracted = new TreeSet<Span>();
        for (BackPointer bp = back[n][label]; bp != null; bp = back[bp.lastT][bp.lastY]) {
            bp.onBestPath = true;
            if (label == 1) {
                extracted.add(bp.span);
            }
            label = bp.lastY;
        }
        if (DEBUG) {
            dumpStuff(best, back);
        }
        return new Segments(extracted);
    }

    /**
     * Debug helper: prints the Viterbi lattice to standard output, one line per (position, label)
     * cell. Each line shows the cell's score, its predecessor cell and segment text, and a
     * {@code <==} marker on cells of the best path.
     */
    private static void dumpStuff(double[][] dArr, BackPointer[][] backPointerArr) {
        DecimalFormat fmt = new DecimalFormat("####.###");
        System.out.println("t.y\tf(t,y)\tt'.y'\tspan");
        for (int t = 0; t < dArr.length; t++) {
            for (int y = 0; y < 2; y++) {
                BackPointer cell = backPointerArr[t][y];
                String text;
                int prevT;
                int prevY;
                boolean marked;
                if (cell == null) {
                    text = "*NULL*";
                    prevT = -1;
                    prevY = -1;
                    marked = false;
                } else {
                    text = cell.span.asString();
                    prevT = cell.lastT;
                    prevY = cell.lastY;
                    marked = cell.onBestPath;
                }
                String line = t + "." + y + "\t" + fmt.format(dArr[t][y]) + "\t" + prevT + "." + prevY
                        + "  '" + text + "' " + (marked ? "<==" : "");
                System.out.println(line);
            }
        }
    }

    /**
     * Scores one candidate segment of the lattice.
     *
     * @param textLabels           labels handed to the feature extractor
     * @param i                    label of the preceding segment; 1 is encoded as the POS history
     *                             class, anything else as NEG
     * @param i2                   label of this segment; label 0 always scores 0.0
     * @param i3                   segment start offset (not used)
     * @param i4                   segment end offset (not used)
     * @param span                 the candidate segment
     * @param spanFeatureExtractor builds the instance for {@code span}
     * @param binaryClassifier     scores the instance
     * @return the classifier score for a label-1 segment, otherwise 0.0
     */
    private static double score(TextLabels textLabels, int i, int i2, int i3, int i4, Span span, SpanFeatureExtractor spanFeatureExtractor, BinaryClassifier binaryClassifier) {
        if (i2 == 0) {
            return 0.0d;
        }
        InstanceFromSequence instanceFromSequence = new InstanceFromSequence(spanFeatureExtractor.extractInstance(textLabels, span), new String[]{i == 1 ? ExampleSchema.POS_CLASS_NAME : ExampleSchema.NEG_CLASS_NAME});
        // Evaluate the classifier once; the debug message reuses the value instead of
        // scoring the instance a second time.
        double result = binaryClassifier.score(instanceFromSequence);
        if (DEBUG) {
            log.debug("score: " + result + "\t" + span);
        }
        return result;
    }

    /**
     * Loads a class by name, converting {@link ClassNotFoundException} into
     * {@link NoClassDefFoundError}. Used only by the static initializer below.
     */
    static Class class$(String str) {
        Class resolved;
        try {
            resolved = Class.forName(str);
        } catch (ClassNotFoundException notFound) {
            throw new NoClassDefFoundError(notFound.getMessage());
        }
        return resolved;
    }

    // Fills the cached Class object on first use, then sets up the logger and the DEBUG flag.
    static {
        if (class$edu$cmu$minorthird$text$learn$ConditionalSemiMarkovModel == null) {
            class$edu$cmu$minorthird$text$learn$ConditionalSemiMarkovModel = class$("edu.cmu.minorthird.text.learn.ConditionalSemiMarkovModel");
        }
        log = Logger.getLogger(class$edu$cmu$minorthird$text$learn$ConditionalSemiMarkovModel);
        DEBUG = log.isDebugEnabled();
    }
}
