package codemining.lm.ngram;

import codemining.lm.ngram.Trie;
import codemining.util.SymbolKey;
import codemining.util.SymbolMap;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

/* loaded from: input_file:codemining/lm/ngram/SymbolTrie.class */
/**
 * A {@link Trie} whose stored keys are {@link SymbolKey}s but whose public API works on n-grams of
 * the original symbols of type {@code K}. Symbols are translated to keys through a
 * {@link SymbolMap} alphabet before being inserted into or looked up in the underlying trie.
 *
 * @param <K> the type of the symbols that make up the n-grams
 */
public class SymbolTrie<K> extends Trie<SymbolKey> {
    private static final long serialVersionUID = -5392154551719623858L;

    /** Translates symbols of type {@code K} to and from the keys stored in the trie. */
    private final SymbolMap<K> alphabet;

    /**
     * Keys collected by {@link #buildVocabularySymbols(Set)}. NOTE(review): within this class the
     * collection is only appended to and cleared, never read.
     */
    private final Collection<SymbolKey> vocabularySyms;

    /**
     * Creates an empty trie with a fresh alphabet and registers {@code unkSymbol} in it as the
     * UNK symbol.
     *
     * @param unkSymbol the symbol used to stand for unknown symbols
     */
    public SymbolTrie(K unkSymbol) {
        super(null);
        this.vocabularySyms = new ArrayList<>();
        this.alphabet = new SymbolMap<>();
        this.unkSymbolId = this.alphabet.getSymbolId(unkSymbol, true);
    }

    /**
     * Creates an empty trie that reuses an existing alphabet.
     *
     * @param alphabet the symbol-to-key mapping to use (shared, not copied)
     * @param unkSymbolId the key within {@code alphabet} that represents the UNK symbol
     */
    public SymbolTrie(SymbolMap<K> alphabet, SymbolKey unkSymbolId) {
        super(null);
        this.vocabularySyms = new ArrayList<>();
        this.alphabet = alphabet;
        this.unkSymbolId = unkSymbolId;
    }

    /**
     * Adds an n-gram to the trie after translating its symbols to keys.
     *
     * @param nGram the n-gram to add
     * @param addNewSymbols forwarded to {@link SymbolMap#getSymbolIds}; when {@code false}, any
     *     symbol for which the alphabet returns no key ({@code null}) is stored as the UNK key.
     *     Presumably {@code true} lets the alphabet register unseen symbols — confirm against
     *     {@code SymbolMap}.
     */
    public void add(NGram<K> nGram, boolean addNewSymbols) {
        final List<SymbolKey> symbolIds = alphabet.getSymbolIds(nGram, addNewSymbols);
        if (!addNewSymbols) {
            // Unknown symbols come back as null; map them onto the UNK key before insertion.
            for (int i = 0; i < symbolIds.size(); i++) {
                if (symbolIds.get(i) == null) {
                    symbolIds.set(i, getUnkSymbolId());
                }
            }
        }
        add(symbolIds);
    }

    /**
     * Registers every symbol of {@code set} in the alphabet (creating keys as needed) and records
     * the resulting keys in the vocabulary collection.
     *
     * @param set the symbols to register
     */
    public void buildVocabularySymbols(Set<K> set) {
        for (final K symbol : set) {
            vocabularySyms.add(alphabet.getSymbolId(symbol, true));
        }
    }

    /**
     * Translates {@code nGram} to keys (without adding new symbols) and delegates to
     * {@link Trie#countDistinctStartingWith}.
     *
     * @param nGram the prefix n-gram
     * @param z forwarded unchanged to the underlying trie method
     * @return the value computed by the underlying trie
     */
    public long countDistinctStartingWith(NGram<K> nGram, boolean z) {
        return countDistinctStartingWith(alphabet.getSymbolIds(nGram, false), z);
    }

    /**
     * Translates {@code nGram} to keys (without adding new symbols) and delegates to
     * {@link Trie#getCount}.
     *
     * @param nGram the n-gram to look up
     * @param z forwarded unchanged to the underlying trie method
     * @param z2 forwarded unchanged to the underlying trie method
     * @return the count reported by the underlying trie
     */
    public long getCount(NGram<K> nGram, boolean z, boolean z2) {
        return getCount(alphabet.getSymbolIds(nGram, false), z, z2);
    }

    /**
     * Returns the trie node reached by {@code nGram}, translating its symbols to keys without
     * adding new symbols to the alphabet.
     *
     * @param nGram the n-gram to follow from the root
     * @param z forwarded unchanged to {@link Trie#getTrieNodeForInput}
     * @return the node returned by the underlying trie (may be {@code null}; see
     *     {@link #getPossibleProductionsWithCounts(NGram)})
     */
    public Trie.TrieNode<SymbolKey> getNGramNodeForInput(NGram<K> nGram, boolean z) {
        return getTrieNodeForInput(alphabet.getSymbolIds(nGram, false), z);
    }

