/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.regression;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import gov.sandia.cognition.learning.function.scalar.KernelScalarFunction;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;

@PublicationReference(author={"John Shawe-Taylor", "Nello Cristianini"}, title="Kernel Methods for Pattern Analysis", type=PublicationType.Book, year=2004, url="http://www.kernel-methods.net/")
public class KernelBasedIterativeRegression<InputType>
extends AbstractAnytimeSupervisedBatchLearner<InputType, Double, KernelScalarFunction<InputType>>
implements MeasurablePerformanceAlgorithm {
    public static final int DEFAULT_MAX_ITERATIONS = 100;
    public static final double DEFAULT_MIN_SENSITIVITY = 10.0;
    private Kernel<? super InputType> kernel;
    private double minSensitivity;
    private KernelScalarFunction<InputType> result;
    private int errorCount;
    private transient LinkedHashMap<InputOutputPair<? extends InputType, Double>, DefaultWeightedValue<InputType>> supportsMap;

    public KernelBasedIterativeRegression() {
        this(null);
    }

    public KernelBasedIterativeRegression(Kernel<? super InputType> kernel) {
        this(kernel, 10.0);
    }

    public KernelBasedIterativeRegression(Kernel<? super InputType> kernel, double minSensitivity) {
        this(kernel, minSensitivity, 100);
    }

    public KernelBasedIterativeRegression(Kernel<? super InputType> kernel, double minSensitivity, int maxIterations) {
        super(maxIterations);
        this.setKernel(kernel);
        this.setMinSensitivity(minSensitivity);
        this.setResult(null);
        this.setErrorCount(0);
        this.setSupportsMap(null);
    }

    @Override
    public KernelBasedIterativeRegression<InputType> clone() {
        KernelBasedIterativeRegression clone = (KernelBasedIterativeRegression)super.clone();
        clone.setKernel((Kernel)ObjectUtil.cloneSafe(this.getKernel()));
        clone.setResult((KernelScalarFunction)ObjectUtil.cloneSafe((CloneableSerializable)this.getResult()));
        clone.setSupportsMap((LinkedHashMap)ObjectUtil.cloneSmart(this.getSupportsMap()));
        return clone;
    }

    @Override
    protected boolean initializeAlgorithm() {
        if (this.getData() == null) {
            return false;
        }
        int validCount = 0;
        for (InputOutputPair example : (Collection)this.getData()) {
            if (example == null) continue;
            ++validCount;
        }
        if (validCount <= 0) {
            return false;
        }
        this.setErrorCount(validCount);
        this.setSupportsMap(new LinkedHashMap<InputOutputPair<? extends InputType, Double>, DefaultWeightedValue<InputType>>());
        this.setResult(new KernelScalarFunction<InputType>(this.getKernel(), this.getSupportsMap().values(), 0.0));
        return true;
    }

    @Override
    protected boolean step() {
        this.setErrorCount(0);
        if (((Collection)this.getData()).size() == 1) {
            InputOutputPair first = (InputOutputPair)((Collection)this.getData()).iterator().next();
            ((KernelScalarFunction)this.getResult()).getExamples().clear();
            ((KernelScalarFunction)this.getResult()).setBias((Double)first.getOutput());
            return false;
        }
        for (InputOutputPair example : (Collection)this.getData()) {
            double oldWeight;
            if (example == null) continue;
            Object input = example.getInput();
            double actual = (Double)example.getOutput();
            double prediction = (Double)this.result.evaluate(input);
            double error = actual - prediction;
            DefaultWeightedValue support = this.supportsMap.get(example);
            double newWeight = oldWeight = support == null ? 0.0 : support.getWeight();
            if (Math.abs(error) >= this.minSensitivity) {
                double weightUpdate = error;
                if (oldWeight == 0.0) {
                    double positiveUpdate = weightUpdate - this.minSensitivity;
                    double negativeUpdate = weightUpdate + this.minSensitivity;
                    weightUpdate = Math.abs(positiveUpdate) <= Math.abs(negativeUpdate) ? (weightUpdate -= this.minSensitivity) : (weightUpdate += this.minSensitivity);
                } else {
                    weightUpdate = oldWeight > 0.0 ? (weightUpdate -= this.minSensitivity) : (weightUpdate += this.minSensitivity);
                }
                double selfKernel = this.kernel.evaluate(input, input);
                if (selfKernel != 0.0) {
                    weightUpdate /= selfKernel;
                }
                if (oldWeight * (newWeight = oldWeight + weightUpdate) < 0.0) {
                    newWeight = 0.0;
                }
            }
            double difference = newWeight - oldWeight;
            System.out.println("Input: " + input);
            System.out.println("actual: " + actual + " prediction: " + prediction + " error: " + (actual - prediction));
            System.out.println("Old weight: " + oldWeight + ", new weight: " + newWeight);
            System.out.println("Difference: " + difference);
            if (difference == 0.0) continue;
            this.setErrorCount(this.getErrorCount() + 1);
            double oldBias = this.result.getBias();
            double newBias = oldBias + difference;
            if (support == null) {
                support = new DefaultWeightedValue(input, newWeight);
                this.supportsMap.put(example, support);
            } else if (newWeight == 0.0) {
                this.supportsMap.remove(example);
            } else {
                support.setWeight(newWeight);
            }
            this.result.setBias(newBias);
        }
        return this.getErrorCount() > 0;
    }

    @Override
    protected void cleanupAlgorithm() {
        if (this.getSupportsMap() != null) {
            ((KernelScalarFunction)this.getResult()).setExamples(new ArrayList<DefaultWeightedValue<InputType>>(this.getSupportsMap().values()));
            this.setSupportsMap(null);
        }
    }

    public Kernel<? super InputType> getKernel() {
        return this.kernel;
    }

    public void setKernel(Kernel<? super InputType> kernel) {
        this.kernel = kernel;
    }

    public KernelScalarFunction<InputType> getResult() {
        return this.result;
    }

    protected void setResult(KernelScalarFunction<InputType> result) {
        this.result = result;
    }

    public int getErrorCount() {
        return this.errorCount;
    }

    protected void setErrorCount(int errorCount) {
        this.errorCount = errorCount;
    }

    protected LinkedHashMap<InputOutputPair<? extends InputType, Double>, DefaultWeightedValue<InputType>> getSupportsMap() {
        return this.supportsMap;
    }

    protected void setSupportsMap(LinkedHashMap<InputOutputPair<? extends InputType, Double>, DefaultWeightedValue<InputType>> supportsMap) {
        this.supportsMap = supportsMap;
    }

    public double getMinSensitivity() {
        return this.minSensitivity;
    }

    public void setMinSensitivity(double minSensitivity) {
        if (minSensitivity < 0.0) {
            throw new IllegalArgumentException("minSensitivity must be non-negative.");
        }
        this.minSensitivity = minSensitivity;
    }

    public NamedValue<Integer> getPerformance() {
        return new DefaultNamedValue("error count", (Object)this.getErrorCount());
    }
}

