package edu.cmu.minorthird.text.learn;

import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.sequential.ConfidenceReportingSequenceClassifier;
import edu.cmu.minorthird.classify.sequential.ConfidenceUtils;
import edu.cmu.minorthird.classify.sequential.SequenceClassifier;
import edu.cmu.minorthird.text.AbstractAnnotator;
import edu.cmu.minorthird.text.Annotator;
import edu.cmu.minorthird.text.BasicSpanLooper;
import edu.cmu.minorthird.text.Details;
import edu.cmu.minorthird.text.FancyLoader;
import edu.cmu.minorthird.text.MonotonicTextLabels;
import edu.cmu.minorthird.text.NestedTextLabels;
import edu.cmu.minorthird.text.Span;
import edu.cmu.minorthird.text.TextLabels;
import edu.cmu.minorthird.text.learn.SequenceAnnotatorLearner;
import edu.cmu.minorthird.text.learn.experiments.MonotonicSubTextLabels;
import edu.cmu.minorthird.text.learn.experiments.SubTextBase;
import edu.cmu.minorthird.util.IOUtil;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.Visible;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.Collections;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/text/learn/ConfidenceReportingSequenceAnnotator.class */
/**
 * Wraps a {@link SequenceAnnotatorLearner.SequenceAnnotator} so that every span it
 * extracts is recorded together with a confidence value.
 *
 * <p>Each document is annotated in isolation by the wrapped annotator; the spans it
 * produces are then copied into the caller's labels with a {@link Details} object
 * carrying the confidence computed by {@link #computeConfidence}.
 *
 * <p>The class can also be run from the command line, see {@link #main(String[])}.
 */
public class ConfidenceReportingSequenceAnnotator extends AbstractAnnotator implements ExtractorAnnotator, Serializable, Visible {
    private static final Logger log = Logger.getLogger(ConfidenceReportingSequenceAnnotator.class);
    private static final boolean DEBUG = false;

    // Intentionally not declared final: this class declares no serialVersionUID, and the
    // default-computed UID depends on the modifiers of instance fields, so adding 'final'
    // would make previously saved annotators unreadable.
    private SequenceAnnotatorLearner.SequenceAnnotator sequenceAnnotator;

    /**
     * Creates an annotator that delegates extraction to the given sequence annotator.
     *
     * @param sequenceAnnotator the learned annotator whose extractions will be given confidences
     */
    public ConfidenceReportingSequenceAnnotator(SequenceAnnotatorLearner.SequenceAnnotator sequenceAnnotator) {
        this.sequenceAnnotator = sequenceAnnotator;
    }

    /**
     * Returns the span type reported by the wrapped annotator.
     *
     * @return the span type under which extracted spans are added
     */
    @Override
    public String getSpanType() {
        return this.sequenceAnnotator.getSpanType();
    }

    /**
     * Runs the wrapped annotator on each document of the labels' text base and adds every
     * extracted span to {@code monotonicTextLabels}, attaching a {@link Details} object that
     * holds the computed confidence.
     *
     * @param monotonicTextLabels the labels to read features from and to add extracted spans to
     * @throws IllegalStateException if a single-document sub text base cannot be built because
     *         the document is unknown to the underlying text base
     */
    @Override
    public void doAnnotate(MonotonicTextLabels monotonicTextLabels) {
        Span.Looper documents = monotonicTextLabels.getTextBase().documentSpanIterator();
        ProgressCounter progressCounter = new ProgressCounter("tagging with classifier", "document", documents.estimatedSize());
        while (documents.hasNext()) {
            Span document = documents.nextSpan();
            log.info("extracting from doc '" + document.getDocumentId() + "'");
            try {
                // Scratch labels restricted to this one document; the wrapped annotator writes
                // its extractions here, not directly into the caller's labels.
                MonotonicSubTextLabels documentLabels = new MonotonicSubTextLabels(
                        new SubTextBase(monotonicTextLabels.getTextBase(), new BasicSpanLooper(Collections.singleton(document))),
                        new NestedTextLabels(monotonicTextLabels));

                // One instance per token of the document, in token order.
                int tokenCount = document.size();
                Instance[] instances = new Instance[tokenCount];
                for (int i = 0; i < tokenCount; i++) {
                    instances[i] = this.sequenceAnnotator.getSpanFeatureExtractor().extractInstance(monotonicTextLabels, document.subSpan(i, 1));
                }

                // Fetched once per document instead of once per extracted span.
                SequenceClassifier classifier = this.sequenceAnnotator.getSequenceClassifier();
                ClassLabel[] predictedLabels = classifier.classification(instances);

                this.sequenceAnnotator.doAnnotate(documentLabels);

                String spanType = this.sequenceAnnotator.getSpanType();
                Span.Looper extracted = documentLabels.instanceIterator(spanType);
                while (extracted.hasNext()) {
                    Span extractedSpan = extracted.nextSpan();
                    double confidence = computeConfidence(classifier, instances, extractedSpan, predictedLabels);
                    monotonicTextLabels.addToType(extractedSpan, spanType, new Details(confidence, ConfidenceReportingSequenceAnnotator.class));
                }
                progressCounter.progress();
            } catch (SubTextBase.UnknownDocumentException e) {
                // Keep the original exception as the cause so its stack trace is not lost.
                throw new IllegalStateException("error: " + e, e);
            }
        }
        progressCounter.finished();
    }