    /**
     * Returns every symbol that follows {@code nGram} in the trie together with the count stored
     * on the corresponding child node.
     *
     * <p>The result is a {@link TreeMap}, so it is ordered by the natural ordering of {@code K};
     * NOTE(review): this presumes {@code K} is {@link Comparable}, otherwise insertion throws
     * {@link ClassCastException}.
     *
     * @param nGram the prefix whose continuations are requested
     * @return a map from next symbol to its count; empty if the prefix is not in the trie
     */
    public Map<K, Long> getPossibleProductionsWithCounts(NGram<K> nGram) {
        final Trie.TrieNode<SymbolKey> prefixNode =
                getTrieNodeForInput(alphabet.getSymbolIds(nGram, false), false);
        final Map<K, Long> productions = new TreeMap<>();
        if (prefixNode == null) {
            return productions;
        }
        for (final Map.Entry<SymbolKey, Trie.TrieNode<SymbolKey>> child : prefixNode.prods.entrySet()) {
            productions.put(alphabet.getSymbolFromId(child.getKey()), Long.valueOf(child.getValue().count));
        }
        return productions;
    }

    /**
     * Translates a trie key back to its symbol.
     *
     * @param symbolKey the key to translate
     * @return the symbol the alphabet associates with {@code symbolKey}
     */
    public K getSymbolFromKey(SymbolKey symbolKey) {
        return alphabet.getSymbolFromId(symbolKey);
    }

    /**
     * Tells whether {@code k} is unknown to the alphabet, i.e. the alphabet returns no key for it.
     * Unlike {@link #substituteWordsToUNK(NGram)}, this does not consult the trie contents.
     *
     * @param k the symbol to test
     * @return {@code true} if the alphabet has no key for {@code k}
     */
    public boolean isUNK(K k) {
        return alphabet.getSymbolId(k, false) == null;
    }

    /** Clears the keys collected by {@link #buildVocabularySymbols(Set)}. */
    public void purgeVocabularySymbols() {
        vocabularySyms.clear();
    }

    /**
     * Returns a copy of {@code nGram} in which every symbol that either has no key in the alphabet
     * or does not appear as a first-level child of the root is replaced by the UNK symbol.
     *
     * @param nGram the n-gram to rewrite (not modified)
     * @return a new n-gram with unknown symbols replaced
     */
    protected NGram<K> substituteWordsToUNK(NGram<K> nGram) {
        final List<K> substituted = new ArrayList<>();
        final Iterator<K> symbols = nGram.iterator();
        while (symbols.hasNext()) {
            final K symbol = symbols.next();
            final SymbolKey key = alphabet.getSymbolId(symbol, false);
            if (key == null || !getRoot().prods.containsKey(key)) {
                substituted.add(alphabet.getSymbolFromId(getUnkSymbolId()));
            } else {
                substituted.add(symbol);
            }
        }
        return new NGram<>(substituted);
    }

    /**
     * Translates {@code nGram} to keys (without adding new symbols) and delegates to
     * {@link Trie#sumStartingWith}.
     *
     * @param nGram the prefix n-gram
     * @param z forwarded unchanged to the underlying trie method
     * @return the value computed by the underlying trie
     */
    public long sumStartingWith(NGram<K> nGram, boolean z) {
        return sumStartingWith(alphabet.getSymbolIds(nGram, false), z);
    }

    /**
     * Renders the trie as {@code [} followed by one line per leaf path and then {@code ]}. Each
     * line has the form {@code "s1, s2, ..., sn count:c\n"}, where {@code c} is the leaf's count.
     * Counts on interior nodes are not printed.
     */
    @Override
    public String toString() {
        final List<String> lines = new ArrayList<>();
        for (final Map.Entry<SymbolKey, Trie.TrieNode<SymbolKey>> entry : getRoot().prods.entrySet()) {
            // String.valueOf keeps null symbols printable, matching how toStringHelper renders them.
            toStringHelper(String.valueOf(alphabet.getSymbolFromId(entry.getKey())), entry.getValue(), lines);
        }
        final StringBuilder sb = new StringBuilder();
        sb.append('[');
        for (final String line : lines) {
            sb.append(line).append('\n');
        }
        sb.append(']');
        return sb.toString();
    }

    /**
     * Depth-first walk that appends one line per leaf reachable from {@code trieNode}.
     *
     * @param prefix the comma-separated symbol path leading to {@code trieNode}
     * @param trieNode the node to expand
     * @param lines the output list receiving {@code "<path> count:<count>"} lines
     */
    private void toStringHelper(String prefix, Trie.TrieNode<SymbolKey> trieNode, List<String> lines) {
        if (trieNode.prods.isEmpty()) {
            lines.add(prefix + " count:" + trieNode.count);
            return;
        }
        for (final Map.Entry<SymbolKey, Trie.TrieNode<SymbolKey>> child : trieNode.prods.entrySet()) {
            toStringHelper(prefix + ", " + alphabet.getSymbolFromId(child.getKey()), child.getValue(), lines);
        }
    }
}
