/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.experiments.sinabill;

import gov.sandia.cognition.math.matrix.Matrix;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.openimaj.io.FileUtils;
import org.openimaj.io.IOUtils;
import org.openimaj.io.WriteableBinary;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator;
import org.openimaj.ml.linear.evaluation.BilinearEvaluator;
import org.openimaj.ml.linear.evaluation.RootMeanSumLossEvaluator;
import org.openimaj.ml.linear.experiments.sinabill.BilinearExperiment;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.init.HardCodedInitStrat;
import org.openimaj.ml.linear.learner.init.SingleValueInitStrat;
import org.openimaj.ml.linear.learner.init.SparseZerosInitStrategy;
import org.openimaj.util.pair.Pair;

/**
 * Streaming regret experiment over a range of {@code dampening} values.
 *
 * <p>For each dampening value, a fresh {@link BilinearSparseOnlineLearner} is trained one item at a
 * time on the data produced by {@link BillMatlabFileDataGenerator}. Before each training step the
 * learner's loss on the incoming item is logged next to the loss a previous batch experiment recorded
 * for the same item index, together with their difference (the regret). After each step, the
 * learner is written to the fold -1 directory.
 */
public class StreamAustrianDampeningExperiments extends BilinearExperiment {
    /** Batch experiment log (relative to {@code DATA_ROOT()}) whose per-item losses are the regret baseline. */
    private static final String BATCH_EXPERIMENT = "batchStreamLossExperiments/batch_1366820115090/experiment.log";

    /** First dampening value tried (inclusive). */
    private static final double DAMPENING_MIN = 0.02;
    /** Step between successive dampening values. */
    private static final double DAMPENING_INCR = 0.1;
    /**
     * Exclusive upper bound on dampening values. With the current settings only
     * {@link #DAMPENING_MIN} is run.
     */
    private static final double DAMPENING_MAX = 0.021;
    /**
     * Number of items the learner is trained on per dampening value. Once this many items have
     * been processed, the next item is still evaluated before the run stops.
     */
    private static final int MAX_ITEMS = 15;

    /** Batch-log marker preceding the index of the item the batch learner has just seen. */
    private static final String ITEM_MARKER = "New Item Seen: ";
    /** Batch-log marker preceding the batch loss recorded for the current item. */
    private static final String LOSS_MARKER = "Loss:";

    @Override
    public String getExperimentName() {
        return "streamingDampeningExperiments";
    }

    /**
     * Runs one streaming trial for every dampening value in
     * [{@link #DAMPENING_MIN}, {@link #DAMPENING_MAX}), stepping by {@link #DAMPENING_INCR}.
     *
     * @throws Exception if loading the batch log, generating data, training, or saving fails
     */
    @Override
    public void performExperiment() throws Exception {
        final Map<Integer, Double> batchLosses = this.loadBatchLoss();

        final BilinearLearnerParameters params = new BilinearLearnerParameters();
        params.put("eta0u", 0.01);
        params.put("eta0w", 0.01);
        params.put("lambda", 0.001);
        params.put("lambda_w", 0.006);
        params.put("biconvex_tol", 0.01);
        params.put("biconvex_maxiter", 10);
        params.put("bias", true);
        params.put("biaseta0", 0.5);
        params.put("winitstrat", new SingleValueInitStrat(0.1));
        params.put("uinitstrat", new SparseZerosInitStrategy());
        // Kept as a local so each trial can set its matrix from that trial's first target matrix.
        final HardCodedInitStrat biasInitStrat = new HardCodedInitStrat();
        params.put("biasinitstrat", biasInitStrat);

        final BillMatlabFileDataGenerator bmfdg = new BillMatlabFileDataGenerator(new File(this.MATLAB_DATA()), 98, true);
        this.prepareExperimentLog(params);

        this.logger.debug(String.format("Beginning dampening experiments: min=%2.5f,max=%2.5f,incr=%2.5f",
                DAMPENING_MIN, DAMPENING_MAX, DAMPENING_INCR));
        for (double dampening = DAMPENING_MIN; dampening < DAMPENING_MAX; dampening += DAMPENING_INCR) {
            params.put("dampening", dampening);
            this.logger.debug("Dampening is now: " + dampening);
            this.runStreamingTrial(params, bmfdg, biasInitStrat, batchLosses);
        }
    }

