/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.svm;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.KernelBinaryCategorizer;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import gov.sandia.cognition.learning.function.kernel.KernelContainer;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.Randomized;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Random;

/**
 * Learns a binary kernel support vector machine, represented as a
 * {@link KernelBinaryCategorizer}, using Platt's Sequential Minimal
 * Optimization (SMO) algorithm.
 * <p>
 * SMO repeatedly picks a pair of Lagrange multipliers ("alphas") and solves
 * the resulting two-variable sub-problem analytically. Each iteration either
 * examines every training example, or only the examples whose alpha lies
 * strictly between 0 and the maximum penalty C (the "non-bound" examples).
 * Learning stops when a full pass over all examples changes nothing, or when
 * the maximum number of iterations is reached.
 * <p>
 * Implementation notes:
 * <ul>
 * <li>Training examples that are null, or that have a null input or output,
 * are ignored.</li>
 * <li>The categorizer output is
 * {@code bias + sum_k (y_k * alpha_k) * K(x_k, x)}, with {@code y_k} in
 * {-1, +1}. The bias therefore has the opposite sign of the threshold
 * {@code b} in Platt's formulation.</li>
 * <li>Errors ({@code output - target}) of non-bound examples are cached and
 * updated incrementally after every successful step.</li>
 * <li>Kernel values can be memoized in a bounded, access-ordered
 * (least-recently-used) cache keyed by the unordered index pair.</li>
 * <li>Pairs whose second derivative {@code eta} is non-negative are skipped,
 * rather than optimized at the end points of the feasible segment.</li>
 * </ul>
 * The learner keeps its working state in unsynchronized instance fields.
 *
 * @param <InputType> The type of input the SVM is trained on.
 */
