package edu.cmu.minorthird.classify.algorithms.linear;

import edu.cmu.minorthird.classify.BasicDataset;
import edu.cmu.minorthird.classify.BasicFeatureIndex;
import edu.cmu.minorthird.classify.BatchClassifierLearner;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Feature;
import edu.cmu.minorthird.classify.SampleDatasets;
import edu.cmu.minorthird.classify.algorithms.random.Estimate;
import edu.cmu.minorthird.classify.algorithms.random.Estimators;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;

/* loaded from: input_file:edu/cmu/minorthird/classify/algorithms/linear/KWayMixtureLearner.class */
/**
 * Batch learner that fits a {@link MultinomialClassifier} on non-binary (K-way) data.
 *
 * <p>For every indexed feature and every class, a feature-given-class parameter is estimated with
 * {@link Estimators} according to the configured model and parameterization:
 * <ul>
 *   <li>{@code "Naive-Bayes"}: {@code "default"}/{@code "mean"} or {@code "weighted-mean"}; feature
 *       values are clipped to at most 1 before estimation.</li>
 *   <li>{@code "Binomial"}: {@code "default"}/{@code "p/N"} or {@code "mu/delta"}.</li>
 *   <li>{@code "Poisson"} (the default model): {@code "default"}/{@code "weighted-lambda"} or
 *       {@code "lambda"}.</li>
 *   <li>{@code "Negative-Binomial"}: {@code "default"}/{@code "mu/delta"}.</li>
 *   <li>{@code "Mixture"}: per class, Binomial (mu/delta) when the feature's mean exceeds its variance,
 *       otherwise Negative-Binomial (mu/delta).</li>
 *   <li>{@code "Dirichlet-Poisson MCMC"}: {@code "default"}/{@code "weighted-lambda"} or {@code "lambda"};
 *       initial Poisson estimates are refined by
 *       {@code Estimators.mcmcEstimateDirichletPoissonTauSigma}.</li>
 * </ul>
 * An unrecognized model or parameterization leaves the corresponding parameters unset.
 */
public class KWayMixtureLearner extends BatchClassifierLearner {
    private static final String NAIVE_BAYES = "Naive-Bayes";
    private static final String BINOMIAL = "Binomial";
    private static final String POISSON = "Poisson";
    private static final String NEGATIVE_BINOMIAL = "Negative-Binomial";
    private static final String MIXTURE = "Mixture";
    private static final String DIRICHLET_POISSON_MCMC = "Dirichlet-Poisson MCMC";

    private static final double DEFAULT_SCALE = 10.0d;
    private static final String DEFAULT_MODEL = POISSON;
    private static final String DEFAULT_PARAMETERIZATION = "default";

    /** Scale passed to {@link MultinomialClassifier#setScale} and to the estimators. */
    private final double scale;
    /** Name of the count model, e.g. {@code "Poisson"}. */
    private final String model;
    /** Parameterization of {@link #model}; {@code "default"} selects the first listed variant. */
    private final String parameterization;

    /** Creates a learner using the Poisson model, default parameterization and scale 10. */
    public KWayMixtureLearner() {
        this(DEFAULT_MODEL, DEFAULT_PARAMETERIZATION, DEFAULT_SCALE);
    }

    /**
     * Creates a learner with the given model, default parameterization and scale 10.
     *
     * @param model model name (see the class documentation)
     */
    public KWayMixtureLearner(String model) {
        this(model, DEFAULT_PARAMETERIZATION, DEFAULT_SCALE);
    }

    /**
     * Creates a learner with the given model and parameterization and scale 10.
     *
     * @param model model name (see the class documentation)
     * @param parameterization parameterization name for {@code model}
     */
    public KWayMixtureLearner(String model, String parameterization) {
        this(model, parameterization, DEFAULT_SCALE);
    }

    /**
     * Creates a fully configured learner.
     *
     * @param model model name (see the class documentation)
     * @param parameterization parameterization name for {@code model}
     * @param scale scale handed to the classifier and to the estimators
     */
    public KWayMixtureLearner(String model, String parameterization, double scale) {
        this.scale = scale;
        this.model = model;
        this.parameterization = parameterization;
    }

