/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.confidence;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.confidence.ConfidenceWeightedDiagonalDeviation;
import gov.sandia.cognition.learning.function.categorization.DiagonalConfidenceWeightedBinaryCategorizer;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorEntry;
import gov.sandia.cognition.math.matrix.VectorFactory;

/**
 * Confidence-weighted binary learner that keeps a mean vector and a diagonal
 * (per-dimension) variance vector and, for each example whose margin falls
 * below the confidence threshold, applies the closed-form mean and variance
 * update associated with the cited Crammer, Dredze and Pereira (2008) paper.
 *
 * <p>The coefficients {@code phi}, {@code psi} and {@code epsilon} used by the
 * update are fields inherited from {@link ConfidenceWeightedDiagonalDeviation};
 * presumably they are derived there from the confidence parameter — confirm
 * against that class.</p>
 */
@PublicationReference(author={"Koby Crammer", "Mark Dredze", "Fernando Pereira"}, title="Exact Convex Confidence-Weighted Learning", year=2008, type=PublicationType.Conference, publication="Advances in Neural Information Processing Systems", url="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.169.3364")
public class ConfidenceWeightedDiagonalDeviationProject
extends ConfidenceWeightedDiagonalDeviation {
    /**
     * Creates a learner with a confidence of 0.85 and a default variance of 1.0.
     */
    public ConfidenceWeightedDiagonalDeviationProject() {
        this(0.85, 1.0);
    }

    /**
     * Creates a learner with the given parameters.
     *
     * @param confidence the confidence value, forwarded unchanged to the
     *        superclass constructor
     * @param defaultVariance the default variance, forwarded unchanged to the
     *        superclass constructor; {@link #update} fills the variance of an
     *        uninitialized target with {@code getDefaultVariance()}
     */
    public ConfidenceWeightedDiagonalDeviationProject(double confidence, double defaultVariance) {
        super(confidence, defaultVariance);
    }

    /**
     * Updates the target's mean and diagonal variance from one labeled example.
     *
     * <p>If the target is not initialized, it is given a new dense mean vector
     * from {@code createVector(dimensionality)} (zero-filled, assuming the
     * factory's usual contract) and a dense variance vector filled with
     * {@code getDefaultVariance()}, both sized to the input.</p>
     *
     * <p>With {@code y = label ? +1 : -1}, the margin is
     * {@code m = y * (x . mean)} and the margin variance is
     * {@code v = x . (variance .* x)}. Nothing is updated unless
     * {@code v > 0 && m <= phi * sqrt(v)}. Otherwise the step size is
     * {@code alpha = (-m psi + sqrt(m^2 phi^4 / 4 + v phi^2 epsilon)) / (v epsilon)},
     * and when {@code alpha > 0}:</p>
     * <ul>
     *   <li>{@code mean += alpha * y * (variance .* x)}, using the variance
     *       from before this update;</li>
     *   <li>with {@code u = 0.25 * (-alpha v phi + sqrt(alpha^2 v^2 phi^2 + 4 v))^2},
     *       each variance entry visited by iterating {@code input} becomes
     *       {@code 1 / (1 / sigma + alpha * phi / sqrt(u) * x_i^2)}, provided
     *       {@code sqrt(u) > 0}.</li>
     * </ul>
     *
     * @param target the categorizer to update; when an update is applied, its
     *        mean and variance vectors are modified in place and then set back
     *        on it
     * @param input the example's feature vector
     * @param label {@code true} for the positive class, {@code false} for the
     *        negative class
     */
    @Override
    public void update(final DiagonalConfidenceWeightedBinaryCategorizer target, final Vector input, final boolean label) {
        final Vector mean;
        final Vector variance;
        if (!target.isInitialized()) {
            final int dimensionality = input.getDimensionality();
            mean = VectorFactory.getDenseDefault().createVector(dimensionality);
            variance = VectorFactory.getDenseDefault().createVector(dimensionality, this.getDefaultVariance());
            target.setMean(mean);
            target.setVariance(variance);
        } else {
            mean = target.getMean();
            variance = target.getVariance();
        }

        // Label mapped to y in {-1, +1}.
        final double actual = label ? 1.0 : -1.0;
        // Signed margin m = y * (x . mean).
        final double m = actual * input.dotProduct(mean);
        // Element-wise variance .* x (Sigma x for a diagonal Sigma); reused by the mean step.
        final Vector varianceTimesInput = input.dotTimes(variance);
        // Margin variance v = x . (Sigma x).
        final double v = input.dotProduct(varianceTimesInput);

        // Only learn when the margin is below the threshold phi * sqrt(v);
        // requiring v > 0 also keeps the division by v below well-defined.
        final boolean needsUpdate = v > 0.0 && m <= this.phi * Math.sqrt(v);
        if (!needsUpdate) {
            return;
        }

        final double alpha = (-m * this.psi + Math.sqrt(m * m * Math.pow(this.phi, 4.0) / 4.0 + v * this.phi * this.phi * this.epsilon)) / (v * this.epsilon);
        if (alpha > 0.0) {
            // Mean step: mean += y * alpha * (Sigma x).
            mean.plusEquals(varianceTimesInput.scale(actual * alpha));

            final double u = 0.25 * Math.pow(-alpha * v * this.phi + Math.sqrt(alpha * alpha * v * v * this.phi * this.phi + 4.0 * v), 2.0);
            final double sqrtU = Math.sqrt(u);
            // sqrtU > 0 exactly when u > 0; it guards the division that forms factor.
            if (sqrtU > 0.0) {
                final double factor = alpha * this.phi / sqrtU;
                for (final VectorEntry entry : input) {
                    final int index = entry.getIndex();
                    final double value = entry.getValue();
                    // Update in inverse-variance form, then invert back.
                    final double inverseVariance = 1.0 / variance.getElement(index) + factor * value * value;
                    variance.setElement(index, 1.0 / inverseVariance);
                }
            }
        }

        // Re-set in case the target stores copies rather than the vectors mutated above.
        target.setMean(mean);
        target.setVariance(variance);
    }
}

