package codemining.lm.grammar.tsg;

import codemining.lm.grammar.tree.TreeNode;
import codemining.lm.grammar.tsg.TSGCompatibleTree;
import codemining.math.distributions.GeometricDistribution;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
import com.google.common.math.DoubleMath;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Stack;

/* JADX WARN: Classes with same name are omitted:
  input_file:lib/naturalize.jar:codemining/lm/grammar/tsg/CollapsedGibbsSampler.class
 */
/* loaded from: input_file:naturalize.jar:codemining/lm/grammar/tsg/CollapsedGibbsSampler.class */
public class CollapsedGibbsSampler {
    /** Concentration parameter used when computing posterior probabilities of trees. */
    private final double concentrationParameter;
    /** Parameter handed to the geometric distribution over tree sizes; set to 1/d in the constructor. */
    private final double geometricProbability;
    /** Trees registered through addTree; every sampling sweep visits each of them. */
    final List<TSGCompatibleTree> treeSet = Lists.newArrayList();
    /**
     * Rule counts collected from the added trees: maps a parent node (a copy flagged as root)
     * to the multiset of child-node sequences it was observed to expand to.
     */
    final Map<TSGCompatibleTree.TSGNode, Multiset<NodeConsequent>> ruleFrequency = Maps.newHashMap();
    /** Grammar updated during sampling and exposed through getGrammar. */
    private final TSGrammar<TSGCompatibleTree.TSGNode> grammar = new TSGrammar<>();
    /** Source of randomness for the per-node root/non-root decision. */
    final Random rng = new Random();

    /* JADX INFO: Access modifiers changed from: private */
    /* JADX WARN: Classes with same name are omitted:
      input_file:lib/naturalize.jar:codemining/lm/grammar/tsg/CollapsedGibbsSampler$NodeConsequent.class
     */
    /* loaded from: input_file:naturalize.jar:codemining/lm/grammar/tsg/CollapsedGibbsSampler$NodeConsequent.class */
    /**
     * Right-hand side of a production: the ordered list of child nodes a parent
     * node expands to. Equality and hashing are defined purely by that list, so
     * instances can be counted in a multiset.
     */
    public static class NodeConsequent {
        public final List<TSGCompatibleTree.TSGNode> nodes = Lists.newArrayList();