    /**
     * Rejects the binary example schema; any other schema is accepted (and not stored).
     *
     * @throws IllegalStateException if {@code exampleSchema} is the binary schema
     */
    @Override // edu.cmu.minorthird.classify.ClassifierLearner
    public void setSchema(ExampleSchema exampleSchema) {
        if (ExampleSchema.BINARY_EXAMPLE_SCHEMA.equals(exampleSchema)) {
            throw new IllegalStateException("can only learn non-binary example data");
        }
    }

    /**
     * Trains a {@link MultinomialClassifier} on {@code dataset}.
     *
     * @param dataset training data; every example's best class must be in the dataset schema
     * @return the trained classifier
     * @throws IllegalStateException if the schema and the classifier disagree on a class index
     */
    @Override // edu.cmu.minorthird.classify.BatchClassifierLearner
    public Classifier batchTrain(Dataset dataset) {
        MultinomialClassifier classifier = new MultinomialClassifier();
        classifier.setScale(this.scale);
        ExampleSchema schema = dataset.getSchema();
        BasicFeatureIndex index = new BasicFeatureIndex(dataset);
        int numberOfClasses = schema.getNumberOfClasses();

        // One array per class with one slot per example of that class. The array is sized by
        // index.size(className), presumably the number of examples with that label; verify.
        // featureValues is refilled for every feature. exampleLengths holds each example's weight
        // summed over all indexed features.
        List<double[]> featureValues = new ArrayList<double[]>(numberOfClasses);
        List<double[]> exampleLengths = new ArrayList<double[]>(numberOfClasses);
        for (int c = 0; c < numberOfClasses; c++) {
            String className = schema.getClassName(c);
            classifier.addValidLabel(new ClassLabel(className));
            int slots = index.size(className);
            featureValues.add(new double[slots]);
            exampleLengths.add(new double[slots]);
        }

        double size = dataset.size();
        double numberOfFeatures = index.numberOfFeatures();
        double[] classWeight = new double[numberOfClasses]; // total feature weight per class
        double[] classCount = new double[numberOfClasses];  // number of examples per class
        int[] seen = new int[numberOfClasses];              // next exampleLengths slot per class
        for (Example.Looper examples = dataset.iterator(); examples.hasNext();) {
            Example example = examples.nextExample();
            int c = classIndexOf(schema, example);
            int classifierIndex = classifier.indexOf(example.getLabel());
            if (c != classifierIndex) {
                throw new IllegalStateException("incompatible class indices for label "
                        + example.getLabel() + ": schema index " + c
                        + ", classifier index " + classifierIndex);
            }
            classCount[c] += 1.0d;
            double[] lengths = exampleLengths.get(c);
            for (Feature.Looper features = index.featureIterator(); features.hasNext();) {
                double weight = example.getWeight(features.nextFeature());
                classWeight[c] += weight;
                lengths[seen[c]] += weight;
            }
            seen[c]++;
        }

        // The class prior and the feature prior do not depend on the feature, so compute them once.
        double[] classPrior = new double[numberOfClasses];
        for (int c = 0; c < numberOfClasses; c++) {
            classPrior[c] = estimateClassProbMLE(1.0d, numberOfClasses, classCount[c], size);
        }
        double prior = 1.0d / numberOfFeatures;
        String singleModel = singleModelName();

        for (Feature.Looper features = index.featureIterator(); features.hasNext();) {
            Feature feature = features.nextFeature();
            collectFeatureValues(dataset, schema, feature, featureValues);
            if (singleModel != null) {
                classifier.setPrior(prior);
                classifier.setUnseenModel(singleModel);
                for (int c = 0; c < numberOfClasses; c++) {
                    classifier.setClassParameter(c, classPrior[c]);
                    setSingleModelParameter(classifier, feature, c, singleModel, featureValues.get(c),
                            exampleLengths.get(c), classWeight[c], classCount[c], numberOfFeatures);
                }
                classifier.setFeatureModel(feature, singleModel);
            } else if (MIXTURE.equals(this.model)) {
                classifier.setPrior(prior);
                classifier.setUnseenModel(MIXTURE);
                for (int c = 0; c < numberOfClasses; c++) {
                    classifier.setClassParameter(c, classPrior[c]);
                    setMixtureParameter(classifier, feature, c, featureValues.get(c),
                            exampleLengths.get(c), prior);
                }
            } else if (DIRICHLET_POISSON_MCMC.equals(this.model)) {
                classifier.setPrior(prior);
                classifier.setUnseenModel(DIRICHLET_POISSON_MCMC);
                fitDirichletPoissonMcmc(classifier, feature, classPrior, classWeight, numberOfFeatures,
                        featureValues, exampleLengths);
                classifier.setFeatureModel(feature, DIRICHLET_POISSON_MCMC);
            }
        }
        return classifier;
    }

