package codemining.lm.ngram;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.io.Serializable;
import java.lang.Comparable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.SortedMap;
import java.util.TreeMap;

/* loaded from: input_file:codemining/lm/ngram/Trie.class */
/**
 * A counting trie (prefix tree) over sequences of comparable symbols.
 *
 * <p>Every call to {@link #add(List)} walks one path from the root, creating missing nodes and
 * incrementing the {@code count} of every node on the path (root included); the last node's
 * {@code terminateHere} is incremented as well. One designated symbol, the "UNK" symbol, is the
 * bucket into which {@link #cutoffRare(int)} folds rarely seen symbols and the fallback that
 * lookups may use for symbols that are not present.
 *
 * <p>Only {@link #add(List)} and {@link #cutoffRare(int)} are {@code synchronized}; the query
 * methods take no lock.
 *
 * @param <T> symbol type; children of a node are kept sorted by natural ordering
 */
public class Trie<T extends Comparable<T>> implements Serializable {
    private static final long serialVersionUID = 1365912094350571019L;

    /** Root of the trie; its {@code count} equals the number of sequences added. */
    private final TrieNode<T> root = new TrieNode<>();

    /**
     * Symbol used as the UNK bucket. The child maps are natural-ordering {@link TreeMap}s, which
     * reject {@code null} keys, so operations that look this symbol up need it to be non-null.
     */
    protected T unkSymbolId;

    /**
     * A single trie node.
     *
     * @param <T> symbol type used as the key of the child map
     */
    public static class TrieNode<T> implements Serializable {
        private static final long serialVersionUID = -3590197616044935851L;

        /** Children of this node, keyed by the next symbol, in sorted key order. */
        public final SortedMap<T, TrieNode<T>> prods = new TreeMap<>();

        /** Number of added sequences whose path passes through (or ends at) this node. */
        public long count;

        /** Number of added sequences whose path ends exactly at this node. */
        public long terminateHere;
    }

    /**
     * Creates an empty trie.
     *
     * @param t the symbol to use as the UNK bucket
     */
    public Trie(T t) {
        this.unkSymbolId = t;
    }

    /**
     * Adds one sequence to the trie, updating the counts along its path.
     *
     * @param list the sequence of symbols to add; may be empty (only the root is updated then)
     * @throws NullPointerException if {@code list} is {@code null} or contains a {@code null}
     *     symbol; the trie is left unmodified in that case
     */
    public final synchronized void add(List<T> list) {
        // Validate first: the child maps reject null keys, and failing half-way through the walk
        // would leave the root and a partial path with counts for a sequence that was never added.
        for (T symbol : list) {
            Objects.requireNonNull(symbol, "sequences added to the trie must not contain null symbols");
        }

        root.count++;
        TrieNode<T> current = root;
        for (T symbol : list) {
            TrieNode<T> child = current.prods.get(symbol);
            if (child == null) {
                child = new TrieNode<>();
                current.prods.put(symbol, child);
            }
            child.count++;
            current = child;
        }
        current.terminateHere++;
    }

    /**
     * Counts the distinct symbols that directly follow the given prefix.
     *
     * @param list the prefix to look up; must not be empty
     * @param z if {@code true}, missing symbols in the prefix fall back to the UNK child and the
     *     UNK child itself is counted as a distinct continuation; if {@code false}, no fallback
     *     is used and an UNK child is excluded from the result
     * @return the number of distinct children of the prefix node, or {@code 0} if the prefix is
     *     not in the trie
     * @throws IllegalArgumentException if {@code list} is empty
     */
    public final long countDistinctStartingWith(List<T> list, boolean z) {
        checkNonEmpty(list);
        TrieNode<T> node = getTrieNodeForInput(list, z);
        if (node == null) {
            return 0L;
        }
        long distinct = node.prods.size();
        if (!z && node.prods.containsKey(getUnkSymbolId())) {
            distinct--;
        }
        return distinct;
    }

    /**
     * Folds every child whose count is at most {@code i} into the UNK child of its parent,
     * recursively over the whole trie.
     *
     * @param i the threshold; children with {@code count > i} are kept, the rest are merged
     * @throws NullPointerException if the UNK symbol is {@code null}
     */
    public final synchronized void cutoffRare(int i) {
        // Fail before touching anything; the lookups below would throw on a null key anyway.
        Objects.requireNonNull(getUnkSymbolId(), "cutoffRare requires a non-null UNK symbol");
        cutoffRare(this.root, i);
    }

    /**
     * Applies the rare-symbol cutoff to {@code node} and, recursively, to its descendants.
     *
     * @param node the node whose children are examined
     * @param threshold children with {@code count <= threshold} are merged into the UNK child
     */
    private void cutoffRare(TrieNode<T> node, int threshold) {
        final T unk = getUnkSymbolId();
        TrieNode<T> unkNode = node.prods.get(unk);
        if (unkNode == null) {
            // Not attached to the parent yet; that happens below only if something was merged in.
            unkNode = new TrieNode<>();
        }

        // Removal is deferred so the child map is not modified while it is being iterated.
        final List<T> rareSymbols = new ArrayList<>();
        for (Map.Entry<T, TrieNode<T>> entry : node.prods.entrySet()) {
            final T symbol = entry.getKey();
            final TrieNode<T> child = entry.getValue();
            if (child.count > threshold || symbol.equals(unk)) {
                cutoffRare(child, threshold);
            } else {
                rareSymbols.add(symbol);
                mergeTrieNodes(child, unkNode);
            }
        }

        if (unkNode.count > 0) {
            node.prods.put(unk, unkNode);
            // The merged subtrees may themselves contain rare children.
            cutoffRare(unkNode, threshold);
        }
        for (T symbol : rareSymbols) {
            node.prods.remove(symbol);
        }
    }

