package codemining.lm.ngram;

import codemining.lm.ngram.Trie;
import com.esotericsoftware.kryo.DefaultSerializer;
import com.esotericsoftware.kryo.serializers.JavaSerializer;
import com.google.common.base.Preconditions;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

/**
 * A trie over symbols of type {@code K} that stores its contents as {@code Long} ids.
 *
 * <p>Every symbol is assigned an id in {@link #alphabet}; the n-grams themselves are kept in a
 * {@link Trie} of those ids. The id of the unknown ("UNK") symbol is allocated first, in the
 * constructor, and is shared with the underlying trie.
 *
 * <p>Thread-safety: only id allocation ({@code addSymbolId}) is synchronized. Lookups in the
 * alphabet and all trie operations are performed without holding a lock in this class.
 *
 * @param <K> the symbol type
 */
@DefaultSerializer(JavaSerializer.class)
public class LongTrie<K> implements Serializable {
    private static final long serialVersionUID = -7194495381473625925L;

    /** The symbol that stands in for anything not present in the alphabet. */
    private final K unkSymbol;

    /** The id-based trie that actually stores the n-grams. */
    private final Trie<Long> baseTrie = new Trie<>(null);

    /** The next id to hand out; ids are allocated in increasing order from {@code Long.MIN_VALUE}. */
    private long nextId = Long.MIN_VALUE;

    /** Bidirectional symbol &lt;-&gt; id mapping. */
    private final BiMap<K, Long> alphabet = HashBiMap.create();

    /**
     * Creates an empty trie and registers {@code k} as the unknown symbol with the first id.
     *
     * @param k the symbol used to represent unknown symbols
     */
    public LongTrie(K k) {
        this.baseTrie.unkSymbolId = Long.valueOf(this.nextId);
        this.unkSymbol = k;
        this.alphabet.put(k, this.baseTrie.unkSymbolId);
        this.nextId++;
    }

    /**
     * Adds an n-gram to the trie.
     *
     * @param nGram the n-gram to add
     * @param z when {@code true}, symbols missing from the alphabet are registered with fresh ids;
     *     when {@code false}, they are stored under the UNK id instead
     */
    public void add(NGram<K> nGram, boolean z) {
        List<Long> symbolIds = getSymbolIds(nGram, z);
        if (!z) {
            // Only the non-registering lookup can leave nulls behind.
            replaceNullsWithUnk(symbolIds);
        }
        this.baseTrie.add(symbolIds);
    }

    /**
     * Returns the id of {@code k}, allocating a new one only when the symbol is not yet registered.
     *
     * <p>The lookup is repeated under the lock on purpose: {@code BiMap.put} would otherwise
     * replace the id of an already-known symbol, orphaning every trie entry stored under the old
     * id (for instance when {@link #buildVocabularySymbols(Set)} is given a known symbol, or when
     * two threads race to register the same new symbol).
     */
    private synchronized long addSymbolId(K k) {
        Long existingId = this.alphabet.get(k);
        if (existingId != null) {
            return existingId.longValue();
        }
        long assignedId = this.nextId;
        this.alphabet.put(k, Long.valueOf(assignedId));
        this.nextId++;
        return assignedId;
    }

    /**
     * Registers every symbol of {@code set} in the alphabet. Symbols that already have an id keep
     * it.
     *
     * @param set the symbols to register
     */
    public void buildVocabularySymbols(Set<K> set) {
        for (K symbol : set) {
            addSymbolId(symbol);
        }
    }

    /**
     * Delegates to {@code Trie.countDistinctStartingWith} with the ids of {@code nGram}; symbols
     * missing from the alphabet are passed as {@code null} ids.
     *
     * @param nGram the prefix
     * @param z forwarded unchanged to the underlying trie
     * @return the value computed by the underlying trie
     */
    public long countDistinctStartingWith(NGram<K> nGram, boolean z) {
        return this.baseTrie.countDistinctStartingWith(getSymbolIds(nGram, false), z);
    }

