package cc.mallet.classify;

import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.constraints.ge.MaxEntGEConstraint;
import cc.mallet.classify.constraints.ge.MaxEntRangeL2FLGEConstraints;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.logging.Logger;
import org.apache.commons.io.IOUtils;

/* loaded from: input_file:cc/mallet/classify/MaxEntGERangeTrainer.class */
/**
 * Trains a {@link MaxEnt} classifier by optimizing a {@link MaxEntOptimizableByGE}
 * objective built from a list of {@link MaxEntGEConstraint}s.
 *
 * <p>Constraints are either supplied through a constructor or, when none were supplied
 * and a constraints file has been set, read as range constraints from that file on the
 * first call to {@link #train(InstanceList, int)}. Unless an optimizer is injected with
 * {@link #setOptimizer(Optimizer)}, a {@link LimitedMemoryBFGS} instance is used.
 */
public class MaxEntGERangeTrainer extends ClassifierTrainer<MaxEnt> implements ClassifierTrainer.ByOptimization<MaxEnt>, Boostable, Serializable {
    private static final long serialVersionUID = 1;
    private static final Logger logger = MalletLogger.getLogger(MaxEntGERangeTrainer.class.getName());
    private static final Logger progressLogger = MalletProgressMessageLogger.getLogger(MaxEntGERangeTrainer.class.getName() + "-pl");

    // Flags forwarded to MaxEntRangeL2FLGEConstraints when constraints are read from a file.
    private boolean normalize = true;
    private boolean useValues = false;
    // Path handed to FeatureConstraintUtil.readRangeConstraintsFromFile; null when unset.
    private String constraintsFile;
    // Running (saturating) sum of the iteration budgets of optimize calls that returned normally.
    private int numIterations = 0;
    // Budget used by train(InstanceList) when the caller does not pass one.
    private int maxIterations = Integer.MAX_VALUE;
    // Values pushed into the optimizable when it is first created.
    private double temperature = 1.0d;
    private double gaussianPriorVariance = 1.0d;
    protected ArrayList<MaxEntGEConstraint> constraints;
    private InstanceList trainingList = null;
    private MaxEnt classifier = null;
    // Lazily created by getOptimizable / getOptimizer and then reused.
    private MaxEntOptimizableByGE ge = null;
    private Optimizer opt = null;

    /**
     * Creates a trainer with no constraints; a constraints file must be set (or the
     * {@code constraints} field populated by a subclass) before training.
     */
    public MaxEntGERangeTrainer() {
    }

    /**
     * Creates a trainer that uses the given constraints.
     *
     * @param arrayList constraints handed to the optimizable; stored by reference
     */
    public MaxEntGERangeTrainer(ArrayList<MaxEntGEConstraint> arrayList) {
        this.constraints = arrayList;
    }

    /**
     * Creates a trainer that uses the given constraints and initial classifier.
     *
     * @param arrayList constraints handed to the optimizable; stored by reference
     * @param maxEnt    classifier passed to the optimizable when it is created
     */
    public MaxEntGERangeTrainer(ArrayList<MaxEntGEConstraint> arrayList, MaxEnt maxEnt) {
        this.constraints = arrayList;
        this.classifier = maxEnt;
    }

    /**
     * Sets the file from which range constraints are read during training. The file is
     * only consulted when no constraints are present at training time.
     *
     * @param str path of the constraints file
     */
    public void setConstraintsFile(String str) {
        this.constraintsFile = str;
    }

    /**
     * Sets the temperature applied to the optimizable when it is first created; has no
     * effect on an optimizable that already exists.
     *
     * @param d temperature value
     */
    public void setTemperature(double d) {
        this.temperature = d;
    }

    /**
     * Sets the Gaussian prior variance applied to the optimizable when it is first
     * created; has no effect on an optimizable that already exists.
     *
     * @param d prior variance
     */
    public void setGaussianPriorVariance(double d) {
        this.gaussianPriorVariance = d;
    }

    /**
     * Returns the classifier supplied at construction or produced by the most recent
     * training call.
     *
     * @return the current classifier, or {@code null} if there is none yet
     */
    @Override
    public MaxEnt getClassifier() {
        return this.classifier;
    }

    /**
     * Sets the {@code useValues} flag passed to the range constraints object built from
     * the constraints file.
     *
     * @param z flag value
     */
    public void setUseValues(boolean z) {
        this.useValues = z;
    }

    /**
     * Sets the {@code normalize} flag passed to the range constraints object built from
     * the constraints file.
     *
     * @param z flag value
     */
    public void setNormalize(boolean z) {
        this.normalize = z;
    }

    /**
     * Returns the optimizable objective, creating it on first use from the given
     * instances, the current constraints and the current classifier.
     *
     * <p>Once created, the same object is returned on every call and
     * {@code instanceList} is ignored.
     *
     * @param instanceList instances used only when the optimizable is first created
     * @return the shared optimizable
     */
    public Optimizable.ByGradientValue getOptimizable(InstanceList instanceList) {
        if (this.ge == null) {
            this.ge = new MaxEntOptimizableByGE(instanceList, this.constraints, this.classifier);
            this.ge.setTemperature(this.temperature);
            this.ge.setGaussianPriorVariance(this.gaussianPriorVariance);
        }
        return this.ge;
    }

    /**
     * Returns the optimizer, creating a {@link LimitedMemoryBFGS} over the optimizable
     * when none has been set. The optimizable itself is created first if necessary,
     * using the most recently recorded training list.
     *
     * @return the optimizer used by {@link #train(InstanceList, int)}
     */
    @Override
    public Optimizer getOptimizer() {
        getOptimizable(this.trainingList);
        if (this.opt == null) {
            this.opt = new LimitedMemoryBFGS(this.ge);
        }
        return this.opt;
    }

