package iitb.CRF;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.AbstractFormatter;
import java.util.Iterator;

/* loaded from: input_file:iitb/CRF/SegmentTrainer.class */
/**
 * Trainer for a segment-level CRF: the likelihood and gradient are computed with forward/backward
 * recursions over the candidate segments supplied by each {@link CandSegDataSequence}, rather than
 * over single positions.
 *
 * <p>All scores handled here are combined with {@code RobustMath.logSumExp}/{@code sumFunc}, i.e.
 * they are treated as log-space values.
 */
public class SegmentTrainer extends SparseTrainer {
    /**
     * Forward vectors, offset by one: slot {@code k} holds the forward scores for segmentations
     * ending at position {@code k - 1}; slot 0 is the start state (set to {@link #allZeroVector}
     * during a pass).
     */
    protected DoubleMatrix1D[] alpha_Y_Array;
    /** When {@code reuseM} is set: cached products {@code Mi_YY^T * alpha_Y_Array[k]}, same offset. */
    protected DoubleMatrix1D[] alpha_Y_ArrayM;
    /** Validity flags for the entries of {@link #alpha_Y_ArrayM} within the current sequence. */
    protected boolean[] initAlphaMDone;
    /** Vector filled with 0.0 (log of 1), used as the initial alpha and the final beta. */
    protected DoubleMatrix1D allZeroVector;

    /**
     * Creates a trainer that works with log-space vectors.
     *
     * @param crfParams training parameters, forwarded to the base trainer
     */
    public SegmentTrainer(CrfParams crfParams) {
        super(crfParams);
        // Tells the base trainer to allocate/handle vectors in log space.
        this.logTrainer = true;
    }

    /**
     * Initializes the base trainer and allocates the all-zero (log 1) vector of size {@code numY}.
     */
    @Override // iitb.CRF.Trainer
    public void init(CRF crf, DataIter dataIter, double[] dArr) {
        super.init(crf, dataIter, dArr);
        this.allZeroVector = newLogDoubleMatrix1D(this.numY);
        this.allZeroVector.assign(0.0d);
    }

