package edu.cmu.minorthird.classify.transform;

import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.MutableInstance;
import gnu.trove.TObjectDoubleHashMap;
import java.io.Serializable;

/* loaded from: input_file:edu/cmu/minorthird/classify/transform/TFIDFTransformLearner.class */
public class TFIDFTransformLearner implements InstanceTransformLearner, Serializable {
    private static final long serialVersionUID = 1;
    // Unused at runtime; kept so the serialized form of this class is unchanged.
    private final int CURRENT_VERSION_NUMBER = 1;
    /** Per-feature example count from the most recent {@link #batchTrain} call. */
    private TObjectDoubleHashMap featureFreq;
    /** Number of examples in the dataset passed to the most recent {@link #batchTrain} call. */
    private double numDocuments;

    /**
     * Instance transform that applies the TF-IDF weights learned by {@link #batchTrain}.
     *
     * <p>Declared {@code static}: the weighter only needs the two values handed to its
     * constructor, so it should not hold (and, when serialized, drag along) a hidden
     * reference to the learner that created it.
     */
    private static class TFIDFWeighter extends AbstractInstanceTransform implements Serializable {
        private static final long serialVersionUID = 1;
        // Unused at runtime; kept so the serialized form of this class is unchanged.
        private final int CURRENT_VERSION_NUMBER = 1;
        private final double numDocuments;
        private final TObjectDoubleHashMap featureFreq;

        /**
         * @param d the number of training examples (the "N" in log(N / df))
         * @param tObjectDoubleHashMap per-feature example counts (document frequencies);
         *     the map is used directly, not copied
         */
        public TFIDFWeighter(double d, TObjectDoubleHashMap tObjectDoubleHashMap) {
            this.numDocuments = d;
            this.featureFreq = tObjectDoubleHashMap;
        }

        /**
         * Returns a new instance with the same source and subpopulation id whose feature
         * values are the TF-IDF weights of {@code instance}, scaled to unit L2 length.
         *
         * <p>If the raw weights have zero length they are left unscaled instead of being
         * divided by zero. This happens when the instance has no features, or when all of
         * its features occur in every training example.
         *
         * @param instance the instance to reweight; it is not modified
         * @return a new {@link MutableInstance} holding the normalized weights
         */
        @Override
        public Instance transform(Instance instance) {
            double sumOfSquares = 0.0d;
            Feature.Looper featureIterator = instance.featureIterator();
            while (featureIterator.hasNext()) {
                double weight = unnormalizedTFIDFWeight(featureIterator.nextFeature(), instance);
                sumOfSquares += weight * weight;
            }
            double norm = Math.sqrt(sumOfSquares);
            // A zero norm means every raw weight is zero. Dividing by it would turn every
            // weight into NaN (0/0), so fall back to a divisor of 1 and keep the zeros.
            double divisor = norm > 0.0d ? norm : 1.0d;

            MutableInstance mutableInstance =
                    new MutableInstance(instance.getSource(), instance.getSubpopulationId());
            Feature.Looper featureIterator2 = instance.featureIterator();
            while (featureIterator2.hasNext()) {
                Feature nextFeature = featureIterator2.nextFeature();
                mutableInstance.addNumeric(
                        nextFeature, unnormalizedTFIDFWeight(nextFeature, instance) / divisor);
            }
            return mutableInstance;
        }

        /**
         * Computes {@code log(w + 1) * log(N / df)} for one feature, where {@code w} is the
         * feature's value in {@code instance}, {@code N} is the number of training examples
         * and {@code df} is the feature's training count. A {@code df} of 0 (for example, a
         * feature never seen in training) is treated as 1.
         */
        private double unnormalizedTFIDFWeight(Feature feature, Instance instance) {
            double d = this.featureFreq.get(feature);
            if (d == 0.0d) {
                d = 1.0d;
            }
            return Math.log(instance.getWeight(feature) + 1.0d) * Math.log(this.numDocuments / d);
        }

        @Override
        public String toString() {
            return "[TFIDFWeighter]";
        }
    }

    /** Does nothing; this learner does not use the example schema. */
    @Override
    public void setSchema(ExampleSchema exampleSchema) {
    }

    /**
     * Learns TF-IDF statistics from {@code dataset} and returns a transform that applies them.
     *
     * <p>For every feature, the method counts one occurrence for each time an example's
     * feature iterator yields it. The counts and the dataset size are also stored on this
     * learner, and the returned transform shares the same count map.
     *
     * <p>NOTE(review): for an empty dataset N is 0, so every IDF term is {@code log(0)}, which
     * is negative infinity. The resulting weights are not meaningful.
     *
     * @param dataset the training examples
     * @return a transform that reweights instances by normalized TF-IDF
     */
    @Override
    public InstanceTransform batchTrain(Dataset dataset) {
        this.numDocuments = dataset.size();
        this.featureFreq = new TObjectDoubleHashMap();
        Example.Looper it = dataset.iterator();
        while (it.hasNext()) {
            Feature.Looper featureIterator = it.nextExample().featureIterator();
            while (featureIterator.hasNext()) {
                Feature nextFeature = featureIterator.nextFeature();
                // Trove's get() returns 0 for absent keys, so this also initializes new features.
                this.featureFreq.put(nextFeature, this.featureFreq.get(nextFeature) + 1.0d);
            }
        }
        return new TFIDFWeighter(this.numDocuments, this.featureFreq);
    }
}
