/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.ensemble;

import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.ensemble.BaggingCategorizerLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Random;
import java.util.Set;

/**
 * A bagging learner for categorizers that fills each bag by drawing (nearly)
 * the same number of examples from every category found in the training
 * data, instead of drawing uniformly over all examples. Within a category,
 * examples are drawn uniformly at random with replacement.
 *
 * @param <InputType> The type of input to the learned categorizers.
 * @param <CategoryType> The type of category (output) labeling the data.
 */
public class CategoryBalancedBaggingLearner<InputType, CategoryType>
extends BaggingCategorizerLearner<InputType, CategoryType> {

    /**
     * The distinct categories found in the data during initialization. Its
     * order may be shuffled by {@link #fillBag(int)}.
     */
    protected ArrayList<CategoryType> categoryList;

    /**
     * Maps each category to the list of indices (into {@code dataList}) of
     * the examples labeled with that category.
     */
    protected HashMap<CategoryType, ArrayList<Integer>> dataPerCategory;

    /**
     * Creates a new learner with a {@code null} base learner and the default
     * parameters of {@link #CategoryBalancedBaggingLearner(BatchLearner)}.
     */
    public CategoryBalancedBaggingLearner() {
        this(null);
    }

    /**
     * Creates a new learner using the given base learner, 100 iterations, a
     * {@code percentToSample} of 1.0 and a fresh {@link Random}.
     *
     * @param learner The base learner used to learn each ensemble member.
     */
    public CategoryBalancedBaggingLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner) {
        this(learner, 100, 1.0, new Random());
    }

    /**
     * Creates a new learner. All arguments are passed to the superclass.
     *
     * @param learner The base learner used to learn each ensemble member.
     * @param maxIterations The maximum number of learning iterations.
     * @param percentToSample The amount of data to sample per bag
     *        (presumably a fraction of the data size; interpreted by the
     *        superclass).
     * @param random The random number generator used for sampling and for
     *        shuffling the category order.
     */
    public CategoryBalancedBaggingLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner, int maxIterations, double percentToSample, Random random) {
        super(learner, maxIterations, percentToSample, random);
    }

    /**
     * Runs the superclass initialization and, if it succeeds, indexes the
     * data by category: collects the unique outputs into
     * {@link #categoryList} and records, for each category, the indices of
     * the examples that carry it in {@link #dataPerCategory}.
     *
     * @return The result of the superclass initialization.
     */
    @Override
    protected boolean initializeAlgorithm() {
        final boolean result = super.initializeAlgorithm();
        if (result) {
            final int dataSize = this.dataList.size();
            final Set<CategoryType> categories =
                DatasetUtil.findUniqueOutputs(this.dataList);
            this.categoryList = new ArrayList<CategoryType>(categories);
            this.dataPerCategory =
                new LinkedHashMap<CategoryType, ArrayList<Integer>>(categories.size());
            for (CategoryType category : categories) {
                this.dataPerCategory.put(category, new ArrayList<Integer>());
            }
            for (int i = 0; i < dataSize; ++i) {
                final CategoryType category = this.dataList.get(i).getOutput();
                this.dataPerCategory.get(category).add(i);
            }
        }
        return result;
    }

    /**
     * Adds {@code sampleCount} examples to the bag, spread as evenly as
     * possible over the categories. Each category visited receives
     * {@code max(1, remaining / categoriesLeft)} samples, so when the count
     * does not divide evenly, the categories visited last receive the extra
     * samples. When {@code sampleCount} is smaller than the number of
     * categories, only the first {@code sampleCount} categories visited
     * receive a single sample each. Every drawn index is also counted in
     * {@code dataInBag}.
     *
     * <p>Assumes {@link #initializeAlgorithm()} has populated at least one
     * category; otherwise the modulo by the category count fails.</p>
     *
     * @param sampleCount The number of examples to add to the bag.
     */
    @Override
    protected void fillBag(int sampleCount) {
        final int categoryCount = this.categoryList.size();
        // An uneven split favors the categories visited last; shuffling
        // randomizes which categories those are for this bag.
        if (sampleCount % categoryCount != 0) {
            Collections.shuffle(this.categoryList, this.random);
        }
        int remainingSampleSize = sampleCount;
        for (int i = 0; i < categoryCount && remainingSampleSize > 0; ++i) {
            final CategoryType category = this.categoryList.get(i);
            final ArrayList<Integer> indices = this.dataPerCategory.get(category);
            final int categorySize = indices.size();
            // Split what is left evenly across the categories not yet
            // visited, taking at least one so small bags still get filled.
            final int categorySampleSize =
                Math.max(1, remainingSampleSize / (categoryCount - i));
            for (int j = 0; j < categorySampleSize; ++j) {
                // Draw uniformly with replacement from this category.
                final int index = indices.get(this.random.nextInt(categorySize));
                final InputOutputPair<? extends InputType, CategoryType> example =
                    this.dataList.get(index);
                this.bag.add(example);
                this.dataInBag[index]++;
            }
            remainingSampleSize -= categorySampleSize;
        }
    }

    /**
     * Releases the per-category structures built by
     * {@link #initializeAlgorithm()} and then runs the superclass cleanup.
     */
    @Override
    protected void cleanupAlgorithm() {
        this.dataPerCategory = null;
        this.categoryList = null;
        super.cleanupAlgorithm();
    }
}