    /**
     * Computes the Gaussian-regularized conditional log-likelihood of the training data and its
     * gradient.
     *
     * @param lambda current feature weights (read only)
     * @param grad   output array, overwritten with the gradient of the returned value
     * @return {@code sum_seq [score(y_seq) - log Z(x_seq)] - ||lambda||^2 * invSigmaSquare / 2}
     */
    @Override // iitb.CRF.SparseTrainer, iitb.CRF.Trainer
    protected double computeFunctionGradient(double[] lambda, double[] grad) {
        try {
            FeatureGeneratorNested featureGen = (FeatureGeneratorNested) this.featureGenerator;
            // Gaussian prior term and its gradient.
            double logli = 0.0d;
            for (int f = 0; f < lambda.length; f++) {
                grad[f] = (-1.0d) * lambda[f] * this.params.invSigmaSquare;
                logli -= ((lambda[f] * lambda[f]) * this.params.invSigmaSquare) / 2.0d;
            }
            this.diter.startScan();
            // Transition matrix must be recomputed at least once per call (weights changed).
            this.initMDone = false;
            if (this.featureGenCache != null) {
                this.featureGenCache.startDataScan();
            }
            int numRecord = 0;
            while (this.diter.hasNext()) {
                CandSegDataSequence dataSeq = (CandSegDataSequence) this.diter.next();
                if (this.featureGenCache != null) {
                    this.featureGenCache.nextDataIndex();
                }
                if (this.params.debugLvl > 1) {
                    Util.printDbg("Read next seq: " + numRecord + " logli " + logli);
                }
                int seqLen = dataSeq.length();
                if (seqLen == 0) {
                    // An empty sequence has Pr(y|x) = 1 and contributes nothing; without this
                    // guard beta_Y[seqLen - 1] below would index position -1.
                    numRecord++;
                    continue;
                }
                for (int f = 0; f < lambda.length; f++) {
                    this.ExpF[f] = RobustMath.LOG0;
                }
                // Alpha needs seqLen + 1 slots; over-allocate so growth is amortized.
                if (this.alpha_Y_Array == null || this.alpha_Y_Array.length < seqLen + 1) {
                    allocateAlphaBeta((2 * seqLen) + 1);
                }
                if (this.reuseM) {
                    for (int slot = seqLen; slot >= 0; slot--) {
                        this.initAlphaMDone[slot] = false;
                    }
                }

                // ---- Backward pass: beta_Y[k] accumulates scores of segments starting at k + 1.
                // The last beta slot is temporarily replaced by the all-zero vector; restored below.
                DoubleMatrix1D savedLastBeta = this.beta_Y[seqLen - 1];
                this.beta_Y[seqLen - 1] = this.allZeroVector;
                for (int pos = seqLen - 2; pos >= 0; pos--) {
                    this.beta_Y[pos].assign(RobustMath.LOG0);
                }
                for (int segEnd = seqLen - 1; segEnd >= 0; segEnd--) {
                    for (int c = dataSeq.numCandSegmentsEndingAt(segEnd) - 1; c >= 0; c--) {
                        int prevPos = dataSeq.candSegmentStart(segEnd, c) - 1;
                        if (prevPos >= 0) {
                            this.initMDone = computeLogMi(dataSeq, prevPos, segEnd, featureGen, lambda,
                                    this.Mi_YY, this.Ri_Y, this.reuseM, this.initMDone);
                            this.tmp_Y.assign(this.Ri_Y);
                            if (segEnd < seqLen - 1) {
                                this.tmp_Y.assign(this.beta_Y[segEnd], sumFunc);
                            }
                            if (this.reuseM) {
                                // Defer the multiplication by Mi_YY until all contributions
                                // to beta_Y[prevPos] have been summed (see below).
                                this.beta_Y[prevPos].assign(this.tmp_Y, RobustMath.logSumExpFunc);
                            } else {
                                this.Mi_YY.zMult(this.tmp_Y, this.beta_Y[prevPos], 1.0d, 1.0d, false);
                            }
                        }
                    }
                    if (this.reuseM && segEnd - 1 >= 0) {
                        // Every segment starting at segEnd has been processed by now, so
                        // beta_Y[segEnd - 1] is complete and can be multiplied once.
                        this.tmp_Y.assign(this.beta_Y[segEnd - 1]);
                        this.Mi_YY.zMult(this.tmp_Y, this.beta_Y[segEnd - 1], 1.0d, 0.0d, false);
                    }
                }

                // ---- Forward pass, plus feature expectations and the score of the true labels.
                double seqScore = 0.0d;
                this.alpha_Y_Array[0] = this.allZeroVector;
                int trainSegEnd = -1;
                int trainSegStart = 0;
                boolean trainSegFound = true;
                boolean noneFired = true;
                for (int pos = 0; pos < seqLen; pos++) {
                    this.alpha_Y_Array[pos + 1].assign(RobustMath.LOG0);
                    if (trainSegEnd < pos) {
                        // Entering a new training segment: verify the previous one was a candidate.
                        if (!trainSegFound && noneFired) {
                            printMissingTrainingSegment(trainSegStart, trainSegEnd);
                        }
                        trainSegFound = false;
                        trainSegStart = pos;
                        trainSegEnd = dataSeq.getSegmentEnd(pos);
                    }
                    for (int c = dataSeq.numCandSegmentsEndingAt(pos) - 1; c >= 0; c--) {
                        // prevPos is the position just before the candidate segment starts.
                        int prevPos = dataSeq.candSegmentStart(pos, c) - 1;
                        this.initMDone = computeLogMi(dataSeq, prevPos, pos, featureGen, lambda,
                                this.Mi_YY, this.Ri_Y, this.reuseM, this.initMDone);
                        if (prevPos >= 0) {
                            if (this.reuseM) {
                                if (!this.initAlphaMDone[prevPos + 1]) {
                                    this.alpha_Y_ArrayM[prevPos + 1].assign(RobustMath.LOG0);
                                    this.Mi_YY.zMult(this.alpha_Y_Array[prevPos + 1],
                                            this.alpha_Y_ArrayM[prevPos + 1], 1.0d, 0.0d, true);
                                    this.initAlphaMDone[prevPos + 1] = true;
                                }
                                this.newAlpha_Y.assign(this.alpha_Y_ArrayM[prevPos + 1]);
                            } else {
                                this.Mi_YY.zMult(this.alpha_Y_Array[prevPos + 1], this.newAlpha_Y, 1.0d, 0.0d, true);
                            }
                            this.newAlpha_Y.assign(this.Ri_Y, sumFunc);
                        } else {
                            this.newAlpha_Y.assign(this.Ri_Y);
                        }
                        this.alpha_Y_Array[pos + 1].assign(this.newAlpha_Y, RobustMath.logSumExpFunc);

                        featureGen.startScanFeaturesAt(dataSeq, prevPos, pos);
                        while (featureGen.hasNext()) {
                            Feature feature = featureGen.next();
                            int fIdx = feature.index();
                            int y = feature.y();
                            int yprev = feature.yprev();
                            float value = feature.value();
                            if (dataSeq.holdsInTrainingData(feature, prevPos, pos)) {
                                // Empirical count and score of the true labelling.
                                grad[fIdx] += value;
                                seqScore += value * lambda[fIdx];
                                noneFired = false;
                                if (this.params.debugLvl > 2) {
                                    System.out.println("Feature fired " + fIdx
                                            + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR + feature);
                                }
                            }
                            // Accumulate the (unnormalized, log-space) model expectation.
                            if (yprev < 0) {
                                this.ExpF[fIdx] = RobustMath.logSumExp(this.ExpF[fIdx],
                                        this.newAlpha_Y.get(y) + RobustMath.log(value) + this.beta_Y[pos].get(y));
                            } else {
                                this.ExpF[fIdx] = RobustMath.logSumExp(this.ExpF[fIdx],
                                        this.alpha_Y_Array[prevPos + 1].get(yprev) + this.Ri_Y.get(y)
                                                + this.Mi_YY.get(yprev, y) + RobustMath.log(value)
                                                + this.beta_Y[pos].get(y));
                            }
                        }

                        if (pos == trainSegEnd && prevPos + 1 == trainSegStart) {
                            trainSegFound = true;
                            double nodeScore = this.Ri_Y.get(dataSeq.y(trainSegEnd));
                            double edgeScore = trainSegStart > 0
                                    ? this.Mi_YY.get(dataSeq.y(trainSegStart - 1), dataSeq.y(trainSegEnd))
                                    : 0.0d;
                            if (nodeScore == RobustMath.LOG0 || edgeScore == RobustMath.LOG0) {
                                // A segment starting at 0 has no previous label; print -1
                                // instead of reading label index -1.
                                String yprevLabel = trainSegStart > 0
                                        ? String.valueOf(dataSeq.y(trainSegStart - 1))
                                        : "-1";
                                System.out.println("Error: training labels not covered in generated features "
                                        + nodeScore + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR + edgeScore
                                        + " yprev " + yprevLabel + " y " + dataSeq.y(trainSegEnd));
                                System.out.println(dataSeq);
                                featureGen.startScanFeaturesAt(dataSeq, prevPos, pos);
                                while (featureGen.hasNext()) {
                                    Feature generated = featureGen.next();
                                    System.out.println(generated + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR
                                            + generated.yprev() + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR
                                            + generated.y());
                                }
                            }
                        }
                    }
                    if (this.params.debugLvl > 2) {
                        System.out.println("Alpha-i " + this.alpha_Y_Array[pos + 1].toString());
                        System.out.println("Ri " + this.Ri_Y.toString());
                        System.out.println("Mi " + this.Mi_YY.toString());
                        System.out.println("Beta-i " + this.beta_Y[pos].toString());
                    }
                }
                // The in-loop check only fires when the next training segment begins, so the
                // last training segment of the sequence has to be checked here.
                if (!trainSegFound && noneFired) {
                    printMissingTrainingSegment(trainSegStart, trainSegEnd);
                }

                // ---- Normalize: log Z(x) is the log-sum of the final forward vector.
                double lZx = this.alpha_Y_Array[seqLen].zSum();
                double seqLogli = seqScore - lZx;
                logli += seqLogli;
                for (int f = 0; f < grad.length; f++) {
                    grad[f] -= expLE(this.ExpF[f] - lZx);
                }
                if (noneFired) {
                    System.out.println("WARNING: no features fired in the training set");
                }
                if (seqLogli > 0.0d) {
                    System.out.println("ERROR: something is wrong Pr(y|x) > 1! for sequence " + numRecord);
                    System.out.println(dataSeq);
                }
                if (this.params.debugLvl > 1 || seqLogli > 0.0d) {
                    System.out.println("Sequence likelihood " + seqLogli + " lZx " + lZx + " Zx " + Math.exp(lZx));
                    System.out.println("Last Alpha-i " + this.alpha_Y_Array[seqLen].toString());
                }
                this.beta_Y[seqLen - 1] = savedLastBeta;
                numRecord++;
            }
            if (this.params.debugLvl > 2) {
                for (int f = 0; f < lambda.length; f++) {
                    System.out.print(lambda[f] + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR);
                }
                System.out.println(" :x");
                for (int f = 0; f < lambda.length; f++) {
                    System.out.println(f + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR + featureGen.featureName(f)
                            + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR + grad[f]
                            + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR);
                }
                System.out.println(" :g");
            }
            if (this.params.debugLvl > 0) {
                if (this.icall == 0) {
                    Util.printDbg("Number of training records " + numRecord);
                }
                Util.printDbg("Iter " + this.icall + " loglikelihood " + logli + " gnorm " + norm(grad)
                        + " xnorm " + norm(lambda));
            }
            return logli;
        } catch (Exception e) {
            e.printStackTrace();
            // Non-zero status so that scripts driving the trainer can detect the failure.
            System.exit(1);
            return 0.0d;
        }
    }

