package edu.columbia.stat.wood.pub.sequencememoizer.util;

import java.io.Serializable;
import org.apache.commons.configuration.tree.DefaultExpressionEngine;

/* loaded from: input_file:edu/columbia/stat/wood/pub/sequencememoizer/util/Discounts.class */
/**
 * Mutable set of per-depth discount values together with their cached natural
 * logarithms and an accumulated gradient.
 *
 * <p>Explicit discounts are stored for depths {@code 0 .. length() - 1}. For
 * depths at or past the last stored index, {@link #get(int, int)} extends the
 * last discount using a geometric series in {@code alpha}; {@code alpha} itself
 * is derived from the "dInfinity" value supplied by the caller.
 *
 * <p>The array passed to the constructor is stored by reference, not copied.
 * This class performs no synchronization.
 */
public class Discounts implements Serializable {
    static final long serialVersionUID = 1;

    /** Message shared by every dInfinity range check. */
    private static final String D_INFINITY_RANGE_MESSAGE = "dInfinity must be in the interval (0.0,1.0)";

    private final double[] discounts;
    /** Cache of {@code Math.log(discounts[k])}; must be kept in sync with {@link #discounts}. */
    private final double[] logDiscounts;
    private final double[] discountGradient;
    private double alpha;
    private double alphaGradient;

    /**
     * Creates a discount set backed by the given array.
     *
     * @param dArr the discount values, one per depth; stored by reference
     * @param d    the dInfinity value, strictly between 0 and 1
     * @throws IllegalArgumentException if {@code dArr} is null or empty, or if
     *                                  {@code d} is not in the open interval (0, 1)
     */
    public Discounts(double[] dArr, double d) {
        if (dArr == null || dArr.length == 0) {
            throw new IllegalArgumentException("discounts must contain at least one value");
        }
        this.alpha = alphaFor(d, dArr[dArr.length - 1]);
        this.discounts = dArr;
        this.logDiscounts = new double[dArr.length];
        this.discountGradient = new double[dArr.length];
        fillLogDiscounts();
    }

    /**
     * Returns the stored discount at the given index.
     *
     * @param i index in {@code [0, length())}
     * @return the discount stored at {@code i}
     * @throws IllegalArgumentException if {@code i} is outside {@code [0, length())}
     */
    public double get(int i) {
        if (i < 0 || i >= this.discounts.length) {
            throw new IllegalArgumentException("Must only get discounts with an index in [0, length())");
        }
        return this.discounts[i];
    }

    /**
     * Overwrites the stored discount at the given index and refreshes its cached
     * logarithm so that {@link #get(int, int)} observes the new value immediately.
     *
     * @param i index in {@code [0, length())}
     * @param d the new discount value
     * @throws IllegalArgumentException if {@code i} is outside {@code [0, length())}
     */
    public void set(int i, double d) {
        if (i < 0 || i >= this.discounts.length) {
            throw new IllegalArgumentException("Must only set discounts with an index in [0, length())");
        }
        this.discounts[i] = d;
        // get(int, int) reads only the log cache, so it has to follow every write.
        this.logDiscounts[i] = Math.log(d);
    }

    /**
     * Returns the dInfinity value implied by the current {@code alpha} and the
     * last stored discount: {@code last ^ (alpha / (1 - alpha))}.
     *
     * @return the current dInfinity value
     */
    public double getdInfinity() {
        double last = this.discounts[this.discounts.length - 1];
        return Math.pow(last, this.alpha / (1.0d - this.alpha));
    }

    /**
     * Recomputes {@code alpha} from a new dInfinity value and the last stored discount.
     *
     * @param d the new dInfinity value, strictly between 0 and 1
     * @throws IllegalArgumentException if {@code d} is not in the open interval (0, 1);
     *                                  {@code alpha} is left unchanged in that case
     */
    public void setDInfinity(double d) {
        this.alpha = alphaFor(d, this.discounts[this.discounts.length - 1]);
    }

