package codemining.lm.ngram.smoothing;

import codemining.lm.ILanguageModel;
import codemining.lm.ngram.AbstractNGramLM;
import codemining.lm.ngram.NGram;
import codemining.lm.ngram.Trie;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
import com.google.common.collect.TreeMultiset;
import java.io.File;
import java.io.IOException;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.Stack;
import java.util.TreeMap;
import java.util.logging.Logger;
import org.apache.commons.lang.exception.ExceptionUtils;

/* JADX WARN: Classes with same name are omitted:
  input_file:lib/naturalize.jar:codemining/lm/ngram/smoothing/KatzBackoff.class
 */
/* loaded from: input_file:naturalize.jar:codemining/lm/ngram/smoothing/KatzBackoff.class */
/**
 * N-gram language model that scores n-grams with discounted ("Katz") counts and backs off to
 * shorter n-grams when the full n-gram was not observed.
 *
 * <p>All tables are computed once in the constructor from the trie of the model handed in; every
 * mutating operation inherited from {@link AbstractNGramLM} throws
 * {@link UnsupportedOperationException}.
 */
public class KatzBackoff extends AbstractNGramLM {
    /** Largest raw count that is still discounted; larger counts are used unchanged. */
    public static final long NO_DISCOUNT_THRESHOLD = 10;

    private static final long serialVersionUID = 8858981913051295954L;

    private static final Logger LOGGER = Logger.getLogger(KatzBackoff.class.getName());

    /** Highest count for which a count-of-counts bucket is kept (threshold + 1). */
    private static final long MAX_TRACKED_COUNT = NO_DISCOUNT_THRESHOLD + 1;

    /**
     * k * (k + 1) for k = {@link #NO_DISCOUNT_THRESHOLD}. Dividing {@code 2 * count} by this value
     * makes the per-count reductions for counts 1..k sum to exactly one discount.
     */
    private static final double DISCOUNT_SPREAD_DENOMINATOR =
            NO_DISCOUNT_THRESHOLD * (NO_DISCOUNT_THRESHOLD + 1);

    /** Message used by every mutating operation, all of which are unsupported. */
    private static final String IMMUTABLE_MESSAGE = "KatzSmoother is an immutable Language Model";

    /** n-gram order -> (raw count in 0..{@link #NO_DISCOUNT_THRESHOLD} -> discounted count). */
    SortedMap<Integer, Map<Long, Double>> katzCounts;

    /** n-gram order -> (raw count in 1..{@link #MAX_TRACKED_COUNT} -> number of trie nodes seen). */
    private final Map<Integer, Map<Long, Long>> countOfCounts;

    /** A trie node paired with the n-gram order (depth below the root) at which it was reached. */
    private static final class NodeOrder {
        final Trie.TrieNode<Long> node;
        final int order;

        NodeOrder(Trie.TrieNode<Long> node, int order) {
            this.node = node;
            this.order = order;
        }
    }

    /**
     * Builds the smoothed model on top of an existing n-gram model.
     *
     * @param abstractNGramLM the model passed to the superclass constructor; its trie is the
     *     source of all counts
     * @throws IllegalArgumentException if the root of the trie has no unknown-symbol child, or if
     *     the discount computed for some order is not strictly between 0 and 1
     */
    public KatzBackoff(AbstractNGramLM abstractNGramLM) {
        super(abstractNGramLM);
        this.countOfCounts = new TreeMap<Integer, Map<Long, Long>>();
        computeKatzCountsOfCounts();
        this.katzCounts = Maps.newTreeMap();
        for (int order = 1; order <= getN(); order++) {
            computeKatzCounts(order);
        }
    }

    /** Always throws: this model cannot be modified. */
    @Override // codemining.lm.ngram.AbstractNGramLM
    public void addFromSentence(List<String> list, boolean z) {
        throw new UnsupportedOperationException(IMMUTABLE_MESSAGE);
    }

    /** Always throws: this model cannot be modified. */
    @Override // codemining.lm.ngram.AbstractNGramLM
    protected void addNgramToDict(NGram<String> nGram, boolean z) {
        throw new UnsupportedOperationException(IMMUTABLE_MESSAGE);
    }

    /** Always throws: this model cannot be modified. */
    @Override // codemining.lm.ngram.AbstractNGramLM
    public void addSentences(Set<List<String>> set, boolean z) {
        throw new UnsupportedOperationException(IMMUTABLE_MESSAGE);
    }