    /** Returns the schema index of the example's best class label. */
    private static int classIndexOf(ExampleSchema schema, Example example) {
        return schema.getClassIndex(example.getLabel().bestClassName().toString());
    }

    /**
     * Returns the canonical name of {@link #model} when it is one of the four models whose estimate
     * depends only on the model and parameterization; otherwise returns {@code null}.
     */
    private String singleModelName() {
        if (NAIVE_BAYES.equals(this.model)) {
            return NAIVE_BAYES;
        }
        if (BINOMIAL.equals(this.model)) {
            return BINOMIAL;
        }
        if (POISSON.equals(this.model)) {
            return POISSON;
        }
        if (NEGATIVE_BINOMIAL.equals(this.model)) {
            return NEGATIVE_BINOMIAL;
        }
        return null;
    }

    /** True if the parameterization is {@code "default"} or equals {@code alternative}. */
    private boolean isParameterization(String alternative) {
        return "default".equals(this.parameterization) || alternative.equals(this.parameterization);
    }

    /**
     * Overwrites {@code featureValues} with the weight of {@code feature} in every example, grouped by
     * class in dataset order. Under Naive-Bayes the weights are clipped to at most 1.
     */
    private void collectFeatureValues(Dataset dataset, ExampleSchema schema, Feature feature,
            List<double[]> featureValues) {
        boolean clip = NAIVE_BAYES.equals(this.model);
        int[] next = new int[featureValues.size()];
        for (Example.Looper examples = dataset.iterator(); examples.hasNext();) {
            Example example = examples.nextExample();
            int c = classIndexOf(schema, example);
            double[] values = featureValues.get(c);
            double weight = example.getWeight(feature);
            values[next[c]++] = clip ? Math.min(1.0d, weight) : weight;
        }
    }

    /**
     * Sets the feature-given-class parameter for one class under a single-distribution model.
     * Does nothing if the parameterization is not recognized for {@code modelName}.
     */
    private void setSingleModelParameter(MultinomialClassifier classifier, Feature feature, int classIndex,
            String modelName, double[] values, double[] lengths, double classWeight, double classCount,
            double numberOfFeatures) {
        double prior = 1.0d / numberOfFeatures;
        if (NAIVE_BAYES.equals(modelName)) {
            if (isParameterization("mean")) {
                classifier.setFeatureGivenClassParameter(feature, classIndex,
                        Estimators.estimateNaiveBayesMean(1.0d, numberOfFeatures, sum(values), classCount));
            } else if ("weighted-mean".equals(this.parameterization)) {
                classifier.setFeatureGivenClassParameter(feature, classIndex,
                        Estimators.estimateNaiveBayesWeightedMean(values, lengths, prior, this.scale));
            }
        } else if (BINOMIAL.equals(modelName)) {
            if (isParameterization("p/N")) {
                classifier.setFeatureGivenClassParameter(feature, classIndex,
                        Estimators.estimateBinomialPN(values, lengths, prior, this.scale));
            } else if ("mu/delta".equals(this.parameterization)) {
                classifier.setFeatureGivenClassParameter(feature, classIndex,
                        Estimators.estimateBinomialMuDelta(values, lengths, prior, this.scale));
            }
        } else if (POISSON.equals(modelName)) {
            if (isParameterization("weighted-lambda")) {
                classifier.setFeatureGivenClassParameter(feature, classIndex,
                        Estimators.estimatePoissonWeightedLambda(values, lengths, prior, this.scale));
            } else if ("lambda".equals(this.parameterization)) {
                classifier.setFeatureGivenClassParameter(feature, classIndex,
                        Estimators.estimatePoissonLambda(1.0d / this.scale, numberOfFeatures, sum(values),
                                classWeight / this.scale));
            }
        } else if (NEGATIVE_BINOMIAL.equals(modelName)) {
            if (isParameterization("mu/delta")) {
                classifier.setFeatureGivenClassParameter(feature, classIndex,
                        Estimators.estimateNegativeBinomialMuDelta(values, lengths, prior, this.scale));
            }
        }
    }

