package edu.cmu.minorthird.classify.relational;

import edu.cmu.minorthird.classify.BatchClassifierLearner;
import edu.cmu.minorthird.classify.ClassLabel;
import edu.cmu.minorthird.classify.Classifier;
import edu.cmu.minorthird.classify.Dataset;
import edu.cmu.minorthird.classify.DatasetClassifierTeacher;
import edu.cmu.minorthird.classify.Example;
import edu.cmu.minorthird.classify.ExampleSchema;
import edu.cmu.minorthird.classify.Explanation;
import edu.cmu.minorthird.classify.Instance;
import edu.cmu.minorthird.classify.SGMExample;
import edu.cmu.minorthird.classify.Splitter;
import edu.cmu.minorthird.classify.algorithms.linear.MaxEntLearner;
import edu.cmu.minorthird.classify.experiments.CrossValSplitter;
import edu.cmu.minorthird.classify.transform.AugmentedInstance;
import edu.cmu.minorthird.util.ProgressCounter;
import edu.cmu.minorthird.util.gui.ParallelViewer;
import edu.cmu.minorthird.util.gui.SmartVanillaViewer;
import edu.cmu.minorthird.util.gui.TransformedViewer;
import edu.cmu.minorthird.util.gui.Viewer;
import edu.cmu.minorthird.util.gui.ViewerFrame;
import edu.cmu.minorthird.util.gui.Visible;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import org.apache.log4j.Logger;

/* loaded from: input_file:edu/cmu/minorthird/classify/relational/StackedGraphicalLearner.class */
/**
 * Learner for relational datasets based on stacking.
 *
 * <p>A base classifier is trained on the original dataset. Then, for each
 * additional stacking level, every example is augmented with features that
 * aggregate (via {@code COUNT} / {@code EXISTS}) the predicted classes of the
 * examples it is linked to, and a new classifier is trained on the augmented
 * data. Training-time predictions are produced by cross-validation over
 * {@link StackingParams#splitter}, so each example is labelled by a model that
 * did not see it.
 */
public class StackedGraphicalLearner extends StackedBatchClassifierLearner {
    private static final Logger log = Logger.getLogger(StackedGraphicalLearner.class);
    /** When true, {@link #batchTrain} opens a viewer window for each stacked dataset. */
    private static final boolean DEBUG = false;
    private ExampleSchema schema;
    private BatchClassifierLearner baseLearner;
    private StackingParams params;

    /**
     * Classifier produced by {@link StackedGraphicalLearner#batchTrain}: one
     * trained classifier per stacking level.
     */
    public class StackedGraphicalClassifier implements Classifier, Visible {
        /** m[0] is the base-level classifier; m[k] was trained on k-times stacked data. */
        private Classifier[] m;
        private RealRelationalDataset dataset;
        private StackingParams params;

        /**
         * @param classifierArr         the per-level classifiers, level 0 first
         * @param stackingParams        the stacking parameters used in training
         * @param realRelationalDataset the dataset the classifiers were trained on
         */
        public StackedGraphicalClassifier(Classifier[] classifierArr, StackingParams stackingParams, RealRelationalDataset realRelationalDataset) {
            this.m = classifierArr;
            this.params = stackingParams;
            this.dataset = realRelationalDataset;
        }

        /**
         * Classifies a single instance with the level-0 classifier only; a lone
         * instance carries no neighbour predictions, so stacked levels cannot
         * be applied to it.
         */
        @Override
        public ClassLabel classification(Instance instance) {
            return this.m[0].classification(instance);
        }

