package com.wcohen.ss.expt;

import cern.colt.matrix.impl.AbstractFormatter;
import com.wcohen.cls.BinaryExample;
import com.wcohen.cls.ClassLabel;
import com.wcohen.cls.MutableInstance;
import com.wcohen.cls.expt.CrossValSplitter;
import com.wcohen.cls.expt.Evaluation;
import com.wcohen.ss.AdaptiveStringDistanceLearner;
import com.wcohen.ss.BasicDistanceInstanceIterator;
import com.wcohen.ss.DistanceLearnerFactory;
import com.wcohen.ss.PrintfFormat;
import com.wcohen.ss.api.StringDistance;
import com.wcohen.ss.api.StringDistanceLearner;
import com.wcohen.ss.expt.Blocker;
import com.wcohen.util.ProgressCounter;
import com.wcohen.util.gui.ComponentViewer;
import com.wcohen.util.gui.ParallelViewer;
import com.wcohen.util.gui.TransformedViewer;
import com.wcohen.util.gui.Viewer;
import com.wcohen.util.gui.Visible;
import java.io.IOException;
import java.io.PrintStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import javax.swing.JComponent;
import javax.swing.JScrollPane;
import javax.swing.JTable;

/* loaded from: input_file:com/wcohen/ss/expt/MatchExpt.class */
public class MatchExpt implements Serializable, Visible {
    /** Package prefix for blocker class names; {@code main} prepends the same string to its first argument. */
    public static final String BLOCKER_PACKAGE = "com.wcohen.ss.expt.";
    /** Package prefix for distance class names (not referenced elsewhere in this class). */
    public static final String DISTANCE_PACKAGE = "com.wcohen.ss.";
    private static final long serialVersionUID = 1;
    /** Candidate pairs produced by the blocker, scored and then sorted with {@code Arrays.sort}. */
    private Blocker.Pair[] pairs;
    /** Value of {@code blocker.numCorrectPairs()}; used as the recall denominator by the metric methods. */
    private int numCorrectPairs;
    /** Seconds spent training the distance function. */
    private double learningTime;
    /** Seconds spent in {@code blocker.block(..)}. */
    private double blockingTime;
    /** Seconds spent scoring candidate pairs. */
    private double matchingTime;
    /** Seconds spent sorting the scored pairs. */
    private double sortingTime;
    /** File name reported by the match data, kept for display. */
    private String fileName;
    /** {@code toString()} of the learner, captured after the experiment has run. */
    private String learnerName;
    /** {@code toString()} of the blocker, captured after the experiment has run. */
    private String blockerName;
    // NOTE(review): not read anywhere in this class; presumably a leftover of a custom serialization scheme.
    private static int CURRENT_SERIALIZED_VERSION_NUMBER = 1;
    /** The eleven standard recall levels 0.0, 0.1, ..., 1.0 used for interpolated precision. */
    private static double[] elevenPoints = {0.0d, 0.1d, 0.2d, 0.3d, 0.4d, 0.5d, 0.6d, 0.7d, 0.8d, 0.9d, 1.0d};

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/wcohen/ss/expt/MatchExpt$MatchExptEvaluation.class */
    /**
     * An {@link Evaluation} for a match experiment that adds a "Details" tab
     * showing the ranked pairs alongside the standard evaluation views.
     *
     * <p>Both fields are {@code transient}, so an instance restored through Java
     * serialization would presumably come back with {@code pairs == null}; the
     * details view therefore treats a missing array as empty instead of failing.
     */
    public static class MatchExptEvaluation extends Evaluation {
        private transient Blocker.Pair[] pairs;
        private transient int numCorrectPairs;

        /**
         * @param pairArr the ranked candidate pairs (entries may be {@code null})
         * @param i       the number of correct pairs, used to size the details table
         */
        public MatchExptEvaluation(Blocker.Pair[] pairArr, int i) {
            this.pairs = pairArr;
            this.numCorrectPairs = i;
        }

