package codemining.lm.ngram.cache;

import cc.mallet.optimize.GradientAscent;
import cc.mallet.optimize.Optimizable;
import codemining.lm.ngram.cache.SymbolicWeightCache;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.collect.Multiset;
import com.google.common.math.DoubleMath;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Logger;
import org.apache.commons.lang.exception.ExceptionUtils;

/* JADX WARN: Classes with same name are omitted:
  input_file:lib/naturalize.jar:codemining/lm/ngram/cache/ParameterOptimizer.class
 */
/* loaded from: input_file:naturalize.jar:codemining/lm/ngram/cache/ParameterOptimizer.class */
/**
 * Fits the cache interpolation weight {@code lambda} (and, by grid search, the
 * cache decay rate) by maximising the average log2-likelihood of the observed
 * {@link LPair}s under the mixture
 * {@code importance * (lambda * cache + (1 - lambda) * ngram) + (1 - importance) * ngram}.
 *
 * <p>Lambda is the single parameter exposed to MALLET's gradient ascent; the
 * decay rate is swept externally in {@link #optimizeParameters()}. The
 * optimisation mutates this object's {@code currentLambda} and {@code decay}
 * fields.
 */
class ParameterOptimizer implements Optimizable.ByGradientValue {
    final Multiset<LPair> elems;
    protected static final Logger LOGGER = Logger.getLogger(ParameterOptimizer.class.getName());
    double currentLambda = 0.5d;
    double decay = 0.5d;

    /**
     * Offset used to keep lambda strictly inside (0, 1) when the optimizer steps
     * outside the feasible range, so the log-likelihood and its gradient stay finite.
     */
    private static final double LAMBDA_EPSILON = 1.0E-9d;

    /** Decay rates tried, in order, by {@link #optimizeParameters()}. */
    private static final double[] DECAY_GRID = {0.15d, 0.2d, 0.25d, 0.3d, 0.35d, 0.4d, 0.45d, 0.5d,
            0.55d, 0.6d, 0.65d, 0.7d, 0.75d, 0.8d, 0.85d, 0.9d, 0.95d, 0.98d, 0.99d};

    /** Parameters kept if no candidate decay yields a converged optimisation. */
    private static final double DEFAULT_DECAY = 0.8d;
    private static final double DEFAULT_LAMBDA = 0.3d;

    /** Starting lambda for every gradient-ascent run. */
    private static final double INITIAL_LAMBDA = 0.5d;

    /** Initial and maximum step size handed to {@link GradientAscent}. */
    private static final double STEP_SIZE = 0.01d;