        /**
         * Collectively classifies a relational dataset by running every
         * stacking level in turn, augmenting the examples with the previous
         * level's predictions before each subsequent level.
         *
         * @param realRelationalDataset the dataset to classify
         * @return map from example ID to the final-level {@link ClassLabel}
         */
        public HashMap classification(RealRelationalDataset realRelationalDataset) {
            HashMap predictions = new HashMap();
            // Depth comes from the classifiers actually trained, not from params:
            // params is shared with (and mutable through) the learner, so its
            // stackingDepth may no longer match m.length after training.
            int depth = this.m.length - 1;
            RealRelationalDataset current = realRelationalDataset;
            for (int level = 0; level <= depth; level++) {
                Example.Looper looper = current.iterator();
                while (looper.hasNext()) {
                    SGMExample example = (SGMExample) looper.nextExample();
                    predictions.put(example.getExampleID(), this.m[level].classification(example));
                }
                if (level < depth) {
                    current = stackTestDataset(current, predictions);
                }
            }
            return predictions;
        }

        /**
         * Builds a new dataset whose examples are those of
         * {@code realRelationalDataset} augmented with aggregated neighbour
         * predictions taken from {@code hashMap}.
         *
         * @param realRelationalDataset dataset whose examples are augmented
         * @param hashMap               map from example ID to predicted {@link ClassLabel}
         * @return the augmented dataset
         */
        public RealRelationalDataset stackTestDataset(RealRelationalDataset realRelationalDataset, HashMap hashMap) {
            RealRelationalDataset stacked = new RealRelationalDataset();
            HashMap linksMap = realRelationalDataset.getLinksMap();
            HashMap aggregators = realRelationalDataset.getAggregators();
            Example.Looper looper = realRelationalDataset.iterator();
            while (looper.hasNext()) {
                stacked.addSGM(StackedGraphicalLearner.this.AugmentExample((SGMExample) looper.nextExample(), linksMap, aggregators, hashMap));
            }
            return stacked;
        }

        /** Returns the weight the level-0 classification assigns to class {@code str}. */
        public double score(Instance instance, String str) {
            return classification(instance).getWeight(str);
        }

        @Override
        public String explain(Instance instance) {
            return "sorry, not implemented yet";
        }

        @Override
        public Explanation getExplanation(Instance instance) {
            return new Explanation(explain(instance));
        }

        /** Shows one sub-view per stacking level's classifier. */
        @Override
        public Viewer toGUI() {
            ParallelViewer parallelViewer = new ParallelViewer();
            for (int i = 0; i < this.m.length; i++) {
                final int level = i;
                parallelViewer.addSubView("Level " + level + " classifier", new TransformedViewer(new SmartVanillaViewer(this.m[level])) {
                    @Override
                    public Object transform(Object obj) {
                        return ((StackedGraphicalClassifier) obj).m[level];
                    }
                });
            }
            parallelViewer.setContent(this);
            return parallelViewer;
        }
    }

    /** Tunable parameters for stacked training. */
    public static class StackingParams {
        /** Number of stacked levels on top of the base classifier (default 1). */
        public int stackingDepth = 1;
        public boolean useLogistic = true;
        public boolean useTargetPrediction = true;
        public boolean useConfidence = true;
        /**
         * Splitter used to produce cross-validated training predictions.
         * Note: assigning this field directly does not update
         * {@link #getCrossValSplits()}; prefer {@link #setCrossValSplits(int)}.
         */
        public Splitter splitter = new CrossValSplitter(5);
        int crossValSplits = 5;

        public boolean getUseLogisticOnConfidences() {
            return this.useLogistic;
        }

        public void setUseLogisticOnConfidences(boolean z) {
            this.useLogistic = z;
        }

        public boolean getUseConfidences() {
            return this.useConfidence;
        }

        public void setUseConfidences(boolean z) {
            this.useConfidence = z;
        }

        public boolean getUseTargetPrediction() {
            return this.useTargetPrediction;
        }

        public void setUseTargetPrediction(boolean z) {
            this.useTargetPrediction = z;
        }

        public int getStackingDepth() {
            return this.stackingDepth;
        }

        public void setStackingDepth(int i) {
            this.stackingDepth = i;
        }

        public int getCrossValSplits() {
            return this.crossValSplits;
        }

        /** Sets the fold count and replaces {@link #splitter} with a matching {@link CrossValSplitter}. */
        public void setCrossValSplits(int i) {
            this.splitter = new CrossValSplitter(i);
            this.crossValSplits = i;
        }
    }

