package codemining.lm.hmm;

import java.text.DecimalFormat;
import net.didion.jwnl.dictionary.file.DictionaryFile;

/* loaded from: input_file:codemining/lm/hmm/HMM.class */
/**
 * Discrete hidden Markov model with dense parameter tables and an iterative
 * forward/backward (Baum-Welch style) re-estimation routine.
 *
 * <p>Observations are {@code int} symbols used directly as column indices into
 * {@link #b}, so every symbol must lie in {@code [0, sigmaSize)}. All
 * probabilities are plain (unscaled, non-logarithmic) doubles.
 *
 * <p>The constructor leaves every parameter at {@code 0.0}; callers are
 * expected to fill {@link #pi}, {@link #a} and {@link #b} before training.
 */
public class HMM {
    /**
     * Text appended after every matrix entry written by {@link #print()}.
     * Stands in for JWNL's {@code DictionaryFile.COMMENT_HEADER}, which the
     * decompiled code referenced here; that constant is believed to be this same
     * two-space string -- NOTE(review): confirm against the JWNL constant values.
     */
    private static final String ENTRY_SEPARATOR = "  ";

    /** Number of hidden states. */
    public int numStates;

    /** Number of distinct observation symbols. */
    public int sigmaSize;

    /** Initial state probabilities, indexed by state. */
    public double[] pi;

    /** Transition probabilities, indexed {@code [fromState][toState]}. */
    public double[][] a;

    /** Emission probabilities, indexed {@code [state][symbol]}. */
    public double[][] b;

    /**
     * Creates a model whose parameter tables are allocated but all zero.
     *
     * @param numStates number of hidden states
     * @param sigmaSize number of observation symbols
     */
    public HMM(int numStates, int sigmaSize) {
        this.numStates = numStates;
        this.sigmaSize = sigmaSize;
        this.pi = new double[numStates];
        this.a = new double[numStates][numStates];
        this.b = new double[numStates][sigmaSize];
    }

    /**
     * Re-estimates {@link #pi}, {@link #a} and {@link #b} from a single
     * observation sequence, repeating the update {@code steps} times.
     *
     * <p>Each iteration builds new parameter arrays and installs them only once
     * all three are complete, so the model being read is never modified while
     * an update is in progress. An empty sequence carries no evidence and
     * leaves the model untouched; a non-positive {@code steps} does the same.
     *
     * @param observations observation symbols, each in {@code [0, sigmaSize)}
     * @param steps number of re-estimation iterations to run
     */
    public void train(int[] observations, int steps) {
        final int length = observations.length;
        if (length == 0) {
            return;
        }

        for (int step = 0; step < steps; step++) {
            final double[][] fwd = forwardProc(observations);
            final double[][] bwd = backwardProc(observations);

            // Per-time normaliser shared by every gamma and xi value at that time.
            final double[] norm = new double[length];
            for (int t = 0; t < length; t++) {
                norm[t] = stateMass(t, fwd, bwd);
            }

            // State occupancies and their sums over time; both are reused by the
            // transition and emission updates instead of being recomputed there.
            final double[][] occupancy = new double[this.numStates][length];
            final double[] occupancySum = new double[this.numStates];
            for (int i = 0; i < this.numStates; i++) {
                for (int t = 0; t < length; t++) {
                    occupancy[i][t] = divide(fwd[i][t] * bwd[i][t], norm[t]);
                    occupancySum[i] += occupancy[i][t];
                }
            }

            final double[] nextPi = new double[this.numStates];
            for (int i = 0; i < this.numStates; i++) {
                nextPi[i] = occupancy[i][0];
            }

            final double[][] nextA = new double[this.numStates][this.numStates];
            for (int i = 0; i < this.numStates; i++) {
                for (int j = 0; j < this.numStates; j++) {
                    double expectedTransitions = 0.0d;
                    for (int t = 0; t < length; t++) {
                        expectedTransitions += divide(xiNumerator(t, i, j, observations, fwd, bwd), norm[t]);
                    }
                    nextA[i][j] = divide(expectedTransitions, occupancySum[i]);
                }
            }

            final double[][] nextB = new double[this.numStates][this.sigmaSize];
            for (int i = 0; i < this.numStates; i++) {
                for (int k = 0; k < this.sigmaSize; k++) {
                    double expectedEmissions = 0.0d;
                    for (int t = 0; t < length; t++) {
                        // Multiplying by the 0/1 indicator (rather than skipping)
                        // keeps the arithmetic identical for non-finite values.
                        expectedEmissions += occupancy[i][t] * (k == observations[t] ? 1 : 0);
                    }
                    nextB[i][k] = divide(expectedEmissions, occupancySum[i]);
                }
            }

            this.pi = nextPi;
            this.a = nextA;
            this.b = nextB;
        }
    }

    /**
     * Computes the forward variables for an observation sequence.
     *
     * @param observations observation symbols, each in {@code [0, sigmaSize)}
     * @return matrix indexed {@code [state][time]}; it has zero columns when
     *         {@code observations} is empty
     */
    public double[][] forwardProc(int[] observations) {
        final int length = observations.length;
        final double[][] fwd = new double[this.numStates][length];
        if (length == 0) {
            return fwd;
        }

        for (int i = 0; i < this.numStates; i++) {
            fwd[i][0] = this.pi[i] * this.b[i][observations[0]];
        }
        for (int t = 0; t + 1 < length; t++) {
            for (int j = 0; j < this.numStates; j++) {
                double incoming = 0.0d;
                for (int i = 0; i < this.numStates; i++) {
                    incoming += fwd[i][t] * this.a[i][j];
                }
                fwd[j][t + 1] = incoming * this.b[j][observations[t + 1]];
            }
        }
        return fwd;
    }