    /**
     * One training observation: the n-gram model probability of a token, the
     * decay factors whose sum (at a given decay rate) gives the cache
     * probability, and the weight {@code importance} given to the cache mixture.
     * Instances are used as {@link Multiset} keys, so equality is value-based.
     */
    public static class LPair {
        double ngramProb;
        List<SymbolicWeightCache.DecayFactor> cacheProb;
        double importance = 1.0d;

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof LPair)) {
                return false;
            }
            final LPair other = (LPair) obj;
            // The cache lists must be EQUAL; the previous code negated this test,
            // which made equals non-reflexive and broke Multiset counting.
            return Double.compare(other.ngramProb, this.ngramProb) == 0
                    && Double.compare(this.importance, other.importance) == 0
                    && Objects.equal(this.cacheProb, other.cacheProb);
        }

        @Override
        public int hashCode() {
            return Objects.hashCode(Double.valueOf(this.ngramProb), this.cacheProb,
                    Double.valueOf(this.importance));
        }
    }

    /**
     * @param multiset the observations to fit, with multiplicities used as
     *                 weights; stored by reference, not copied
     */
    public ParameterOptimizer(Multiset<LPair> multiset) {
        this.elems = multiset;
    }

    /** Returns the cache probability of {@code pair}: the sum of its decay factors at {@code decayRate}. */
    private static double cacheProbability(final LPair pair, final double decayRate) {
        double sum = 0.0d;
        for (final SymbolicWeightCache.DecayFactor factor : pair.cacheProb) {
            sum += factor.getForAlpha(decayRate);
        }
        return sum;
    }

    /** Returns the interpolated probability of {@code pair} for the given lambda and cache probability. */
    private static double mixtureProbability(final LPair pair, final double lambda,
            final double cacheProb) {
        return (pair.importance * ((lambda * cacheProb) + ((1.0d - lambda) * pair.ngramProb)))
                + ((1.0d - pair.importance) * pair.ngramProb);
    }

    /** Maps lambda into the feasible region: below 0 to epsilon, above 1 to 1 - epsilon, otherwise unchanged. */
    private static double clampLambda(final double lambda) {
        if (lambda < 0.0d) {
            return LAMBDA_EPSILON;
        }
        if (lambda > 1.0d) {
            return 1.0d - LAMBDA_EPSILON;
        }
        return lambda;
    }

    /** Returns the signed distance by which lambda lies outside [0, 1] (0 when inside). */
    private static double lambdaExcursion(final double lambda) {
        if (lambda < 0.0d) {
            return lambda;
        }
        if (lambda > 1.0d) {
            return lambda - 1.0d;
        }
        return 0.0d;
    }

    /**
     * Computes the objective: the average log2 mixture probability over all
     * observations (weighted by multiplicity), minus a quadratic penalty on
     * how far {@code lambda} lies outside [0, 1].
     *
     * @throws IllegalArgumentException if an n-gram probability is outside (0, 1]
     *         or the result is NaN/infinite
     */
    private double computeValue(final double decayRate, final double lambda) {
        // Penalise the real excursion (not the clamped value) so the optimizer is
        // pushed back into the feasible range instead of drifting on a flat plateau.
        final double excursion = lambdaExcursion(lambda);
        final double penalty = excursion * excursion;
        final double feasibleLambda = clampLambda(lambda);

        double logLikelihood = 0.0d;
        for (final Multiset.Entry<LPair> entry : this.elems.entrySet()) {
            final LPair pair = entry.getElement();
            final double cacheProb = cacheProbability(pair, decayRate);
            Preconditions.checkArgument(pair.ngramProb > 0.0d && pair.ngramProb <= 1.0d,
                    "n-gram probability must be in (0, 1] but is %s", pair.ngramProb);
            logLikelihood += entry.getCount()
                    * DoubleMath.log2(mixtureProbability(pair, feasibleLambda, cacheProb));
        }
        // Multiset.size() counts duplicates, so this is a per-occurrence average.
        final double value = (logLikelihood / this.elems.size()) - penalty;
        Preconditions.checkArgument(!Double.isInfinite(value) && !Double.isNaN(value),
                "Value Should not be NaN or Inf but is " + value + " with sum=" + logLikelihood);
        return value;
    }

    /** Only lambda is optimised by gradient ascent; the decay rate is grid-searched. */
    public int getNumParameters() {
        return 1;
    }

    /** Returns the current lambda; the index is ignored because there is a single parameter. */
    public double getParameter(int i) {
        return this.currentLambda;
    }

    /** Writes the current lambda into {@code dArr[0]}. */
    public void getParameters(double[] dArr) {
        dArr[0] = this.currentLambda;
    }

    /** Returns the objective at the current lambda and decay rate. */
    public double getValue() {
        return computeValue(this.decay, this.currentLambda);
    }

    /**
     * Writes d(objective)/d(lambda) at the current lambda and decay rate into
     * {@code dArr[0]}. The gradient is taken at the same clamped lambda used by
     * {@link #getValue()}, so the two stay consistent outside [0, 1].
     *
     * @throws IllegalArgumentException if a cache probability is outside [0, 1]
     *         or the gradient is NaN/infinite
     */
    public void getValueGradient(double[] dArr) {
        final double lambda = clampLambda(this.currentLambda);
        // Derivative of the quadratic out-of-range penalty subtracted in computeValue.
        final double penaltyGradient = 2.0d * lambdaExcursion(this.currentLambda);
        final double ln2 = Math.log(2.0d);

        double sum = 0.0d;
        for (final Multiset.Entry<LPair> entry : this.elems.entrySet()) {
            final LPair pair = entry.getElement();
            final double cacheProb = cacheProbability(pair, this.decay);
            Preconditions.checkArgument(cacheProb >= 0.0d && cacheProb <= 1.0d,
                    "Cache probability must be between 0 and 1 but is " + cacheProb);
            // d/dlambda log2(mix) = importance * (cache - ngram) / (ln 2 * mix)
            sum += entry.getCount() * pair.importance * (cacheProb - pair.ngramProb)
                    / (ln2 * mixtureProbability(pair, lambda, cacheProb));
        }
        dArr[0] = (sum / this.elems.size()) - penaltyGradient;
        Preconditions.checkArgument(!Double.isInfinite(dArr[0]) && !Double.isNaN(dArr[0]),
                "gradient(lambda) should not be NaN or Inf but is " + dArr[0]);
        LOGGER.fine(Arrays.toString(dArr));
    }

    /**
     * For each candidate decay rate, runs gradient ascent on lambda from
     * {@value #INITIAL_LAMBDA}. It keeps the (decay, lambda) pair with the highest
     * objective among the runs that converged, and leaves {@code decay} and
     * {@code currentLambda} set to that pair. If no run converges, the defaults
     * (decay {@value #DEFAULT_DECAY}, lambda {@value #DEFAULT_LAMBDA}) are kept.
     */
    public void optimizeParameters() {
        double bestValue = Double.NEGATIVE_INFINITY;
        double bestDecay = DEFAULT_DECAY;
        double bestLambda = DEFAULT_LAMBDA;
        for (final double candidateDecay : DECAY_GRID) {
            this.decay = candidateDecay;
            this.currentLambda = INITIAL_LAMBDA;
            LOGGER.fine("Start at d=" + candidateDecay + " l=" + this.currentLambda);
            final GradientAscent ascent = new GradientAscent(this);
            ascent.setMaxStepSize(STEP_SIZE);
            ascent.setInitialStepSize(STEP_SIZE);
            boolean converged = false;
            try {
                converged = ascent.optimize();
            } catch (IllegalArgumentException e) {
                LOGGER.warning(ExceptionUtils.getFullStackTrace(e));
            }
            // Check convergence first: after a failed run the state may be invalid,
            // and getValue() could throw and abort the remaining candidates.
            if (!converged) {
                continue;
            }
            final double value = getValue();
            if (value > bestValue) {
                LOGGER.info("With l=" + this.currentLambda + " d=" + candidateDecay
                        + " the value is " + value);
                Preconditions.checkArgument(Double.compare(candidateDecay, this.decay) == 0,
                        "Decay value has been changed!");
                bestDecay = candidateDecay;
                // Store a usable mixing weight even if the ascent ended outside [0, 1].
                bestLambda = clampLambda(getParameter(0));
                bestValue = value;
            }
        }
        this.decay = bestDecay;
        this.currentLambda = bestLambda;
        LOGGER.info("Optimized parameters for d=" + this.decay + " l=" + this.currentLambda
                + " val=" + bestValue);
    }

    /** Sets lambda (index 0) or the decay rate (index 1); other indices are ignored. */
    public void setParameter(int i, double d) {
        if (i == 0) {
            this.currentLambda = d;
        } else if (i == 1) {
            this.decay = d;
        }
    }

    /** Sets lambda from {@code dArr[0]}. */
    public void setParameters(double[] dArr) {
        this.currentLambda = dArr[0];
    }
}
