/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.tree.CategorizationTreeLearner;
import gov.sandia.cognition.learning.algorithm.tree.VectorThresholdMaximumGainLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.distribution.MapBasedDataHistogram;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.DefaultWeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;

public abstract class AbstractVectorThresholdMaximumGainLearner<OutputType>
extends AbstractCloneableSerializable
implements VectorThresholdMaximumGainLearner<OutputType> {
    protected int[] dimensionsToConsider;

    @Override
    public VectorElementThresholdCategorizer learn(Collection<? extends InputOutputPair<? extends Vectorizable, OutputType>> data) {
        int totalCount = CollectionUtil.size(data);
        if (totalCount <= 1) {
            return null;
        }
        MapBasedDataHistogram baseCounts = CategorizationTreeLearner.getOutputCounts(data);
        ArrayList<DefaultWeightedValue> workspace = new ArrayList<DefaultWeightedValue>(totalCount);
        for (int i = 0; i < totalCount; ++i) {
            workspace.add(new DefaultWeightedValue());
        }
        int dimensionality = AbstractVectorThresholdMaximumGainLearner.getDimensionality(data);
        double bestGain = -1.0;
        int bestIndex = -1;
        double bestThreshold = 0.0;
        int dimensionsCount = this.dimensionsToConsider == null ? dimensionality : this.dimensionsToConsider.length;
        for (int i = 0; i < dimensionsCount; ++i) {
            int index = this.dimensionsToConsider == null ? i : this.dimensionsToConsider[i];
            DefaultPair<Double, Double> gainThresholdPair = this.computeBestGainAndThreshold(data, index, baseCounts);
            if (gainThresholdPair == null) continue;
            double gain = (Double)gainThresholdPair.getFirst();
            if (bestIndex != -1 && !(gain > bestGain)) continue;
            double threshold = (Double)gainThresholdPair.getSecond();
            bestGain = gain;
            bestIndex = index;
            bestThreshold = threshold;
        }
        if (bestIndex < 0) {
            return null;
        }
        return new VectorElementThresholdCategorizer(bestIndex, bestThreshold);
    }

    public DefaultPair<Double, Double> computeBestGainAndThreshold(Collection<? extends InputOutputPair<? extends Vectorizable, OutputType>> data, int dimension, MapBasedDataHistogram<OutputType> baseCounts) {
        int totalCount = data.size();
        ArrayList<DefaultWeightedValue<OutputType>> workspace = new ArrayList<DefaultWeightedValue<OutputType>>(totalCount);
        for (int i = 0; i < totalCount; ++i) {
            workspace.add(new DefaultWeightedValue());
        }
        return this.computeBestGainAndThreshold(data, dimension, baseCounts, workspace);
    }

    protected DefaultPair<Double, Double> computeBestGainAndThreshold(Collection<? extends InputOutputPair<? extends Vectorizable, OutputType>> data, int dimension, MapBasedDataHistogram<OutputType> baseCounts, ArrayList<DefaultWeightedValue<OutputType>> values) {
        int totalCount = data.size();
        if (totalCount <= 1) {
            return null;
        }
        int index = 0;
        for (InputOutputPair<Vectorizable, OutputType> example : data) {
            Vector input = example.getInput().convertToVector();
            OutputType output = example.getOutput();
            double value = input.getElement(dimension);
            DefaultWeightedValue<OutputType> entry = values.get(index);
            entry.setWeight(value);
            entry.setValue(output);
            ++index;
        }
        Collections.sort(values, DefaultWeightedValue.WeightComparator.getInstance());
        double smallestValue = values.get(0).getWeight();
        double largestValue = values.get(totalCount - 1).getWeight();
        if (smallestValue >= largestValue) {
            return null;
        }
        CloneableSerializable positiveCounts = baseCounts.clone();
        MapBasedDataHistogram<Object> negativeCounts = new MapBasedDataHistogram<Object>(baseCounts.getDomain().size());
        double bestGain = Double.NEGATIVE_INFINITY;
        double bestTieBreaker = Double.NEGATIVE_INFINITY;
        double bestThreshold = Double.NEGATIVE_INFINITY;
        double previousValue = smallestValue;
        for (int i = 1; i < totalCount; ++i) {
            Object label = values.get(i - 1).getValue();
            positiveCounts.remove(label);
            negativeCounts.add(label);
            double value = values.get(i).getWeight();
            if (value == previousValue) continue;
            double gain = this.computeSplitGain(baseCounts, (MapBasedDataHistogram<OutputType>)positiveCounts, (MapBasedDataHistogram<OutputType>)negativeCounts);
            if (gain >= bestGain) {
                double proportionPositive = (double)positiveCounts.getTotalCount() / (double)totalCount;
                double proportionNegative = (double)negativeCounts.getTotalCount() / (double)totalCount;
                double tieBreaker = 1.0 - Math.abs(proportionPositive - proportionNegative);
                if (gain > bestGain || tieBreaker > bestTieBreaker) {
                    double threshold = (value + previousValue) / 2.0;
                    bestGain = gain;
                    bestTieBreaker = tieBreaker;
                    bestThreshold = threshold;
                }
            }
            previousValue = value;
        }
        if (bestThreshold <= smallestValue || bestThreshold >= largestValue) {
            throw new RuntimeException("bestThreshold (" + bestThreshold + ") lies outside range of values (" + smallestValue + ", " + largestValue + ")");
        }
        return new DefaultPair((Object)bestGain, (Object)bestThreshold);
    }

    public abstract double computeSplitGain(MapBasedDataHistogram<OutputType> var1, MapBasedDataHistogram<OutputType> var2, MapBasedDataHistogram<OutputType> var3);

    @Override
    public int[] getDimensionsToConsider() {
        return this.dimensionsToConsider;
    }

    @Override
    public void setDimensionsToConsider(int[] dimensionsToConsider) {
        this.dimensionsToConsider = dimensionsToConsider;
    }

    protected static int getDimensionality(Collection<? extends InputOutputPair<? extends Vectorizable, ?>> data) {
        if (CollectionUtil.isEmpty(data)) {
            return 0;
        }
        return ((Vectorizable)((InputOutputPair)CollectionUtil.getFirst(data)).getInput()).convertToVector().getDimensionality();
    }
}