    /**
     * Replaces the optimizer used for training.
     *
     * @param optimizer optimizer to use; {@code null} makes the next
     *                  {@link #getOptimizer()} call create the default one
     */
    public void setOptimizer(Optimizer optimizer) {
        this.opt = optimizer;
    }

    /**
     * Sets the iteration budget used by {@link #train(InstanceList)}.
     *
     * @param i maximum number of iterations
     */
    public void setMaxIterations(int i) {
        this.maxIterations = i;
    }

    /**
     * Returns the accumulated iteration budget of all optimize calls that returned
     * without throwing. This is the sum of the requested budgets, not a count of the
     * steps the optimizer actually performed, and it saturates at
     * {@link Integer#MAX_VALUE} instead of wrapping around.
     *
     * @return accumulated iteration budget
     */
    @Override
    public int getIteration() {
        return this.numIterations;
    }

    /**
     * Trains using the budget configured with {@link #setMaxIterations(int)}.
     *
     * @param instanceList training instances
     * @return the trained classifier
     */
    @Override
    public MaxEnt train(InstanceList instanceList) {
        return train(instanceList, this.maxIterations);
    }

    /**
     * Trains the classifier on the given instances.
     *
     * <p>If no constraints are present and a constraints file has been set, the
     * constraints are read from that file first. The optimizer is then run with the
     * given budget; when the budget is {@link Integer#MAX_VALUE} and the optimizer is a
     * {@link LimitedMemoryBFGS}, it is reset and run a second time. Any exception thrown
     * by the optimizer is printed and treated as convergence rather than propagated.
     *
     * @param instanceList training instances
     * @param i            iteration budget passed to the optimizer
     * @return the classifier obtained from the optimizable after optimization
     */
    @Override
    public MaxEnt train(InstanceList instanceList, int i) {
        this.trainingList = instanceList;
        if (this.constraints == null && this.constraintsFile != null) {
            // Assigned only after the list is fully built, so a failed load leaves
            // constraints null and a later call can retry.
            this.constraints = buildRangeConstraints(instanceList);
        }

        getOptimizable(this.trainingList);
        getOptimizer();

        boolean usingLbfgs = this.opt instanceof LimitedMemoryBFGS;
        if (usingLbfgs) {
            ((LimitedMemoryBFGS) this.opt).reset();
        }
        logger.fine("trainingList.size() = " + this.trainingList.size());
        optimizeAndCount(i);

        if (i == Integer.MAX_VALUE && usingLbfgs) {
            // Second pass after a reset. NOTE(review): presumably lets L-BFGS keep
            // improving once its accumulated history is discarded -- confirm.
            ((LimitedMemoryBFGS) this.opt).reset();
            optimizeAndCount(i);
        }

        // NOTE(review): presumably terminates a single-line progress display -- confirm.
        progressLogger.info("\n");
        this.classifier = this.ge.getClassifier();
        return this.classifier;
    }

    /**
     * Reads the range constraints file and wraps its contents in a single
     * {@link MaxEntRangeL2FLGEConstraints} object.
     */
    private ArrayList<MaxEntGEConstraint> buildRangeConstraints(InstanceList instanceList) {
        HashMap<Integer, double[][]> rangesByKey = FeatureConstraintUtil.readRangeConstraintsFromFile(this.constraintsFile, instanceList);
        logger.info("number of constraints: " + rangesByKey.size());

        MaxEntRangeL2FLGEConstraints rangeConstraints = new MaxEntRangeL2FLGEConstraints(
                instanceList.getDataAlphabet().size(),
                instanceList.getTargetAlphabet().size(),
                this.useValues,
                this.normalize);

        for (Map.Entry<Integer, double[][]> entry : rangesByKey.entrySet()) {
            int key = entry.getKey().intValue();
            double[][] rows = entry.getValue();
            for (int row = 0; row < rows.length; row++) {
                // Rows whose first value is infinite are skipped.
                // NOTE(review): presumably these mark "no constraint" for that row -- confirm.
                if (!Double.isInfinite(rows[row][0])) {
                    rangeConstraints.addConstraint(key, row, rows[row][0], rows[row][1], 1.0d);
                }
            }
        }

        ArrayList<MaxEntGEConstraint> loaded = new ArrayList<>();
        loaded.add(rangeConstraints);
        return loaded;
    }

    /**
     * Runs the optimizer with the given budget and, if it returns normally, adds the
     * budget to the iteration counter; an exception is printed and swallowed.
     */
    private void optimizeAndCount(int budget) {
        try {
            this.opt.optimize(budget);
            this.numIterations = saturatingAdd(this.numIterations, budget);
        } catch (Exception e) {
            // Deliberate best effort: a failing optimizer is treated as having converged.
            e.printStackTrace();
            logger.info("Catching exception; saying converged.");
        }
    }

    /**
     * Adds two ints, clamping to the int range instead of wrapping. Needed because the
     * default budget is Integer.MAX_VALUE and may be added more than once.
     */
    private static int saturatingAdd(int current, int delta) {
        long sum = (long) current + delta;
        if (sum > Integer.MAX_VALUE) {
            return Integer.MAX_VALUE;
        }
        if (sum < Integer.MIN_VALUE) {
            return Integer.MIN_VALUE;
        }
        return (int) sum;
    }
}
