package cc.mallet.classify.tui;

import cc.mallet.classify.FeatureConstraintUtil;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Alphabet;
import cc.mallet.types.InstanceList;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.logging.Logger;
import org.apache.commons.io.IOUtils;
import org.eclipse.core.internal.boot.PlatformURLHandler;
import org.eclipse.jdt.internal.corext.refactoring.JDTRefactoringDescriptor;

/* loaded from: input_file:cc/mallet/classify/tui/Vectors2FeatureConstraints.class */
/**
 * Command-line tool that selects a set of features from a serialized instance list,
 * optionally estimates a target label distribution for each of them, and writes the
 * resulting feature constraints to a text file.
 *
 * <p>Features come either from a user-supplied file ({@code --features-file}) or from
 * automatic selection ({@code --feature-selection infogain | lda}). Targets are estimated
 * according to {@code --targets none | oracle | heuristic | voted}.
 *
 * <p>Output format: one line per feature, starting with the feature name and followed by
 * space-separated {@code label:probability} pairs when targets were estimated.
 */
public class Vectors2FeatureConstraints {
    private static final Logger logger = MalletLogger.getLogger(Vectors2FeatureConstraints.class.getName());
    public static CommandOption.File vectorsFile = new CommandOption.File(Vectors2FeatureConstraints.class, "input", "FILENAME", true, null, "Data file used to generate constraints.", null);
    public static CommandOption.File constraintsFile = new CommandOption.File(Vectors2FeatureConstraints.class, "output", "FILENAME", true, null, "Output file for constraints.", null);
    public static CommandOption.File featuresFile = new CommandOption.File(Vectors2FeatureConstraints.class, "features-file", "FILENAME", false, null, "File with list of features used to generate constraints.", null);
    public static CommandOption.File ldaFile = new CommandOption.File(Vectors2FeatureConstraints.class, "lda-file", "FILENAME", false, null, "File with serialized LDA object (if using LDA feature constraint selection).", null);
    // The argument placeholder is INTEGER: this option takes a count, not a file name.
    public static CommandOption.Integer numConstraints = new CommandOption.Integer(Vectors2FeatureConstraints.class, "num-constraints", "INTEGER", true, 10, "Number of feature constraints.", null);
    public static CommandOption.String featureSelection = new CommandOption.String(Vectors2FeatureConstraints.class, "feature-selection", "STRING", true, "infogain | lda", "Method used to choose feature constraints.", null);
    public static CommandOption.String targets = new CommandOption.String(Vectors2FeatureConstraints.class, "targets", "STRING", true, "none | oracle | heuristic | voted", "Method used to estimate constraint targets.", null);
    public static CommandOption.Double majorityProb = new CommandOption.Double(Vectors2FeatureConstraints.class, "majority-prob", "DOUBLE", false, 0.9d, "Probability for majority labels when using heuristic target estimation.", null);

