package edu.cmu.minorthird.classify.transform;

import cern.colt.matrix.impl.AbstractFormatter;
import com.wcohen.ss.BasicStringWrapper;
import com.wcohen.ss.DistanceLearnerFactory;
import com.wcohen.ss.api.StringDistance;
import com.wcohen.ss.api.StringDistanceLearner;
import com.wcohen.ss.api.StringWrapper;
import com.wcohen.ss.lookup.SoftDictionary;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import java.io.Serializable;

/* loaded from: input_file:edu/cmu/minorthird/classify/transform/LeaveOneOutDictTransformLearner.class */
/**
 * Learns an {@link InstanceTransform} that augments instances with soft-dictionary similarity
 * features, one feature per (class, string distance) pair.
 *
 * <p>During {@link #batchTrain(Dataset)} one {@link SoftDictionary} is built per class of the
 * dataset schema. Each dictionary is filled from the string value of a single "pattern" feature
 * (see {@link #getFeatureValue(Instance, String[])}). Entries are stored under
 * {@link #instanceId(Instance)}, and lookups pass the same id. Presumably
 * {@code SoftDictionary.lookup} uses the id to exclude an instance's own entry (leave-one-out);
 * TODO confirm against SoftDictionary.
 */
public class LeaveOneOutDictTransformLearner {
    /** When true, {@link #instanceId(Instance)} is just the instance's subpopulation id. */
    private static boolean idIsSubpopulation = true;

    /** Default feature-name prefix; the feature's final name component is the looked-up string. */
    public static final String[] DEFAULT_PATTERN = {"eq", "lc"};

    /** Feature-name prefix identifying the feature whose value is looked up. */
    private String[] featurePattern;
    /** Whether examples labelled with the negative class also populate a dictionary. */
    private boolean buildDictionaryForNegativeClass;
    /** Per-class string distances produced by the most recent {@link #trainDistances} call. */
    private StringDistance[][] distances;
    /** Distance specification handed to {@code DistanceLearnerFactory.buildArray}. */
    String distanceNames;

    /**
     * Instance transform that adds one numeric feature per (class, distance) pair. The value of
     * that feature is the distance score between the instance's pattern-feature string and the
     * closest entry found in that class's soft dictionary.
     */
    public static class DictionaryTransform extends AbstractInstanceTransform implements Serializable {
        private static final long serialVersionUID = 1;
        private final int CURRENT_VERSION_NUMBER = 1;
        private SoftDictionary[] softDict;
        private String[] featurePattern;
        private ExampleSchema schema;
        /** Names of the added features; index = classIndex * numDistances + distanceIndex. */
        private String[] newFeatureNames;
        private StringDistance[][] distances;
        /** Number of distances per class (row length of {@code distances}). */
        int numDistances;

        /**
         * Creates a transform.
         *
         * @param exampleSchema schema whose classes index {@code softDictionaryArr} and the rows of
         *     {@code stringDistanceArr}
         * @param softDictionaryArr one dictionary per class
         * @param strArr feature-name pattern used to select the looked-up feature
         * @param stringDistanceArr per-class distances; all rows are expected to have the same
         *     length (trainDistances builds every row from the same specification)
         */
        public DictionaryTransform(ExampleSchema exampleSchema, SoftDictionary[] softDictionaryArr, String[] strArr, StringDistance[][] stringDistanceArr) {
            this.schema = exampleSchema;
            this.softDict = softDictionaryArr;
            this.featurePattern = strArr;
            this.distances = stringDistanceArr;
            this.numDistances = stringDistanceArr.length == 0 ? 0 : stringDistanceArr[0].length;
            int numClasses = exampleSchema.getNumberOfClasses();
            this.newFeatureNames = new String[numClasses * this.numDistances];
            for (int c = 0; c < numClasses; c++) {
                for (int d = 0; d < this.distances[c].length; d++) {
                    // Same index formula as transform(), so names and values always line up.
                    this.newFeatureNames[c * this.numDistances + d] =
                        this.distances[c][d].toString() + "_" + exampleSchema.getClassName(c);
                }
            }
        }

        /**
         * Returns {@code instance} augmented with dictionary-distance features. The original
         * instance is returned unchanged when it has no pattern feature, or when no distance
         * produced a non-negative score.
         */
        @Override // edu.cmu.minorthird.classify.transform.AbstractInstanceTransform, edu.cmu.minorthird.classify.transform.InstanceTransform
        public Instance transform(Instance instance) {
            String featureValue = LeaveOneOutDictTransformLearner.getFeatureValue(instance, this.featurePattern);
            if (featureValue == null) {
                return instance;
            }
            // Allocate fresh values on every call instead of reusing a shared field. This way a
            // previously returned AugmentedInstance can never have its values overwritten by a
            // later call (should it retain the array rather than copy it), and concurrent calls
            // do not race on one buffer.
            double[] newFeatureValues = new double[this.newFeatureNames.length];
            boolean anyScore = false;
            BasicStringWrapper query = new BasicStringWrapper(featureValue);
            String id = LeaveOneOutDictTransformLearner.instanceId(instance);
            for (int c = 0; c < this.schema.getNumberOfClasses(); c++) {
                Object match = this.softDict[c].lookup(id, query);
                if (match == null) {
                    continue;
                }
                for (int d = 0; d < this.distances[c].length; d++) {
                    double score = this.distances[c][d].score(query, (StringWrapper) match);
                    // Only non-negative scores are recorded; others leave the feature at 0.
                    if (score >= 0.0d) {
                        anyScore = true;
                        newFeatureValues[c * this.numDistances + d] = score;
                    }
                }
            }
            return anyScore ? new AugmentedInstance(instance, this.newFeatureNames, newFeatureValues) : instance;
        }

