package edu.cmu.minorthird.classify.transform;

import edu.cmu.minorthird.classify.BasicFeatureIndex;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.SampleDatasets;

/* loaded from: input_file:edu/cmu/minorthird/classify/transform/ChiSquareTransformLearner.class */
/**
 * Learns a {@link ChiSquareInstanceTransform} for binary datasets.
 *
 * <p>Each feature is given the score returned by
 * {@link ContingencyTable#getChiSquared()} for a 2x2 table of class label
 * (positive/negative) versus feature presence. The per-cell counts come from a
 * {@link BasicFeatureIndex} built over the training dataset.
 *
 * <p>Only the {@code "document"} frequency model is implemented. {@code "word"}
 * is recognised but unsupported. Any other model name is rejected when
 * {@link #batchTrain(Dataset)} is called.
 */
public class ChiSquareTransformLearner implements InstanceTransformLearner {
    /** Name of the frequency model that counts examples per class via the feature index. */
    private static final String DOCUMENT_MODEL = "document";
    /** Name of a frequency model that is recognised but not implemented. */
    private static final String WORD_MODEL = "word";

    /** Frequency model selected at construction time; checked in {@link #batchTrain}. */
    private final String frequencyModel;

    /** Creates a learner that uses the {@code "document"} frequency model. */
    public ChiSquareTransformLearner() {
        this(DOCUMENT_MODEL);
    }

    /**
     * Creates a learner with an explicit frequency model.
     *
     * @param str frequency model name; only {@code "document"} is currently
     *            supported (validated lazily in {@link #batchTrain})
     */
    public ChiSquareTransformLearner(String str) {
        this.frequencyModel = str;
    }

    /**
     * Accepts only the binary example schema.
     *
     * @param exampleSchema schema of the data that will be trained on
     * @throws IllegalStateException if the schema is not
     *         {@link ExampleSchema#BINARY_EXAMPLE_SCHEMA}
     */
    @Override
    public void setSchema(ExampleSchema exampleSchema) {
        if (!ExampleSchema.BINARY_EXAMPLE_SCHEMA.equals(exampleSchema)) {
            throw new IllegalStateException("can only learn binary example data");
        }
    }

    /**
     * Scores every feature of {@code dataset} by its chi-squared statistic.
     *
     * @param dataset binary-labelled training data
     * @return a {@link ChiSquareInstanceTransform} holding one score per feature
     * @throws UnsupportedOperationException if the {@code "word"} model was selected
     * @throws IllegalStateException if the frequency model is unknown (or null),
     *         or if the index's per-class example counts do not add up to the
     *         dataset size
     */
    @Override
    public InstanceTransform batchTrain(Dataset dataset) {
        // Constant-first equals so a null model yields a clear error instead of an NPE.
        // Throwing (rather than System.exit) lets callers handle a bad configuration.
        if (DOCUMENT_MODEL.equals(this.frequencyModel)) {
            return trainDocumentModel(dataset);
        }
        if (WORD_MODEL.equals(this.frequencyModel)) {
            throw new UnsupportedOperationException(
                    "warning: " + this.frequencyModel + " not implemented yet!");
        }
        throw new IllegalStateException(
                "warning: " + this.frequencyModel + " is an unknown model for frequency!");
    }

    /** Builds the transform using per-class example counts from a feature index. */
    private static ChiSquareInstanceTransform trainDocumentModel(Dataset dataset) {
        ChiSquareInstanceTransform transform = new ChiSquareInstanceTransform();
        BasicFeatureIndex index = new BasicFeatureIndex(dataset);

        int posTotal = index.size(ExampleSchema.POS_CLASS_NAME);
        int negTotal = index.size(ExampleSchema.NEG_CLASS_NAME);
        // Sanity check: every example must be accounted for as positive or negative,
        // otherwise the "without feature" cells below would be wrong.
        if (posTotal + negTotal != dataset.size()) {
            throw new IllegalStateException("ERROR - Dataset size and index size do not match");
        }

        for (Feature.Looper it = index.featureIterator(); it.hasNext(); ) {
            Feature feature = it.nextFeature();
            int posWithFeature = index.size(feature, ExampleSchema.POS_CLASS_NAME);
            int negWithFeature = index.size(feature, ExampleSchema.NEG_CLASS_NAME);
            // 2x2 table cells: (pos, feature), (neg, feature), (pos, no feature), (neg, no feature).
            ContingencyTable table = new ContingencyTable(
                    posWithFeature,
                    negWithFeature,
                    posTotal - posWithFeature,
                    negTotal - negWithFeature);
            transform.addFeature(table.getChiSquared(), feature);
        }
        return transform;
    }

    /**
     * Demo: trains on the "toy" sample dataset, keeps 10 features, and prints
     * the data before and after the transform.
     */
    public static void main(String[] strArr) {
        Dataset sampleData = SampleDatasets.sampleData("toy", false);
        System.out.println("old data:\n" + sampleData);
        ChiSquareInstanceTransform chiSquareInstanceTransform =
                (ChiSquareInstanceTransform) new ChiSquareTransformLearner().batchTrain(sampleData);
        chiSquareInstanceTransform.setNumberOfFeatures(10);
        System.out.println("new data:\n" + chiSquareInstanceTransform.transform(sampleData));
        System.out.println("\n\n\n " + chiSquareInstanceTransform.toString(8));
    }
}