    /**
     * Picks Binomial (mean > variance) or Negative-Binomial (mean <= variance) for one class, records
     * it as the feature model, and sets the matching mu/delta estimate. If either statistic is NaN,
     * neither comparison holds: the feature model is set to "" and no estimate is stored.
     */
    private void setMixtureParameter(MultinomialClassifier classifier, Feature feature, int classIndex,
            double[] values, double[] lengths, double prior) {
        double mean = Estimators.estimateMean(values);
        double variance = Estimators.estimateVar(values);
        String chosen = "";
        if (mean > variance) {
            chosen = BINOMIAL;
        } else if (mean <= variance) {
            chosen = NEGATIVE_BINOMIAL;
        }
        // setFeatureModel takes no class index, so the last class's choice presumably wins.
        // NOTE(review): confirm this is intended.
        classifier.setFeatureModel(feature, chosen);
        if (BINOMIAL.equals(chosen)) {
            classifier.setFeatureGivenClassParameter(feature, classIndex,
                    Estimators.estimateBinomialMuDelta(values, lengths, prior, this.scale));
        } else if (NEGATIVE_BINOMIAL.equals(chosen)) {
            classifier.setFeatureGivenClassParameter(feature, classIndex,
                    Estimators.estimateNegativeBinomialMuDelta(values, lengths, prior, this.scale));
        }
    }

    /**
     * Sets class parameters and MCMC-refined Poisson estimates for every class. Does nothing if the
     * parameterization is neither {@code "default"}/{@code "weighted-lambda"} nor {@code "lambda"}.
     *
     * <p>NOTE(review): the sampler is fed the totals of classes 0 and 1 only. With fewer than two
     * classes this throws, and with more than two the remaining classes do not inform the sampler.
     * Whether the returned array has one entry per class is not visible here; confirm.
     */
    private void fitDirichletPoissonMcmc(MultinomialClassifier classifier, Feature feature,
            double[] classPrior, double[] classWeight, double numberOfFeatures,
            List<double[]> featureValues, List<double[]> exampleLengths) {
        boolean weighted = isParameterization("weighted-lambda");
        if (!weighted && !"lambda".equals(this.parameterization)) {
            return;
        }
        int numberOfClasses = classPrior.length;
        double prior = 1.0d / numberOfFeatures;
        double[] valueTotals = new double[numberOfClasses];
        double[] lengthTotals = new double[numberOfClasses];
        Estimate[] initial = new Estimate[numberOfClasses];
        for (int c = 0; c < numberOfClasses; c++) {
            classifier.setClassParameter(c, classPrior[c]);
            double[] values = featureValues.get(c);
            double[] lengths = exampleLengths.get(c);
            valueTotals[c] = sum(values);
            lengthTotals[c] = sum(lengths);
            if (weighted) {
                initial[c] = Estimators.estimatePoissonWeightedLambda(values, lengths, prior, this.scale);
            } else {
                initial[c] = Estimators.estimatePoissonLambda(1.0d, numberOfFeatures, valueTotals[c],
                        classWeight[c]);
            }
        }
        Estimate[] posterior;
        if (weighted) {
            double lambdaScale = (lambdaOf(initial[0]) + lambdaOf(initial[1])) / 10.0d;
            posterior = Estimators.mcmcEstimateDirichletPoissonTauSigma(initial,
                    new double[] {1.0E-7d, 1.0E-7d}, new double[] {1.0d, 150.0d},
                    valueTotals[0], valueTotals[1], lengthTotals[0], lengthTotals[1],
                    new double[] {2.0d, 1.0d}, 0.075d, lambdaScale, 100);
        } else {
            posterior = Estimators.mcmcEstimateDirichletPoissonTauSigma(initial,
                    new double[] {1.0E-7d, 1.0E-7d}, new double[] {1.0E-7d, 150.0d},
                    valueTotals[0], valueTotals[1], lengthTotals[0], lengthTotals[1],
                    new double[] {2.0d, 1.0d}, 0.1d, 0.5d, 100);
        }
        for (int c = 0; c < numberOfClasses; c++) {
            // Each class gets its own posterior. The "lambda" path previously wrote every one to class 0.
            classifier.setFeatureGivenClassParameter(feature, c, posterior[c]);
        }
    }

