package cc.mallet.fst.semi_supervised.pr;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Iterator;

/* loaded from: input_file:cc/mallet/fst/semi_supervised/pr/CRFTrainerByPR.class */
/**
 * Trains a {@link CRF} from a list of {@link PRConstraint}s.
 *
 * <p>Each outer iteration runs two L-BFGS optimizations over the same filtered training set:
 * first a {@code ConstraintsOptimizableByPR} objective over a shared {@code PRAuxiliaryModel},
 * then a {@code CRFOptimizableByKL} objective over the CRF itself, using the Gaussian prior
 * variance set by {@link #setPGaussianPriorVariance(double)}. Only instances that at least one
 * constraint selects in its {@code preProcess} result take part in training.
 *
 * <p>The iteration counter persists across calls to {@code train}, so repeated calls continue
 * counting from where the previous call stopped.
 */
public class CRFTrainerByPR extends TransducerTrainer implements TransducerTrainer.ByOptimization {
    /** Added to the convergence threshold so it stays positive when both values are near zero. */
    private static final double CONVERGENCE_EPSILON = 1.0e-5;

    private boolean converged;
    /** Number of completed outer iterations, accumulated across all {@code train} calls. */
    private int iter;
    private final int numThreads;
    /** Gaussian prior variance for the CRF-side optimization (default 10.0). */
    private double pGpv;
    /** Relative tolerance of the outer-loop convergence test (default 0.001). */
    private double tolerance;
    /** Most recent total value, {@code pValue - qValue}; negative infinity before training. */
    private double value;
    /** Most recent value contribution reported by the constraint-side optimization. */
    private double qValue;
    private final ArrayList<PRConstraint> constraints;
    /** Optimizer of the most recent constraint-side step; {@code null} until {@code train} runs. */
    private LimitedMemoryBFGS bfgs;
    private final CRF crf;
    private StateLabelMap stateLabelMap;

    /**
     * Creates a single-threaded trainer.
     *
     * @param crf the model to train
     * @param constraints the constraints defining the objective (used directly, not copied)
     */
    public CRFTrainerByPR(CRF crf, ArrayList<PRConstraint> constraints) {
        this(crf, constraints, 1);
    }

    /**
     * Creates a trainer with a default {@link StateLabelMap} built from the CRF's output alphabet.
     *
     * @param crf the model to train
     * @param constraints the constraints defining the objective (used directly, not copied)
     * @param numThreads thread count handed to both optimizable objectives
     */
    public CRFTrainerByPR(CRF crf, ArrayList<PRConstraint> constraints, int numThreads) {
        this.crf = crf;
        this.iter = 0;
        this.value = Double.NEGATIVE_INFINITY;
        this.constraints = constraints;
        this.pGpv = 10.0d;
        this.tolerance = 0.001d;
        this.numThreads = numThreads;
        this.stateLabelMap = new StateLabelMap(crf.getOutputAlphabet(), true);
    }

    /** Returns the number of outer iterations completed so far, across all {@code train} calls. */
    @Override // cc.mallet.fst.TransducerTrainer
    public int getIteration() {
        return this.iter;
    }

    /** Returns the CRF being trained. */
    @Override // cc.mallet.fst.TransducerTrainer
    public Transducer getTransducer() {
        return this.crf;
    }

    /** Returns {@code true} once any {@code train} call has finished its outer loop. */
    @Override // cc.mallet.fst.TransducerTrainer
    public boolean isFinishedTraining() {
        return this.converged;
    }

    /**
     * Replaces the state-to-label map. It is pushed to every constraint at the start of the next
     * {@code train} call.
     */
    public void setStateLabelMap(StateLabelMap stateLabelMap) {
        this.stateLabelMap = stateLabelMap;
    }

    /** Sets the Gaussian prior variance used in the CRF-side optimization (default 10.0). */
    public void setPGaussianPriorVariance(double variance) {
        this.pGpv = variance;
    }

    /** Sets the relative tolerance of the outer-loop convergence test (default 0.001). */
    public void setTolerance(double tolerance) {
        this.tolerance = tolerance;
    }

    /**
     * Trains for at most {@code numIterations} outer iterations, with no minimum.
     *
     * @return always {@code true}; see {@link #train(InstanceList, int, int, int)}
     */
    @Override // cc.mallet.fst.TransducerTrainer
    public boolean train(InstanceList instanceList, int numIterations) {
        return train(instanceList, 0, numIterations);
    }

    /**
     * Trains with an unbounded number of L-BFGS iterations per inner optimization.
     *
     * @return always {@code true}; see {@link #train(InstanceList, int, int, int)}
     */
    public boolean train(InstanceList instanceList, int minIterations, int maxIterations) {
        return train(instanceList, minIterations, maxIterations, Integer.MAX_VALUE);
    }