    /** Returns the (mutable) stacking parameters used by this learner. */
    public StackingParams getParams() {
        return this.params;
    }

    /** Creates a learner with a {@link MaxEntLearner} base learner and stacking depth 1. */
    public StackedGraphicalLearner() {
        this(new MaxEntLearner(), 1);
    }

    /** Creates a learner with the given base learner and stacking depth 1. */
    public StackedGraphicalLearner(BatchClassifierLearner batchClassifierLearner) {
        this(batchClassifierLearner, 1);
    }

    /**
     * @param batchClassifierLearner learner used at every stacking level
     * @param i                      number of stacked levels above the base level
     */
    public StackedGraphicalLearner(BatchClassifierLearner batchClassifierLearner, int i) {
        this.baseLearner = batchClassifierLearner;
        this.params = new StackingParams();
        this.params.setStackingDepth(i);
    }

    @Override
    public final void setSchema(ExampleSchema exampleSchema) {
        this.schema = exampleSchema;
    }

    /**
     * Trains one classifier per stacking level, stacking the dataset between
     * levels.
     *
     * @param realRelationalDataset the training data
     * @return a {@link StackedGraphicalClassifier} holding all levels
     */
    @Override
    public Classifier batchTrain(RealRelationalDataset realRelationalDataset) {
        // Snapshot once so the array size and loop bound cannot disagree.
        int depth = this.params.stackingDepth;
        Classifier[] levels = new Classifier[depth + 1];
        RealRelationalDataset current = realRelationalDataset;
        ProgressCounter progressCounter = new ProgressCounter("training stacked learner", "stacking level", depth + 1);
        for (int level = 0; level <= depth; level++) {
            levels[level] = new DatasetClassifierTeacher(current).train(this.baseLearner);
            if (level < depth) {
                current = stackDataset(current);
                // Debug aid only: opening a GUI frame during training is
                // disruptive and fails on headless JVMs.
                if (DEBUG) {
                    new ViewerFrame("Dataset " + (level + 1), new SmartVanillaViewer(current));
                }
            }
            progressCounter.progress();
        }
        progressCounter.finished();
        return new StackedGraphicalClassifier(levels, this.params, realRelationalDataset);
    }

    /**
     * Produces the next-level training dataset: each example is labelled by a
     * classifier trained on the other folds of {@code params.splitter}, and
     * those predictions are aggregated over each example's links into new
     * features. Also sets this learner's schema to the dataset's schema.
     *
     * @param realRelationalDataset the current-level training data
     * @return a new dataset of augmented examples
     */
    public RealRelationalDataset stackDataset(RealRelationalDataset realRelationalDataset) {
        RealRelationalDataset stacked = new RealRelationalDataset();
        Dataset.Split split = realRelationalDataset.split(this.params.splitter);
        // AugmentExample reads this.schema, so it must match the data being stacked.
        this.schema = realRelationalDataset.getSchema();
        ProgressCounter progressCounter = new ProgressCounter("labeling for stacking", "fold", split.getNumPartitions());
        HashMap predictions = new HashMap();
        for (int fold = 0; fold < split.getNumPartitions(); fold++) {
            RealRelationalDataset trainFold = (RealRelationalDataset) split.getTrain(fold);
            RealRelationalDataset testFold = (RealRelationalDataset) split.getTest(fold);
            log.info("splitting with " + this.params.splitter + ", preparing to train on " + trainFold.size() + " and test on " + testFold.size());
            Classifier foldClassifier = new DatasetClassifierTeacher(trainFold).train(this.baseLearner);
            Example.Looper looper = testFold.iterator();
            while (looper.hasNext()) {
                SGMExample example = (SGMExample) looper.nextExample();
                predictions.put(example.getExampleID(), foldClassifier.classification(example));
            }
            log.info("splitting with " + this.params.splitter + ", stored classified dataset");
            progressCounter.progress();
        }
        HashMap linksMap = realRelationalDataset.getLinksMap();
        HashMap aggregators = realRelationalDataset.getAggregators();
        Example.Looper looper = realRelationalDataset.iterator();
        while (looper.hasNext()) {
            stacked.add(AugmentExample((SGMExample) looper.nextExample(), linksMap, aggregators, predictions));
        }
        progressCounter.finished();
        return stacked;
    }