        /**
         * Builds a tabbed viewer with the summary, property and 11-point
         * precision views plus a "Details" table of the ranked pairs.
         *
         * @return a viewer whose content is this evaluation
         */
        public Viewer toGUI() {
            ParallelViewer parallelViewer = new ParallelViewer();
            parallelViewer.addSubView("Summary", new Evaluation.SummaryViewer());
            parallelViewer.addSubView("Properties", new Evaluation.PropertyViewer());
            parallelViewer.addSubView("11Pt Precision/Recall", new Evaluation.ElevenPointPrecisionViewer());
            parallelViewer.addSubView("Details", new ComponentViewer() {
                public JComponent componentFor(Object obj) {
                    Blocker.Pair[] ranked = MatchExptEvaluation.this.pairs;
                    if (ranked == null) {
                        // transient field was not restored; show an empty table rather than throwing
                        ranked = new Blocker.Pair[0];
                    }
                    Object[][] rows = new Object[MatchExptEvaluation.this.numCorrectPairs + 1][5];
                    PrintfFormat scoreFormat = new PrintfFormat("%7.2f");
                    int row = 0;
                    for (int rank = 0; rank < ranked.length; rank++) {
                        Blocker.Pair pair = ranked[rank];
                        if (pair == null) {
                            continue;
                        }
                        boolean correct = pair.isCorrect();
                        rows[row][0] = Integer.valueOf(rank);
                        rows[row][1] = correct ? "+" : "-";
                        rows[row][2] = scoreFormat.sprintf(pair.getDistance());
                        rows[row][3] = pair.getA() == null ? "***" : pair.getA().unwrap();
                        rows[row][4] = pair.getB() == null ? "***" : pair.getB().unwrap();
                        // The row only advances after a correct pair, so an incorrect
                        // pair written here is overwritten by the next pair in the ranking.
                        if (correct) {
                            row++;
                        }
                    }
                    JScrollPane scrollPane = new JScrollPane(new JTable(rows, new String[]{"rank", "", "score", "String A", "String B"}));
                    scrollPane.setHorizontalScrollBarPolicy(JScrollPane.HORIZONTAL_SCROLLBAR_AS_NEEDED);
                    return scrollPane;
                }
            });
            parallelViewer.setContent(this);
            return parallelViewer;
        }
    }

    public MatchExpt(MatchData matchData, StringDistanceLearner stringDistanceLearner, Blocker blocker) {
        if (stringDistanceLearner instanceof AdaptiveStringDistanceLearner) {
            setUpAdaptiveExperiment(matchData, stringDistanceLearner, blocker);
        } else {
            setUpFixedExperiment(matchData, stringDistanceLearner, blocker);
        }
        this.fileName = matchData.getFilename();
        this.learnerName = stringDistanceLearner.toString();
        this.blockerName = blocker.toString();
    }

    /**
     * Runs a matching experiment with a freshly created {@code NullBlocker}.
     *
     * @param matchData             the data set to match
     * @param stringDistanceLearner the learner that produces the distance function
     */
    public MatchExpt(MatchData matchData, StringDistanceLearner stringDistanceLearner) {
        this(matchData, stringDistanceLearner, new NullBlocker());
    }

    /**
     * Describes the experiment as {@code [MatchExpt: file,learner,blocker]}.
     *
     * @return a short description of this experiment
     */
    @Override
    public String toString() {
        StringBuilder description = new StringBuilder("[MatchExpt: ");
        description.append(this.fileName);
        description.append(",");
        description.append(this.learnerName);
        description.append(",");
        description.append(this.blockerName);
        description.append("]");
        return description.toString();
    }

