package edu.cmu.minorthird.classify.semisupervised;

import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.BasicFeatureIndex;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.ClassifierLearner;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.SampleDatasets;
import java.util.ArrayList;
import java.util.List;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/classify/semisupervised/SemiSupervisedNaiveBayesLearner.class */
public class SemiSupervisedNaiveBayesLearner extends SemiSupervisedBatchClassifierLearner {
    private static Logger log = Logger.getLogger(SemiSupervisedNaiveBayesLearner.class);
    private int MAX_ITER;
    private Instance.Looper iteratorOverUnlabeled;

    public SemiSupervisedNaiveBayesLearner() {
        this.MAX_ITER = 1000;
    }

    public SemiSupervisedNaiveBayesLearner(int i) {
        this.MAX_ITER = 1000;
        this.MAX_ITER = i;
    }

    @Override // edu.cmu.minorthird.classify.semisupervised.SemiSupervisedBatchClassifierLearner, edu.cmu.minorthird.classify.ClassifierLearner
    public void setSchema(ExampleSchema exampleSchema) {
    }

    @Override // edu.cmu.minorthird.classify.semisupervised.SemiSupervisedBatchClassifierLearner, edu.cmu.minorthird.classify.ClassifierLearner
    public void setInstancePool(Instance.Looper looper) {
        this.iteratorOverUnlabeled = looper;
    }

    @Override // edu.cmu.minorthird.classify.ClassifierLearner
    public ClassifierLearner copy() {
        ClassifierLearner classifierLearner = null;
        try {
            classifierLearner = (ClassifierLearner) clone();
            classifierLearner.reset();
        } catch (Exception e) {
            e.printStackTrace();
        }
        return classifierLearner;
    }

