/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.experiment;

import gov.sandia.cognition.collection.RangeExcludedArrayList;
import gov.sandia.cognition.learning.data.DefaultPartitionedDataset;
import gov.sandia.cognition.learning.data.PartitionedDataset;
import gov.sandia.cognition.learning.experiment.ValidationFoldCreator;
import gov.sandia.cognition.math.Permutation;
import gov.sandia.cognition.util.AbstractRandomized;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;

/**
 * Creates cross-validation folds by randomly shuffling a collection of data
 * and splitting the shuffled order into contiguous, nearly equal-sized
 * partitions. Each resulting fold uses one partition as its testing set and
 * all remaining items as its training set.
 *
 * @param <DataType> The type of the data items being partitioned.
 */
public class CrossFoldCreator<DataType>
extends AbstractRandomized
implements ValidationFoldCreator<DataType, DataType> {
    /** The number of folds used by the no-argument constructor. */
    public static final int DEFAULT_NUM_FOLDS = 10;

    /** The requested number of folds; {@link #setNumFolds} requires it to be greater than 1. */
    protected int numFolds;

    /**
     * Creates a fold creator with {@link #DEFAULT_NUM_FOLDS} folds and a newly
     * constructed {@link Random}.
     */
    public CrossFoldCreator() {
        this(DEFAULT_NUM_FOLDS);
    }

    /**
     * Creates a fold creator with the given number of folds and a newly
     * constructed {@link Random}.
     *
     * @param numFolds The number of folds to create; must be greater than 1.
     * @throws IllegalArgumentException If {@code numFolds <= 1}.
     */
    public CrossFoldCreator(int numFolds) {
        this(numFolds, new Random());
    }

    /**
     * Creates a fold creator with the given number of folds and random source.
     *
     * @param numFolds The number of folds to create; must be greater than 1.
     * @param random The random number generator used to shuffle the data.
     * @throws IllegalArgumentException If {@code numFolds <= 1}.
     */
    public CrossFoldCreator(int numFolds, Random random) {
        super(random);
        this.setNumFolds(numFolds);
    }

    /**
     * Creates the cross-validation folds for the given data using this
     * creator's number of folds and random source.
     *
     * @param data The data to split into folds; must contain at least 2 items.
     * @return One partitioned dataset per fold.
     * @throws IllegalArgumentException If {@code data} has fewer than 2 items.
     */
    @Override
    public List<PartitionedDataset<DataType>> createFolds(Collection<? extends DataType> data) {
        return CrossFoldCreator.createFolds(data, this.getNumFolds(), this.getRandom());
    }

    /**
     * Shuffles {@code data} once and splits it into
     * {@code min(data.size(), numFolds)} folds. Each fold's testing set is one
     * contiguous slice of the shuffled order, and its training set is every
     * other item. Fold sizes differ by at most one, with larger folds last.
     * When {@code numFolds} exceeds the number of items, each fold holds out
     * exactly one item.
     *
     * @param <DataType> The type of the data items.
     * @param data The data to split; must contain at least 2 items.
     * @param numFolds The requested number of folds; must be greater than 1.
     * @param random The random number generator used to shuffle the data.
     * @return One partitioned dataset per fold, in shuffled-slice order.
     * @throws IllegalArgumentException If {@code data} has fewer than 2 items
     *         or {@code numFolds <= 1}.
     */
    public static <DataType> List<PartitionedDataset<DataType>> createFolds(
            Collection<? extends DataType> data, int numFolds, Random random) {
        final int total = data.size();
        if (total < 2) {
            throw new IllegalArgumentException("data must have at least 2 items");
        }
        checkNumFolds(numFolds);

        // A single shuffle shared by all folds guarantees the testing slices
        // are disjoint and together cover every item exactly once.
        final ArrayList<DataType> reordering = Permutation.createReordering(data, random);

        // There cannot be more non-empty folds than items.
        final int numActualFolds = Math.min(total, numFolds);
        final ArrayList<PartitionedDataset<DataType>> datasets =
            new ArrayList<PartitionedDataset<DataType>>(numActualFolds);

        int toIndex = 0;
        for (int i = 0; i < numActualFolds; ++i) {
            final int fromIndex = toIndex;
            // Spread the remaining items evenly over the remaining folds. This
            // keeps sizes within one of each other and makes the final fold
            // end exactly at total.
            final int foldSize = (total - fromIndex) / (numActualFolds - i);
            toIndex = fromIndex + foldSize;

            // NOTE(review): the excluded range is passed as [fromIndex, toIndex - 1],
            // which implies RangeExcludedArrayList treats its upper bound as
            // inclusive (unlike subList below) — confirm against its docs.
            final RangeExcludedArrayList<DataType> training =
                new RangeExcludedArrayList<DataType>(reordering, fromIndex, toIndex - 1);
            // subList returns a view backed by the shared reordering list.
            final List<DataType> testing = reordering.subList(fromIndex, toIndex);
            datasets.add(new DefaultPartitionedDataset<DataType>(training, testing));
        }
        return datasets;
    }

    /**
     * Gets the requested number of folds.
     *
     * @return The number of folds.
     */
    public int getNumFolds() {
        return this.numFolds;
    }

    /**
     * Sets the requested number of folds.
     *
     * @param numFolds The number of folds; must be greater than 1.
     * @throws IllegalArgumentException If {@code numFolds <= 1}.
     */
    public void setNumFolds(int numFolds) {
        checkNumFolds(numFolds);
        this.numFolds = numFolds;
    }

    /**
     * Validates a number of folds.
     *
     * @param numFolds The number of folds to check.
     * @throws IllegalArgumentException If {@code numFolds <= 1}.
     */
    protected static void checkNumFolds(int numFolds) {
        if (numFolds <= 1) {
            throw new IllegalArgumentException("numFolds must be greater than 1.");
        }
    }
}

