/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.svm;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.KernelBinaryCategorizer;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;

/**
 * Learns a {@code KernelBinaryCategorizer} using the Successive
 * Overrelaxation (SOR) algorithm for support vector machines of Mangasarian
 * and Musicant.
 * <p>
 * Each training example is wrapped in an {@link Entry} holding a labeled
 * weight (the label sign times an unlabeled weight clamped to
 * [0, maxWeight]). Entries with a nonzero weight are the supports of the
 * categorizer being learned. The categorizer's bias is kept equal to the
 * running sum of labeled-weight changes. Learning stops when the Euclidean
 * norm of the weight change over a step is no longer greater than
 * {@code minChange}. The iteration limit passed to the constructor is
 * presumably enforced by the superclass.
 *
 * @param <InputType> The type of input the learned categorizer accepts.
 */
@CodeReview(reviewer={"Kevin R. Dixon"}, date="2008-07-23", changesNeeded=false, comments={"Minor cosmetic to javadoc.", "Great looking code."})
@PublicationReference(author={"Olvi L. Mangasarian", "David R. Musicant"}, title="Successive Overrelaxation for Support Vector Machines", type=PublicationType.Journal, year=1999, publication="IEEE Transactions on Neural Networks", pages={1032, 1037}, url="ftp://ftp.cs.wisc.edu/math-prog/tech-reports/98-18.ps")
public class SuccessiveOverrelaxation<InputType>
extends AbstractAnytimeSupervisedBatchLearner<InputType, Boolean, KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>>>
implements MeasurablePerformanceAlgorithm {
    /** The default maximum number of iterations, {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 1000;

    /** The default upper bound on each unlabeled weight, {@value}. */
    public static final double DEFAULT_MAX_WEIGHT = 100.0;

    /** The default overrelaxation factor, {@value}. */
    public static final double DEFAULT_OVERRELAXATION = 1.3;

    /** The default minimum total change required to keep iterating, {@value}. */
    public static final double DEFAULT_MIN_CHANGE = 1.0E-4;

    /** The kernel used to compare inputs. */
    protected Kernel<? super InputType> kernel;

    /** The upper bound each unlabeled weight is clamped to; always positive. */
    protected double maxWeight;

    /** The overrelaxation factor; always strictly between 0.0 and 2.0. */
    protected double overrelaxation;

    /** The step stops requesting more iterations once the total change is at or below this. */
    protected double minChange;

    /** The categorizer being learned. */
    protected KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>> result;

    /** The Euclidean norm of the weight change made during the last step. */
    protected double totalChange;

    /** One entry per usable training example. */
    protected ArrayList<Entry> entries;

    /**
     * Maps each training example to its entry, for exactly those entries whose
     * weight is currently nonzero. During learning, its values view backs the
     * examples of {@link #result}.
     */
    protected LinkedHashMap<InputOutputPair<? extends InputType, ? extends Boolean>, Entry> supportsMap;

    /**
     * Creates a new learner with a null kernel and default parameters. A
     * kernel must be set before learning.
     */
    public SuccessiveOverrelaxation() {
        this(null);
    }

    /**
     * Creates a new learner with the given kernel and default parameters.
     *
     * @param kernel The kernel to use.
     */
    public SuccessiveOverrelaxation(Kernel<? super InputType> kernel) {
        this(kernel, DEFAULT_MAX_WEIGHT, DEFAULT_OVERRELAXATION, DEFAULT_MIN_CHANGE, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a new learner.
     *
     * @param kernel The kernel to use.
     * @param maxWeight The upper bound on each unlabeled weight; must be positive.
     * @param overrelaxation The overrelaxation factor; must be in (0.0, 2.0).
     * @param minChange The minimum total change to keep iterating; must be non-negative.
     * @param maxIterations The maximum number of iterations, passed to the superclass.
     * @throws IllegalArgumentException If a parameter is out of its valid range.
     */
    public SuccessiveOverrelaxation(Kernel<? super InputType> kernel, double maxWeight, double overrelaxation, double minChange, int maxIterations) {
        super(maxIterations);
        this.setKernel(kernel);
        this.setMaxWeight(maxWeight);
        this.setOverrelaxation(overrelaxation);
        this.setMinChange(minChange);
        this.setEntries(null);
        this.setResult(null);
        this.setTotalChange(0.0);
        this.setSupportsMap(null);
    }

    /**
     * Builds one entry per usable example and creates an empty categorizer
     * whose examples are a live, read-only view of the support map.
     *
     * @return False if there is no data or no example with a non-null output;
     *         otherwise true.
     */
    @Override
    protected boolean initializeAlgorithm() {
        if (this.getData() == null) {
            return false;
        }

        // Count exactly the examples that will become entries. Counting pairs
        // with a null output as valid would let a data set with no usable
        // labels pass this check and "learn" from zero entries.
        int validCount = 0;
        for (InputOutputPair<? extends InputType, ? extends Boolean> example : this.getData()) {
            if (example != null && example.getOutput() != null) {
                ++validCount;
            }
        }
        if (validCount <= 0) {
            return false;
        }

        this.setTotalChange(1.0);
        this.setEntries(new ArrayList<Entry>(validCount));
        for (InputOutputPair<? extends InputType, ? extends Boolean> example : this.getData()) {
            if (example != null && example.getOutput() != null) {
                this.entries.add(new Entry(example));
            }
        }

        this.setSupportsMap(new LinkedHashMap<InputOutputPair<? extends InputType, ? extends Boolean>, Entry>());

        // Widen the live view of the entries to the categorizer's example type;
        // unmodifiableCollection accepts Collection<? extends T>, so no unchecked cast is needed.
        final Collection<DefaultWeightedValue<InputType>> supports =
            Collections.<DefaultWeightedValue<InputType>>unmodifiableCollection(this.getSupportsMap().values());
        this.setResult(new KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>>(this.getKernel(), supports, 0.0));
        return true;
    }

    /**
     * Performs one SOR step.
     * <p>
     * The step first updates every entry once, in order of decreasing
     * unlabeled weight. It then makes a heuristically chosen number of extra
     * passes over the entries that were supports after that sweep, in order of
     * increasing unlabeled weight.
     *
     * @return True if the Euclidean norm of the weight change over this step
     *         is greater than {@code minChange}.
     */
    @Override
    protected boolean step() {
        this.setTotalChange(0.0);

        // Full sweep over all entries, remembering each weight before the step.
        Collections.sort(this.entries, Collections.reverseOrder());
        for (Entry entry : this.entries) {
            entry.previousStepWeight = entry.getWeight();
            this.update(entry);
        }

        // Heuristic guess for the number of interior passes over the supports.
        // NOTE(review): numNotPinned is simply set to numSupports here; entries
        // clamped at maxWeight are not excluded.
        final int numInstances = this.entries.size();
        final int numSupports = this.supportsMap.size();
        final int numNotPinned = numSupports;
        final double interiorIterationsGuess = 0.5 * (numInstances + 1.0)
            + (numSupports + 1.0) / ((numNotPinned + 1.0) * (numNotPinned + 1.0));
        final int interiorIterations = Math.max((int) interiorIterationsGuess, 1);

        // Iterate over a snapshot, because update() may add to or remove from the support map.
        final ArrayList<Entry> currentSupports = new ArrayList<Entry>(this.supportsMap.values());
        Collections.sort(currentSupports);
        for (int i = 0; i < interiorIterations; ++i) {
            for (Entry support : currentSupports) {
                this.update(support);
            }
        }

        double changeSum = 0.0;
        for (Entry entry : this.entries) {
            final double change = entry.getWeight() - entry.previousStepWeight;
            changeSum += change * change;
        }
        this.setTotalChange(Math.sqrt(changeSum));
        return this.getTotalChange() > this.getMinChange();
    }

    /**
     * Applies the SOR update to a single entry.
     * <p>
     * The update adjusts the entry's unlabeled weight, clamps it to
     * [0, maxWeight], and stores it back as a labeled weight. It then adds or
     * removes the entry from the support map according to whether the new
     * weight is nonzero. Finally, it adds the change in labeled weight to the
     * categorizer's bias.
     *
     * @param entry The entry to update.
     */
    protected void update(Entry entry) {
        final InputType input = entry.getInput();
        final double label = entry.outputDouble;
        final double prediction = this.result.evaluateAsDouble(input);
        final double oldWeight = entry.getWeight();
        final double bias = this.result.getBias();

        // label * oldWeight is the unlabeled weight, since label is +1.0 or -1.0.
        double newUnlabeledWeight = label * oldWeight
            - this.overrelaxation / (entry.selfKernel + 1.0) * (label * prediction - 1.0);
        newUnlabeledWeight = Math.max(0.0, Math.min(this.maxWeight, newUnlabeledWeight));

        // Convert back to a labeled weight before storing it.
        final double newWeight = newUnlabeledWeight * label;
        entry.setWeight(newWeight);

        if (newWeight != 0.0) {
            if (!entry.supportInserted) {
                this.supportsMap.put(entry.example, entry);
                entry.supportInserted = true;
            }
        } else if (entry.supportInserted) {
            this.supportsMap.remove(entry.example);
            entry.supportInserted = false;
        }

        this.result.setBias(bias + (newWeight - oldWeight));
    }

    /**
     * Replaces the categorizer's live view of the entries with standalone
     * copies of the final supports, then drops the support map.
     */
    @Override
    protected void cleanupAlgorithm() {
        if (this.getSupportsMap() != null) {
            final ArrayList<DefaultWeightedValue<InputType>> supports =
                new ArrayList<DefaultWeightedValue<InputType>>(this.supportsMap.size());
            for (Entry entry : this.supportsMap.values()) {
                supports.add(new DefaultWeightedValue<InputType>(entry));
            }
            this.getResult().setExamples(supports);
            this.setSupportsMap(null);
        }
    }

    /**
     * Gets the kernel.
     *
     * @return The kernel used to compare inputs.
     */
    public Kernel<? super InputType> getKernel() {
        return this.kernel;
    }

    /**
     * Sets the kernel.
     *
     * @param kernel The kernel used to compare inputs.
     */
    public void setKernel(Kernel<? super InputType> kernel) {
        this.kernel = kernel;
    }

    /**
     * Gets the upper bound on each unlabeled weight.
     *
     * @return The maximum weight.
     */
    public double getMaxWeight() {
        return this.maxWeight;
    }

    /**
     * Sets the upper bound on each unlabeled weight.
     *
     * @param maxWeight The maximum weight; must be positive.
     * @throws IllegalArgumentException If maxWeight is not positive.
     */
    public void setMaxWeight(double maxWeight) {
        if (maxWeight <= 0.0) {
            throw new IllegalArgumentException("maxWeight must be positive");
        }
        this.maxWeight = maxWeight;
    }

    /**
     * Gets the overrelaxation factor.
     *
     * @return The overrelaxation factor.
     */
    public double getOverrelaxation() {
        return this.overrelaxation;
    }

    /**
     * Sets the overrelaxation factor.
     *
     * @param overrelaxation The factor; must be in (0.0, 2.0), exclusive.
     * @throws IllegalArgumentException If the factor is out of range.
     */
    public void setOverrelaxation(double overrelaxation) {
        if (overrelaxation <= 0.0 || overrelaxation >= 2.0) {
            throw new IllegalArgumentException("overrelaxation must be in (0.0, 2.0), exclusive.");
        }
        this.overrelaxation = overrelaxation;
    }

    /**
     * Gets the minimum total change required to keep iterating.
     *
     * @return The minimum change.
     */
    public double getMinChange() {
        return this.minChange;
    }

    /**
     * Sets the minimum total change required to keep iterating.
     *
     * @param minChange The minimum change; must be non-negative (zero is allowed).
     * @throws IllegalArgumentException If minChange is negative.
     */
    public void setMinChange(double minChange) {
        if (minChange < 0.0) {
            throw new IllegalArgumentException("minChange must be non-negative");
        }
        this.minChange = minChange;
    }

    /**
     * Gets the categorizer being (or last) learned.
     *
     * @return The learned categorizer; null before any successful initialization.
     */
    public KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>> getResult() {
        return this.result;
    }

    /**
     * Sets the categorizer being learned.
     *
     * @param result The categorizer.
     */
    protected void setResult(KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>> result) {
        this.result = result;
    }

    /**
     * Gets the per-example entries.
     *
     * @return The entries.
     */
    protected ArrayList<Entry> getEntries() {
        return this.entries;
    }

    /**
     * Sets the per-example entries.
     *
     * @param entries The entries.
     */
    protected void setEntries(ArrayList<Entry> entries) {
        this.entries = entries;
    }

    /**
     * Gets the map from example to entry for entries with nonzero weight.
     *
     * @return The support map; null outside of learning.
     */
    protected LinkedHashMap<InputOutputPair<? extends InputType, ? extends Boolean>, Entry> getSupportsMap() {
        return this.supportsMap;
    }

    /**
     * Sets the support map.
     *
     * @param supportsMap The support map.
     */
    protected void setSupportsMap(LinkedHashMap<InputOutputPair<? extends InputType, ? extends Boolean>, Entry> supportsMap) {
        this.supportsMap = supportsMap;
    }

    /**
     * Gets the Euclidean norm of the weight change from the last step.
     *
     * @return The total change.
     */
    public double getTotalChange() {
        return this.totalChange;
    }

    /**
     * Sets the total change.
     *
     * @param totalChange The total change.
     */
    protected void setTotalChange(double totalChange) {
        this.totalChange = totalChange;
    }

    /**
     * Reports the total change of the last step as the performance measure.
     *
     * @return A named value called "change" holding the total change.
     */
    @Override
    public NamedValue<Double> getPerformance() {
        return new DefaultNamedValue<Double>("change", this.getTotalChange());
    }

    /**
     * Learning state for a single training example.
     * <p>
     * The stored weight is the labeled weight; the unlabeled weight is its
     * magnitude with the sign of the label removed. Entries order by unlabeled
     * weight; this ordering is used only for sorting and is not consistent
     * with equals.
     */
    protected class Entry
    extends DefaultWeightedValue<InputType>
    implements Comparable<Entry> {
        /** The training example this entry wraps. */
        protected InputOutputPair<? extends InputType, ? extends Boolean> example;

        /** The boolean label of the example. */
        protected boolean output;

        /** The label as +1.0 (true) or -1.0 (false). */
        protected double outputDouble;

        /** Whether this entry is currently in the support map. */
        protected boolean supportInserted;

        /** The kernel evaluated on the example's input with itself. */
        protected double selfKernel;

        /** The labeled weight at the start of the current step. */
        protected double previousStepWeight;

        /**
         * Creates an entry with zero weight.
         *
         * @param example The example; its output must not be null.
         */
        protected Entry(InputOutputPair<? extends InputType, ? extends Boolean> example) {
            super(example.getInput(), 0.0);
            final InputType input = example.getInput();
            this.example = example;
            this.output = example.getOutput();
            this.outputDouble = this.output ? 1.0 : -1.0;
            this.supportInserted = false;
            this.selfKernel = SuccessiveOverrelaxation.this.kernel.evaluate(input, input);
            this.previousStepWeight = 0.0;
        }

        /**
         * Gets the example's input.
         *
         * @return The input.
         */
        public InputType getInput() {
            return this.value;
        }

        /**
         * Gets the example's label.
         *
         * @return The label.
         */
        public boolean getOutput() {
            return this.output;
        }

        /**
         * Sets the weight from an unlabeled value, applying the label's sign.
         *
         * @param unlabeledWeight The unlabeled weight.
         */
        public void setUnlabeledWeight(double unlabeledWeight) {
            this.weight = this.output ? unlabeledWeight : -unlabeledWeight;
        }

        /**
         * Gets the weight with the label's sign removed.
         *
         * @return The unlabeled weight.
         */
        public double getUnlabeledWeight() {
            return this.output ? this.weight : -this.weight;
        }

        @Override
        public int compareTo(Entry other) {
            return Double.compare(this.getUnlabeledWeight(), other.getUnlabeledWeight());
        }
    }
}