    /**
     * Trains a new learner built from {@code params} on the full data stream (fold -1, all items).
     * Before each training step it logs the loss and regret against {@code batchLosses}, and after
     * each step it saves a learner snapshot.
     *
     * <p>The trial ends early, without training on the current item, when no batch loss exists for
     * the item about to be evaluated, or after {@link #MAX_ITEMS} items have been trained on.
     */
    private void runStreamingTrial(BilinearLearnerParameters params, BillMatlabFileDataGenerator bmfdg,
            HardCodedInitStrat biasInitStrat, Map<Integer, Double> batchLosses) throws Exception {
        final BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(params);
        final RootMeanSumLossEvaluator eval = new RootMeanSumLossEvaluator();
        eval.setLearner(learner);
        bmfdg.setFold(-1, BillMatlabFileDataGenerator.Mode.ALL);

        int item = 0;
        boolean first = true;
        Pair<Matrix> next;
        while ((next = bmfdg.generate()) != null) {
            if (first) {
                first = false;
                biasInitStrat.setMatrix(next.secondObject());
            }
            final ArrayList<Pair<Matrix>> asList = new ArrayList<Pair<Matrix>>();
            asList.add(next);

            // W stays null until the first process() call, so the very first item is not evaluated.
            if (learner.getW() != null) {
                if (!batchLosses.containsKey(item)) {
                    this.logger.debug(String.format("...No batch result found for: %d, done", item));
                    return;
                }
                this.logger.debug("...Calculating regret for item" + item);
                final double loss = eval.evaluate(asList);
                this.logger.debug(String.format("... loss: %f", loss));
                final double batchloss = batchLosses.get(item);
                this.logger.debug(String.format("... batch loss: %f", batchloss));
                this.logger.debug(String.format("... regret: %f", loss - batchloss));
            }
            if (item >= MAX_ITEMS) {
                return;
            }

            learner.process(next.firstObject(), next.secondObject());
            final Matrix w = learner.getW();
            final Matrix u = learner.getU();
            this.logger.debug("W row sparcity: " + CFMatrixUtils.rowSparsity(w));
            this.logger.debug(String.format("W range: %2.5f -> %2.5f", CFMatrixUtils.min(w), CFMatrixUtils.max(w)));
            this.logger.debug("U row sparcity: " + CFMatrixUtils.rowSparsity(u));
            this.logger.debug(String.format("U range: %2.5f -> %2.5f", CFMatrixUtils.min(u), CFMatrixUtils.max(u)));
            this.logger.debug(String.format("... loss (post addition): %f", eval.evaluate(asList)));

            this.logger.debug(String.format("Saving learner, Fold %d, Item %d", -1, item));
            // NOTE(review): the file name depends only on the item index, so snapshots from different
            // dampening values overwrite each other if the dampening range ever spans more than one value.
            final File learnerOut = new File(this.FOLD_ROOT(-1), String.format("learner_%d", item));
            IOUtils.writeBinary(learnerOut, learner);
            ++item;
        }
    }

    /**
     * Reads the per-item losses recorded in the {@link #BATCH_EXPERIMENT} log.
     *
     * <p>Each line containing {@link #ITEM_MARKER} sets the current item index. Each line containing
     * {@link #LOSS_MARKER} records a loss for the current index, replacing any earlier value for
     * that index. A loss that appears before any item marker is stored under item 0.
     *
     * @return map from item index to the batch loss recorded for it
     * @throws IOException if the batch log cannot be read
     * @throws NumberFormatException if a marked value is not a valid number
     */
    private Map<Integer, Double> loadBatchLoss() throws IOException {
        final String[] batchExperimentLines = FileUtils.readlines(new File(this.DATA_ROOT(), BATCH_EXPERIMENT));
        final HashMap<Integer, Double> losses = new HashMap<Integer, Double>();
        int seenItems = 0;
        for (final String line : batchExperimentLines) {
            if (line.contains(ITEM_MARKER)) {
                seenItems = Integer.parseInt(valueAfter(line, ITEM_MARKER));
            }
            if (line.contains(LOSS_MARKER)) {
                losses.put(seenItems, Double.parseDouble(valueAfter(line, LOSS_MARKER)));
            }
        }
        return losses;
    }

    /**
     * Returns the trimmed text that follows the first occurrence of {@code marker} in
     * {@code line}, up to the next ':' or the end of the line. This matches
     * {@code line.split(":")[1]} when nothing before the marker contains a colon. Unlike that
     * expression, it is not thrown off by colons earlier in the line, such as a timestamp prefix.
     * The caller must ensure {@code line} contains {@code marker}.
     */
    private static String valueAfter(String line, String marker) {
        final int start = line.indexOf(marker) + marker.length();
        final int end = line.indexOf(':', start);
        return (end < 0 ? line.substring(start) : line.substring(start, end)).trim();
    }

    /** Command-line entry point: runs the dampening experiment with its default settings. */
    public static void main(String[] args) throws Exception {
        new StreamAustrianDampeningExperiments().performExperiment();
    }
}