    /**
     * Returns a copy of {@code sGMExample} extended with one feature per
     * (link type, aggregator, class) triple, named by
     * {@link #stackFeatureName}. {@code COUNT} features hold the number of
     * linked examples predicted as that class; {@code EXISTS} features are 1.0
     * when that count is positive; other aggregator names yield 0.0.
     * Originally private (per decompiler note); public here so the nested
     * classifier can call it.
     *
     * @param sGMExample the example to augment
     * @param hashMap    links: example ID -> (link type -> HashSet of linked example IDs)
     * @param hashMap2   aggregators: link type -> HashSet of aggregator names
     * @param hashMap3   predictions: example ID -> {@link ClassLabel}
     * @return the augmented example, or {@code sGMExample} itself if it has no links entry
     */
    public SGMExample AugmentExample(SGMExample sGMExample, HashMap hashMap, HashMap hashMap2, HashMap hashMap3) {
        String exampleID = sGMExample.getExampleID();
        if (!hashMap.containsKey(exampleID)) {
            return sGMExample;
        }
        HashMap exampleLinks = (HashMap) hashMap.get(exampleID);
        int numClasses = this.schema.getNumberOfClasses();

        // Upper bound on the number of new features: every aggregator of every link type.
        int capacity = 0;
        for (Iterator it = hashMap2.keySet().iterator(); it.hasNext();) {
            capacity += ((HashSet) hashMap2.get(it.next())).size() * numClasses;
        }
        String[] names = new String[capacity];
        double[] values = new double[capacity];
        int used = 0;

        for (Iterator linkIt = hashMap2.keySet().iterator(); linkIt.hasNext();) {
            String linkType = linkIt.next().toString();
            if (!exampleLinks.containsKey(linkType)) {
                continue;
            }
            // Neighbour class counts depend only on the link type, so compute them once.
            int[] counts = countNeighborClasses((HashSet) exampleLinks.get(linkType), hashMap3, numClasses);
            for (Iterator aggIt = ((HashSet) hashMap2.get(linkType)).iterator(); aggIt.hasNext();) {
                String aggregator = (String) aggIt.next();
                for (int c = 0; c < numClasses; c++) {
                    names[used] = stackFeatureName(linkType, aggregator, this.schema.getClassName(c));
                    if (aggregator.equals("COUNT")) {
                        values[used] = counts[c];
                    } else if (aggregator.equals("EXISTS") && counts[c] > 0) {
                        values[used] = 1.0d;
                    }
                    used++;
                }
            }
        }

        String[] trimmedNames = new String[used];
        double[] trimmedValues = new double[used];
        System.arraycopy(names, 0, trimmedNames, 0, used);
        System.arraycopy(values, 0, trimmedValues, 0, used);
        return new SGMExample(new AugmentedInstance(sGMExample.asInstance(), trimmedNames, trimmedValues), sGMExample.getLabel(), sGMExample.getExampleID());
    }

    /** Counts, per class index, how many of {@code neighborIDs} have a prediction of that class. */
    private int[] countNeighborClasses(HashSet neighborIDs, HashMap predictions, int numClasses) {
        int[] counts = new int[numClasses];
        for (Iterator it = neighborIDs.iterator(); it.hasNext();) {
            String neighborID = (String) it.next();
            ClassLabel predicted = (ClassLabel) predictions.get(neighborID);
            if (predicted != null) {
                counts[this.schema.getClassIndex(predicted.bestClassName())]++;
            }
        }
        return counts;
    }

    /** Builds the feature name {@code pred.<linkType>.<aggregator>.<className>}. */
    private static String stackFeatureName(String str, String str2, String str3) {
        return "pred." + str + "." + str2 + "." + str3;
    }
}