    /**
     * Computes the back-off weight applied to the probability of the suffix of {@code nGram}.
     *
     * @return {@code 1e-10} (via {@code Math.pow}) when the prefix has no residual probability,
     *     {@code 1} when the shortened prefix has none, otherwise the ratio of the two residuals
     * @throws IllegalArgumentException if a residual probability falls outside [0, 1]
     */
    private double computeGamma(NGram<String> nGram) {
        final NGram<String> prefix = nGram.getPrefix();
        final double prefixResidual = getResidualProbability(prefix);
        if (prefixResidual == 0.0d) {
            return Math.pow(10.0d, -10.0d);
        }
        final double lowerOrderResidual = getResidualProbability(prefix.getSuffix());
        if (lowerOrderResidual == 0.0d) {
            return 1.0d;
        }
        return prefixResidual / lowerOrderResidual;
    }

    /**
     * Fills {@link #katzCounts} for one n-gram order.
     *
     * <p>The discount is {@code (k + 1) * N(k + 1) / N(1)} with k = {@link #NO_DISCOUNT_THRESHOLD}
     * and N(c) taken from {@link #countOfCounts}. A raw count of 0 maps to the discount itself;
     * a raw count c in 1..k maps to {@code c - discount * c * 2 / (k * (k + 1))}.
     *
     * @param order the n-gram order, expected to be in 1..getN()
     * @throws IllegalArgumentException if the discount is not strictly between 0 and 1, or a
     *     discounted count is not in {@code (c - 1, c + 1]}
     */
    private void computeKatzCounts(int order) {
        final Map<Long, Long> orderCounts = this.countOfCounts.get(Integer.valueOf(order));
        final long aboveThreshold = orderCounts.get(MAX_TRACKED_COUNT);
        final long singletons = orderCounts.get(1L);
        // Double division: a zero singleton count yields Infinity/NaN, which the check rejects.
        final double discount = ((double) MAX_TRACKED_COUNT * aboveThreshold) / singletons;
        Preconditions.checkArgument(discount > 0.0d && discount < 1.0d,
                "Discount must be between 0 and 1, but is " + discount);

        final Map<Long, Double> discounted = Maps.newTreeMap();
        for (long count = 0; count <= NO_DISCOUNT_THRESHOLD; count++) {
            final double katzCount = count == 0
                    ? discount
                    : count - (discount * count * 2.0d) / DISCOUNT_SPREAD_DENOMINATOR;
            discounted.put(count, katzCount);
            // The enforced upper bound is (count + 1], looser than the message text suggests.
            Preconditions.checkArgument(count - 1 < katzCount && katzCount <= count + 1,
                    "smoothed katz count is between original and (original-1), unless it's zero. Original "
                            + count + " but katz " + katzCount);
        }
        final Map<Long, Double> frozen = ImmutableSortedMap.copyOf(discounted);
        this.katzCounts.put(Integer.valueOf(order), frozen);
    }

    /**
     * Fills {@link #countOfCounts} by walking the trie depth-first from the children of the root.
     *
     * @throws IllegalArgumentException if the root has no child for the unknown symbol
     */
    private void computeKatzCountsOfCounts() {
        // Start every tracked bucket at zero so later lookups never miss.
        for (int order = 1; order <= getN(); order++) {
            final Map<Long, Long> buckets = Maps.newTreeMap();
            this.countOfCounts.put(Integer.valueOf(order), buckets);
            for (long count = 1; count <= MAX_TRACKED_COUNT; count++) {
                buckets.put(count, 0L);
            }
        }

        final Stack<NodeOrder> pending = new Stack<NodeOrder>();
        NodeOrder unknownRoot = null;
        for (Map.Entry<Long, Trie.TrieNode<Long>> rootChild : this.trie.getRoot().prods.entrySet()) {
            final NodeOrder start = new NodeOrder(rootChild.getValue(), 1);
            if (rootChild.getKey().equals(this.trie.getUnkSymbolId())) {
                // The unknown symbol at the root is kept aside and not traversed.
                unknownRoot = start;
            } else {
                pending.push(start);
            }
        }

        while (!pending.isEmpty()) {
            final NodeOrder current = pending.pop();
            final long nodeCount = current.node.count;
            final Map<Long, Long> buckets = this.countOfCounts.get(Integer.valueOf(current.order));
            final Long seenSoFar = buckets.get(nodeCount);
            if (nodeCount > MAX_TRACKED_COUNT) {
                // NOTE(review): the subtree below a frequent node is skipped as well, so rarer
                // higher-order n-grams under it are never counted. Kept as-is because changing
                // it alters the discount; confirm whether this pruning is intended.
                continue;
            }
            buckets.put(nodeCount, seenSoFar == null ? 1L : seenSoFar + 1L);
            for (Map.Entry<Long, Trie.TrieNode<Long>> child : current.node.prods.entrySet()) {
                pending.push(new NodeOrder(child.getValue(), current.order + 1));
            }
        }

        Preconditions.checkArgument(unknownRoot != null);
        // NOTE(review): the bucket for count 1 of every order is replaced by the count of the
        // root unknown-symbol node, discarding the value gathered above - confirm this is intended.
        for (int order = 1; order <= getN(); order++) {
            this.countOfCounts.get(Integer.valueOf(order)).put(1L, unknownRoot.node.count);
        }
    }