    /**
     * Entry point: parses the command-line options, selects features, estimates targets
     * and writes the constraints file.
     *
     * @param strArr command-line arguments
     * @throws RuntimeException if an option value is unsupported, if a labeled features
     *         file is combined with {@code --targets oracle}, or if the LDA model cannot
     *         be loaded
     */
    public static void main(String[] strArr) {
        CommandOption.process(Vectors2FeatureConstraints.class, strArr);
        InstanceList instances = InstanceList.load(vectorsFile.value);

        ArrayList<Integer> features = null;
        HashMap<Integer, ArrayList<Integer>> featuresAndLabels = null;
        if (featuresFile.wasInvoked()) {
            if (fileContainsLabels(featuresFile.value)) {
                if (targets.value.equals("oracle")) {
                    throw new RuntimeException("with --targets oracle, features file must be unlabeled");
                }
                featuresAndLabels = readFeaturesAndLabelsFromFile(featuresFile.value, instances.getDataAlphabet(), instances.getTargetAlphabet());
                // Keep the plain feature list available too, so that "--targets none"
                // works with a labeled features file instead of failing on a null list.
                features = new ArrayList<>(featuresAndLabels.keySet());
            } else {
                features = readFeaturesFromFile(featuresFile.value, instances.getDataAlphabet());
            }
        } else if (featureSelection.value.equals("infogain")) {
            features = FeatureConstraintUtil.selectFeaturesByInfoGain(instances, numConstraints.value);
        } else if (featureSelection.value.equals("lda")) {
            features = FeatureConstraintUtil.selectTopLDAFeatures(numConstraints.value, loadLdaModel(ldaFile.value), instances.getDataAlphabet());
        } else {
            throw new RuntimeException("Unsupported value for feature selection: " + featureSelection.value);
        }

        HashMap<Integer, double[]> constraints;
        if (targets.value.equals("none")) {
            // No target estimation: every selected feature is written without a distribution.
            constraints = new HashMap<>();
            for (Integer featureIndex : features) {
                constraints.put(featureIndex, null);
            }
        } else if (targets.value.equals("oracle")) {
            constraints = FeatureConstraintUtil.setTargetsUsingData(instances, features);
        } else {
            boolean heuristic = targets.value.equals("heuristic");
            // Reject unknown values before doing the (potentially expensive) labeling.
            if (!heuristic && !targets.value.equals("voted")) {
                throw new RuntimeException("Unsupported value for targets: " + targets.value);
            }
            if (featuresAndLabels == null) {
                featuresAndLabels = FeatureConstraintUtil.labelFeatures(instances, features);
                for (int featureIndex : featuresAndLabels.keySet()) {
                    logger.info(instances.getDataAlphabet().lookupObject(featureIndex) + ":  ");
                    for (int labelIndex : featuresAndLabels.get(featureIndex)) {
                        logger.info(instances.getTargetAlphabet().lookupObject(labelIndex) + " ");
                    }
                }
            }
            if (heuristic) {
                constraints = FeatureConstraintUtil.setTargetsUsingHeuristic(featuresAndLabels, instances.getTargetAlphabet().size(), majorityProb.value);
            } else {
                constraints = FeatureConstraintUtil.setTargetsUsingFeatureVoting(featuresAndLabels, instances);
            }
        }
        writeConstraints(constraints, constraintsFile.value, instances.getDataAlphabet(), instances.getTargetAlphabet());
    }