    /**
     * Computes the confidence for one extracted span.
     *
     * <p>The token range is {@code [lo, hi)} with {@code lo = span.documentSpanStartIndex()} and
     * {@code hi = lo + span.size()}. If the classifier is a
     * {@link ConfidenceReportingSequenceClassifier}, an alternate labeling is built in which the
     * positions in that range hold {@code ClassLabel.negativeLabel(-1.0)} (all other positions
     * are left {@code null}) and the classifier's {@code confidence} method is asked to compare
     * it with the predicted labeling. Otherwise the result is
     * {@code ConfidenceUtils.sumPredictedWeights} over the same range.
     *
     * @param sequenceClassifier the classifier that produced {@code classLabelArr}
     * @param instanceArr the per-token instances of the document
     * @param span the extracted span to score
     * @param classLabelArr the labels predicted for {@code instanceArr}
     * @return the confidence value for {@code span}
     */
    private double computeConfidence(SequenceClassifier sequenceClassifier, Instance[] instanceArr, Span span, ClassLabel[] classLabelArr) {
        int lo = span.documentSpanStartIndex();
        int hi = lo + span.size();
        if (!(sequenceClassifier instanceof ConfidenceReportingSequenceClassifier)) {
            return ConfidenceUtils.sumPredictedWeights(classLabelArr, lo, hi);
        }
        ClassLabel[] alternateLabels = new ClassLabel[classLabelArr.length];
        for (int i = lo; i < hi; i++) {
            alternateLabels[i] = ClassLabel.negativeLabel(-1.0d);
        }
        return ((ConfidenceReportingSequenceClassifier) sequenceClassifier).confidence(instanceArr, classLabelArr, alternateLabels, lo, hi);
    }

    /**
     * Delegates the explanation of an annotation to the wrapped annotator.
     *
     * @param textLabels the labels the span belongs to
     * @param span the span to explain
     * @return the wrapped annotator's explanation
     */
    @Override
    public String explainAnnotation(TextLabels textLabels, Span span) {
        return this.sequenceAnnotator.explainAnnotation(textLabels, span);
    }

    /**
     * Returns a viewer for the wrapped annotator.
     *
     * @return a {@link SmartVanillaViewer} over the wrapped annotator
     */
    @Override
    public Viewer toGUI() {
        return new SmartVanillaViewer(this.sequenceAnnotator);
    }

    /**
     * Command-line entry point with two modes.
     *
     * <ul>
     *   <li>{@code -test annotatorFile labelsKey}: loads a serialized sequence annotator, wraps
     *       it, annotates the labels loaded via {@code FancyLoader.loadTextLabels(labelsKey)} and
     *       prints the confidence of every span of the annotator's span type.</li>
     *   <li>{@code previouslySavedAnnotatorFile newAnnotatorFile}: loads a serialized annotator,
     *       wraps it and saves the wrapper to the second file.</li>
     * </ul>
     *
     * @param strArr the command-line arguments
     * @throws IllegalArgumentException if the arguments match neither mode, a file cannot be
     *         read or written, or the loaded annotator was not learned with a
     *         {@link SequenceAnnotatorLearner}
     */
    public static void main(String[] strArr) {
        if (strArr.length == 3 && "-test".equals(strArr[0])) {
            File annotatorFile = new File(strArr[1]);
            NestedTextLabels labels = new NestedTextLabels(FancyLoader.loadTextLabels(strArr[2]));
            try {
                ConfidenceReportingSequenceAnnotator wrapper = new ConfidenceReportingSequenceAnnotator(
                        (SequenceAnnotatorLearner.SequenceAnnotator) IOUtil.loadSerialized(annotatorFile));
                wrapper.annotate(labels);
                String spanType = wrapper.getSpanType();
                Span.Looper spans = labels.instanceIterator(spanType);
                while (spans.hasNext()) {
                    Span span = spans.nextSpan();
                    System.out.println("confidence=" + labels.getDetails(span, spanType).getConfidence() + " for span " + span);
                }
                return;
            } catch (IOException e) {
                throw new IllegalArgumentException("can't load annotator from " + annotatorFile + ": " + e, e);
            }
        }
        if (strArr.length != 2) {
            throw new IllegalArgumentException("usage: previouslySavedAnnotatorFile newAnnotatorFile");
        }
        File inputFile = new File(strArr[0]);
        File outputFile = new File(strArr[1]);

        Annotator annotator;
        try {
            annotator = (Annotator) IOUtil.loadSerialized(inputFile);
        } catch (IOException e) {
            throw new IllegalArgumentException("can't load annotator from " + inputFile + ": " + e, e);
        }
        if (!(annotator instanceof SequenceAnnotatorLearner.SequenceAnnotator)) {
            throw new IllegalArgumentException(inputFile + " does not contain an annotator learned with a SequenceAnnotatorLearner");
        }
        try {
            IOUtil.saveSerialized(new ConfidenceReportingSequenceAnnotator((SequenceAnnotatorLearner.SequenceAnnotator) annotator), outputFile);
        } catch (IOException e) {
            throw new IllegalArgumentException("can't save new annotator in " + outputFile + ": " + e, e);
        }
    }
}