    /** Always throws: this model cannot be modified. */
    @Override // codemining.lm.ngram.AbstractNGramLM
    public void cutoffRare(int i) {
        throw new UnsupportedOperationException(IMMUTABLE_MESSAGE);
    }

    /** Returns this instance, since every mutating operation is already unsupported. */
    @Override // codemining.lm.ILanguageModel
    public ILanguageModel getImmutableVersion() {
        return this;
    }

    /**
     * Returns the discounted count for a raw count at a given n-gram order.
     *
     * @param count the raw count
     * @param order the n-gram order whose table is consulted
     * @return {@code count} itself when it exceeds {@link #NO_DISCOUNT_THRESHOLD}, otherwise the
     *     value stored in {@link #katzCounts}
     * @throws NullPointerException if there is no table for {@code order}, or the table has no
     *     entry for {@code count} (tables hold 0..{@link #NO_DISCOUNT_THRESHOLD} only)
     */
    public double getKatzCount(long count, int order) {
        if (count > NO_DISCOUNT_THRESHOLD) {
            return count;
        }
        final Map<Long, Double> table =
                Preconditions.checkNotNull(this.katzCounts.get(Integer.valueOf(order)));
        return table.get(Long.valueOf(count));
    }

    /**
     * Returns the smoothed probability of {@code nGram}.
     *
     * <p>A seen n-gram, or any unigram, gets its discounted count divided by the count of its
     * prefix. An unseen n-gram backs off to its suffix, scaled by {@link #computeGamma} when the
     * prefix was seen and unscaled otherwise.
     */
    @Override // codemining.lm.ngram.AbstractNGramLM
    public double getProbabilityFor(NGram<String> nGram) {
        final long ngramCount = this.trie.getCount(nGram, false, true);
        final long prefixCount = this.trie.getCount(nGram.getPrefix(), false, false);
        if (ngramCount > 0 || nGram.size() == 1) {
            return getKatzCount(ngramCount, nGram.size()) / prefixCount;
        }
        if (prefixCount <= 0) {
            return getProbabilityFor(nGram.getSuffix());
        }
        try {
            return computeGamma(nGram) * getProbabilityFor(nGram.getSuffix());
        } catch (IllegalArgumentException e) {
            // A failed residual check is treated as a back-off weight of one.
            LOGGER.warning("Failed to compute gamma, using 1 instead: " + ExceptionUtils.getFullStackTrace(e));
            return getProbabilityFor(nGram.getSuffix());
        }
    }

    /**
     * Returns one minus the discounted count mass of the continuations of {@code nGram}, relative
     * to the node's count excluding its {@code terminateHere} value. Continuations through the
     * unknown symbol are ignored.
     *
     * @throws IllegalArgumentException if the result is outside [0, 1] (including NaN, which
     *     arises when the denominator is zero)
     */
    private double getResidualProbability(NGram<String> nGram) {
        final Trie.TrieNode<Long> context = this.trie.getNGramNodeForInput(nGram, true);
        final Long unkSymbolId = this.trie.getUnkSymbolId();

        // Group continuations by raw count so each distinct count is discounted only once.
        final TreeMultiset<Long> continuationCounts = TreeMultiset.create();
        for (Map.Entry<Long, Trie.TrieNode<Long>> continuation : context.prods.entrySet()) {
            if (!continuation.getKey().equals(unkSymbolId)) {
                continuationCounts.add(Long.valueOf(continuation.getValue().count));
            }
        }

        final int continuationOrder = nGram.size() + 1;
        double discountedMass = 0.0d;
        for (Multiset.Entry<Long> group : continuationCounts.entrySet()) {
            discountedMass += getKatzCount(group.getElement().longValue(), continuationOrder)
                    * group.getCount();
        }

        final double residual = 1.0d - (discountedMass / (context.count - context.terminateHere));
        Preconditions.checkArgument(residual >= 0.0d);
        Preconditions.checkArgument(residual <= 1.0d);
        return residual;
    }

    /** Always throws: this model cannot be trained. */
    @Override // codemining.lm.ILanguageModel
    public void trainIncrementalModel(Collection<File> collection) throws IOException {
        throw new UnsupportedOperationException(IMMUTABLE_MESSAGE);
    }

    /** Always throws: this model cannot be trained. */
    @Override // codemining.lm.ILanguageModel
    public void trainModel(Collection<File> collection) throws IOException {
        throw new UnsupportedOperationException(IMMUTABLE_MESSAGE);
    }
}