        /** Summarises the transform as the dictionary size for each class. */
        @Override
        public String toString() {
            StringBuilder buf = new StringBuilder("[DictionaryTransform: dictSize");
            for (int i = 0; i < this.schema.getNumberOfClasses(); i++) {
                buf.append(AbstractFormatter.DEFAULT_COLUMN_SEPARATOR)
                    .append(this.schema.getClassName(i))
                    .append("=")
                    .append(this.softDict[i].size());
            }
            return buf.append("]").toString();
        }
    }

    /** Uses {@link #DEFAULT_PATTERN} and the "SoftTFIDF" distance specification. */
    public LeaveOneOutDictTransformLearner() {
        this("SoftTFIDF");
    }

    /** Uses {@link #DEFAULT_PATTERN} and the given distance specification. */
    public LeaveOneOutDictTransformLearner(String str) {
        this(DEFAULT_PATTERN, str);
    }

    /** Uses the given feature pattern and the "SoftTFIDF" distance specification. */
    public LeaveOneOutDictTransformLearner(String[] strArr) {
        this(strArr, "SoftTFIDF");
    }

    /**
     * @param strArr feature-name prefix selecting the feature whose value is looked up
     * @param str distance specification passed to {@code DistanceLearnerFactory.buildArray}
     */
    public LeaveOneOutDictTransformLearner(String[] strArr, String str) {
        this.buildDictionaryForNegativeClass = false;
        this.featurePattern = strArr;
        this.distanceNames = str;
    }

    /** Does nothing; {@link #batchTrain(Dataset)} takes the schema from the dataset itself. */
    public void setSchema(ExampleSchema exampleSchema) {
    }

    /**
     * Builds one array of string distances per class. Every distance that is a
     * {@link StringDistanceLearner} is then trained with the teacher of that class's dictionary.
     * The result is stored in {@code this.distances}.
     */
    public void trainDistances(ExampleSchema exampleSchema, SoftDictionary[] softDictionaryArr) {
        int numClasses = exampleSchema.getNumberOfClasses();
        this.distances = new StringDistance[numClasses][];
        for (int c = 0; c < numClasses; c++) {
            this.distances[c] = DistanceLearnerFactory.buildArray(this.distanceNames);
        }
        for (int c = 0; c < numClasses; c++) {
            StringDistance[] row = this.distances[c];
            for (int d = 0; d < row.length; d++) {
                if (row[d] instanceof StringDistanceLearner) {
                    row[d] = softDictionaryArr[c].getTeacher().train((StringDistanceLearner) row[d]);
                }
            }
        }
    }

    /**
     * Builds per-class soft dictionaries from {@code dataset}, trains the distances, and returns
     * the resulting {@link DictionaryTransform}. Examples without a pattern feature are skipped.
     * Negative-class examples are skipped unless {@code buildDictionaryForNegativeClass} is set.
     */
    public InstanceTransform batchTrain(Dataset dataset) {
        ExampleSchema schema = dataset.getSchema();
        int negativeIndex = schema.getClassIndex(ExampleSchema.NEG_CLASS_NAME);
        SoftDictionary[] dictionaries = new SoftDictionary[schema.getNumberOfClasses()];
        for (int i = 0; i < dictionaries.length; i++) {
            dictionaries[i] = new SoftDictionary();
        }
        Example.Looper it = dataset.iterator();
        while (it.hasNext()) {
            Example example = it.nextExample();
            String value = getFeatureValue(example, this.featurePattern);
            if (value == null) {
                continue;
            }
            int labelIndex = schema.getClassIndex(example.getLabel().bestClassName());
            if (this.buildDictionaryForNegativeClass || labelIndex != negativeIndex) {
                dictionaries[labelIndex].put(instanceId(example), value, example);
            }
        }
        trainDistances(schema, dictionaries);
        return new DictionaryTransform(schema, dictionaries, this.featurePattern, this.distances);
    }

    /**
     * Returns the id under which an instance is stored in, and looked up from, the dictionaries.
     * This is the subpopulation id when {@code idIsSubpopulation} is set. Otherwise, for an
     * {@link Example} it is "subpopulationId:hashCode", and for any other instance it is null.
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static String instanceId(Instance instance) {
        if (idIsSubpopulation) {
            return instance.getSubpopulationId();
        }
        if (instance instanceof Example) {
            return instance.getSubpopulationId() + ":" + instance.hashCode();
        }
        return null;
    }

    /**
     * Returns the last name component of the first feature whose name is {@code strArr} followed
     * by exactly one extra component, or null if no feature matches.
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static String getFeatureValue(Instance instance, String[] strArr) {
        Feature.Looper featureIterator = instance.featureIterator();
        while (featureIterator.hasNext()) {
            String[] name = featureIterator.nextFeature().getName();
            if (matches(name, strArr)) {
                return name[name.length - 1];
            }
        }
        return null;
    }

    /** True iff {@code name} is {@code pattern} plus exactly one trailing component. */
    private static boolean matches(String[] name, String[] pattern) {
        if (name.length - 1 != pattern.length) {
            return false;
        }
        for (int i = 0; i < pattern.length; i++) {
            if (!pattern[i].equals(name[i])) {
                return false;
            }
        }
        return true;
    }
}
