/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.learning.algorithm.tree.DeciderLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;

/**
 * Decider learner for real-valued outputs that picks a single vector element
 * and a threshold on it. The chosen (element, threshold) pair is the one whose
 * two-way split of the data gives the largest reduction in output variance
 * relative to the variance of the unsplit data.
 */
public class VectorThresholdVarianceLearner
extends AbstractCloneableSerializable
implements DeciderLearner<Vectorizable, Double, Boolean, VectorElementThresholdCategorizer> {
    /**
     * Searches every vector dimension for the split with the highest variance
     * reduction and wraps the winner in a categorizer.
     *
     * @param data the input-output examples to split; may be null
     * @return a categorizer built from the best element index and threshold,
     *         or null when {@code data} is null, has at most one example, or
     *         no dimension yields a candidate split
     */
    @Override
    public VectorElementThresholdCategorizer learn(Collection<? extends InputOutputPair<? extends Vectorizable, Double>> data) {
        if (data == null || data.size() <= 1) {
            return null;
        }
        final double baseVariance = DatasetUtil.computeOutputVariance(data);
        final int dimensionality = this.getDimensionality(data);
        double bestGain = -1.0;
        int bestIndex = -1;
        double bestThreshold = 0.0;
        for (int i = 0; i < dimensionality; ++i) {
            final DefaultPair<Double, Double> candidate = this.computeBestGainThreshold(data, i, baseVariance);
            if (candidate == null) {
                // This dimension cannot be split (all values are equal).
                continue;
            }
            final double gain = candidate.getFirst();
            // The first usable dimension is always accepted; later ones must
            // strictly improve on the gain, so earlier dimensions win ties.
            if (bestIndex == -1 || gain > bestGain) {
                bestGain = gain;
                bestIndex = i;
                bestThreshold = candidate.getSecond();
            }
        }
        if (bestIndex < 0) {
            return null;
        }
        return new VectorElementThresholdCategorizer(bestIndex, bestThreshold);
    }

    /**
     * Reads the dimensionality of the first example's input vector.
     *
     * @param data the examples to inspect; may be null or empty
     * @return the dimensionality of the first input converted to a vector, or
     *         0 when {@code data} is null or empty
     */
    protected int getDimensionality(Collection<? extends InputOutputPair<? extends Vectorizable, ?>> data) {
        if (data == null || data.size() <= 0) {
            return 0;
        }
        // Only the first example is consulted.
        // NOTE(review): presumably all inputs share this dimensionality - confirm with callers.
        return data.iterator().next().getInput().convertToVector().getDimensionality();
    }

    /**
     * Finds the best threshold on one vector element, measured by the
     * reduction in output variance that the resulting split achieves.
     *
     * <p>The examples are sorted by the element's value. Each boundary between
     * two distinct consecutive values is a candidate split, with the threshold
     * placed at the midpoint of the two values. The gain of a candidate is
     * {@code baseVariance} minus the proportion-weighted variances of the two
     * sides. Candidates with equal gain are ranked by how balanced the two
     * sides are.
     *
     * @param data the input-output examples; may be null
     * @param dimension the index of the vector element to threshold
     * @param baseVariance the output variance of the unsplit data
     * @return a pair of (gain, threshold), or null when {@code data} is null,
     *         has at most one example, or all values in the dimension are
     *         equal; if no candidate reaches a non-negative gain the pair is
     *         (0.0, smallest value in the dimension)
     */
    public DefaultPair<Double, Double> computeBestGainThreshold(Collection<? extends InputOutputPair<? extends Vectorizable, Double>> data, int dimension, double baseVariance) {
        if (data == null) {
            // Mirrors learn(), which also answers null data with null.
            return null;
        }
        final int total = data.size();

        // Collect (element value, output) pairs and the sum of all outputs.
        final ArrayList<DefaultPair<Double, Double>> values = new ArrayList<DefaultPair<Double, Double>>(total);
        double totalOutputSum = 0.0;
        for (InputOutputPair<? extends Vectorizable, Double> example : data) {
            final Vector input = example.getInput().convertToVector();
            final Double output = example.getOutput();
            final double value = input.getElement(dimension);
            values.add(new DefaultPair<Double, Double>(value, output));
            totalOutputSum += output.doubleValue();
        }

        // Order the pairs by element value so splits are contiguous ranges.
        Collections.sort(values, new Comparator<DefaultPair<Double, Double>>(){

            @Override
            public int compare(DefaultPair<Double, Double> o1, DefaultPair<Double, Double> o2) {
                return o1.getFirst().compareTo(o2.getFirst());
            }
        });

        // Nothing to split when there is at most one example or when the
        // smallest and largest values coincide.
        if (total <= 1 || values.get(0).getFirst().equals(values.get(total - 1).getFirst())) {
            return null;
        }

        // Running output sums for the examples before ("negative") and from
        // ("positive") the current index.
        double sumNegative = 0.0;
        double sumPositive = totalOutputSum;
        double bestGain = 0.0;
        double bestTieBreaker = 0.0;
        // Fallback threshold when no candidate split is accepted below.
        double bestThreshold = values.get(0).getFirst();
        double previousValue = 0.0;
        for (int i = 0; i < total; ++i) {
            final DefaultPair<Double, Double> valueLabel = values.get(i);
            final double value = valueLabel.getFirst();
            final double label = valueLabel.getSecond();
            // A split is only possible between two distinct values.
            if (i > 0 && value != previousValue) {
                final int numNegative = i;
                final int numPositive = total - i;
                final double meanNegative = sumNegative / (double)numNegative;
                final double varianceNegative = VectorThresholdVarianceLearner.meanSquaredDeviation(values, 0, i, meanNegative);
                final double meanPositive = sumPositive / (double)numPositive;
                final double variancePositive = VectorThresholdVarianceLearner.meanSquaredDeviation(values, i, total, meanPositive);
                final double proportionPositive = (double)numPositive / (double)total;
                final double proportionNegative = (double)numNegative / (double)total;
                final double gain = baseVariance - proportionPositive * variancePositive - proportionNegative * varianceNegative;
                if (gain >= bestGain) {
                    // Equals 1.0 for a perfectly even split and approaches
                    // 0.0 as the split becomes more lopsided.
                    final double tieBreaker = 1.0 - Math.abs(proportionPositive - proportionNegative);
                    if (gain > bestGain || tieBreaker > bestTieBreaker) {
                        bestGain = gain;
                        bestTieBreaker = tieBreaker;
                        bestThreshold = (value + previousValue) / 2.0;
                    }
                }
            }
            // Move the current example from the positive to the negative side.
            sumPositive -= label;
            sumNegative += label;
            previousValue = value;
        }
        return new DefaultPair<Double, Double>(bestGain, bestThreshold);
    }

    /**
     * Computes the mean squared deviation of the outputs in
     * {@code values[fromIndex, toIndex)} from the given mean.
     *
     * <p>This is a full pass over the range for every call, so evaluating all
     * candidate splits of a dimension costs time quadratic in the number of
     * examples. It is kept as a direct two-pass computation so that numeric
     * results stay exactly as they were.
     *
     * @param values the (value, output) pairs, sorted by value
     * @param fromIndex first index of the range, inclusive
     * @param toIndex last index of the range, exclusive
     * @param mean the mean of the outputs in the range
     * @return the sum of squared deviations divided by the range length
     */
    private static double meanSquaredDeviation(ArrayList<DefaultPair<Double, Double>> values, int fromIndex, int toIndex, double mean) {
        double sumOfSquares = 0.0;
        for (int j = fromIndex; j < toIndex; ++j) {
            final double output = values.get(j).getSecond();
            final double difference = output - mean;
            sumOfSquares += difference * difference;
        }
        return sumOfSquares / (double)(toIndex - fromIndex);
    }
}