@PublicationReference(title="Fast training of support vector machines using sequential minimal optimization", author={"John C. Platt"}, year=1999, type=PublicationType.BookChapter, pages={185, 208}, publication="Advances in Kernel Methods", url="http://research.microsoft.com/pubs/68391/smo-book.pdf")
public class SequentialMinimalOptimization<InputType>
    extends AbstractAnytimeSupervisedBatchLearner<InputType, Boolean, KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>>>
    implements KernelContainer<InputType>, Randomized {

    /** Default maximum number of iterations: {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 1000;

    /** Default maximum penalty C; positive infinity places no upper bound on the alphas. */
    public static final double DEFAULT_MAX_PENALTY = Double.POSITIVE_INFINITY;

    /** Default tolerance used when checking an example for a KKT violation: {@value}. */
    public static final double DEFAULT_ERROR_TOLERANCE = 0.001;

    /** Default threshold below which alphas and alpha changes are treated as zero: {@value}. */
    public static final double DEFAULT_EFFECTIVE_ZERO = 1.0E-10;

    /** Default maximum number of kernel values held in the kernel cache: {@value}. */
    public static final int DEFAULT_KERNEL_CACHE_SIZE = 1000;

    /** Upper bound C on every alpha (the soft-margin penalty). */
    private double maxPenalty;

    /** Tolerance on {@code y * error} used to decide whether an example violates the KKT conditions. */
    private double errorTolerance;

    /** Values (and alpha changes) smaller than this are treated as zero. */
    private double effectiveZero;

    /** Maximum number of kernel values to cache; values of 0 or 1 disable caching. */
    private int kernelCacheSize;

    /** Source of randomness for choosing where to start scanning for a second example. */
    private Random random;

    /** Kernel used to compare inputs. */
    private Kernel<? super InputType> kernel;

    /** The categorizer being learned; its support collection is a live view of supportsMap. */
    private transient KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>> result;

    /** The usable (non-null) training examples, indexed by example number. */
    private transient ArrayList<InputOutputPair<? extends InputType, Boolean>> dataList;

    /** Number of entries in dataList. */
    private transient int dataSize;

    /** True when the next iteration should examine every example instead of only the non-bound ones. */
    private transient boolean examineAll;

    /** Number of successful optimization steps taken during the last iteration. */
    private transient int changeCount;

    /** Support vectors (examples with non-zero alpha), keyed by example index; weight is y * alpha. */
    private transient LinkedHashMap<Integer, DefaultWeightedValue<InputType>> supportsMap;

    /** Indices of examples whose alpha is strictly between 0 and maxPenalty. */
    private transient LinkedHashSet<Integer> nonBoundAlphaIndices;

    /** Cached errors (output - target) for non-bound examples, keyed by example index. */
    private transient LinkedHashMap<Integer, Double> errorCache;

    /** Optional LRU cache of kernel values keyed by the packed, ordered index pair; null when disabled. */
    private transient LinkedHashMap<Long, Double> kernelCache;

    /**
     * Creates a new SMO learner with no kernel and default parameters.
     */
    public SequentialMinimalOptimization() {
        this(null);
    }

    /**
     * Creates a new SMO learner with the given kernel and default parameters.
     *
     * @param kernel The kernel to use.
     */
    public SequentialMinimalOptimization(Kernel<? super InputType> kernel) {
        this(kernel, new Random());
    }

    /**
     * Creates a new SMO learner with the given kernel and random number
     * generator, and default values for all other parameters.
     *
     * @param kernel The kernel to use.
     * @param random The random number generator to use.
     */
    public SequentialMinimalOptimization(Kernel<? super InputType> kernel, Random random) {
        this(kernel, DEFAULT_MAX_PENALTY, DEFAULT_ERROR_TOLERANCE, DEFAULT_EFFECTIVE_ZERO,
            DEFAULT_KERNEL_CACHE_SIZE, DEFAULT_MAX_ITERATIONS, random);
    }

    /**
     * Creates a new SMO learner with every parameter given explicitly.
     *
     * @param kernel The kernel to use.
     * @param maxPenalty The maximum penalty C; must be positive.
     * @param errorTolerance The KKT-violation tolerance; cannot be negative.
     * @param effectiveZero The threshold treated as zero; cannot be negative.
     * @param kernelCacheSize The maximum kernel-cache size; cannot be negative.
     * @param maxIterations The maximum number of iterations.
     * @param random The random number generator to use.
     * @throws IllegalArgumentException If any numeric parameter is invalid.
     */
    public SequentialMinimalOptimization(Kernel<? super InputType> kernel, double maxPenalty, double errorTolerance, double effectiveZero, int kernelCacheSize, int maxIterations, Random random) {
        super(maxIterations);
        this.setKernel(kernel);
        this.setMaxPenalty(maxPenalty);
        this.setErrorTolerance(errorTolerance);
        this.setEffectiveZero(effectiveZero);
        this.setKernelCacheSize(kernelCacheSize);
        this.setRandom(random);
    }

    /**
     * Collects the usable training examples and resets all learning state.
     *
     * @return True if there is data to learn from; false otherwise.
     * @throws IllegalArgumentException If every usable example has the same label.
     */
    @Override
    protected boolean initializeAlgorithm() {
        this.result = null;
        final Collection<? extends InputOutputPair<? extends InputType, Boolean>> data = this.getData();
        if (data == null) {
            return false;
        }

        // Keep only complete examples and count the positive ones, so we can
        // reject data that contains a single category.
        this.dataList = new ArrayList<InputOutputPair<? extends InputType, Boolean>>(data.size());
        int positives = 0;
        for (InputOutputPair<? extends InputType, Boolean> example : data) {
            if (example == null || example.getInput() == null || example.getOutput() == null) {
                continue;
            }
            this.dataList.add(example);
            if (example.getOutput()) {
                ++positives;
            }
        }

        this.dataSize = this.dataList.size();
        if (this.dataSize <= 0) {
            this.dataList = null;
            return false;
        }
        if (positives <= 0 || positives >= this.dataSize) {
            throw new IllegalArgumentException("Data is all one category");
        }

        // As in Platt's main loop, the first pass examines every example.
        this.examineAll = true;
        this.changeCount = data.size();
        this.supportsMap = new LinkedHashMap<Integer, DefaultWeightedValue<InputType>>();
        this.nonBoundAlphaIndices = new LinkedHashSet<Integer>();
        this.errorCache = new LinkedHashMap<Integer, Double>();

        if (this.kernelCacheSize > 1 && this.dataSize > 1) {
            // Use long arithmetic: dataSize * dataSize overflows int once
            // dataSize exceeds 46340, which would give a negative capacity.
            final int cacheSize = (int) Math.min(
                (long) this.dataSize * (long) this.dataSize, (long) this.kernelCacheSize);
            // Access-ordered map that evicts the least-recently-used entry.
            this.kernelCache = new LinkedHashMap<Long, Double>(cacheSize, 0.75f, true) {

                @Override
                protected boolean removeEldestEntry(Map.Entry<Long, Double> eldest) {
                    return this.size() > cacheSize;
                }
            };
        } else {
            // Never reuse a cache left over from a previous run.
            this.kernelCache = null;
        }

        this.result = new KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>>(this.kernel, this.supportsMap.values(), 0.0);
        return true;
    }

    /**
     * Performs one pass of SMO, over either all examples or only the
     * non-bound ones.
     *
     * @return True if learning should continue.
     */
    @Override
    protected boolean step() {
        this.changeCount = 0;
        if (this.examineAll) {
            for (int j = 0; j < this.dataSize; ++j) {
                this.changeCount += this.examineExample(j);
            }
        } else {
            // Iterate over a snapshot because examineExample mutates the set.
            for (Integer j : new ArrayList<Integer>(this.nonBoundAlphaIndices)) {
                final double alphaJ = this.getAlpha(j);
                if (alphaJ > 0.0 && alphaJ < this.maxPenalty) {
                    this.changeCount += this.examineExample(j);
                }
            }
        }

        // Alternate between a full pass and non-bound passes. A full pass
        // that changes nothing ends learning.
        if (this.examineAll) {
            this.examineAll = false;
        } else if (this.changeCount <= 0) {
            this.examineAll = true;
        }
        return this.changeCount > 0 || this.examineAll;
    }

    /**
     * Releases the working state of the learning run. The result is kept.
     */
    @Override
    protected void cleanupAlgorithm() {
        this.dataList = null;
        this.supportsMap = null;
        this.nonBoundAlphaIndices = null;
        this.errorCache = null;
        this.kernelCache = null;
    }

    /**
     * Checks whether example j violates the KKT conditions (within the error
     * tolerance). If it does, tries to find a partner example to optimize
     * jointly with it.
     * <p>
     * Partners are tried in this order:
     * <ol>
     * <li>the non-bound example whose error differs most from j's;</li>
     * <li>every non-bound example, starting at a random position;</li>
     * <li>every example, starting at a random position.</li>
     * </ol>
     *
     * @param j The index of the example to examine.
     * @return 1 if a successful step was taken; 0 otherwise.
     */
    protected int examineExample(int j) {
        final double c = this.maxPenalty;
        final double tolerance = this.errorTolerance;
        final double yJ = this.getTarget(j);
        final double alphaJ = this.getAlpha(j);
        final double eJ = this.getError(j);
        final double rJ = eJ * yJ;

        final boolean violatesKKT = (rJ < -tolerance && alphaJ < c) || (rJ > tolerance && alphaJ > 0.0);
        if (!violatesKKT) {
            return 0;
        }

        final int nonBoundAlphasCount = this.nonBoundAlphaIndices.size();

        // Heuristic: maximize |eI - eJ| to approximate the largest step.
        if (nonBoundAlphasCount > 1) {
            final int i = eJ > 0.0 ? this.getMinErrorIndex() : this.getMaxErrorIndex();
            // The index is -1 if no error compared (e.g. all NaN).
            if (i >= 0 && this.takeStep(i, j)) {
                return 1;
            }
        }

        if (nonBoundAlphasCount > 0) {
            final int offset = this.random.nextInt(nonBoundAlphasCount);
            final ArrayList<Integer> nonBoundAlphas = new ArrayList<Integer>(this.nonBoundAlphaIndices);
            for (int n = 0; n < nonBoundAlphasCount; ++n) {
                final int i = nonBoundAlphas.get((offset + n) % nonBoundAlphasCount);
                if (this.takeStep(i, j)) {
                    return 1;
                }
            }
        }

        final int offset = this.random.nextInt(this.dataSize);
        for (int n = 0; n < this.dataSize; ++n) {
            final int i = (offset + n) % this.dataSize;
            if (this.takeStep(i, j)) {
                return 1;
            }
        }
        return 0;
    }

    /**
     * Analytically optimizes the alphas of examples i and j together, then
     * updates the bias and the error cache.
     *
     * @param i The index of the first example.
     * @param j The index of the second example.
     * @return True if the alphas changed; false otherwise.
     */
    private boolean takeStep(int i, int j) {
        if (i == j) {
            return false;
        }

        final double c = this.maxPenalty;
        final double epsilon = this.effectiveZero;
        final double cMinusEpsilon = c - epsilon;
        final double yI = this.getTarget(i);
        final double eI = this.getError(i);
        final double oldAlphaI = this.getAlpha(i);
        final double yJ = this.getTarget(j);
        final double eJ = this.getError(j);
        final double oldAlphaJ = this.getAlpha(j);

        // Ends [lowerBound, upperBound] of the feasible segment for alphaJ.
        final double lowerBound;
        final double upperBound;
        if (yI != yJ) {
            final double alphaJMinusAlphaI = oldAlphaJ - oldAlphaI;
            lowerBound = Math.max(0.0, alphaJMinusAlphaI);
            upperBound = Math.min(c, alphaJMinusAlphaI + c);
        } else {
            final double alphaIPlusAlphaJ = oldAlphaI + oldAlphaJ;
            lowerBound = Math.max(0.0, alphaIPlusAlphaJ - c);
            upperBound = Math.min(c, alphaIPlusAlphaJ);
        }
        if (lowerBound >= upperBound) {
            return false;
        }

        final double kII = this.evaluateKernel(i, i);
        final double kIJ = this.evaluateKernel(i, j);
        final double kJJ = this.evaluateKernel(j, j);
        // Second derivative of the objective along the constraint line.
        final double eta = 2.0 * kIJ - kII - kJJ;
        if (eta >= 0.0) {
            // Non-negative curvature is not handled; skip this pair.
            return false;
        }

        // Unconstrained optimum, clipped to the segment, then snapped to the
        // bounds when within epsilon of them.
        double newAlphaJ = oldAlphaJ - yJ * (eI - eJ) / eta;
        if (newAlphaJ <= lowerBound) {
            newAlphaJ = lowerBound;
        } else if (newAlphaJ >= upperBound) {
            newAlphaJ = upperBound;
        }
        if (newAlphaJ < epsilon) {
            newAlphaJ = 0.0;
        } else if (newAlphaJ > cMinusEpsilon) {
            newAlphaJ = c;
        }
        if (Math.abs(newAlphaJ - oldAlphaJ) < epsilon) {
            return false;
        }

        // Keep sum(y * alpha) constant.
        double newAlphaI = oldAlphaI + yI * yJ * (oldAlphaJ - newAlphaJ);
        if (newAlphaI < epsilon) {
            newAlphaI = 0.0;
        } else if (newAlphaI > cMinusEpsilon) {
            newAlphaI = c;
        }

        // b1 zeroes the error of i and b2 zeroes the error of j. Prefer the
        // one belonging to a non-bound alpha; otherwise average them.
        final double oldBias = this.getBias();
        final double deltaI = yI * (newAlphaI - oldAlphaI);
        final double deltaJ = yJ * (newAlphaJ - oldAlphaJ);
        final double b1 = oldBias - eI - deltaI * kII - deltaJ * kIJ;
        final double b2 = oldBias - eJ - deltaI * kIJ - deltaJ * kJJ;
        final double newBias;
        if (newAlphaI > epsilon && newAlphaI < cMinusEpsilon) {
            newBias = b1;
        } else if (newAlphaJ > epsilon && newAlphaJ < cMinusEpsilon) {
            newBias = b2;
        } else {
            newBias = (b1 + b2) / 2.0;
        }

        this.setAlpha(i, newAlphaI);
        this.setAlpha(j, newAlphaJ);
        this.setBias(newBias);
        this.updateErrorCache(i, yI, oldAlphaI, newAlphaI, j, yJ, oldAlphaJ, newAlphaJ, oldBias, newBias);
        return true;
    }

    /**
     * Brings the cached errors of all non-bound examples up to date after
     * alphas i and j and the bias have changed.
     * <p>
     * Bound examples i and j are dropped from the cache. Examples i and j get
     * an error of 0. Other non-bound examples are updated incrementally, or
     * computed from scratch if they had no cached error.
     */
    private void updateErrorCache(int i, double yI, double oldAlphaI, double newAlphaI, int j, double yJ, double oldAlphaJ, double newAlphaJ, double oldBias, double newBias) {
        if (newAlphaI <= 0.0 || newAlphaI >= this.maxPenalty) {
            this.errorCache.remove(i);
        }
        if (newAlphaJ <= 0.0 || newAlphaJ >= this.maxPenalty) {
            this.errorCache.remove(j);
        }

        final double weightIChange = yI * (newAlphaI - oldAlphaI);
        final double weightJChange = yJ * (newAlphaJ - oldAlphaJ);
        final double biasChange = newBias - oldBias;
        for (int k : this.nonBoundAlphaIndices) {
            final double newError;
            if (k == i || k == j) {
                newError = 0.0;
            } else {
                final Double oldError = this.errorCache.get(k);
                newError = oldError == null
                    ? this.getSVMOutput(k) - this.getTarget(k)
                    : oldError + weightIChange * this.evaluateKernel(i, k)
                        + weightJChange * this.evaluateKernel(j, k) + biasChange;
            }
            this.errorCache.put(k, newError);
        }
    }

    /**
     * Evaluates the kernel between examples i and j, using the kernel cache
     * when one is enabled.
     */
    private double evaluateKernel(int i, int j) {
        if (this.kernelCache == null) {
            return this.kernel.evaluate(this.getPoint(i), this.getPoint(j));
        }

        // The key packs the smaller index into the high 32 bits, so that
        // (i, j) and (j, i) share one cache entry.
        final long kernelCacheKey = i <= j ? (long) i << 32 | (long) j : (long) j << 32 | (long) i;
        final Double cachedValue = this.kernelCache.get(kernelCacheKey);
        if (cachedValue != null) {
            return cachedValue;
        }
        final double value = this.kernel.evaluate(this.getPoint(i), this.getPoint(j));
        this.kernelCache.put(kernelCacheKey, value);
        return value;
    }

    /**
     * Computes the current raw SVM output for example i: the bias plus the
     * weighted kernel values of all support vectors.
     */
    private double getSVMOutput(int i) {
        double output = this.getBias();
        for (Map.Entry<Integer, DefaultWeightedValue<InputType>> entry : this.supportsMap.entrySet()) {
            output += entry.getValue().getWeight() * this.evaluateKernel(i, entry.getKey());
        }
        return output;
    }

    /**
     * Returns the error (output minus target) for example i. The cached value
     * is used when present; otherwise the error is computed without caching.
     */
    private double getError(int i) {
        final Double cachedError = this.errorCache.get(i);
        if (cachedError != null) {
            return cachedError;
        }
        return this.getSVMOutput(i) - this.getTarget(i);
    }

    /**
     * Returns the index of the non-bound example with the smallest error, or
     * -1 if no error compares below positive infinity.
     */
    private int getMinErrorIndex() {
        double minError = Double.POSITIVE_INFINITY;
        int minIndex = -1;
        for (int index : this.nonBoundAlphaIndices) {
            final double error = this.getError(index);
            if (error < minError) {
                minError = error;
                minIndex = index;
            }
        }
        return minIndex;
    }

    /**
     * Returns the index of the non-bound example with the largest error, or
     * -1 if no error compares above negative infinity.
     */
    private int getMaxErrorIndex() {
        double maxError = Double.NEGATIVE_INFINITY;
        int maxIndex = -1;
        for (int index : this.nonBoundAlphaIndices) {
            final double error = this.getError(index);
            if (error > maxError) {
                maxError = error;
                maxIndex = index;
            }
        }
        return maxIndex;
    }

    /** Returns the input of example i. */
    private InputType getPoint(int i) {
        return this.dataList.get(i).getInput();
    }

    /** Returns the target of example i: +1 for a true label, -1 for a false one. */
    private double getTarget(int i) {
        return this.dataList.get(i).getOutput() ? 1.0 : -1.0;
    }

    /** Returns the alpha of example i; 0 if the example is not a support vector. */
    private double getAlpha(int i) {
        final DefaultWeightedValue<InputType> support = this.supportsMap.get(i);
        return support == null ? 0.0 : Math.abs(support.getWeight());
    }

    /**
     * Sets the alpha of example i. This updates the support map (weight =
     * target * alpha) and the membership of i in the non-bound set.
     */
    private void setAlpha(int i, double alpha) {
        if (alpha == 0.0) {
            this.supportsMap.remove(i);
            this.nonBoundAlphaIndices.remove(i);
            return;
        }

        final double weight = this.getTarget(i) * alpha;
        final DefaultWeightedValue<InputType> support = this.supportsMap.get(i);
        if (support == null) {
            this.supportsMap.put(i, new DefaultWeightedValue<InputType>(this.getPoint(i), weight));
        } else {
            support.setWeight(weight);
        }

        if (alpha == this.maxPenalty) {
            this.nonBoundAlphaIndices.remove(i);
        } else {
            this.nonBoundAlphaIndices.add(i);
        }
    }

    /** Returns the current bias of the result being learned. */
    private double getBias() {
        return this.result.getBias();
    }

    /** Sets the bias of the result being learned. */
    private void setBias(double bias) {
        this.result.setBias(bias);
    }

    /**
     * Gets the categorizer produced by the most recent learning run.
     *
     * @return The learned categorizer, or null if initialization failed or
     *         none has run.
     */
    public KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>> getResult() {
        return this.result;
    }

    @Override
    public Kernel<? super InputType> getKernel() {
        return this.kernel;
    }

    /**
     * Sets the kernel used to compare inputs.
     *
     * @param kernel The kernel.
     */
    public void setKernel(Kernel<? super InputType> kernel) {
        this.kernel = kernel;
    }

    /**
     * Gets the maximum penalty C, which is the upper bound on each alpha.
     *
     * @return The maximum penalty.
     */
    public double getMaxPenalty() {
        return this.maxPenalty;
    }

    /**
     * Sets the maximum penalty C, which is the upper bound on each alpha.
     *
     * @param maxPenalty The maximum penalty; must be positive (not NaN).
     * @throws IllegalArgumentException If maxPenalty is not positive.
     */
    public void setMaxPenalty(double maxPenalty) {
        if (!(maxPenalty > 0.0)) {
            throw new IllegalArgumentException("maxPenalty must be positive.");
        }
        this.maxPenalty = maxPenalty;
    }

    /**
     * Gets the tolerance used to decide whether an example violates the KKT
     * conditions.
     *
     * @return The error tolerance.
     */
    public double getErrorTolerance() {
        return this.errorTolerance;
    }

    /**
     * Sets the tolerance used to decide whether an example violates the KKT
     * conditions.
     *
     * @param errorTolerance The error tolerance; must be non-negative (not NaN).
     * @throws IllegalArgumentException If errorTolerance is negative or NaN.
     */
    public void setErrorTolerance(double errorTolerance) {
        if (!(errorTolerance >= 0.0)) {
            throw new IllegalArgumentException("errorTolerance cannot be negative.");
        }
        this.errorTolerance = errorTolerance;
    }

    /**
     * Gets the threshold below which alphas and alpha changes count as zero.
     *
     * @return The effective zero.
     */
    public double getEffectiveZero() {
        return this.effectiveZero;
    }

    /**
     * Sets the threshold below which alphas and alpha changes count as zero.
     *
     * @param effectiveZero The effective zero; must be non-negative (not NaN).
     * @throws IllegalArgumentException If effectiveZero is negative or NaN.
     */
    public void setEffectiveZero(double effectiveZero) {
        if (!(effectiveZero >= 0.0)) {
            throw new IllegalArgumentException("effectiveZero cannot be negative.");
        }
        this.effectiveZero = effectiveZero;
    }

    /**
     * Gets the maximum number of kernel values to cache.
     *
     * @return The kernel cache size.
     */
    public int getKernelCacheSize() {
        return this.kernelCacheSize;
    }

    /**
     * Sets the maximum number of kernel values to cache. Values of 0 or 1
     * disable caching.
     *
     * @param kernelCacheSize The kernel cache size; cannot be negative.
     * @throws IllegalArgumentException If kernelCacheSize is negative.
     */
    public void setKernelCacheSize(int kernelCacheSize) {
        if (kernelCacheSize < 0) {
            throw new IllegalArgumentException("kernelCacheSize cannot be negative");
        }
        this.kernelCacheSize = kernelCacheSize;
    }

    public Random getRandom() {
        return this.random;
    }

    public void setRandom(Random random) {
        this.random = random;
    }

    /**
     * Gets the number of successful optimization steps in the last iteration.
     *
     * @return The change count.
     */
    public int getChangeCount() {
        return this.changeCount;
    }
}