    /** Reads the {@code "lambda"} entry of an estimate's parameter map. */
    private static double lambdaOf(Estimate estimate) {
        return ((Double) estimate.getPms().get("lambda")).doubleValue();
    }

    /**
     * Smoothed class probability {@code (pseudoCount / numberOfClasses + classCount) / (1 + total)}.
     * The denominator always adds 1, so the values sum to one across classes only when
     * {@code pseudoCount} is 1 (which is how this class calls it).
     */
    private double estimateClassProbMLE(double pseudoCount, double numberOfClasses, double classCount,
            double total) {
        return ((pseudoCount / numberOfClasses) + classCount) / (1.0d + total);
    }

    /** Returns the sum of all entries of {@code values} (0 for an empty array). */
    private double sum(double[] values) {
        double total = 0.0d;
        for (double value : values) {
            total += value;
        }
        return total;
    }

    /**
     * Method-of-moments negative-binomial estimate, returned as a map with keys {@code "mu"} and
     * {@code "delta"}. {@code delta} is clamped to be non-negative, and a NaN delta becomes 0.
     * Not called anywhere in this class.
     *
     * @param counts per-example counts
     * @param lengths per-example lengths, parallel to {@code counts}
     * @param priorCount pseudo-count added to the summed counts, divided by {@link #scale}
     */
    private TreeMap<String, Double> estimateNegBinMOME(double[] counts, double[] lengths,
            double priorCount) {
        int n = counts.length;
        double countSum = 0.0d;
        double lengthSum = 0.0d;
        double lengthSquareSum = 0.0d;
        for (int i = 0; i < n; i++) {
            countSum += counts[i];
            lengthSum += lengths[i];
            lengthSquareSum += Math.pow(lengths[i], 2.0d);
        }
        double mu = (countSum + (priorCount / this.scale)) / (lengthSum + (1.0d / this.scale));
        double effectiveLength = 0.0d;
        double dispersion = 0.0d;
        if (n > 1) {
            effectiveLength = (lengthSum - (lengthSquareSum / lengthSum)) / (n - 1.0d);
            for (int i = 0; i < n; i++) {
                dispersion += (lengths[i] * Math.pow((counts[i] / lengths[i]) - mu, 2.0d)) / (n - 1.0d);
            }
        }
        double delta = Math.max(0.0d, (dispersion - mu) / (effectiveLength * mu));
        if (Double.isNaN(delta)) {
            delta = 0.0d;
        }
        TreeMap<String, Double> estimate = new TreeMap<String, Double>();
        estimate.put("mu", Double.valueOf(mu));
        estimate.put("delta", Double.valueOf(delta));
        return estimate;
    }

    /** Prints the {@code "bayesUnlabeled"} sample dataset as a quick manual smoke check. */
    public static void main(String[] args) {
        System.out.println("SampleDatasets (bayesUnlabeled):\n" + SampleDatasets.sampleData("bayesUnlabeled", false));
    }
}