        /** Two consequents are equal when their node lists are equal. */
        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof NodeConsequent)) {
                return false;
            }
            NodeConsequent other = (NodeConsequent) obj;
            return Objects.equal(other.nodes, this.nodes);
        }

        /** Hash derived from the node list, consistent with {@link #equals(Object)}. */
        @Override
        public int hashCode() {
            return Objects.hashCode(this.nodes);
        }
    }

    /**
     * Creates a sampler with no trees registered.
     *
     * @param d  value whose reciprocal becomes the parameter of the geometric tree-size
     *           distribution (presumably the expected tree size - confirm with callers)
     * @param d2 concentration parameter used in the posterior computation
     */
    public CollapsedGibbsSampler(double d, double d2) {
        concentrationParameter = d2;
        geometricProbability = 1.0d / d;
    }

    /**
     * Registers a tree for sampling and folds its productions into the rule counts.
     *
     * @param tSGCompatibleTree the tree to add
     */
    public void addTree(TSGCompatibleTree tSGCompatibleTree) {
        treeSet.add(tSGCompatibleTree);
        TreeNode<TSGCompatibleTree.TSGNode> root = tSGCompatibleTree.getRoot();
        updateRuleFrequencies(root);
    }

    /**
     * Returns the grammar maintained by this sampler.
     *
     * @return the sampler's own grammar instance (not a copy)
     */
    public TSGrammar<TSGCompatibleTree.TSGNode> getGrammar() {
        return grammar;
    }

    /**
     * Returns the base-2 log probability of a parent node expanding to the given children,
     * estimated from the rule counts in {@code ruleFrequency}.
     *
     * <p>Both the parent and the children are looked up through copies flagged as roots,
     * matching how the counts are stored. The arguments themselves are never modified.
     *
     * @param tSGNode        the parent (left-hand side) node
     * @param nodeConsequent the children (right-hand side) of the production
     * @return {@code -100} when the parent was never seen; {@code log2(1 / (n + 1))} when the
     *         parent was seen {@code n} times but never with these children; otherwise
     *         {@code log2(count / n)}
     * @throws NullPointerException     if either argument is null
     * @throws IllegalArgumentException if the resulting log probability is infinite or NaN
     */
    private double getLogProbFor(TSGCompatibleTree.TSGNode tSGNode, NodeConsequent nodeConsequent) {
        Preconditions.checkNotNull(tSGNode);
        Preconditions.checkNotNull(nodeConsequent);

        // Flag the COPY as root: the counts are keyed on root-flagged copies, and the
        // caller's node belongs to a live tree that must not be altered by a lookup.
        TSGCompatibleTree.TSGNode antecedent = new TSGCompatibleTree.TSGNode(tSGNode);
        antecedent.isRoot = true;

        NodeConsequent normalizedConsequent = new NodeConsequent();
        for (TSGCompatibleTree.TSGNode child : nodeConsequent.nodes) {
            TSGCompatibleTree.TSGNode childCopy = new TSGCompatibleTree.TSGNode(child);
            childCopy.isRoot = true;
            normalizedConsequent.nodes.add(childCopy);
        }

        Multiset<NodeConsequent> expansions = this.ruleFrequency.get(antecedent);
        if (expansions == null) {
            // Parent never observed: fixed very small log probability.
            return -100.0d;
        }

        double totalExpansions = expansions.size();
        if (!expansions.contains(normalizedConsequent)) {
            // Seen parent, unseen expansion: treat it as one extra observation.
            return DoubleMath.log2(1.0d / (totalExpansions + 1.0d));
        }

        double logProb = DoubleMath.log2(expansions.count(normalizedConsequent) / totalExpansions);
        Preconditions.checkArgument(!Double.isInfinite(logProb) && !Double.isNaN(logProb));
        return logProb;
    }

    /**
     * Computes the posterior probability of a tree given the current grammar counts:
     * {@code (count(tree) + alpha * prior(tree)) / (count(trees with same root) + alpha)},
     * where {@code alpha} is the concentration parameter.
     *
     * @param treeNode the tree whose probability is requested
     * @param z        when true and the tree occurs in the grammar, one occurrence is excluded
     *                 from both counts before the ratio is taken
     * @return the posterior probability of the tree
     * @throws NullPointerException if {@code treeNode} is null
     */
    protected double getPosteriorProbabilityForTree(TreeNode<TSGCompatibleTree.TSGNode> treeNode, boolean z) {
        Preconditions.checkNotNull(treeNode);
        double sameRootCount = this.grammar.countTreesWithRoot(treeNode.getData());
        double treeCount = this.grammar.countTreeOccurences(treeNode);
        double prior = getPriorForTree(treeNode);
        if (z && treeCount > 0.0d) {
            treeCount -= 1.0d;
            sameRootCount -= 1.0d;
        }
        // The tree's own count goes in the numerator and the count of all trees sharing
        // its root in the denominator; the reverse is not a probability.
        return (treeCount + (this.concentrationParameter * prior)) / (sameRootCount + this.concentrationParameter);
    }

    /**
     * Computes the prior probability of a tree as 2 raised to the sum of the geometric
     * log probability of its size and the log probability of its productions.
     *
     * @param treeNode the tree to score
     * @return the prior probability of the tree
     * @throws NullPointerException if {@code treeNode} is null
     */
    protected double getPriorForTree(TreeNode<TSGCompatibleTree.TSGNode> treeNode) {
        Preconditions.checkNotNull(treeNode);
        int size = getTreeSize(treeNode);
        double sizeLogProb = GeometricDistribution.getLogProb(size, this.geometricProbability);
        double rulesLogProb = getTreeRuleLogProbability(treeNode);
        return Math.pow(2.0d, sizeLogProb + rulesLogProb);
    }

    /**
     * Sums the log probabilities of the productions in a tree. The given node is always
     * scored; below it, only non-leaf children are visited and scored in turn.
     *
     * @param treeNode root of the tree to score
     * @return the summed log probability
     * @throws NullPointerException     if {@code treeNode} is null
     * @throws IllegalArgumentException if the sum is infinite or NaN
     */
    private double getTreeRuleLogProbability(TreeNode<TSGCompatibleTree.TSGNode> treeNode) {
        Preconditions.checkNotNull(treeNode);
        Stack<TreeNode<TSGCompatibleTree.TSGNode>> pending = new Stack<>();
        pending.push(treeNode);
        double logProbSum = 0.0d;
        while (!pending.isEmpty()) {
            TreeNode<TSGCompatibleTree.TSGNode> current = pending.pop();

            // Collect this node's children, in order, as the right-hand side of its production.
            NodeConsequent expansion = new NodeConsequent();
            for (int idx = 0; idx < current.nChildren(); idx++) {
                expansion.nodes.add(current.getChild(idx).getData());
            }
            logProbSum += getLogProbFor(current.getData(), expansion);

            for (TreeNode<TSGCompatibleTree.TSGNode> child : current.getChildren()) {
                if (child.isLeaf()) {
                    continue;
                }
                pending.push(child);
            }
        }
        boolean finite = !Double.isInfinite(logProbSum) && !Double.isNaN(logProbSum);
        Preconditions.checkArgument(finite);
        return logProbSum;
    }

    /**
     * Returns the size measure used by the tree-size prior: the number of nodes in the
     * tree (root included), integer-divided by two.
     *
     * @param treeNode root of the tree to measure
     * @return half the node count, rounded down
     * @throws NullPointerException if {@code treeNode} is null
     */
    protected int getTreeSize(TreeNode<TSGCompatibleTree.TSGNode> treeNode) {
        Preconditions.checkNotNull(treeNode);
        Stack<TreeNode<TSGCompatibleTree.TSGNode>> pending = new Stack<>();
        pending.push(treeNode);
        int nodeCount = 1; // the root itself
        while (!pending.isEmpty()) {
            TreeNode<TSGCompatibleTree.TSGNode> current = pending.pop();
            nodeCount += current.nChildren();
            for (TreeNode<TSGCompatibleTree.TSGNode> child : current.getChildren()) {
                pending.push(child);
            }
        }
        // NOTE(review): the halving is not explained anywhere visible here; a single-node
        // tree yields 0. Confirm this is intended before relying on it as a node count.
        return nodeCount / 2;
    }

    /**
     * Runs the given number of sampling sweeps over all registered trees, printing the
     * zero-based sweep index to standard output before each one.
     *
     * @param i number of sweeps to run; nothing happens when it is zero or negative
     */
    public void performSampling(int i) {
        int sweep = 0;
        while (sweep < i) {
            System.out.println("Iteration " + sweep);
            sampleAllTreesOnce();
            sweep++;
        }
    }

    /**
     * Removes from the grammar every tree whose posterior probability, divided by the
     * probability of its productions, falls below the threshold. Within each group of
     * trees in the internal grammar, the occurrences removed are re-added under a
     * {@code null} element, so the group's total count is unchanged.
     *
     * <p>The {@code null} element left by an earlier call is skipped when scoring, so the
     * method can be invoked more than once.
     *
     * @param d threshold on the ratio of posterior probability to production probability
     */
    public void pruneNonSurprisingRules(double d) {
        for (Multiset<TreeNode<TSGCompatibleTree.TSGNode>> productions : this.grammar.getInternalGrammar().values()) {
            // Collect first, remove afterwards: the element set must not change while iterated.
            List<TreeNode<TSGCompatibleTree.TSGNode>> unsurprising = Lists.newArrayList();
            for (TreeNode<TSGCompatibleTree.TSGNode> tree : productions.elementSet()) {
                if (tree == null) {
                    // Bucket of previously pruned occurrences; it cannot be scored.
                    continue;
                }
                double posterior = getPosteriorProbabilityForTree(tree, false);
                double ruleProbability = Math.pow(2.0d, getTreeRuleLogProbability(tree));
                if (posterior / ruleProbability < d) {
                    unsurprising.add(tree);
                }
            }

            int prunedOccurrences = 0;
            for (TreeNode<TSGCompatibleTree.TSGNode> tree : unsurprising) {
                int occurrences = productions.count(tree);
                prunedOccurrences += occurrences;
                productions.remove(tree, occurrences);
            }
            productions.add(null, prunedOccurrences);
        }
    }

    /**
     * Delegates to the grammar's own pruning with the given threshold.
     *
     * @param i threshold passed to the grammar's prune operation (presumably a minimum
     *          occurrence count - confirm against TSGrammar)
     */
    public void pruneRareTrees(int i) {
        grammar.prune(i);
    }

    /**
     * Performs one sampling sweep: for every registered tree, resamples the subtree under
     * each child of the root, with the tree's root as the enclosing root.
     */
    public void sampleAllTreesOnce() {
        for (TSGCompatibleTree tree : this.treeSet) {
            TreeNode<TSGCompatibleTree.TSGNode> root = tree.getRoot();
            for (TreeNode<TSGCompatibleTree.TSGNode> child : root.getChildren()) {
                sampleSubTree(child, root);
            }
        }
    }

    /**
     * Resamples whether a node is a root and updates the grammar if the decision changed.
     *
     * <p>Three trees are extracted with {@code TSGCompatibleTree.getSubTreeFromRoot}: one from
     * {@code treeNode2} with the node flagged as non-root (the "joined" tree), and, with the
     * node flagged as root, one from {@code treeNode2} (the "upper" tree) and one from the
     * node itself (the "lower" tree). The node stays non-root with probability
     * {@code P(joined) / (P(joined) + P(upper) * P(lower))}.
     *
     * @param treeNode  the node whose root flag is resampled
     * @param treeNode2 the root of the tree currently containing {@code treeNode}
     * @throws NullPointerException if either argument is null
     */
    protected void sampleAt(TreeNode<TSGCompatibleTree.TSGNode> treeNode, TreeNode<TSGCompatibleTree.TSGNode> treeNode2) {
        Preconditions.checkNotNull(treeNode);
        Preconditions.checkNotNull(treeNode2);

        TSGCompatibleTree.TSGNode site = treeNode.getData();
        final boolean wasRoot = site.isRoot;

        // Extract the candidate trees under both settings of the flag.
        site.isRoot = false;
        TreeNode<TSGCompatibleTree.TSGNode> joined = TSGCompatibleTree.getSubTreeFromRoot(treeNode2);
        site.isRoot = true;
        TreeNode<TSGCompatibleTree.TSGNode> upper = TSGCompatibleTree.getSubTreeFromRoot(treeNode2);
        TreeNode<TSGCompatibleTree.TSGNode> lower = TSGCompatibleTree.getSubTreeFromRoot(treeNode);

        // Whichever configuration is currently in the grammar has its own occurrence excluded.
        double joinedPosterior = getPosteriorProbabilityForTree(joined, !wasRoot);
        double upperPosterior = getPosteriorProbabilityForTree(upper, wasRoot);
        double lowerPosterior = getPosteriorProbabilityForTree(lower, wasRoot);
        double joinProbability = joinedPosterior / (joinedPosterior + (upperPosterior * lowerPosterior));

        boolean nowRoot = this.rng.nextDouble() > joinProbability;
        site.isRoot = nowRoot;
        if (nowRoot == wasRoot) {
            return;
        }

        if (wasRoot) {
            // Split -> joined.
            this.grammar.removeTree(upper);
            this.grammar.removeTree(lower);
            this.grammar.addTree(joined);
        } else {
            // Joined -> split.
            this.grammar.removeTree(joined);
            this.grammar.addTree(upper);
            this.grammar.addTree(lower);
        }
    }

    /**
     * Resamples the given node and every non-leaf node below it. For each node visited,
     * the root handed to {@code sampleAt} is the nearest ancestor flagged as root at the
     * time its parent was sampled (initially {@code treeNode2}).
     *
     * @param treeNode  the node at which sampling starts
     * @param treeNode2 the root of the tree containing {@code treeNode}
     */
    protected void sampleSubTree(TreeNode<TSGCompatibleTree.TSGNode> treeNode, TreeNode<TSGCompatibleTree.TSGNode> treeNode2) {
        // Parallel stacks: entry k of enclosingRoots is the root for entry k of nodes.
        Stack<TreeNode<TSGCompatibleTree.TSGNode>> nodes = new Stack<>();
        Stack<TreeNode<TSGCompatibleTree.TSGNode>> enclosingRoots = new Stack<>();
        nodes.push(treeNode);
        enclosingRoots.push(treeNode2);

        while (!nodes.isEmpty()) {
            TreeNode<TSGCompatibleTree.TSGNode> current = nodes.pop();
            TreeNode<TSGCompatibleTree.TSGNode> currentRoot = enclosingRoots.pop();
            sampleAt(current, currentRoot);

            // If the node ended up as a root, its descendants now hang under it.
            TreeNode<TSGCompatibleTree.TSGNode> rootForChildren;
            if (current.getData().isRoot) {
                rootForChildren = current;
            } else {
                rootForChildren = currentRoot;
            }

            for (TreeNode<TSGCompatibleTree.TSGNode> child : current.getChildren()) {
                if (child.isLeaf()) {
                    continue;
                }
                nodes.push(child);
                enclosingRoots.push(rootForChildren);
            }
        }
    }

    /**
     * Adds the productions of a tree to {@code ruleFrequency}. The given node is always
     * recorded; below it, only non-leaf children are visited. Each production maps a
     * root-flagged copy of the parent to the ordered list of root-flagged copies of its
     * children, so the nodes of the tree itself are never modified.
     *
     * @param treeNode root of the tree whose productions are counted
     * @throws NullPointerException if {@code treeNode} is null
     */
    private void updateRuleFrequencies(TreeNode<TSGCompatibleTree.TSGNode> treeNode) {
        Preconditions.checkNotNull(treeNode);
        Stack<TreeNode<TSGCompatibleTree.TSGNode>> pending = new Stack<>();
        pending.push(treeNode);
        while (!pending.isEmpty()) {
            TreeNode<TSGCompatibleTree.TSGNode> current = pending.pop();

            TSGCompatibleTree.TSGNode antecedent = new TSGCompatibleTree.TSGNode(current.getData());
            antecedent.isRoot = true;

            // Fetch the counts for this parent, creating them on first sight.
            Multiset<NodeConsequent> expansions = this.ruleFrequency.get(antecedent);
            if (expansions == null) {
                expansions = HashMultiset.create();
                this.ruleFrequency.put(antecedent, expansions);
            }

            NodeConsequent expansion = new NodeConsequent();
            for (int idx = 0; idx < current.nChildren(); idx++) {
                TSGCompatibleTree.TSGNode childCopy = new TSGCompatibleTree.TSGNode(current.getChild(idx).getData());
                childCopy.isRoot = true;
                expansion.nodes.add(childCopy);
            }
            expansions.add(expansion);

            for (TreeNode<TSGCompatibleTree.TSGNode> child : current.getChildren()) {
                if (!child.isLeaf()) {
                    pending.push(child);
                }
            }
        }
    }
}