    /** Prints the error for a training segment that is not among the candidate segments. */
    private static void printMissingTrainingSegment(int segStart, int segEnd) {
        System.out.println("Error: Training segment (" + segStart + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR
                + segEnd + ") not found amongst candidate segments");
    }

    /**
     * (Re)allocates the alpha, beta and cached alpha*M vector arrays plus the cache flags.
     *
     * @param size number of slots in each array
     */
    protected void allocateAlphaBeta(int size) {
        this.alpha_Y_Array = newLogVectorArray(size);
        this.beta_Y = newLogVectorArray(size);
        this.alpha_Y_ArrayM = newLogVectorArray(size);
        this.initAlphaMDone = new boolean[size];
    }

    /** Returns {@code count} fresh vectors of length {@code numY} from {@code newLogDoubleMatrix1D}. */
    private DoubleMatrix1D[] newLogVectorArray(int count) {
        DoubleMatrix1D[] vectors = new DoubleMatrix1D[count];
        for (int k = 0; k < count; k++) {
            vectors[k] = newLogDoubleMatrix1D(this.numY);
        }
        return vectors;
    }

    /**
     * Starts a feature scan over segment ({@code i}, {@code i2}] and initializes the log potentials.
     *
     * <p>If the segment has constraints, every entry starts at {@code LOG0} and only the
     * (yprev, y) / y entries allowed by restrict constraints (type 3) are set to 0. Without
     * constraints every entry is set to 0.
     *
     * @param doubleMatrix2D transition potentials to initialize; may be {@code null} to skip
     * @param doubleMatrix1D state potentials to initialize
     * @return the default value for entries not set explicitly ({@code LOG0} or 0.0)
     */
    public static double initLogMi(CandSegDataSequence candSegDataSequence, int i, int i2, FeatureGeneratorNested featureGeneratorNested, double[] dArr, DoubleMatrix2D doubleMatrix2D, DoubleMatrix1D doubleMatrix1D) {
        featureGeneratorNested.startScanFeaturesAt(candSegDataSequence, i, i2);
        Iterator constraints = candSegDataSequence.constraints(i, i2);
        if (constraints == null) {
            if (doubleMatrix2D != null) {
                doubleMatrix2D.assign(0.0d);
            }
            doubleMatrix1D.assign(0.0d);
            return 0.0d;
        }
        double defaultValue = RobustMath.LOG0;
        if (doubleMatrix2D != null) {
            doubleMatrix2D.assign(defaultValue);
        }
        doubleMatrix1D.assign(defaultValue);
        while (constraints.hasNext()) {
            Constraint constraint = (Constraint) constraints.next();
            if (constraint.type() != 3) {
                continue;
            }
            RestrictConstraint restrict = (RestrictConstraint) constraint;
            restrict.startScan();
            while (restrict.hasNext()) {
                restrict.advance();
                int y = restrict.y();
                int yprev = restrict.yprev();
                if (yprev < 0) {
                    doubleMatrix1D.set(y, 0.0d);
                } else if (doubleMatrix2D != null) {
                    doubleMatrix2D.set(yprev, y, 0.0d);
                }
            }
        }
        return defaultValue;
    }

