/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.naturalli;

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.WeightedDataset;
import edu.stanford.nlp.ie.machinereading.structure.Span;
import edu.stanford.nlp.ie.util.RelationTriple;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.naturalli.ClauseSplitterSearchProblem;
import edu.stanford.nlp.naturalli.OpenIE;
import edu.stanford.nlp.naturalli.SentenceFragment;
import edu.stanford.nlp.naturalli.Util;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Trilean;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.stream.Stream;
import java.util.zip.GZIPOutputStream;

public interface ClauseSplitter
extends BiFunction<SemanticGraph, Boolean, ClauseSplitterSearchProblem> {
    public static ClauseSplitter train(Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData, Optional<File> modelPath, Optional<File> trainingDataDump, ClauseSplitterSearchProblem.Featurizer featurizer) {
        LinearClassifierFactory factory = new LinearClassifierFactory();
        OpenIE openie = new OpenIE(new Properties(){
            {
                this.setProperty("splitter.nomodel", "true");
                this.setProperty("optimizefor", "GENERAL");
            }
        });
        WeightedDataset<ClauseClassifierLabel, String> dataset = new WeightedDataset<ClauseClassifierLabel, String>();
        AtomicInteger numExamplesProcessed = new AtomicInteger(0);
        Optional<PrintWriter> datasetDumpWriter = trainingDataDump.map(file -> {
            try {
                return new PrintWriter(new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream((File)trainingDataDump.get()))));
            }
            catch (IOException e) {
                throw new RuntimeIOException(e);
            }
        });
        Redwood.Util.forceTrack("Training inference");
        trainingData.forEach(rawExample -> {
            CoreMap sentence = (CoreMap)rawExample.first;
            Collection spans = (Collection)rawExample.second;
            List tokens = (List)sentence.get(CoreAnnotations.TokensAnnotation.class);
            SemanticGraph tree = (SemanticGraph)sentence.get(SemanticGraphCoreAnnotations.CollapsedDependenciesAnnotation.class);
            ClauseSplitterSearchProblem problem = new ClauseSplitterSearchProblem(tree, true);
            problem.search(fragmentAndScore -> {
                List features = (List)fragmentAndScore.second;
                SentenceFragment fragment = (SentenceFragment)((Supplier)fragmentAndScore.third).get();
                HashSet<RelationTriple> extractions = new HashSet<RelationTriple>(openie.relationsInFragments(openie.entailmentsFromClause(fragment)));
                Trilean correct = Trilean.FALSE;
                block0: for (RelationTriple extraction : extractions) {
                    Span span = Span.fromValues(extraction.subject.get(0).index() - 1, extraction.subject.get(extraction.subject.size() - 1).index());
                    Span objectGuess = Span.fromValues(extraction.object.get(0).index() - 1, extraction.object.get(extraction.object.size() - 1).index());
                    for (Pair candidateGold : spans) {
                        Span subjectSpan = (Span)candidateGold.first;
                        Span objectSpan = (Span)candidateGold.second;
                        if (span.equals(subjectSpan) && objectGuess.equals(objectSpan) || span.equals(objectSpan) && objectGuess.equals(subjectSpan)) {
                            correct = Trilean.TRUE;
                            break block0;
                        }
                        if (Util.nerOverlap(tokens, subjectSpan, span) && Util.nerOverlap(tokens, objectSpan, objectGuess) || Util.nerOverlap(tokens, subjectSpan, objectGuess) && Util.nerOverlap(tokens, objectSpan, span)) {
                            if (correct.isTrue()) continue;
                            correct = Trilean.TRUE;
                            break block0;
                        }
                        if (correct.isTrue()) continue;
                        correct = Trilean.UNKNOWN;
                        break block0;
                    }
                }
                if (!features.isEmpty()) {
                    ArrayList decisionsToAddAsDatums = new ArrayList();
                    if (correct.isTrue()) {
                        for (int i = 0; i < features.size(); ++i) {
                            if (i == features.size() - 1) {
                                decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
                                continue;
                            }
                            decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
                        }
                    } else if (correct.isFalse()) {
                        decisionsToAddAsDatums.add(Pair.makePair(features.get(features.size() - 1), ClauseClassifierLabel.NOT_A_CLAUSE));
                    } else if (correct.isUnknown()) {
                        boolean isSimpleSplit = false;
                        for (Counter feats : features) {
                            if (!featurizer.isSimpleSplit(feats)) continue;
                            isSimpleSplit = true;
                            break;
                        }
                        if (isSimpleSplit) {
                            void var13_19;
                            boolean bl = false;
                            while (var13_19 < features.size()) {
                                if (var13_19 == features.size() - 1) {
                                    decisionsToAddAsDatums.add(Pair.makePair(features.get((int)var13_19), ClauseClassifierLabel.CLAUSE_SPLIT));
                                } else {
                                    decisionsToAddAsDatums.add(Pair.makePair(features.get((int)var13_19), ClauseClassifierLabel.CLAUSE_INTERM));
                                }
                                ++var13_19;
                            }
                        }
                    }
                    for (Pair pair : decisionsToAddAsDatums) {
                        RVFDatum datum = new RVFDatum((Counter)pair.first);
                        datum.setLabel(pair.second);
                        if (datasetDumpWriter.isPresent()) {
                            ((PrintWriter)datasetDumpWriter.get()).println("" + pair.second + "\t" + StringUtils.join(((Counter)pair.first).entrySet().stream().map(entry -> "" + (String)entry.getKey() + "->" + entry.getValue()), ";"));
                        }
                        dataset.add(datum);
                    }
                }
                return true;
            }, new LinearClassifier<ClauseClassifierLabel, String>(new ClassicCounter()), Collections.EMPTY_MAP, featurizer, 10000);
            if (numExamplesProcessed.incrementAndGet() % 100 == 0) {
                Redwood.Util.log("processed " + numExamplesProcessed + " training sentences: " + dataset.size() + " datums");
            }
        });
        Redwood.Util.endTrack("Training inference");
        if (datasetDumpWriter.isPresent()) {
            datasetDumpWriter.get().close();
        }
        Redwood.Util.forceTrack("Training");
        Classifier fullClassifier = factory.trainClassifier((GeneralDataset)dataset);
        Redwood.Util.endTrack("Training");
        if (modelPath.isPresent()) {
            Pair<Classifier, ClauseSplitterSearchProblem.Featurizer> toSave = Pair.makePair(fullClassifier, featurizer);
            try {
                IOUtils.writeObjectToFile(toSave, modelPath.get());
                Redwood.Util.log("SUCCESS: wrote model to " + modelPath.get().getPath());
            }
            catch (IOException e) {
                Redwood.Util.log("ERROR: failed to save model to path: " + modelPath.get().getPath());
                Redwood.Util.err(e);
            }
        }
        Redwood.Util.forceTrack("Training accuracy");
        dataset.randomize(42L);
        Util.dumpAccuracy(fullClassifier, dataset);
        Redwood.Util.endTrack("Training accuracy");
        int numFolds = 5;
        Redwood.Util.forceTrack("" + numFolds + " fold cross-validation");
        for (int fold = 0; fold < numFolds; ++fold) {
            Redwood.Util.forceTrack("Fold " + (fold + 1));
            Redwood.Util.forceTrack("Training");
            Pair foldData = dataset.splitOutFold(fold, numFolds);
            Classifier classifier = factory.trainClassifier((GeneralDataset)foldData.first);
            Redwood.Util.endTrack("Training");
            Redwood.Util.forceTrack("Test");
            Util.dumpAccuracy(classifier, (GeneralDataset)foldData.second);
            Redwood.Util.endTrack("Test");
            Redwood.Util.endTrack("Fold " + (fold + 1));
        }
        Redwood.Util.endTrack("" + numFolds + " fold cross-validation");
        return (tree, truth) -> new ClauseSplitterSearchProblem((SemanticGraph)tree, (boolean)truth, Optional.of(fullClassifier), Optional.of(featurizer));
    }

    /**
     * Convenience overload: trains with {@link ClauseSplitterSearchProblem#DEFAULT_FEATURIZER},
     * always saving the model and always dumping the generated training data.
     *
     * @param trainingData     sentences paired with their gold (subject span, object span) pairs
     * @param modelPath        file the trained model is written to; must not be {@code null}
     * @param trainingDataDump file the generated datums are dumped to; must not be {@code null}
     * @return the splitter produced by the four-argument {@code train} method
     * @throws NullPointerException if {@code modelPath} or {@code trainingDataDump} is {@code null}
     */
    public static ClauseSplitter train(Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData, File modelPath, File trainingDataDump) {
        Optional<File> modelDestination = Optional.of(modelPath);
        Optional<File> dumpDestination = Optional.of(trainingDataDump);
        return train(trainingData, modelDestination, dumpDestination, ClauseSplitterSearchProblem.DEFAULT_FEATURIZER);
    }

    /**
     * Loads a serialized (classifier, featurizer) pair and wraps it in a {@link ClauseSplitter}.
     *
     * <p>Progress and the elapsed load time are printed to {@code System.err}.
     *
     * @param serializedModel location of the model, resolved by
     *                        {@code IOUtils.readObjectFromURLOrClasspathOrFileSystem}
     * @return a splitter that creates search problems using the loaded classifier and featurizer
     * @throws IOException           if the model cannot be read
     * @throws IllegalStateException if a class needed to deserialize the model is not found
     */
    public static ClauseSplitter load(String serializedModel) throws IOException {
        try {
            long start = System.currentTimeMillis();
            System.err.print("Loading clause searcher from " + serializedModel + "...");
            // The serialized object is trusted to have the shape written by train(); the cast
            // cannot be checked at runtime because of erasure.
            @SuppressWarnings("unchecked")
            Pair<Classifier<ClauseClassifierLabel, String>, ClauseSplitterSearchProblem.Featurizer> data =
                    (Pair<Classifier<ClauseClassifierLabel, String>, ClauseSplitterSearchProblem.Featurizer>) IOUtils.readObjectFromURLOrClasspathOrFileSystem(serializedModel);
            ClauseSplitter rtn = (tree, truth) -> new ClauseSplitterSearchProblem(tree, truth, Optional.of(data.first), Optional.of(data.second));
            System.err.println("done [" + Redwood.formatTimeDifference(System.currentTimeMillis() - start) + "]");
            return rtn;
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException("Invalid model at path: " + serializedModel, e);
        }
    }

    /**
     * Labels the clause classifier is trained to predict for a single search decision.
     * Each label carries a small numeric {@link #index} that {@link #fromIndex(int)} maps back
     * to the label.
     */
    public static enum ClauseClassifierLabel {
        /** The decision that ends a path whose fragment was judged correct. */
        CLAUSE_SPLIT(2),
        /** A decision on the way to a correct fragment, before the final one. */
        CLAUSE_INTERM(1),
        /** The decision that ends a path whose fragment was judged incorrect. */
        NOT_A_CLAUSE(0);

        /** Numeric code of this label (0, 1 or 2). */
        public final byte index;

        private ClauseClassifierLabel(int code) {
            this.index = (byte) code;
        }

        /** Returns the constant's declared name. */
        @Override
        public String toString() {
            return name();
        }

        /**
         * Looks up the label whose {@link #index} equals the given value.
         *
         * @param index numeric code to resolve
         * @return the matching label
         * @throws IllegalArgumentException if no label has that index
         */
        public static ClauseClassifierLabel fromIndex(int index) {
            for (ClauseClassifierLabel candidate : values()) {
                if (candidate.index == index) {
                    return candidate;
                }
            }
            throw new IllegalArgumentException("Not a valid index: " + index);
        }
    }
}