    /**
     * Asks the underlying trie to cut off rare entries and then drops from the alphabet every
     * symbol (except UNK) whose id no longer appears anywhere in the trie.
     *
     * @param i the threshold forwarded to {@code Trie.cutoffRare}
     */
    public void cutoffRare(int i) {
        this.baseTrie.cutoffRare(i);

        // Walk the whole trie and collect every id that is still in use.
        TreeSet<Long> usedIds = Sets.newTreeSet();
        ArrayDeque<Trie.TrieNode<Long>> pending = new ArrayDeque<>();
        pending.push(this.baseTrie.getRoot());
        while (!pending.isEmpty()) {
            Trie.TrieNode<Long> node = pending.pop();
            usedIds.addAll(node.prods.keySet());
            for (Trie.TrieNode<Long> child : node.prods.values()) {
                pending.push(child);
            }
        }

        // Copy the difference first: it is a live view and the alphabet is modified below.
        List<Long> unusedIds = Lists.newArrayList(Sets.difference(this.alphabet.values(), usedIds));
        for (Long unusedId : unusedIds) {
            if (!unusedId.equals(getUnkSymbolId())) {
                Preconditions.checkNotNull(this.alphabet.inverse().remove(unusedId));
            }
        }
    }

    /**
     * Delegates to {@code Trie.getCount} with the ids of {@code nGram}; symbols missing from the
     * alphabet are passed as {@code null} ids.
     *
     * @param nGram the n-gram to look up
     * @param z forwarded unchanged to the underlying trie
     * @param z2 forwarded unchanged to the underlying trie
     * @return the value computed by the underlying trie
     */
    public long getCount(NGram<K> nGram, boolean z, boolean z2) {
        return this.baseTrie.getCount(getSymbolIds(nGram, false), z, z2);
    }

    /**
     * Delegates to {@code Trie.getTrieNodeForInput} with the ids of {@code nGram}.
     *
     * @param nGram the n-gram to look up
     * @param z forwarded unchanged to the underlying trie
     * @return the node returned by the underlying trie
     */
    public Trie.TrieNode<Long> getNGramNodeForInput(NGram<K> nGram, boolean z) {
        return this.baseTrie.getTrieNodeForInput(getSymbolIds(nGram, false), z);
    }

    /**
     * Delegates to the three-argument {@code Trie.getTrieNodeForInput} with the ids of
     * {@code nGram}.
     *
     * @param nGram the n-gram to look up
     * @param z forwarded unchanged to the underlying trie
     * @param trieNode forwarded unchanged to the underlying trie
     * @return the node returned by the underlying trie
     */
    public Trie.TrieNode<Long> getNGramNodeForInput(NGram<K> nGram, boolean z, Trie.TrieNode<Long> trieNode) {
        return this.baseTrie.getTrieNodeForInput(getSymbolIds(nGram, false), z, trieNode);
    }

    /**
     * Returns the symbols that directly follow {@code nGram} in the trie together with the count
     * stored on each child node.
     *
     * @param nGram the prefix
     * @return a sorted map from symbol to count; empty when the trie has no node for the prefix.
     *     A child whose id has no symbol in the alphabet is reported under the UNK symbol.
     */
    public Map<K, Long> getPossibleProductionsWithCounts(NGram<K> nGram) {
        Trie.TrieNode<Long> prefixNode = this.baseTrie.getTrieNodeForInput(getSymbolIds(nGram, false), false);
        Map<K, Long> productions = new TreeMap<>();
        if (prefixNode == null) {
            return productions;
        }
        for (Map.Entry<Long, Trie.TrieNode<Long>> production : prefixNode.prods.entrySet()) {
            K symbol = this.alphabet.inverse().get(production.getKey());
            Long count = Long.valueOf(production.getValue().count);
            productions.put(symbol != null ? symbol : this.unkSymbol, count);
        }
        return productions;
    }

    /**
     * @return the root node of the underlying id trie
     */
    public Trie.TrieNode<Long> getRoot() {
        return this.baseTrie.getRoot();
    }

    /**
     * @return a new set holding the symbol of every direct child of the root
     */
    public Set<K> getRootSymbols() {
        Set<K> rootSymbols = Sets.newHashSet();
        for (Long id : this.baseTrie.getRoot().prods.keySet()) {
            rootSymbols.add(getSymbolFromKey(id));
        }
        return rootSymbols;
    }

    /**
     * Translates an id back to its symbol.
     *
     * @param l the id; must not be {@code null}
     * @return the UNK symbol for the UNK id, otherwise the symbol registered for {@code l}, or
     *     {@code null} when the id is not in the alphabet
     */
    public K getSymbolFromKey(Long l) {
        if (l.equals(this.baseTrie.unkSymbolId)) {
            return this.unkSymbol;
        }
        return this.alphabet.inverse().get(l);
    }