    /**
     * Sets up an experiment for an adaptive learner using 3-fold cross
     * validation over the blocked pairs: for each fold a distance is trained on
     * the training partition and used to score the pairs of the test partition.
     * The scored pairs are finally sorted.
     *
     * <p>Records blocking, learning, matching and sorting time (in seconds).
     * Sorting time is recorded here as well so that {@link #time()} and
     * {@link #pairsPerSecond()} account for the same phases as they do after
     * the fixed set-up.
     *
     * @param matchData             the data set to match
     * @param stringDistanceLearner the adaptive learner to train on each fold
     * @param blocker               the blocker that proposes candidate pairs
     */
    private void setUpAdaptiveExperiment(MatchData matchData, StringDistanceLearner stringDistanceLearner, Blocker blocker) {
        System.out.println("setting up expt: " + stringDistanceLearner + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR + blocker + " file: " + matchData.getFilename());

        long blockingStart = System.currentTimeMillis();
        blocker.block(matchData);
        this.blockingTime = (System.currentTimeMillis() - blockingStart) / 1000.0d;

        // Collect the blocked pairs so they can be partitioned into folds.
        final int numPairs = blocker.size();
        ArrayList<Blocker.Pair> candidates = new ArrayList<Blocker.Pair>(numPairs);
        for (int i = 0; i < numPairs; i++) {
            candidates.add(blocker.getPair(i));
        }
        CrossValSplitter crossValSplitter = new CrossValSplitter(3);
        crossValSplitter.split(candidates.iterator());

        this.learningTime = 0.0d;
        this.matchingTime = 0.0d;
        this.pairs = new Blocker.Pair[numPairs];
        this.numCorrectPairs = blocker.numCorrectPairs();

        int next = 0;
        for (int fold = 0; fold < crossValSplitter.getNumPartitions(); fold++) {
            BasicTeacher basicTeacher = new BasicTeacher(matchData.getIterator(), new BasicDistanceInstanceIterator(Collections.EMPTY_SET.iterator()), new BasicDistanceInstanceIterator(crossValSplitter.getTrain(fold)));

            long learningStart = System.currentTimeMillis();
            StringDistance distance = basicTeacher.train(stringDistanceLearner);
            this.learningTime += (System.currentTimeMillis() - learningStart) / 1000.0d;
            System.out.println("fold " + fold + " distance is '" + distance + "'");

            // Score the held-out pairs of this fold with the distance just trained.
            long matchingStart = System.currentTimeMillis();
            Iterator test = crossValSplitter.getTest(fold);
            while (test.hasNext()) {
                Blocker.Pair pair = (Blocker.Pair) test.next();
                this.pairs[next] = pair;
                pair.setDistance(distance.score(pair.getA(), pair.getB()));
                next++;
            }
            this.matchingTime += (System.currentTimeMillis() - matchingStart) / 1000.0d;
        }

        long sortingStart = System.currentTimeMillis();
        Arrays.sort(this.pairs);
        this.sortingTime = (System.currentTimeMillis() - sortingStart) / 1000.0d;
    }