    /**
     * Returns the count recorded for the given sequence.
     *
     * @param list the sequence to look up; an empty list addresses the root
     * @param z if {@code true}, missing symbols fall back to the UNK child during lookup and
     *     the UNK child's count is kept in the result; if {@code false}, no fallback is used and
     *     the count of the node's UNK child (if any) is subtracted
     * @param z2 if {@code true}, sequences that end exactly at the node are included; if
     *     {@code false}, the node's {@code terminateHere} is subtracted
     * @return the resulting count, or {@code 0} if the sequence is not in the trie
     * @throws IllegalArgumentException if the computed count is negative
     */
    public final long getCount(List<T> list, boolean z, boolean z2) {
        TrieNode<T> node = getTrieNodeForInput(list, z);
        if (node == null) {
            return 0L;
        }

        long unkCount = 0L;
        if (!z) {
            TrieNode<T> unkNode = node.prods.get(getUnkSymbolId());
            if (unkNode != null) {
                unkCount = unkNode.count;
            }
        }

        long result = node.count - unkCount;
        if (!z2) {
            result -= node.terminateHere;
        }
        if (result < 0) {
            throw new IllegalArgumentException("negative count computed: " + result);
        }
        return result;
    }

    /**
     * Returns the root node of the trie.
     *
     * @return the root node (the live node, not a copy)
     */
    public final TrieNode<T> getRoot() {
        return this.root;
    }

    /**
     * Looks up the node reached by following {@code list} from the root.
     *
     * @param list the sequence of symbols to follow
     * @param z whether a missing (or {@code null}) symbol may be replaced by the UNK symbol
     * @return the node reached, or {@code null} if the path does not exist
     */
    public final TrieNode<T> getTrieNodeForInput(List<T> list, boolean z) {
        return getTrieNodeForInput(list, z, this.root);
    }

    /**
     * Looks up the node reached by following {@code list} from {@code trieNode}.
     *
     * @param list the sequence of symbols to follow; an empty list yields {@code trieNode}
     * @param z whether a missing (or {@code null}) symbol may be replaced by the UNK symbol
     * @param trieNode the node to start from
     * @return the node reached, or {@code null} if some symbol has no matching child and either
     *     {@code z} is {@code false} or the current node has no UNK child
     */
    public TrieNode<T> getTrieNodeForInput(List<T> list, boolean z, TrieNode<T> trieNode) {
        TrieNode<T> current = trieNode;
        for (T symbol : list) {
            // A null symbol can never be a key, so it is treated like any other missing symbol.
            TrieNode<T> next = (symbol != null) ? current.prods.get(symbol) : null;
            if (next == null) {
                if (!z) {
                    return null;
                }
                next = current.prods.get(getUnkSymbolId());
                if (next == null) {
                    return null;
                }
            }
            current = next;
        }
        return current;
    }

    /**
     * Returns the symbol used as the UNK bucket.
     *
     * @return the UNK symbol
     */
    public final T getUnkSymbolId() {
        return this.unkSymbolId;
    }

    /**
     * Merges the subtree rooted at {@code from} into the subtree rooted at {@code to}: counts are
     * added, children present on both sides are merged recursively, and children present only in
     * {@code from} are attached to {@code to} as-is (the node objects are shared, not copied).
     *
     * @param from the node whose counts and children are folded in
     * @param to the node that receives them
     * @throws NullPointerException if either node is {@code null}
     */
    private void mergeTrieNodes(TrieNode<T> from, TrieNode<T> to) {
        Objects.requireNonNull(to, "to");
        Objects.requireNonNull(from, "from");

        to.count += from.count;
        to.terminateHere += from.terminateHere;
        for (Map.Entry<T, TrieNode<T>> entry : from.prods.entrySet()) {
            final TrieNode<T> existing = to.prods.get(entry.getKey());
            if (existing != null) {
                mergeTrieNodes(entry.getValue(), existing);
            } else {
                to.prods.put(entry.getKey(), entry.getValue());
            }
        }
    }

    /**
     * Returns how many added sequences continue past the given prefix, i.e. the prefix node's
     * {@code count} minus its {@code terminateHere}.
     *
     * @param list the prefix to look up; must not be empty
     * @param z whether missing symbols in the prefix may fall back to the UNK child
     * @return the number of sequences extending the prefix, or {@code 0} if the prefix is not
     *     in the trie
     * @throws IllegalArgumentException if {@code list} is empty
     */
    public final long sumStartingWith(List<T> list, boolean z) {
        checkNonEmpty(list);
        TrieNode<T> node = getTrieNodeForInput(list, z);
        if (node == null) {
            return 0L;
        }
        return node.count - node.terminateHere;
    }

    /** Rejects empty prefixes with an {@link IllegalArgumentException}. */
    private static void checkNonEmpty(List<?> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("the prefix must contain at least one symbol");
        }
    }
}