    /**
     * Translates a sequence of symbols to their ids, in order.
     *
     * @param iterable the symbols to translate
     * @param z when {@code true}, symbols missing from the alphabet are registered with fresh ids;
     *     when {@code false}, they produce a {@code null} entry in the result
     * @return a new, mutable list with one entry per input symbol
     */
    public List<Long> getSymbolIds(Iterable<K> iterable, boolean z) {
        List<Long> symbolIds = Lists.newArrayList();
        for (K symbol : iterable) {
            Long id = this.alphabet.get(symbol);
            if (id == null && z) {
                id = Long.valueOf(addSymbolId(symbol));
            }
            symbolIds.add(id);
        }
        return symbolIds;
    }

    /**
     * @return the id of the UNK symbol, as reported by the underlying trie
     */
    public Long getUnkSymbolId() {
        return this.baseTrie.getUnkSymbolId();
    }

    /**
     * @return the key set of the alphabet. This is the map's own view, not a copy, so it reflects
     *     later changes to the alphabet.
     */
    public Set<K> getVocabulary() {
        return this.alphabet.keySet();
    }

    /**
     * @param k the symbol to test
     * @return {@code true} when {@code k} has no id in the alphabet. The UNK symbol itself is
     *     registered by the constructor and therefore yields {@code false}.
     */
    public boolean isUNK(K k) {
        return !this.alphabet.containsKey(k);
    }

    /**
     * Removes an n-gram from the trie. Symbols missing from the alphabet are looked up under the
     * UNK id, mirroring {@link #add(NGram, boolean)} without vocabulary registration.
     *
     * @param nGram the n-gram to remove
     */
    public void remove(NGram<K> nGram) {
        List<Long> symbolIds = getSymbolIds(nGram, false);
        replaceNullsWithUnk(symbolIds);
        this.baseTrie.remove(symbolIds);
    }

    /** Replaces, in place, every {@code null} id with the UNK id of the underlying trie. */
    private void replaceNullsWithUnk(List<Long> symbolIds) {
        for (int index = 0; index < symbolIds.size(); index++) {
            if (symbolIds.get(index) == null) {
                symbolIds.set(index, this.baseTrie.getUnkSymbolId());
            }
        }
    }

    /**
     * Builds a copy of {@code nGram} in which every symbol missing from the alphabet is replaced
     * by the symbol registered for the UNK id.
     *
     * @param nGram the n-gram to rewrite
     * @return a new n-gram of the same length
     */
    public NGram<K> substituteWordsToUNK(NGram<K> nGram) {
        ArrayList<K> substituted = Lists.newArrayList();
        for (K symbol : nGram) {
            if (this.alphabet.get(symbol) == null) {
                substituted.add(this.alphabet.inverse().get(this.baseTrie.unkSymbolId));
            } else {
                substituted.add(symbol);
            }
        }
        return new NGram<>(substituted);
    }

    /**
     * Delegates to {@code Trie.sumStartingWith} with the ids of {@code nGram}; symbols missing
     * from the alphabet are passed as {@code null} ids.
     *
     * @param nGram the prefix
     * @param z forwarded unchanged to the underlying trie
     * @return the value computed by the underlying trie
     */
    public long sumStartingWith(NGram<K> nGram, boolean z) {
        return this.baseTrie.sumStartingWith(getSymbolIds(nGram, false), z);
    }

    /**
     * Renders every root-to-leaf path as {@code "a, b, c count:N"}, one per line, wrapped in
     * square brackets.
     */
    @Override
    public String toString() {
        StringBuilder rendered = new StringBuilder();
        rendered.append('[');
        for (Map.Entry<Long, Trie.TrieNode<Long>> rootChild : this.baseTrie.getRoot().prods.entrySet()) {
            List<String> paths = Lists.newArrayList();
            // String.valueOf keeps a null symbol printable instead of throwing.
            String prefix = String.valueOf(this.alphabet.inverse().get(rootChild.getKey()));
            toStringHelper(prefix, rootChild.getValue(), paths);
            for (String path : paths) {
                rendered.append(path).append(System.lineSeparator());
            }
        }
        rendered.append(']');
        return rendered.toString();
    }

    /**
     * Appends to {@code list} one line per leaf reachable from {@code trieNode}, each prefixed by
     * {@code str} and the symbols on the way down.
     */
    private void toStringHelper(String str, Trie.TrieNode<Long> trieNode, List<String> list) {
        if (trieNode.prods.isEmpty()) {
            list.add(str + " count:" + trieNode.count);
            return;
        }
        for (Map.Entry<Long, Trie.TrieNode<Long>> child : trieNode.prods.entrySet()) {
            String extended = str + ", " + this.alphabet.inverse().get(child.getKey());
            toStringHelper(extended, child.getValue(), list);
        }
    }
}