    /**
     * Sets up an experiment for a non-adaptive learner: trains one distance,
     * blocks the data, scores every proposed pair and sorts the result.
     * Learning, blocking, matching and sorting time are recorded in seconds.
     *
     * @param matchData             the data set to match
     * @param stringDistanceLearner the learner that produces the distance function
     * @param blocker               the blocker that proposes candidate pairs
     */
    private void setUpFixedExperiment(MatchData matchData, StringDistanceLearner stringDistanceLearner, Blocker blocker) {
        System.out.println("setting up expt: " + stringDistanceLearner + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR + blocker + " file: " + matchData.getFilename());

        // Training happens before blocking.
        BasicTeacher teacher = new BasicTeacher(blocker, matchData);
        long phaseStart = System.currentTimeMillis();
        StringDistance distance = teacher.train(stringDistanceLearner);
        this.learningTime = (System.currentTimeMillis() - phaseStart) / 1000.0d;
        System.out.println("distance is '" + distance + "'");

        // Blocking.
        phaseStart = System.currentTimeMillis();
        blocker.block(matchData);
        this.blockingTime = (System.currentTimeMillis() - phaseStart) / 1000.0d;
        this.numCorrectPairs = blocker.numCorrectPairs();
        this.pairs = new Blocker.Pair[blocker.size()];

        // Matching: score every proposed pair.
        phaseStart = System.currentTimeMillis();
        System.out.println("Pairs: " + this.pairs.length + " Correct: " + blocker.numCorrectPairs());
        ProgressCounter counter = new ProgressCounter("computing distances", "proposed pair", blocker.size());
        for (int k = 0; k < blocker.size(); k++) {
            Blocker.Pair pair = blocker.getPair(k);
            this.pairs[k] = pair;
            pair.setDistance(distance.score(pair.getA(), pair.getB()));
            counter.progress();
        }
        counter.finished();
        this.matchingTime = (System.currentTimeMillis() - phaseStart) / 1000.0d;

        // Sorting.
        phaseStart = System.currentTimeMillis();
        Arrays.sort(this.pairs);
        this.sortingTime = (System.currentTimeMillis() - phaseStart) / 1000.0d;
        System.out.println("Matching time: " + this.matchingTime);
    }

    /**
     * Total recorded time of the experiment.
     *
     * @return the sum of learning, blocking, matching and sorting time, in seconds
     */
    public Double time() {
        double totalSeconds = this.learningTime + this.blockingTime + this.matchingTime + this.sortingTime;
        return Double.valueOf(totalSeconds);
    }

    /**
     * Throughput of the experiment.
     *
     * @return the number of candidate pairs divided by the total recorded time in seconds
     */
    public Double pairsPerSecond() {
        double totalSeconds = this.learningTime + this.blockingTime + this.matchingTime + this.sortingTime;
        double rate = this.pairs.length / totalSeconds;
        return Double.valueOf(rate);
    }

    /**
     * Non-interpolated average precision: the precision at the rank of each
     * correct pair, summed and divided by {@code numCorrectPairs}.
     *
     * @return the average precision of the ranking
     */
    public Double averagePrecision() {
        int hits = 0;
        double precisionSum = 0.0d;
        int rank = 0;
        while (rank < this.pairs.length) {
            if (correctPair(rank)) {
                hits++;
                precisionSum += hits / (rank + 1.0d);
            }
            rank++;
        }
        return Double.valueOf(precisionSum / this.numCorrectPairs);
    }

    /**
     * Largest F1 value observed at the rank of any correct pair.
     *
     * @return the maximum F1, or {@code -Double.MAX_VALUE} when no rank yields
     *         a positive precision and recall (e.g. no correct pair is ranked)
     */
    public Double maxF1() {
        double best = -Double.MAX_VALUE;
        int hits = 0;
        for (int rank = 0; rank < this.pairs.length; rank++) {
            if (!correctPair(rank)) {
                continue;
            }
            hits++;
            double precision = hits / (rank + 1.0d);
            double recall = ((double) hits) / this.numCorrectPairs;
            if (precision > 0.0d && recall > 0.0d) {
                double f1 = (2.0d * (precision * recall)) / (precision + recall);
                best = Math.max(f1, best);
            }
        }
        return Double.valueOf(best);
    }

    /**
     * Fraction of the correct pairs that appear among the candidate pairs.
     *
     * @return the number of correct candidate pairs divided by {@code numCorrectPairs}
     */
    public Double blockerRecall() {
        int found = 0;
        for (int idx = 0; idx < this.pairs.length; idx++) {
            found += correctPair(idx) ? 1 : 0;
        }
        return Double.valueOf(((double) found) / this.numCorrectPairs);
    }

    /**
     * The recall levels at which {@link #interpolated11PointPrecision()} reports precision.
     *
     * @return a fresh copy of the eleven recall levels; a copy is returned so
     *         that callers cannot modify the shared static array
     */
    public static double[] interpolated11PointRecallLevels() {
        return elevenPoints.clone();
    }