    /**
     * Runs up to {@code maxIterations} additional outer iterations.
     *
     * <p>Training stops early once the iteration counter (which is cumulative across calls) is at
     * least {@code minIterations} and the total value has stopped changing, within
     * {@link #setTolerance(double) tolerance}.
     *
     * @param instanceList training data; instances selected by no constraint are dropped
     * @param minIterations absolute iteration count before the convergence test may stop training
     * @param maxIterations maximum number of outer iterations to run in this call
     * @param maxIterationsPerStep L-BFGS iteration limit for each inner optimization
     * @return always {@code true}
     */
    public boolean train(InstanceList instanceList, int minIterations, int maxIterations,
            int maxIterationsPerStep) {
        // Saturate instead of overflowing: callers such as train(InstanceList) pass
        // Integer.MAX_VALUE, and iter + MAX_VALUE would wrap negative and skip training entirely.
        final int lastIteration = (int) Math.min((long) this.iter + maxIterations, Integer.MAX_VALUE);
        InstanceList constrained = preprocessAndFilter(instanceList);
        PRAuxiliaryModel auxiliaryModel = new PRAuxiliaryModel(this.crf, this.constraints);
        double oldValue = 0.0d;

        while (this.iter < lastIteration) {
            long startMillis = System.currentTimeMillis();

            // Constraint-side step over the auxiliary model.
            ConstraintsOptimizableByPR qOptimizable =
                    new ConstraintsOptimizableByPR(this.crf, constrained, auxiliaryModel, this.numThreads);
            this.bfgs = new LimitedMemoryBFGS(qOptimizable);
            optimizeBestEffort(this.bfgs, maxIterationsPerStep);
            qOptimizable.shutdown();
            this.qValue = qOptimizable.getCompleteValueContribution();
            assert this.qValue > 0.0d : "expected positive qValue, got " + this.qValue;

            // CRF-side step, reusing the dot products cached by the constraint-side step.
            CRFOptimizableByKL pOptimizable = new CRFOptimizableByKL(this.crf, constrained,
                    auxiliaryModel, qOptimizable.getCachedDots(), this.numThreads, 1.0d);
            pOptimizable.setGaussianPriorVariance(this.pGpv);
            optimizeBestEffort(new LimitedMemoryBFGS(pOptimizable), maxIterationsPerStep);
            pOptimizable.shutdown();
            double pValue = pOptimizable.getValue();

            this.value = pValue - this.qValue;
            assert this.value < 0.0d : "expected negative total value, got " + this.value;
            System.err.println("Total value = " + this.value + " (pValue = " + pValue
                    + ") (qValue = " + (-this.qValue) + ")");
            System.err.println("Time for iteration "
                    + String.format("%.2f", (System.currentTimeMillis() - startMillis) / 1000.0d) + "s");

            if (this.iter >= minIterations && hasConverged(this.value, oldValue)) {
                System.err.println("AP value difference below tolerance (oldValue: " + oldValue
                        + ", newValue: " + this.value + ")");
                // As before, the converging iteration is neither evaluated nor counted.
                break;
            }
            oldValue = this.value;
            runEvaluators();
            this.iter++;
        }
        // NOTE(review): set unconditionally, even when maxIterations was exhausted without meeting
        // the tolerance; kept as-is because callers may rely on train() returning true.
        this.converged = true;
        return this.converged;
    }

    /**
     * Runs each constraint's {@code preProcess} on the data, installs the current state-label map
     * on every constraint, and returns a copy of the data containing only the selected instances.
     */
    private InstanceList preprocessAndFilter(InstanceList instanceList) {
        BitSet selected = new BitSet();
        for (PRConstraint constraint : this.constraints) {
            selected.or(constraint.preProcess(instanceList));
            constraint.setStateLabelMap(this.stateLabelMap);
        }
        InstanceList filtered = instanceList.cloneEmpty();
        int removed = 0;
        for (int index = 0; index < instanceList.size(); index++) {
            if (selected.get(index)) {
                filtered.add(instanceList.get(index));
            } else {
                removed++;
            }
        }
        System.err.println("Removed " + removed + " instances that do not contain constraints.");
        return filtered;
    }

    /**
     * Runs the optimizer and reports, rather than propagates, any failure. This is deliberate best
     * effort, kept from the original: training continues from whatever parameters the optimizer
     * reached before it failed.
     */
    private static void optimizeBestEffort(LimitedMemoryBFGS optimizer, int maxIterations) {
        try {
            optimizer.optimize(maxIterations);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Relative-change convergence test: 2|new - old| <= tolerance * (|new| + |old| + eps). */
    private boolean hasConverged(double newValue, double oldValue) {
        return 2.0d * Math.abs(newValue - oldValue)
                <= this.tolerance * (Math.abs(newValue) + Math.abs(oldValue) + CONVERGENCE_EPSILON);
    }

    /** Returns the total value ({@code pValue - qValue}) of the most recent outer iteration. */
    public double getTotalValue() {
        return this.value;
    }

    /** Returns the constraint-side value contribution of the most recent outer iteration. */
    public double getQValue() {
        return this.qValue;
    }

    /**
     * Returns the L-BFGS optimizer of the most recent constraint-side step, or {@code null} if
     * {@code train} has not run yet.
     */
    @Override // cc.mallet.fst.TransducerTrainer.ByOptimization
    public Optimizer getOptimizer() {
        return this.bfgs;
    }
}