    @Override // edu.cmu.minorthird.classify.semisupervised.SemiSupervisedBatchClassifierLearner
    public Classifier batchTrain(SemiSupervisedDataset semiSupervisedDataset) {
        MultinomialClassifier multinomialClassifier = new MultinomialClassifier();
        int i = 0;
        Example.Looper it = semiSupervisedDataset.iterator();
        while (it.hasNext()) {
            Example nextExample = it.nextExample();
            if (!multinomialClassifier.isPresent(nextExample.getLabel())) {
                multinomialClassifier.addValidLabel(nextExample.getLabel());
                i++;
            }
        }
        BasicFeatureIndex basicFeatureIndex = new BasicFeatureIndex(semiSupervisedDataset);
        double[] dArr = new double[i];
        double[] dArr2 = new double[i];
        double size = semiSupervisedDataset.size();
        double numberOfFeatures = basicFeatureIndex.numberOfFeatures();
        Example.Looper it2 = semiSupervisedDataset.iterator();
        while (it2.hasNext()) {
            Example nextExample2 = it2.nextExample();
            int indexOf = multinomialClassifier.indexOf(nextExample2.getLabel());
            dArr2[indexOf] = dArr2[indexOf] + 1.0d;
            Feature.Looper featureIterator = basicFeatureIndex.featureIterator();
            while (featureIterator.hasNext()) {
                dArr[indexOf] = dArr[indexOf] + nextExample2.getWeight(featureIterator.nextFeature());
            }
        }
        for (int i2 = 0; i2 < i; i2++) {
            multinomialClassifier.setClassParameter(i2, estimateClassProbMLE(1.0d, i, dArr2[i2], size));
        }
        Feature.Looper featureIterator2 = basicFeatureIndex.featureIterator();
        while (featureIterator2.hasNext()) {
            Feature nextFeature = featureIterator2.nextFeature();
            double[] dArr3 = new double[i];
            for (int i3 = 0; i3 < basicFeatureIndex.size(nextFeature); i3++) {
                Example example = basicFeatureIndex.getExample(nextFeature, i3);
                int indexOf2 = multinomialClassifier.indexOf(example.getLabel());
                dArr3[indexOf2] = dArr3[indexOf2] + example.getWeight(nextFeature);
            }
            for (int i4 = 0; i4 < i; i4++) {
                multinomialClassifier.setFeatureGivenClassParameter(nextFeature, i4, estimateFeatureProbMLE(1.0d, numberOfFeatures, dArr3[i4], dArr[i4]));
            }
            multinomialClassifier.setFeatureModel(nextFeature, "Binomial");
        }
        BasicDataset basicDataset = new BasicDataset();
        Instance.Looper looper = new Instance.Looper(this.iteratorOverUnlabeled);
        while (looper.hasNext()) {
            Instance nextInstance = looper.nextInstance();
            System.out.println(nextInstance);
            basicDataset.add(new Example(nextInstance, multinomialClassifier.classification(nextInstance)));
        }
        double d = Double.NEGATIVE_INFINITY;
        int i5 = 0;
        boolean z = false;
        while (true) {
            if (!(i5 < this.MAX_ITER) || !(!z)) {
                return multinomialClassifier;
            }
            double d2 = d;
            BasicDataset basicDataset2 = new BasicDataset();
            Example.Looper it3 = semiSupervisedDataset.iterator();
            while (it3.hasNext()) {
                basicDataset2.add(it3.nextExample());
            }
            Example.Looper it4 = basicDataset.iterator();
            while (it4.hasNext()) {
                basicDataset2.add(it4.nextExample());
            }
            multinomialClassifier.reset();
            BasicFeatureIndex basicFeatureIndex2 = new BasicFeatureIndex(basicDataset2);
            double[] dArr4 = new double[i];
            double[] dArr5 = new double[i];
            double size2 = basicDataset2.size();
            double numberOfFeatures2 = basicFeatureIndex2.numberOfFeatures();
            Example.Looper it5 = semiSupervisedDataset.iterator();
            while (it5.hasNext()) {
                Example nextExample3 = it5.nextExample();
                int indexOf3 = multinomialClassifier.indexOf(nextExample3.getLabel());
                dArr5[indexOf3] = dArr5[indexOf3] + 1.0d;
                Feature.Looper featureIterator3 = basicFeatureIndex2.featureIterator();
                while (featureIterator3.hasNext()) {
                    dArr4[indexOf3] = dArr4[indexOf3] + nextExample3.getWeight(featureIterator3.nextFeature());
                }
            }
            for (int i6 = 0; i6 < i; i6++) {
                multinomialClassifier.setClassParameter(i6, estimateClassProbMLE(1.0d, i, dArr5[i6], size2));
            }
            Feature.Looper featureIterator4 = basicFeatureIndex2.featureIterator();
            while (featureIterator4.hasNext()) {
                Feature nextFeature2 = featureIterator4.nextFeature();
                double[] dArr6 = new double[i];
                for (int i7 = 0; i7 < basicFeatureIndex2.size(nextFeature2); i7++) {
                    Example example2 = basicFeatureIndex2.getExample(nextFeature2, i7);
                    int indexOf4 = multinomialClassifier.indexOf(example2.getLabel());
                    dArr6[indexOf4] = dArr6[indexOf4] + example2.getWeight(nextFeature2);
                }
                for (int i8 = 0; i8 < i; i8++) {
                    multinomialClassifier.setFeatureGivenClassParameter(nextFeature2, i8, estimateFeatureProbMLE(1.0d, numberOfFeatures2, dArr6[i8], dArr4[i8]));
                }
                multinomialClassifier.setFeatureModel(nextFeature2, "Binomial");
            }
            Instance.Looper looper2 = new Instance.Looper(this.iteratorOverUnlabeled);
            while (looper2.hasNext()) {
                Instance nextInstance2 = looper2.nextInstance();
                System.out.println(nextInstance2);
                basicDataset.add(new Example(nextInstance2, multinomialClassifier.classification(nextInstance2)));
            }
            d = 0.0d;
            Example.Looper it6 = basicDataset2.iterator();
            while (it6.hasNext()) {
                d += multinomialClassifier.getLogLikelihood(it6.nextExample());
            }
            if (EMconverged(d, d2, 1.0E-7d, true)) {
                z = true;
                System.out.println("EM converged!");
            } else {
                System.out.println("iteration=" + (i5 + 1) + " log-likelihood=" + d);
            }
            i5++;
        }
    }

    private double estimateClassProbMLE(double d, double d2, double d3, double d4) {
        return (d + d3) / (d2 + d4);
    }

    private double estimateFeatureProbMLE(double d, double d2, double d3, double d4) {
        return (d + d3) / (d2 + d4);
    }

    private boolean EMconverged(double d, double d2, double d3, boolean z) {
        boolean z2 = false;
        if (z && d - d2 < -0.001d) {
            System.out.println("******likelihood decreased from " + d2 + " to " + d);
        }
        if (Math.abs(d - d2) / (((Math.abs(d) + Math.abs(d2)) + 2.2204E-16d) / 2.0d) < d3) {
            z2 = true;
        }
        return z2;
    }

    public static void main(String[] strArr) {
        new BasicDataset();
        System.out.println("SampleDatasets (bayesUnlabeled):\n" + SampleDatasets.sampleData("bayesUnlabeled", false));
    }
}