    /**
     * Computes the backward variables for an observation sequence.
     *
     * @param observations observation symbols, each in {@code [0, sigmaSize)}
     * @return matrix indexed {@code [state][time]} whose last column is all
     *         {@code 1.0}; it has zero columns when {@code observations} is empty
     */
    public double[][] backwardProc(int[] observations) {
        final int length = observations.length;
        final double[][] bwd = new double[this.numStates][length];
        if (length == 0) {
            return bwd;
        }

        for (int i = 0; i < this.numStates; i++) {
            bwd[i][length - 1] = 1.0d;
        }
        for (int t = length - 2; t >= 0; t--) {
            for (int i = 0; i < this.numStates; i++) {
                double outgoing = 0.0d;
                for (int j = 0; j < this.numStates; j++) {
                    outgoing += bwd[j][t + 1] * this.a[i][j] * this.b[j][observations[t + 1]];
                }
                bwd[i][t] = outgoing;
            }
        }
        return bwd;
    }

    /**
     * Returns the normalised weight of the transition {@code i -> j} at time
     * {@code t}, given precomputed forward and backward variables.
     *
     * <p>At the last time step there is no following observation, so the
     * numerator is just {@code fwd[i][t] * a[i][j]}.
     *
     * @param t time index into the observation sequence
     * @param i source state
     * @param j target state
     * @param observations the observation sequence the trellises were built from
     * @param fwd forward variables, indexed {@code [state][time]}
     * @param bwd backward variables, indexed {@code [state][time]}
     * @return the numerator divided by the sum over states of
     *         {@code fwd[k][t] * bwd[k][t]}, or {@code 0.0} when the numerator is zero
     */
    public double p(int t, int i, int j, int[] observations, double[][] fwd, double[][] bwd) {
        return divide(xiNumerator(t, i, j, observations, fwd, bwd), stateMass(t, fwd, bwd));
    }

    /**
     * Returns the normalised weight of being in state {@code i} at time
     * {@code t}, given precomputed forward and backward variables.
     *
     * @param i state index
     * @param t time index
     * @param observations unused; retained so the signature stays unchanged
     * @param fwd forward variables, indexed {@code [state][time]}
     * @param bwd backward variables, indexed {@code [state][time]}
     * @return {@code fwd[i][t] * bwd[i][t]} divided by the same product summed
     *         over all states, or {@code 0.0} when the numerator is zero
     */
    public double gamma(int i, int t, int[] observations, double[][] fwd, double[][] bwd) {
        return divide(fwd[i][t] * bwd[i][t], stateMass(t, fwd, bwd));
    }

    /**
     * Writes {@link #pi}, {@link #a} and {@link #b} to standard output with
     * five fraction digits, one {@code pi} entry per line and one matrix row
     * per line.
     */
    public void print() {
        final DecimalFormat fmt = new DecimalFormat();
        fmt.setMinimumFractionDigits(5);
        fmt.setMaximumFractionDigits(5);

        for (int i = 0; i < this.numStates; i++) {
            System.out.println("pi(" + i + ") = " + fmt.format(this.pi[i]));
        }
        System.out.println();

        for (int i = 0; i < this.numStates; i++) {
            for (int j = 0; j < this.numStates; j++) {
                System.out.print("a(" + i + "," + j + ") = " + fmt.format(this.a[i][j]) + ENTRY_SEPARATOR);
            }
            System.out.println();
        }
        System.out.println();

        for (int i = 0; i < this.numStates; i++) {
            for (int k = 0; k < this.sigmaSize; k++) {
                System.out.print("b(" + i + "," + k + ") = " + fmt.format(this.b[i][k]) + ENTRY_SEPARATOR);
            }
            System.out.println();
        }
    }

    /**
     * Divides {@code n} by {@code d}, defining a zero numerator as {@code 0.0}
     * so that {@code 0 / 0} does not yield {@code NaN}.
     *
     * @param n numerator
     * @param d denominator
     * @return {@code 0.0} when {@code n} is zero, otherwise {@code n / d}
     */
    public double divide(double n, double d) {
        if (n == 0.0d) {
            return 0.0d;
        }
        return n / d;
    }

    /** Sums {@code fwd[k][t] * bwd[k][t]} over all states k; the normaliser used by gamma and p. */
    private double stateMass(int t, double[][] fwd, double[][] bwd) {
        double total = 0.0d;
        for (int k = 0; k < this.numStates; k++) {
            total += fwd[k][t] * bwd[k][t];
        }
        return total;
    }

    /** Unnormalised weight of the transition i -> j at time t (see {@link #p}). */
    private double xiNumerator(int t, int i, int j, int[] observations, double[][] fwd, double[][] bwd) {
        if (t == observations.length - 1) {
            return fwd[i][t] * this.a[i][j];
        }
        return fwd[i][t] * this.a[i][j] * this.b[j][observations[t + 1]] * bwd[j][t + 1];
    }
}