    /**
     * Interpolated precision at the eleven standard recall levels: for each
     * level, the highest precision observed at any rank whose recall is at
     * least that level.
     *
     * @return one precision value per recall level; a level that is never
     *         reached keeps the value {@code 0.0}
     */
    public double[] interpolated11PointPrecision() {
        double[] interpolated = new double[elevenPoints.length];
        int hits = 0;
        for (int rank = 0; rank < this.pairs.length; rank++) {
            if (correctPair(rank)) {
                hits++;
            }
            // Divide as double: int/int would truncate recall to 0 until every
            // correct pair is found, and would throw when numCorrectPairs is 0.
            double recall = ((double) hits) / this.numCorrectPairs;
            double precision = hits / (rank + 1.0d);
            for (int level = 0; level < elevenPoints.length; level++) {
                if (recall >= elevenPoints[level]) {
                    interpolated[level] = Math.max(interpolated[level], precision);
                }
            }
        }
        return interpolated;
    }

    /**
     * Prints one {@code recall<TAB>interpolatedPrecision} line per correct pair,
     * in rank order. The interpolated precision at a correct pair is the
     * highest precision reached at that pair's rank or at the rank of any later
     * correct pair.
     *
     * @param printStream the stream the points are written to
     * @throws IOException declared for interface compatibility; nothing in this method raises it
     */
    public void graphPrecisionRecall(PrintStream printStream) throws IOException {
        final int numPairs = this.pairs.length;
        double[] interpolated = new double[numPairs];

        // Forward pass: the actual precision at the rank of every correct pair.
        int hits = 0;
        for (int rank = 0; rank < numPairs; rank++) {
            if (correctPair(rank)) {
                hits++;
                interpolated[rank] = hits / (rank + 1.0d);
            }
        }

        // Backward pass: running maximum, so each correct pair sees its own
        // precision and that of every later correct pair.
        double best = 0.0d;
        for (int rank = numPairs - 1; rank >= 0; rank--) {
            if (correctPair(rank)) {
                best = Math.max(best, interpolated[rank]);
                interpolated[rank] = best;
            }
        }

        // Emit the curve.
        double found = 0.0d;
        for (int rank = 0; rank < numPairs; rank++) {
            if (correctPair(rank)) {
                found += 1.0d;
                printStream.println((found / this.numCorrectPairs) + "\t" + interpolated[rank]);
            }
        }
    }

    /**
     * Prints the ranked pairs, one per line, with a correctness marker, the
     * 1-based rank, the distance and both strings.
     *
     * @param z           when {@code true} every pair is printed; otherwise only correct pairs
     * @param printStream the stream the lines are written to
     * @throws IOException declared by the signature; nothing in this method raises it
     */
    public void displayResults(boolean z, PrintStream printStream) throws IOException {
        PrintfFormat lineFormat = new PrintfFormat("%s %3d %7.2f | %30s | %30s\n");
        for (int idx = 0; idx < this.pairs.length; idx++) {
            Blocker.Pair pair = this.pairs[idx];
            if (pair == null) {
                continue;
            }
            boolean correct = pair.isCorrect();
            String marker = correct ? "+" : "-";
            String left = pair.getA() == null ? "***" : pair.getA().unwrap();
            String right = pair.getB() == null ? "***" : pair.getB().unwrap();
            if (!z && !correct) {
                continue;
            }
            Object[] columns = {marker, Integer.valueOf(idx + 1), Double.valueOf(pair.getDistance()), left, right};
            printStream.print(lineFormat.sprintf(columns));
        }
    }

    /**
     * Creates a viewer for this experiment by wrapping the viewer of its
     * evaluation; content handed to the wrapper is converted with
     * {@link #toEvaluation()} before being displayed.
     *
     * @return a viewer whose content is this experiment
     */
    public Viewer toGUI() {
        TransformedViewer wrapper = new TransformedViewer(toEvaluation().toGUI()) {
            public Object transform(Object obj) {
                MatchExpt experiment = (MatchExpt) obj;
                return experiment.toEvaluation();
            }
        };
        wrapper.setContent(this);
        return wrapper;
    }