    /**
     * Returns the product of the discounts for depths {@code i + 1 .. i2}.
     *
     * <p>Depths below {@code length() - 1} use the stored values. When {@code i2}
     * reaches {@code length() - 1} or beyond, the remaining depths contribute the
     * last stored log-discount scaled by a geometric series in {@code alpha}.
     * When {@code i2 == 0} the result is simply the discount at index 0.
     *
     * @param i  parent depth
     * @param i2 depth of this restaurant
     * @return the combined discount
     * @throws IllegalArgumentException if {@code i >= i2} while both are non-zero
     */
    public double get(int i, int i2) {
        if (i >= i2 && i != 0 && i2 != 0) {
            throw new IllegalArgumentException("parent depth (" + i + ") must be less than depth of this restaurant (" + i2 + ")");
        }
        if (i2 == 0) {
            return Math.exp(this.logDiscounts[0]);
        }

        int lastIndex = this.discounts.length - 1;
        int depth = i + 1;
        double logSum = 0.0d;
        while (depth <= i2 && depth < lastIndex) {
            logSum += this.logDiscounts[depth];
            depth++;
        }
        if (i2 >= lastIndex) {
            // Closed form of sum_{k=first}^{first+count-1} alpha^k, applied to the last log-discount.
            double firstExponent = (depth - this.discounts.length) + 1.0d;
            double termCount = (i2 - depth) + 1.0d;
            logSum += ((this.logDiscounts[lastIndex] * Math.pow(this.alpha, firstExponent))
                    * (1.0d - Math.pow(this.alpha, termCount))) / (1.0d - this.alpha);
        }
        return Math.exp(logSum);
    }

    /**
     * Returns the number of explicitly stored discounts.
     *
     * @return the length of the backing array
     */
    public int length() {
        return this.discounts.length;
    }

    /** Resets every accumulated discount gradient and the alpha gradient to zero. */
    public void clearGradient() {
        for (int k = 0; k < this.discountGradient.length; k++) {
            this.discountGradient[k] = 0.0d;
        }
        this.alphaGradient = 0.0d;
    }

    /**
     * Accumulates gradient contributions for the depths {@code i + 1 .. i2}.
     *
     * <p>Nothing happens unless {@code i4 > 0}. Every touched entry receives
     * {@code ((i5 * d - i3) * d2 * w / i4) * d3}, where {@code w} is
     * {@code 1 / discount} for an explicitly stored depth. For the tail
     * ({@code i2 >= length() - 1}) {@code w} is the geometric-series factor
     * divided by the last discount, and the alpha gradient is updated as well.
     *
     * <p>NOTE(review): the statistical meaning of {@code i3}, {@code i4},
     * {@code i5}, {@code d}, {@code d2} and {@code d3} is not visible in this
     * file; confirm against the callers.
     *
     * @param i  parent depth (same role as in {@link #get(int, int)})
     * @param i2 depth (same role as in {@link #get(int, int)})
     * @param i3 integer subtracted from {@code i5 * d}
     * @param i4 integer divisor; the call is a no-op when it is not positive
     * @param i5 integer multiplied by {@code d}
     * @param d  factor multiplied by {@code i5}
     * @param d2 multiplier applied to {@code (i5 * d - i3)}
     * @param d3 final multiplier applied to each contribution
     */
    public void updateGradient(int i, int i2, int i3, int i4, int i5, double d, double d2, double d3) {
        if (i4 <= 0) {
            return;
        }
        // Factor common to every contribution; evaluation order below matches the
        // original expressions so floating-point results are unchanged.
        double common = ((i5 * d) - i3) * d2;

        if (i2 == 0) {
            double inverse = 1.0d / this.discounts[0];
            this.discountGradient[0] += ((common * inverse) / i4) * d3;
            return;
        }

        int lastIndex = this.discounts.length - 1;
        int depth = i + 1;
        while (depth <= i2 && depth < lastIndex) {
            double inverse = 1.0d / this.discounts[depth];
            this.discountGradient[depth] += ((common * inverse) / i4) * d3;
            depth++;
        }

        if (i2 >= lastIndex) {
            double firstExponent = (depth - this.discounts.length) + 1.0d;
            double termCount = (i2 - depth) + 1.0d;
            double oneMinusAlpha = 1.0d - this.alpha;

            double tailWeight = ((Math.pow(this.alpha, firstExponent) * (1.0d - Math.pow(this.alpha, termCount)))
                    / (1.0d - this.alpha)) / this.discounts[lastIndex];
            this.discountGradient[lastIndex] += ((common * tailWeight) / i4) * d3;

            // Derivative with respect to alpha of (alpha^a - alpha^(a+b)) / (1 - alpha).
            double endExponent = firstExponent + termCount;
            double numeratorDerivative = ((firstExponent * Math.pow(this.alpha, firstExponent - 1.0d))
                    - (endExponent * Math.pow(this.alpha, endExponent - 1.0d))) / oneMinusAlpha;
            double denominatorDerivative = ((Math.pow(this.alpha, firstExponent) - Math.pow(this.alpha, endExponent))
                    / oneMinusAlpha) / oneMinusAlpha;
            double alphaWeight = this.logDiscounts[lastIndex] * (numeratorDerivative + denominatorDerivative);
            this.alphaGradient += ((common * alphaWeight) / i4) * d3;
        }
    }

