/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.tree;

import gov.sandia.cognition.learning.algorithm.tree.DeciderLearner;
import gov.sandia.cognition.learning.algorithm.tree.VectorThresholdMaximumGainLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.Categorizer;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.math.Permutation;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.AbstractRandomized;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;

/**
 * A decider learner that learns a {@link VectorElementThresholdCategorizer}
 * while only letting its wrapped sub-learner see a randomly chosen subset of
 * the input dimensions.
 * <p>
 * On each call to {@link #learn(Collection)} a random permutation of the input
 * dimensions is drawn. Its first {@link #getSubDimensionality(int)} entries are
 * the dimensions made available to the sub-learner. When the data is projected
 * onto those dimensions, the index of the learned categorizer is mapped back
 * so that it refers to a dimension of the original input vectors.
 *
 * @param <OutputType> the output (label) type of the training examples
 */
public class RandomSubVectorThresholdLearner<OutputType>
extends AbstractRandomized
implements DeciderLearner<Vectorizable, OutputType, Boolean, Categorizer<? super Vectorizable, ? extends Boolean>>,
VectorFactoryContainer {
    /** The default fraction of input dimensions to sample: {@value}. */
    public static final double DEFAULT_PERCENT_TO_SAMPLE = 0.1;

    /** The learner applied to the sampled dimensions; must be non-null before {@link #learn} is called. */
    protected DeciderLearner<? super Vectorizable, OutputType, Boolean, VectorElementThresholdCategorizer> subLearner;

    /**
     * Fraction (not a percentage, despite the name) of the input dimensions to
     * sample. Always within [0.0, 1.0].
     */
    protected double percentToSample;

    /** Factory used to build the projected sub-vectors handed to a generic sub-learner. */
    protected VectorFactory<? extends Vector> vectorFactory;

    /**
     * Creates a learner with no sub-learner, the default sampling fraction and
     * a fresh {@link Random}. {@link #setSubLearner} must be called before
     * {@link #learn}.
     */
    public RandomSubVectorThresholdLearner() {
        this(null, DEFAULT_PERCENT_TO_SAMPLE, new Random());
    }

    /**
     * Creates a learner that uses the default vector factory.
     *
     * @param subLearner the learner applied to the sampled dimensions
     * @param percentToSample fraction of dimensions to sample, in [0.0, 1.0]
     * @param random source of randomness for choosing dimensions
     * @throws IllegalArgumentException if {@code percentToSample} is outside [0.0, 1.0] or NaN
     */
    public RandomSubVectorThresholdLearner(DeciderLearner<? super Vectorizable, OutputType, Boolean, VectorElementThresholdCategorizer> subLearner, double percentToSample, Random random) {
        this(subLearner, percentToSample, random, (VectorFactory<? extends Vector>) VectorFactory.getDefault());
    }

    /**
     * Creates a fully specified learner.
     *
     * @param subLearner the learner applied to the sampled dimensions
     * @param percentToSample fraction of dimensions to sample, in [0.0, 1.0]
     * @param random source of randomness for choosing dimensions
     * @param vectorFactory factory used to create the projected sub-vectors
     * @throws IllegalArgumentException if {@code percentToSample} is outside [0.0, 1.0] or NaN
     */
    public RandomSubVectorThresholdLearner(DeciderLearner<? super Vectorizable, OutputType, Boolean, VectorElementThresholdCategorizer> subLearner, double percentToSample, Random random, VectorFactory<? extends Vector> vectorFactory) {
        super(random);
        this.setSubLearner(subLearner);
        this.setPercentToSample(percentToSample);
        this.setVectorFactory(vectorFactory);
    }

    /**
     * Learns a threshold categorizer using a random subset of the input dimensions.
     *
     * @param data the labelled training examples
     * @return the categorizer produced by the sub-learner, whose index refers to
     *         the original input dimensions; may be {@code null} if the
     *         sub-learner returns {@code null}
     * @throws IllegalStateException if no sub-learner has been set
     */
    @Override
    public VectorElementThresholdCategorizer learn(Collection<? extends InputOutputPair<? extends Vectorizable, OutputType>> data) {
        if (this.subLearner == null) {
            throw new IllegalStateException("subLearner must be set before calling learn");
        }
        if (this.random == null) {
            this.random = new Random();
        }

        final int dimensionality = DatasetUtil.getInputDimensionality(data);
        final int subDimensionality = this.getSubDimensionality(dimensionality);
        if (subDimensionality >= dimensionality) {
            // Sampling would keep every dimension, so the data is passed through unchanged.
            // NOTE(review): if subLearner is a VectorThresholdMaximumGainLearner, dimensions set by
            // setDimensionsToConsider in an earlier call are not reset here — confirm intended.
            return (VectorElementThresholdCategorizer) this.subLearner.learn(data);
        }

        final int[] permutation = Permutation.createPermutation(dimensionality, this.random);
        if (this.subLearner instanceof VectorThresholdMaximumGainLearner) {
            // This sub-learner can be told which dimensions to consider directly, avoiding a copy
            // of every example; its result is returned without index remapping.
            final int[] subDimensions = new int[subDimensionality];
            System.arraycopy(permutation, 0, subDimensions, 0, subDimensionality);
            ((VectorThresholdMaximumGainLearner) this.subLearner).setDimensionsToConsider(subDimensions);
            return (VectorElementThresholdCategorizer) this.subLearner.learn(data);
        }
        return this.learnOnSubVectors(data, permutation, subDimensionality);
    }

    /**
     * Projects every example onto the dimensions {@code permutation[0..subDimensionality)},
     * learns on the projected data, and rewrites the learned index so it refers
     * to the original (unprojected) dimension.
     *
     * @param data the labelled training examples
     * @param permutation a permutation of the original dimension indices
     * @param subDimensionality how many leading entries of {@code permutation} to keep
     * @return the remapped categorizer, or {@code null} if the sub-learner returned {@code null}
     */
    private VectorElementThresholdCategorizer learnOnSubVectors(
            final Collection<? extends InputOutputPair<? extends Vectorizable, OutputType>> data,
            final int[] permutation,
            final int subDimensionality) {
        final Collection<DefaultInputOutputPair<Vector, OutputType>> subData =
            new ArrayList<DefaultInputOutputPair<Vector, OutputType>>(data.size());
        for (InputOutputPair<? extends Vectorizable, OutputType> example : data) {
            final Vector subVector = this.vectorFactory.createVector(subDimensionality);
            final Vector vector = example.getInput().convertToVector();
            for (int i = 0; i < subDimensionality; ++i) {
                subVector.setElement(i, vector.getElement(permutation[i]));
            }
            subData.add(new DefaultInputOutputPair<Vector, OutputType>(subVector, example.getOutput()));
        }

        final VectorElementThresholdCategorizer subDecider =
            (VectorElementThresholdCategorizer) this.subLearner.learn(subData);
        if (subDecider != null) {
            // Index i in the projected space corresponds to dimension permutation[i] originally.
            subDecider.setIndex(permutation[subDecider.getIndex()]);
        }
        return subDecider;
    }

    /**
     * Computes how many dimensions to sample for a given input dimensionality.
     *
     * @param dimensionality the full input dimensionality
     * @return {@code dimensionality * percentToSample} truncated to an int, but never less than 1
     */
    public int getSubDimensionality(int dimensionality) {
        return Math.max(1, (int) (dimensionality * this.percentToSample));
    }

    /**
     * Gets the learner applied to the sampled dimensions.
     *
     * @return the sub-learner, possibly {@code null}
     */
    public DeciderLearner<? super Vectorizable, OutputType, Boolean, VectorElementThresholdCategorizer> getSubLearner() {
        return this.subLearner;
    }

    /**
     * Sets the learner applied to the sampled dimensions.
     *
     * @param subLearner the sub-learner; must be non-null before {@link #learn} is called
     */
    public void setSubLearner(DeciderLearner<? super Vectorizable, OutputType, Boolean, VectorElementThresholdCategorizer> subLearner) {
        this.subLearner = subLearner;
    }

    /**
     * Gets the fraction of input dimensions to sample.
     *
     * @return a value in [0.0, 1.0]
     */
    public double getPercentToSample() {
        return this.percentToSample;
    }

    /**
     * Sets the fraction of input dimensions to sample.
     *
     * @param percentToSample a value in [0.0, 1.0]
     * @throws IllegalArgumentException if the value is outside [0.0, 1.0] or is NaN
     */
    public void setPercentToSample(double percentToSample) {
        // Written as a positive range check so that NaN (for which every comparison is false) is rejected.
        if (!(percentToSample >= 0.0 && percentToSample <= 1.0)) {
            throw new IllegalArgumentException("percentToSample must be between 0.0 and 1.0");
        }
        this.percentToSample = percentToSample;
    }

    /**
     * Gets the factory used to create projected sub-vectors.
     *
     * @return the vector factory
     */
    public VectorFactory<? extends Vector> getVectorFactory() {
        return this.vectorFactory;
    }

    /**
     * Sets the factory used to create projected sub-vectors.
     *
     * @param vectorFactory the vector factory
     */
    public void setVectorFactory(VectorFactory<? extends Vector> vectorFactory) {
        this.vectorFactory = vectorFactory;
    }
}