    /**
     * Tells whether the pair at rank {@code i} exists and is marked correct.
     *
     * @param i index into the ranked pair array
     * @return {@code true} only for a non-null pair whose {@code isCorrect()} is true
     */
    private boolean correctPair(int i) {
        Blocker.Pair candidate = this.pairs[i];
        if (candidate == null) {
            return false;
        }
        return candidate.isCorrect();
    }

    /**
     * Converts the ranked pairs into an evaluation object: each pair
     * contributes one binary example labelled by its correctness, together with
     * a label built from the pair's distance. The blocker, distance and file
     * names are attached as properties.
     *
     * <p>Every entry of the pair array is dereferenced without a null check.
     *
     * @return the populated evaluation
     */
    public Evaluation toEvaluation() {
        final int total = this.pairs.length;
        ProgressCounter counter = new ProgressCounter("computing statistics", "distance", total);
        MatchExptEvaluation evaluation = new MatchExptEvaluation(this.pairs, this.numCorrectPairs);
        int idx = 0;
        while (idx < total) {
            Blocker.Pair pair = this.pairs[idx];
            double distance = pair.getDistance();
            evaluation.extend(
                    ClassLabel.negativeLabel(distance),
                    new BinaryExample(
                            new MutableInstance(pair),
                            pair.isCorrect() ? ClassLabel.positiveLabel(1.0d) : ClassLabel.negativeLabel(-1.0d)));
            counter.progress();
            idx++;
        }
        counter.finished();
        evaluation.setProperty("Blocker", this.blockerName);
        evaluation.setProperty("Distance", this.learnerName);
        evaluation.setProperty("File", this.fileName);
        return evaluation;
    }

    /**
     * Command-line entry point.
     *
     * <p>Usage: {@code <blocker> <distanceClass> <matchDataFile> [commands]},
     * where the blocker is a class name relative to {@link #BLOCKER_PACKAGE}
     * and each command is one of {@code -display}, {@code -shortDisplay},
     * {@code -graph} or {@code -summarize}. Commands run in the order given.
     *
     * @param strArr the command-line arguments
     */
    public static void main(String[] strArr) {
        final String usage = "\nusage: <blocker> <distanceClass> <matchDataFile> [commands]\n";
        // Check the argument count up front instead of relying on an
        // ArrayIndexOutOfBoundsException being caught below.
        if (strArr == null || strArr.length < 3) {
            System.out.println(usage);
            return;
        }
        try {
            MatchData matchData = new MatchData(strArr[2]);
            StringDistanceLearner learner = DistanceLearnerFactory.build(strArr[1]);
            Blocker blocker = (Blocker) Class.forName(BLOCKER_PACKAGE + strArr[0]).getDeclaredConstructor().newInstance();
            MatchExpt matchExpt = new MatchExpt(matchData, learner, blocker);
            for (int i = 3; i < strArr.length; i++) {
                String command = strArr[i];
                if (command.equals("-display")) {
                    matchExpt.displayResults(true, System.out);
                } else if (command.equals("-shortDisplay")) {
                    matchExpt.displayResults(false, System.out);
                } else if (command.equals("-graph")) {
                    matchExpt.graphPrecisionRecall(System.out);
                } else if (command.equals("-summarize")) {
                    System.out.println("maxF1:\t" + matchExpt.maxF1());
                    System.out.println("avgPrec:\t" + matchExpt.averagePrecision());
                } else {
                    throw new IllegalArgumentException("illegal command " + command);
                }
            }
        } catch (Exception e) {
            // Top-level boundary: report the failure and remind the user of the usage.
            e.printStackTrace();
            System.out.println(usage);
        }
    }
}
