/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.ie.crf.CRFCliqueTree;
import edu.stanford.nlp.ie.crf.CRFLabel;
import edu.stanford.nlp.ie.crf.CliquePotentialFunction;
import edu.stanford.nlp.ie.crf.HasCliquePotentialFunction;
import edu.stanford.nlp.ie.crf.LinearCliquePotentialFunction;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractStochasticCachingDiffUpdateFunction;
import edu.stanford.nlp.optimization.HasFeatureGrouping;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * Training objective for a linear-chain CRF.
 *
 * <p>The value is the negated sum of per-document log-probabilities of the gold label sequences
 * (see {@link #documentLogProbability}). An optional prior penalty is added by {@link #applyPrior}.
 * The gradient is the model-expected feature counts {@link #E} minus the empirical counts
 * {@link #Ehat}. Documents can be processed in parallel by {@code multiThreadGrad} workers.
 */
public class CRFLogConditionalObjectiveFunction
extends AbstractStochasticCachingDiffUpdateFunction
implements HasCliquePotentialFunction,
HasFeatureGrouping {
    /** No penalty term is added by {@link #applyPrior}. */
    public static final int NO_PRIOR = 0;
    /** Per-weight penalty {@code w^2 / (2 sigma^2)}. */
    public static final int QUADRATIC_PRIOR = 1;
    /** Per-weight Huber penalty: quadratic for {@code |w| < epsilon}, linear beyond. */
    public static final int HUBER_PRIOR = 2;
    /** Per-weight penalty {@code w^4 / (2 sigma^4)}. */
    public static final int QUARTIC_PRIOR = 3;
    /** Dropout prior; {@link #applyPrior} adds no penalty for it in this class. */
    public static final int DROPOUT_PRIOR = 4;
    public static final boolean DEBUG2 = false;
    public static final boolean DEBUG3 = false;
    public static final boolean TIMED = false;
    public static final boolean CONDENSE = true;
    /** When true, per-position probabilities and per-weight derivatives are printed to stderr. */
    public static boolean VERBOSE = false;
    /** One of the {@code *_PRIOR} constants, as resolved by {@link #getPriorType}. */
    protected final int prior;
    protected final double sigma;
    /** Width of the quadratic region of the Huber prior. */
    protected final double epsilon = 0.1;
    /** Label index per clique size: {@code labelIndices.get(j)} holds labelings of {@code j + 1} positions. */
    protected final List<Index<CRFLabel>> labelIndices;
    protected final Index<String> classIndex;
    /**
     * Empirical counts, {@code Ehat[feature][labelIndex]}, taken from the gold labels. Filled for all
     * documents at construction when {@code calcEmpirical} is true. The batch-empirical methods
     * ({@link #calculateStochasticUpdate}, {@link #calculateStochasticGradient}) overwrite it with
     * counts for their batch only.
     */
    protected final double[][] Ehat;
    /** Model-expected counts with the same layout as {@link #Ehat}; recomputed by each gradient call. */
    protected final double[][] E;
    /** Per-worker scratch copies of {@link #E}, used only when {@code multiThreadGrad > 1}. */
    protected double[][][] parallelE;
    /** Per-worker scratch copies of {@link #Ehat}, used only when {@code multiThreadGrad > 1}. */
    protected double[][][] parallelEhat;
    protected final int window;
    protected final int numClasses;
    /** {@code map[f]} selects the entry of {@link #labelIndices} that sizes row {@code f} of the weights. */
    protected final int[] map;
    protected int[][][][] data;
    protected double[][][][] featureVal;
    protected int[][] labels;
    protected final int domainDimension;
    protected int[][] weightIndices;
    protected final String backgroundSymbol;
    protected int[][] featureGrouping = null;
    /** Added to every random initial weight. */
    protected static final double smallConst = 1.0E-6;
    protected Random rand = new Random(Integer.MAX_VALUE);
    /** Number of worker partitions used by {@link #multiThreadGradient}. */
    protected final int multiThreadGrad;
    protected double[][] weights;
    protected CliquePotentialFunction cliquePotentialFunc;
    private ThreadsafeProcessor<Pair<Integer, List<Integer>>, Pair<Integer, Double>> expectedThreadProcessor = new ExpectationThreadsafeProcessor();
    private ThreadsafeProcessor<Pair<Integer, List<Integer>>, Pair<Integer, Double>> expectedAndEmpiricalThreadProcessor = new ExpectationThreadsafeProcessor(true);

    /** Returns random initial weights drawn from this function's fixed-seed generator. */
    @Override
    public double[] initial() {
        return this.initial(this.rand);
    }

    /**
     * Returns random initial weights.
     *
     * @param useRandomSeed if true, draw from a newly created, non-deterministically seeded
     *     {@link Random}; otherwise use this function's fixed-seed generator
     * @return a fresh array of length {@link #domainDimension()}
     */
    public double[] initial(boolean useRandomSeed) {
        Random randToUse = useRandomSeed ? new Random() : this.rand;
        // Previously this ignored randToUse and always used this.rand.
        return this.initial(randToUse);
    }

    /**
     * Returns random initial weights, each {@code randGen.nextDouble() + smallConst}.
     *
     * @param randGen source of randomness (advanced by this call)
     * @return a fresh array of length {@link #domainDimension()}
     */
    public double[] initial(Random randGen) {
        double[] initial = new double[this.domainDimension()];
        for (int i = 0; i < initial.length; ++i) {
            initial[i] = randGen.nextDouble() + smallConst;
        }
        return initial;
    }

    /**
     * Maps a prior name (case-insensitive) to one of the {@code *_PRIOR} constants.
     *
     * <p>{@code null} means quadratic. The names "lasso", "ridge", "gaussian", "ae-lasso", "sg-lasso"
     * and "g-lasso" map to {@link #NO_PRIOR}, so {@link #applyPrior} adds nothing for them.
     *
     * @throws IllegalArgumentException if the name is not recognized
     */
    public static int getPriorType(String priorTypeStr) {
        if (priorTypeStr == null || "QUADRATIC".equalsIgnoreCase(priorTypeStr)) {
            return QUADRATIC_PRIOR;
        }
        if ("HUBER".equalsIgnoreCase(priorTypeStr)) {
            return HUBER_PRIOR;
        }
        if ("QUARTIC".equalsIgnoreCase(priorTypeStr)) {
            return QUARTIC_PRIOR;
        }
        if ("DROPOUT".equalsIgnoreCase(priorTypeStr)) {
            return DROPOUT_PRIOR;
        }
        if ("NONE".equalsIgnoreCase(priorTypeStr)) {
            return NO_PRIOR;
        }
        if (priorTypeStr.equalsIgnoreCase("lasso") || priorTypeStr.equalsIgnoreCase("ridge")
                || priorTypeStr.equalsIgnoreCase("gaussian") || priorTypeStr.equalsIgnoreCase("ae-lasso")
                || priorTypeStr.equalsIgnoreCase("sg-lasso") || priorTypeStr.equalsIgnoreCase("g-lasso")) {
            return NO_PRIOR;
        }
        throw new IllegalArgumentException("Unknown prior type: " + priorTypeStr);
    }

    /** Same as the full constructor with {@code calcEmpirical = true}. */
    CRFLogConditionalObjectiveFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, String priorType, String backgroundSymbol, double sigma, double[][][][] featureVal, int multiThreadGrad) {
        this(data, labels, window, classIndex, labelIndices, map, priorType, backgroundSymbol, sigma, featureVal, multiThreadGrad, true);
    }

    /**
     * Creates the objective.
     *
     * @param calcEmpirical if true, full-data empirical counts are computed into {@link #Ehat} now
     */
    CRFLogConditionalObjectiveFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, String priorType, String backgroundSymbol, double sigma, double[][][][] featureVal, int multiThreadGrad, boolean calcEmpirical) {
        this.window = window;
        this.classIndex = classIndex;
        this.numClasses = classIndex.size();
        this.labelIndices = labelIndices;
        this.map = map;
        this.data = data;
        this.featureVal = featureVal;
        this.labels = labels;
        this.prior = CRFLogConditionalObjectiveFunction.getPriorType(priorType);
        this.backgroundSymbol = backgroundSymbol;
        this.sigma = sigma;
        this.multiThreadGrad = multiThreadGrad;
        this.Ehat = this.empty2D();
        this.E = this.empty2D();
        this.weights = this.empty2D();
        if (calcEmpirical) {
            this.empiricalCounts(this.Ehat);
        }
        int myDomainDimension = 0;
        for (int dim : map) {
            myDomainDimension += labelIndices.get(dim).size();
        }
        this.domainDimension = myDomainDimension;
    }

    /** Adds the empirical counts of every document into {@code eHat}. */
    protected void empiricalCounts(double[][] eHat) {
        for (int m = 0; m < this.data.length; ++m) {
            this.empiricalCountsForADoc(eHat, m);
        }
    }

    /**
     * Adds the empirical counts of one document into {@code eHat}.
     *
     * <p>For each position {@code i} and clique size {@code j + 1}, the gold labels of the last
     * {@code j + 1} positions are looked up in {@code labelIndices.get(j)}. Each active feature
     * contributes its value: supplied values for {@code j == 0}, 1.0 otherwise. Positions before the
     * start of the document are labeled with the background symbol, unless the label array carries
     * extra leading labels.
     */
    protected void empiricalCountsForADoc(double[][] eHat, int docIndex) {
        int[][][] docData = this.data[docIndex];
        int[] docLabels = this.labels[docIndex];
        double[][][] featureValArr = this.featureVal == null ? null : this.featureVal[docIndex];
        // windowLabels[window - 1] is the current position; earlier slots hold the preceding labels.
        int[] windowLabels = new int[this.window];
        Arrays.fill(windowLabels, this.classIndex.indexOf(this.backgroundSymbol));
        if (docLabels.length > docData.length) {
            int prefixLen = docLabels.length - docData.length;
            // Seed the context with the labels immediately preceding position 0.
            this.copyPrecedingLabels(docLabels, prefixLen, windowLabels);
            docLabels = Arrays.copyOfRange(docLabels, prefixLen, docLabels.length);
        }
        for (int i = 0; i < docData.length; ++i) {
            System.arraycopy(windowLabels, 1, windowLabels, 0, this.window - 1);
            windowLabels[this.window - 1] = docLabels[i];
            for (int j = 0; j < docData[i].length; ++j) {
                int[] cliqueLabel = new int[j + 1];
                System.arraycopy(windowLabels, this.window - 1 - j, cliqueLabel, 0, j + 1);
                int labelIndex = this.labelIndices.get(j).indexOf(new CRFLabel(cliqueLabel));
                int[] features = docData[i][j];
                for (int n = 0; n < features.length; ++n) {
                    // Only single-position cliques (j == 0) use supplied feature values.
                    double fVal = (featureValArr != null && j == 0) ? featureValArr[i][j][n] : 1.0;
                    eHat[features[n]][labelIndex] += fVal;
                }
            }
        }
    }

    /**
     * Right-aligns into {@code dest} the labels that immediately precede the first real position of
     * a document whose label array begins with {@code prefixLen} extra labels. At most
     * {@code window - 1} labels are copied; earlier slots of {@code dest} are left unchanged.
     */
    private void copyPrecedingLabels(int[] docLabels, int prefixLen, int[] dest) {
        int count = Math.min(prefixLen, Math.min(this.window - 1, dest.length));
        System.arraycopy(docLabels, prefixLen - count, dest, dest.length - count, count);
    }

    /** Unpacks {@code x} into {@link #weights} and wraps them in a linear clique potential. */
    @Override
    public CliquePotentialFunction getCliquePotentialFunction(double[] x) {
        this.to2D(x, this.weights);
        return new LinearCliquePotentialFunction(this.weights);
    }

    /** Adds empirical counts to {@code Ehat} and expected counts to {@code E}; returns the doc log-prob. */
    protected double expectedAndEmpiricalCountsAndValueForADoc(double[][] E, double[][] Ehat, int docIndex) {
        this.empiricalCountsForADoc(Ehat, docIndex);
        return this.expectedCountsAndValueForADoc(E, docIndex);
    }

    /** Returns the log-probability of one document's gold labels under the current weights. */
    public double valueForADoc(int docIndex) {
        return this.expectedCountsAndValueForADoc(null, docIndex, false, true);
    }

    /** Adds expected counts to {@code E} and returns the document log-probability. */
    protected double expectedCountsAndValueForADoc(double[][] E, int docIndex) {
        return this.expectedCountsAndValueForADoc(E, docIndex, true, true);
    }

    /** Adds expected counts to {@code E}; always returns 0.0. */
    protected double expectedCountsForADoc(double[][] E, int docIndex) {
        return this.expectedCountsAndValueForADoc(E, docIndex, true, false);
    }

    /**
     * Builds a calibrated clique tree for one document using {@link #cliquePotentialFunc}.
     * Optionally adds expected counts to {@code E} and computes the gold log-probability.
     *
     * @return the document log-probability if {@code doValueCalc}, else 0.0
     */
    protected double expectedCountsAndValueForADoc(double[][] E, int docIndex, boolean doExpectedCountCalc, boolean doValueCalc) {
        int[][][] docData = this.data[docIndex];
        double[][][] featureVal3DArr = this.featureVal == null ? null : this.featureVal[docIndex];
        CRFCliqueTree<String> cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(docData, this.labelIndices, this.numClasses, this.classIndex, this.backgroundSymbol, this.cliquePotentialFunc, featureVal3DArr);
        double prob = doValueCalc ? this.documentLogProbability(docData, docIndex, cliqueTree) : 0.0;
        if (doExpectedCountCalc) {
            this.documentExpectedCounts(E, docData, featureVal3DArr, cliqueTree);
        }
        return prob;
    }

    /**
     * Adds, for every position, clique and labeling, {@code cliqueTree.prob(i, label)} times each
     * active feature's value (supplied for {@code j == 0}, else 1.0) into {@code E}.
     */
    protected void documentExpectedCounts(double[][] E, int[][][] docData, double[][][] featureVal3DArr, CRFCliqueTree cliqueTree) {
        for (int i = 0; i < docData.length; ++i) {
            for (int j = 0; j < docData[i].length; ++j) {
                Index<CRFLabel> labelIndex = this.labelIndices.get(j);
                int[] features = docData[i][j];
                int liSize = labelIndex.size();
                for (int k = 0; k < liSize; ++k) {
                    double p = cliqueTree.prob(i, labelIndex.get(k).getLabel());
                    for (int n = 0; n < features.length; ++n) {
                        double fVal = (j == 0 && featureVal3DArr != null) ? featureVal3DArr[i][j][n] : 1.0;
                        E[features[n]][k] += p * fVal;
                    }
                }
            }
        }
    }

    /**
     * Sums {@code cliqueTree.logProbStartPos()} and each position's log-probability of its gold label
     * given the previous {@code window - 1} gold labels. Labels before the document start are the
     * background symbol, unless the label array carries extra leading labels.
     */
    private double documentLogProbability(int[][][] docData, int docIndex, CRFCliqueTree cliqueTree) {
        int[] docLabels = this.labels[docIndex];
        int[] given = new int[this.window - 1];
        Arrays.fill(given, this.classIndex.indexOf(this.backgroundSymbol));
        if (docLabels.length > docData.length) {
            int prefixLen = docLabels.length - docData.length;
            // Same context interpretation as empiricalCountsForADoc: the labels right before position 0.
            this.copyPrecedingLabels(docLabels, prefixLen, given);
            docLabels = Arrays.copyOfRange(docLabels, prefixLen, docLabels.length);
        }
        double startPosLogProb = cliqueTree.logProbStartPos();
        if (VERBOSE) {
            System.err.printf("P_-1(Background) = % 5.3f%n", startPosLogProb);
        }
        double prob = startPosLogProb;
        for (int i = 0; i < docData.length; ++i) {
            int label = docLabels[i];
            double p = cliqueTree.condLogProbGivenPrevious(i, label, given);
            if (VERBOSE) {
                System.err.println("P(" + label + "|" + ArrayMath.toString(given) + ")=" + p);
            }
            prob += p;
            if (given.length > 0) {
                System.arraycopy(given, 1, given, 0, given.length - 1);
                given[given.length - 1] = label;
            }
        }
        return prob;
    }

    /** Stores {@code weights} (by reference) and rebuilds {@link #cliquePotentialFunc} from them. */
    public void setWeights(double[][] weights) {
        this.weights = weights;
        this.cliquePotentialFunc = new LinearCliquePotentialFunction(weights);
    }

    /** Accumulates expected counts for all documents into {@link #E}; returns the summed log-prob. */
    protected double regularGradientAndValue() {
        int totalLen = this.data.length;
        List<Integer> docIDs = new ArrayList<Integer>(totalLen);
        for (int m = 0; m < totalLen; ++m) {
            docIDs.add(m);
        }
        return this.multiThreadGradient(docIDs, false);
    }

    /**
     * Splits {@code docIDs} into {@code multiThreadGrad} contiguous parts, the last taking the
     * remainder, and processes them via {@link MulticoreWrapper}. Results are <em>added</em> into
     * {@link #E} (and into {@link #Ehat} when {@code calculateEmpirical}), so callers must clear
     * those arrays first if they want fresh counts.
     *
     * @return the summed log-probability of the given documents
     */
    protected double multiThreadGradient(List<Integer> docIDs, boolean calculateEmpirical) {
        if (this.multiThreadGrad > 1) {
            if (this.parallelE == null) {
                this.parallelE = new double[this.multiThreadGrad][][];
                for (int t = 0; t < this.multiThreadGrad; ++t) {
                    this.parallelE[t] = this.empty2D();
                }
            }
            if (calculateEmpirical && this.parallelEhat == null) {
                this.parallelEhat = new double[this.multiThreadGrad][][];
                for (int t = 0; t < this.multiThreadGrad; ++t) {
                    this.parallelEhat[t] = this.empty2D();
                }
            }
        }
        ThreadsafeProcessor<Pair<Integer, List<Integer>>, Pair<Integer, Double>> processor =
                calculateEmpirical ? this.expectedAndEmpiricalThreadProcessor : this.expectedThreadProcessor;
        MulticoreWrapper<Pair<Integer, List<Integer>>, Pair<Integer, Double>> wrapper =
                new MulticoreWrapper<Pair<Integer, List<Integer>>, Pair<Integer, Double>>(this.multiThreadGrad, processor);
        int totalLen = docIDs.size();
        int partLen = totalLen / this.multiThreadGrad;
        int currIndex = 0;
        for (int part = 0; part < this.multiThreadGrad; ++part) {
            int endIndex = (part == this.multiThreadGrad - 1) ? totalLen : currIndex + partLen;
            wrapper.put(new Pair<Integer, List<Integer>>(part, docIDs.subList(currIndex, endIndex)));
            currIndex = endIndex;
        }
        wrapper.join();
        double objective = 0.0;
        while (wrapper.peek()) {
            Pair<Integer, Double> result = wrapper.poll();
            objective += result.second().doubleValue();
            if (this.multiThreadGrad > 1) {
                // With one worker the processor already wrote straight into E / Ehat.
                int tID = result.first();
                CRFLogConditionalObjectiveFunction.combine2DArr(this.E, this.parallelE[tID]);
                if (calculateEmpirical) {
                    CRFLogConditionalObjectiveFunction.combine2DArr(this.Ehat, this.parallelEhat[tID]);
                }
            }
        }
        return objective;
    }

    /**
     * Full-data value and gradient: value = -(sum of document log-probs) + prior.
     * Derivative = E - Ehat + prior gradient.
     */
    @Override
    public void calculate(double[] x) {
        this.to2D(x, this.weights);
        this.setWeights(this.weights);
        CRFLogConditionalObjectiveFunction.clear2D(this.E);
        double prob = this.regularGradientAndValue();
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculate() - this may well indicate numeric underflow due to overly long documents.");
        }
        this.value = -prob;
        if (VERBOSE) {
            System.err.println("value is " + Math.exp(-this.value));
        }
        int index = 0;
        for (int i = 0; i < this.E.length; ++i) {
            for (int j = 0; j < this.E[i].length; ++j) {
                // The negated log-likelihood is minimized, so its gradient is expected - empirical.
                this.derivative[index] = this.E[i][j] - this.Ehat[i][j];
                if (VERBOSE) {
                    System.err.println("deriv(" + i + "," + j + ") = " + this.E[i][j] + " - " + this.Ehat[i][j] + " = " + this.derivative[index]);
                }
                ++index;
            }
        }
        this.applyPrior(x, 1.0);
    }

    /** Returns the number of training documents. */
    @Override
    public int dataDimension() {
        return this.data.length;
    }

    /**
     * Batch value and gradient. Expected counts come from the batch. Empirical counts are the
     * full-data {@link #Ehat} scaled by {@code batch.length / dataDimension()}, as is the prior.
     * {@code v} is not used.
     */
    @Override
    public void calculateStochastic(double[] x, double[] v, int[] batch) {
        this.to2D(x, this.weights);
        this.setWeights(this.weights);
        double batchScale = (double) batch.length / (double) this.dataDimension();
        List<Integer> docIDs = new ArrayList<Integer>(batch.length);
        for (int item : batch) {
            docIDs.add(item);
        }
        // multiThreadGradient adds into E, so stale counts from earlier calls must be cleared.
        CRFLogConditionalObjectiveFunction.clear2D(this.E);
        double prob = this.multiThreadGradient(docIDs, false);
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculate()");
        }
        this.value = -prob;
        int index = 0;
        for (int i = 0; i < this.E.length; ++i) {
            for (int j = 0; j < this.E[i].length; ++j) {
                this.derivative[index++] = this.E[i][j] - batchScale * this.Ehat[i][j];
                if (VERBOSE) {
                    System.err.println("deriv(" + i + "," + j + ") = " + this.E[i][j] + " - " + this.Ehat[i][j] + " = " + this.derivative[index - 1]);
                }
            }
        }
        this.applyPrior(x, batchScale);
    }

    /**
     * Scales {@code x} in place by {@code xScale}, computes batch expected and empirical counts, then
     * adds {@code (Ehat - E) * gScale} into {@code x}. No prior is applied. {@link #Ehat} is
     * overwritten with the batch's empirical counts.
     *
     * @return the negated summed log-probability of the batch
     */
    @Override
    public double calculateStochasticUpdate(double[] x, double xScale, int[] batch, double gScale) {
        this.to2D(x, xScale, this.weights);
        this.setWeights(this.weights);
        List<Integer> docIDs = new ArrayList<Integer>(batch.length);
        for (int item : batch) {
            docIDs.add(item);
        }
        // Counts are accumulated, so start from zero; otherwise the batch empirical counts would be
        // added on top of the full-data counts and all previous batches.
        CRFLogConditionalObjectiveFunction.clear2D(this.E);
        CRFLogConditionalObjectiveFunction.clear2D(this.Ehat);
        double prob = this.multiThreadGradient(docIDs, true);
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculate()");
        }
        this.value = -prob;
        int index = 0;
        for (int i = 0; i < this.E.length; ++i) {
            for (int j = 0; j < this.E[i].length; ++j) {
                x[index++] += (this.Ehat[i][j] - this.E[i][j]) * gScale;
            }
        }
        return this.value;
    }

    /**
     * Stores {@code E - Ehat} for the batch in {@code derivative}; no prior, value not updated.
     * {@link #Ehat} is overwritten with the batch's empirical counts.
     */
    @Override
    public void calculateStochasticGradient(double[] x, int[] batch) {
        if (this.derivative == null) {
            this.derivative = new double[this.domainDimension()];
        }
        this.to2D(x, this.weights);
        this.setWeights(this.weights);
        List<Integer> docIDs = new ArrayList<Integer>(batch.length);
        for (int item : batch) {
            docIDs.add(item);
        }
        CRFLogConditionalObjectiveFunction.clear2D(this.E);
        CRFLogConditionalObjectiveFunction.clear2D(this.Ehat);
        this.multiThreadGradient(docIDs, true);
        int index = 0;
        for (int i = 0; i < this.E.length; ++i) {
            for (int j = 0; j < this.E[i].length; ++j) {
                this.derivative[index++] = this.E[i][j] - this.Ehat[i][j];
            }
        }
    }

    /**
     * Scales {@code x} in place by {@code xScale} and returns the negated summed log-probability of
     * the batch documents (no prior). Documents are evaluated serially.
     */
    @Override
    public double valueAt(double[] x, double xScale, int[] batch) {
        this.to2D(x, xScale, this.weights);
        this.setWeights(this.weights);
        double prob = 0.0;
        for (int ind : batch) {
            prob += this.valueForADoc(ind);
        }
        if (Double.isNaN(prob)) {
            throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunction.calculate()");
        }
        this.value = -prob;
        return this.value;
    }

    /** Returns the configured feature grouping, or a single group containing every weight. */
    @Override
    public int[][] getFeatureGrouping() {
        if (this.featureGrouping != null) {
            return this.featureGrouping;
        }
        return new int[][]{ArrayMath.range(0, this.domainDimension())};
    }

    public void setFeatureGrouping(int[][] fg) {
        this.featureGrouping = fg;
    }

    /**
     * Adds {@code batchScale} times the prior penalty to {@code value}, and its gradient to
     * {@code derivative}. {@link #NO_PRIOR} and {@link #DROPOUT_PRIOR} add nothing here.
     */
    protected void applyPrior(double[] x, double batchScale) {
        switch (this.prior) {
            case QUADRATIC_PRIOR: {
                double lambda = 1.0 / (this.sigma * this.sigma);
                for (int i = 0; i < x.length; ++i) {
                    double w = x[i];
                    this.value += batchScale * w * w * lambda * 0.5;
                    this.derivative[i] += batchScale * w * lambda;
                }
                break;
            }
            case HUBER_PRIOR: {
                double sigmaSq = this.sigma * this.sigma;
                for (int i = 0; i < x.length; ++i) {
                    double w = x[i];
                    double wabs = Math.abs(w);
                    if (wabs < this.epsilon) {
                        this.value += batchScale * w * w / 2.0 / this.epsilon / sigmaSq;
                        this.derivative[i] += batchScale * w / this.epsilon / sigmaSq;
                    } else {
                        // Linear part, offset by epsilon / 2 so it meets the quadratic part at |w| == epsilon.
                        this.value += batchScale * (wabs - this.epsilon / 2.0) / sigmaSq;
                        this.derivative[i] += batchScale * (w < 0.0 ? -1.0 : 1.0) / sigmaSq;
                    }
                }
                break;
            }
            case QUARTIC_PRIOR: {
                double sigmaQu = this.sigma * this.sigma * this.sigma * this.sigma;
                double lambda = 0.5 / sigmaQu;
                for (int i = 0; i < x.length; ++i) {
                    double w = x[i];
                    this.value += batchScale * w * w * w * w * lambda;
                    // d/dw (lambda * w^4) = 4 * lambda * w^3; previously w / sigma^4, which did not
                    // match the value term above.
                    this.derivative[i] += batchScale * 4.0 * lambda * w * w * w;
                }
                break;
            }
            default:
                break;
        }
    }

    /**
     * Computes, per position, conditional distributions over neighbouring labels from pair-clique
     * log-probabilities {@code cTree.logProb(i, {l1, l2})}. Each row is log-normalized, then
     * exponentiated.
     *
     * <p>NOTE(review): assumes {@code {l1, l2}} labels positions {@code i - 1} and {@code i}; then
     * {@code prevGivenCurr[i][curr][prev]} and {@code nextGivenCurr[i][curr][next]} are the
     * conditionals. {@code nextGivenCurr} for the last position stays all zeros.
     *
     * @return {@code (prevGivenCurr, nextGivenCurr)}
     */
    protected Pair<double[][][], double[][][]> getCondProbs(CRFCliqueTree cTree, int[][][] docData) {
        int docLength = docData.length;
        double[][][] prevGivenCurr = new double[docLength][this.numClasses][this.numClasses];
        double[][][] nextGivenCurr = new double[docLength][this.numClasses][this.numClasses];
        for (int i = 0; i < docLength; ++i) {
            int[] labelPair = new int[2];
            for (int l1 = 0; l1 < this.numClasses; ++l1) {
                labelPair[0] = l1;
                for (int l2 = 0; l2 < this.numClasses; ++l2) {
                    labelPair[1] = l2;
                    double logProb = cTree.logProb(i, labelPair);
                    if (i > 0) {
                        nextGivenCurr[i - 1][l1][l2] = logProb;
                    }
                    prevGivenCurr[i][l2][l1] = logProb;
                }
            }
            for (int c = 0; c < this.numClasses; ++c) {
                if (i > 0) {
                    logNormalizeThenExp(nextGivenCurr[i - 1][c]);
                }
                logNormalizeThenExp(prevGivenCurr[i][c]);
            }
        }
        return new Pair<double[][][], double[][][]>(prevGivenCurr, nextGivenCurr);
    }

    /** Log-normalizes {@code row} in place via {@link ArrayMath#logNormalize}, then exponentiates each entry. */
    private static void logNormalizeThenExp(double[] row) {
        ArrayMath.logNormalize(row);
        for (int k = 0; k < row.length; ++k) {
            row[k] = Math.exp(row[k]);
        }
    }

    /** Adds {@code scale * toBeCombined} element-wise into {@code combineInto}. */
    protected static void combine2DArr(double[][] combineInto, double[][] toBeCombined, double scale) {
        for (int i = 0; i < toBeCombined.length; ++i) {
            double[] target = combineInto[i];
            double[] source = toBeCombined[i];
            for (int j = 0; j < source.length; ++j) {
                target[j] += source[j] * scale;
            }
        }
    }

    /** Adds {@code toBeCombined} element-wise into {@code combineInto}. */
    protected static void combine2DArr(double[][] combineInto, double[][] toBeCombined) {
        for (int i = 0; i < toBeCombined.length; ++i) {
            double[] target = combineInto[i];
            double[] source = toBeCombined[i];
            for (int j = 0; j < source.length; ++j) {
                target[j] += source[j];
            }
        }
    }

    /** Adds each map value into the row of {@code combineInto} named by its key. */
    protected static void combine2DArr(double[][] combineInto, Map<Integer, double[]> toBeCombined) {
        for (Map.Entry<Integer, double[]> entry : toBeCombined.entrySet()) {
            double[] target = combineInto[entry.getKey()];
            double[] source = entry.getValue();
            for (int i = 0; i < source.length; ++i) {
                target[i] += source[i];
            }
        }
    }

    /** Adds {@code scale} times each map value into the row of {@code combineInto} named by its key. */
    protected static void combine2DArr(double[][] combineInto, Map<Integer, double[]> toBeCombined, double scale) {
        for (Map.Entry<Integer, double[]> entry : toBeCombined.entrySet()) {
            double[] target = combineInto[entry.getKey()];
            double[] source = entry.getValue();
            for (int i = 0; i < source.length; ++i) {
                target[i] += source[i] * scale;
            }
        }
    }

    /** Returns the total number of weights (sum of label-index sizes over {@link #map}). */
    @Override
    public int domainDimension() {
        return this.domainDimension;
    }

    /**
     * Splits a flat weight vector into newly allocated rows. Row {@code i} has
     * {@code labelIndices.get(map[i]).size()} entries.
     *
     * @throws RuntimeException wrapping the copy failure if {@code weights} is too short
     */
    public static double[][] to2D(double[] weights, List<Index<CRFLabel>> labelIndices, int[] map) {
        double[][] newWeights = new double[map.length][];
        for (int i = 0; i < map.length; ++i) {
            newWeights[i] = new double[labelIndices.get(map[i]).size()];
        }
        CRFLogConditionalObjectiveFunction.to2D(weights, labelIndices, map, newWeights);
        return newWeights;
    }

    /** Splits {@code weights} into newly allocated rows using this function's label layout. */
    public double[][] to2D(double[] weights) {
        return CRFLogConditionalObjectiveFunction.to2D(weights, this.labelIndices, this.map);
    }

    /**
     * Copies a flat weight vector into the pre-allocated rows of {@code newWeights}.
     *
     * @throws RuntimeException wrapping the copy failure; both arrays are printed to stderr first
     */
    public static void to2D(double[] weights, List<Index<CRFLabel>> labelIndices, int[] map, double[][] newWeights) {
        int index = 0;
        for (int i = 0; i < map.length; ++i) {
            int labelSize = labelIndices.get(map[i]).size();
            try {
                System.arraycopy(weights, index, newWeights[i], 0, labelSize);
            } catch (RuntimeException ex) {
                System.err.println("weights: " + Arrays.toString(weights));
                System.err.println("newWeights[" + i + "]: " + Arrays.toString(newWeights[i]));
                throw new RuntimeException(ex);
            }
            index += labelSize;
        }
    }

    /** Copies {@code weights1D} into the pre-allocated rows of {@code newWeights}. */
    public void to2D(double[] weights1D, double[][] newWeights) {
        CRFLogConditionalObjectiveFunction.to2D(weights1D, this.labelIndices, this.map, newWeights);
    }

    /** Multiplies {@code weights1D} <em>in place</em> by {@code wScale}, then splits it into new rows. */
    public double[][] to2D(double[] weights1D, double wScale) {
        for (int i = 0; i < weights1D.length; ++i) {
            weights1D[i] *= wScale;
        }
        return CRFLogConditionalObjectiveFunction.to2D(weights1D, this.labelIndices, this.map);
    }

    /** Multiplies {@code weights1D} <em>in place</em> by {@code wScale}, then copies it into {@code newWeights}. */
    public void to2D(double[] weights1D, double wScale, double[][] newWeights) {
        for (int i = 0; i < weights1D.length; ++i) {
            weights1D[i] *= wScale;
        }
        CRFLogConditionalObjectiveFunction.to2D(weights1D, this.labelIndices, this.map, newWeights);
    }

    /** Sets every entry of {@code arr2D} to 0.0. */
    public static void clear2D(double[][] arr2D) {
        for (double[] row : arr2D) {
            Arrays.fill(row, 0.0);
        }
    }

    /** Concatenates the rows of {@code weights} into {@code newWeights}, starting at index 0. */
    public static void to1D(double[][] weights, double[] newWeights) {
        int index = 0;
        for (double[] weightVector : weights) {
            System.arraycopy(weightVector, 0, newWeights, index, weightVector.length);
            index += weightVector.length;
        }
    }

    /** Concatenates the rows of {@code weights} into a new array of length {@code domainDimension}. */
    public static double[] to1D(double[][] weights, int domainDimension) {
        double[] newWeights = new double[domainDimension];
        CRFLogConditionalObjectiveFunction.to1D(weights, newWeights);
        return newWeights;
    }

    /** Concatenates the rows of {@code weights} into a new array of length {@link #domainDimension()}. */
    public double[] to1D(double[][] weights) {
        return CRFLogConditionalObjectiveFunction.to1D(weights, this.domainDimension());
    }

    /** Lazily builds and caches the flat index of each 2D weight position ({@code [row][label]}). */
    public int[][] getWeightIndices() {
        if (this.weightIndices == null) {
            this.weightIndices = new int[this.map.length][];
            int index = 0;
            for (int i = 0; i < this.map.length; ++i) {
                int rowSize = this.labelIndices.get(this.map[i]).size();
                this.weightIndices[i] = new int[rowSize];
                for (int j = 0; j < rowSize; ++j) {
                    this.weightIndices[i][j] = index++;
                }
            }
        }
        return this.weightIndices;
    }

    /** Allocates a zeroed array shaped like the weights: row {@code i} sized by {@code labelIndices.get(map[i])}. */
    protected double[][] empty2D() {
        double[][] d = new double[this.map.length][];
        for (int i = 0; i < this.map.length; ++i) {
            d[i] = new double[this.labelIndices.get(this.map[i]).size()];
        }
        return d;
    }

    /** Returns the gold label arrays (not a copy). */
    public int[][] getLabels() {
        return this.labels;
    }

    /**
     * Worker that processes one {@code (threadID, docIDs)} part and returns {@code (threadID, logProbSum)}.
     *
     * <p>With one worker it writes directly into {@link #E} / {@link #Ehat}. Otherwise it clears and
     * fills {@code parallelE[tID]} (and {@code parallelEhat[tID]}), which
     * {@link #multiThreadGradient} then merges.
     */
    class ExpectationThreadsafeProcessor
    implements ThreadsafeProcessor<Pair<Integer, List<Integer>>, Pair<Integer, Double>> {
        /** When true, empirical counts are accumulated alongside expected counts. */
        boolean calculateEmpirical = false;

        public ExpectationThreadsafeProcessor() {
        }

        public ExpectationThreadsafeProcessor(boolean calculateEmpirical) {
            this.calculateEmpirical = calculateEmpirical;
        }

        @Override
        public Pair<Integer, Double> process(Pair<Integer, List<Integer>> threadIDAndDocIndices) {
            CRFLogConditionalObjectiveFunction outer = CRFLogConditionalObjectiveFunction.this;
            int tID = threadIDAndDocIndices.first();
            if (tID < 0 || tID >= outer.multiThreadGrad) {
                throw new IllegalArgumentException("threadID must be with in range 0 <= tID < multiThreadGrad(=" + outer.multiThreadGrad + ")");
            }
            List<Integer> docIDs = threadIDAndDocIndices.second();
            double[][] partE;
            double[][] partEhat = null;
            if (outer.multiThreadGrad == 1) {
                partE = outer.E;
                if (this.calculateEmpirical) {
                    partEhat = outer.Ehat;
                }
            } else {
                partE = outer.parallelE[tID];
                CRFLogConditionalObjectiveFunction.clear2D(partE);
                if (this.calculateEmpirical) {
                    partEhat = outer.parallelEhat[tID];
                    CRFLogConditionalObjectiveFunction.clear2D(partEhat);
                }
            }
            double probSum = 0.0;
            for (int docIndex : docIDs) {
                probSum += this.calculateEmpirical
                        ? outer.expectedAndEmpiricalCountsAndValueForADoc(partE, partEhat, docIndex)
                        : outer.expectedCountsAndValueForADoc(partE, docIndex);
            }
            return new Pair<Integer, Double>(tID, probSum);
        }

        /** Returns this instance; its only state is the immutable-after-construction flag. */
        @Override
        public ThreadsafeProcessor<Pair<Integer, List<Integer>>, Pair<Integer, Double>> newInstance() {
            return this;
        }
    }
}