    /**
     * Applies the accumulated gradient to every discount and to {@code alpha},
     * then clears the gradient and refreshes the cached logarithms.
     *
     * <p>Each discount moves by {@code d * gradient / max(d2, 0.05)} and is then
     * clamped: values above 1 become 1, values at or below 0 become 1e-8 (a zero
     * discount would make its logarithm and reciprocal infinite). If the stepped
     * {@code alpha} leaves the open interval (0, 1) it instead moves halfway
     * towards the boundary it crossed.
     *
     * @param d  step multiplier
     * @param d2 divisor for the step; values below 0.05 are treated as 0.05
     */
    public void stepDiscounts(double d, double d2) {
        double divisor = d2 < 0.05d ? 0.05d : d2;

        for (int k = 0; k < this.discountGradient.length; k++) {
            double stepped = this.discounts[k] + ((d * this.discountGradient[k]) / divisor);
            if (stepped > 1.0d) {
                this.discounts[k] = 1.0d;
            } else if (stepped <= 0.0d) {
                this.discounts[k] = 1.0E-8d;
            } else {
                this.discounts[k] = stepped;
            }
        }

        double steppedAlpha = this.alpha + ((d * this.alphaGradient) / divisor);
        if (steppedAlpha >= 1.0d) {
            steppedAlpha = this.alpha + ((1.0d - this.alpha) / 2.0d);
        } else if (steppedAlpha <= 0.0d) {
            steppedAlpha = this.alpha / 2.0d;
        }
        // Also rejects NaN, which fails both comparisons.
        if (steppedAlpha < 1.0d && steppedAlpha > 0.0d) {
            this.alpha = steppedAlpha;
        }

        clearGradient();
        fillLogDiscounts();
    }

    /** Recomputes the whole logarithm cache from the current discounts. */
    private void fillLogDiscounts() {
        for (int k = 0; k < this.discounts.length; k++) {
            this.logDiscounts[k] = Math.log(this.discounts[k]);
        }
    }

    /**
     * Validates a dInfinity value and converts it to the matching {@code alpha}:
     * {@code log(dInfinity) / (log(dInfinity) + log(lastDiscount))}.
     *
     * @throws IllegalArgumentException if {@code dInfinity} is NaN or outside (0, 1)
     */
    private static double alphaFor(double dInfinity, double lastDiscount) {
        // Written as a negated conjunction so that NaN is rejected as well.
        if (!(dInfinity > 0.0d && dInfinity < 1.0d)) {
            throw new IllegalArgumentException(D_INFINITY_RANGE_MESSAGE);
        }
        double logInfinity = Math.log(dInfinity);
        return logInfinity / (logInfinity + Math.log(lastDiscount));
    }

    /** Writes the stored discounts and the current dInfinity value to standard output. */
    public void print() {
        System.out.print("[" + this.discounts[0]);
        for (int k = 1; k < this.discounts.length; k++) {
            System.out.print(", " + this.discounts[k]);
        }
        System.out.println("]");
        System.out.println("The infinite discount is = " + getdInfinity());
    }
}