    /**
     * Deserializes the LDA topic model used for {@code --feature-selection lda}.
     *
     * @param file file holding the serialized model; must not be {@code null}
     * @return the deserialized model
     * @throws RuntimeException if no file was given or the model cannot be read; the
     *         underlying exception is preserved as the cause
     */
    private static ParallelTopicModel loadLdaModel(File file) {
        if (file == null) {
            throw new RuntimeException("--feature-selection lda requires --lda-file");
        }
        // NOTE(review): this is native Java deserialization; only use model files from a trusted source.
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(file))) {
            return (ParallelTopicModel) in.readObject();
        } catch (Exception e) {
            // Fail here with the real cause rather than continuing with no features.
            throw new RuntimeException("Couldn't read LDA model from '" + file + "'", e);
        }
    }

    /**
     * Decides whether a features file carries labels by inspecting its first line: more
     * than one whitespace-separated token means "feature followed by labels".
     *
     * <p>On any read error (including an empty file) the stack trace is printed and the
     * JVM exits with status 1.
     *
     * @param file features file to inspect
     * @return {@code true} if the first line has more than one token
     */
    private static boolean fileContainsLabels(File file) {
        String firstLine = "";
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            String line = reader.readLine();
            if (line == null) {
                throw new IllegalArgumentException("Features file '" + file + "' is empty.");
            }
            firstLine = line.trim();
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        return firstLine.split("\\s+").length != 1;
    }

    /**
     * Reads an unlabeled features file (one feature name per line) and maps each name to
     * its index in the data alphabet. Blank lines are skipped.
     *
     * <p>If the file cannot be read, or a feature is not present in the alphabet, the
     * stack trace is printed and the JVM exits with status 1.
     *
     * @param file features file
     * @param dataAlphabet alphabet used to look up feature indices (never grown)
     * @return indices of the listed features, in file order
     */
    private static ArrayList<Integer> readFeaturesFromFile(File file, Alphabet dataAlphabet) {
        ArrayList<Integer> features = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                String featureName = line.trim();
                if (featureName.isEmpty()) {
                    continue;
                }
                int featureIndex = dataAlphabet.lookupIndex(featureName, false);
                // Same check as the labeled reader: an index of -1 must not leak downstream.
                if (featureIndex == -1) {
                    throw new RuntimeException("Couldn't find feature '" + featureName + "' in the data alphabet.");
                }
                features.add(Integer.valueOf(featureIndex));
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        return features;
    }

    /**
     * Reads a labeled features file. Each non-blank line holds a feature name followed by
     * one or more whitespace-separated label names.
     *
     * <p>If the file cannot be read, or a feature is not present in the data alphabet, the
     * stack trace is printed and the JVM exits with status 1.
     *
     * @param file labeled features file
     * @param alphabet data alphabet used to look up feature indices (never grown)
     * @param alphabet2 target alphabet used to look up label indices
     * @return map from feature index to the indices of the labels listed for it
     */
    public static HashMap<Integer, ArrayList<Integer>> readFeaturesAndLabelsFromFile(File file, Alphabet alphabet, Alphabet alphabet2) {
        HashMap<Integer, ArrayList<Integer>> featuresAndLabels = new HashMap<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                String[] tokens = trimmed.split("\\s+");
                int featureIndex = alphabet.lookupIndex(tokens[0], false);
                if (featureIndex == -1) {
                    throw new RuntimeException("Couldn't find feature '" + tokens[0] + "' in the data alphabet.");
                }
                ArrayList<Integer> labels = new ArrayList<>();
                for (int i = 1; i < tokens.length; i++) {
                    // NOTE(review): the single-argument lookup is kept as in the original;
                    // it presumably adds unknown labels to the target alphabet — confirm this is intended.
                    int labelIndex = alphabet2.lookupIndex(tokens[i]);
                    labels.add(Integer.valueOf(labelIndex));
                    logger.info("found label " + labelIndex);
                }
                featuresAndLabels.put(Integer.valueOf(featureIndex), labels);
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
        return featuresAndLabels;
    }

    /**
     * Writes the constraints to {@code file}, one feature per line. A feature whose
     * distribution is {@code null} is written as the feature name only.
     *
     * <p>Nothing is written (and a warning is logged) when there are no constraints. On a
     * write error the stack trace is printed and the JVM exits with status 1.
     *
     * @param constraints map from feature index to its target distribution (or {@code null})
     * @param file output file
     * @param dataAlphabet alphabet used to turn feature indices into names
     * @param targetAlphabet alphabet used to turn label indices into names
     */
    private static void writeConstraints(HashMap<Integer, double[]> constraints, File file, Alphabet dataAlphabet, Alphabet targetAlphabet) {
        if (constraints.isEmpty()) {
            logger.warning("No constraints written!");
            return;
        }
        // try-with-resources closes (and flushes) the writer even when a write fails.
        try (FileWriter writer = new FileWriter(file)) {
            for (int featureIndex : constraints.keySet()) {
                writer.write(dataAlphabet.lookupObject(featureIndex) + " ");
                double[] distribution = constraints.get(featureIndex);
                if (distribution != null) {
                    for (int labelIndex = 0; labelIndex < distribution.length; labelIndex++) {
                        writer.write(targetAlphabet.lookupObject(labelIndex) + ":" + distribution[labelIndex] + " ");
                    }
                }
                writer.write("\n");
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }
}