    /**
     * Computes log potentials for segment ({@code i}, {@code i2}], optionally reusing the
     * transition matrix.
     *
     * <p>When {@code z} (reuse) and {@code z2} (already computed) are both true, the transition
     * matrix is left untouched and only the state vector is recomputed.
     *
     * @return the updated "transition matrix computed" flag: true once a segment with
     *         {@code i >= 0} has been processed while reuse is enabled, otherwise {@code z2}
     */
    static boolean computeLogMi(CandSegDataSequence candSegDataSequence, int i, int i2, FeatureGeneratorNested featureGeneratorNested, double[] dArr, DoubleMatrix2D doubleMatrix2D, DoubleMatrix1D doubleMatrix1D, boolean z, boolean z2) {
        DoubleMatrix2D transitions = (z && z2) ? null : doubleMatrix2D;
        computeLogMi(candSegDataSequence, i, i2, featureGeneratorNested, dArr, transitions, doubleMatrix1D);
        return (i >= 0 && z) || z2;
    }

    /**
     * Computes log potentials for segment ({@code i}, {@code i2}] by initializing them with
     * {@link #initLogMi} and accumulating weighted features via
     * {@code SparseTrainer.computeLogMiInitDone}.
     *
     * @param doubleMatrix2D transition potentials, or {@code null} to leave them unchanged
     *                       (presumably tolerated by computeLogMiInitDone — verify)
     */
    public static void computeLogMi(CandSegDataSequence candSegDataSequence, int i, int i2, FeatureGeneratorNested featureGeneratorNested, double[] dArr, DoubleMatrix2D doubleMatrix2D, DoubleMatrix1D doubleMatrix1D) {
        double defaultValue = initLogMi(candSegDataSequence, i, i2, featureGeneratorNested, dArr, doubleMatrix2D, doubleMatrix1D);
        SparseTrainer.computeLogMiInitDone(featureGeneratorNested, dArr, doubleMatrix2D, doubleMatrix1D, defaultValue);
    }
}
