/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.bayesian;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.IncrementalLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.RingAccumulator;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.AbstractSufficientStatistic;
import gov.sandia.cognition.statistics.bayesian.AbstractBayesianRegression;
import gov.sandia.cognition.statistics.distribution.InverseGammaDistribution;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussianInverseGammaDistribution;
import gov.sandia.cognition.statistics.distribution.StudentTDistribution;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Collection;

/**
 * Bayesian linear regression in which both the weight vector and the output
 * (noise) variance are unknown.
 * <p>
 * The weight vector has a {@link MultivariateGaussian} prior and the output
 * variance has an {@link InverseGammaDistribution} prior. Learning combines
 * these priors with (optionally weighted) input-output pairs and returns the
 * joint posterior as a {@link MultivariateGaussianInverseGammaDistribution}.
 * Inputs are turned into feature vectors by the feature map held by
 * {@link AbstractBayesianRegression}.
 *
 * @param <InputType> type of the inputs; the feature map converts them to
 *     {@link Vector}s
 */
@PublicationReferences(references = {
    @PublicationReference(
        author = {"Christopher M. Bishop"},
        title = "Pattern Recognition and Machine Learning",
        type = PublicationType.Book,
        year = 2006,
        pages = {152, 159}),
    @PublicationReference(
        author = {"Jan Drugowitsch"},
        title = "Bayesian Linear Regression",
        type = PublicationType.Misc,
        year = 2009,
        url = "http://www.bcs.rochester.edu/people/jdrugowitsch/code/bayes_linear_notes_0.1.1.pdf")
})
public class BayesianRobustLinearRegression<InputType>
    extends AbstractBayesianRegression<InputType, Double, MultivariateGaussianInverseGammaDistribution> {

    /**
     * Default prior variance of each weight. It scales the identity covariance
     * used by {@link #BayesianRobustLinearRegression(int)}.
     */
    public static final double DEFAULT_WEIGHT_VARIANCE = 1.0;

    /** Prior distribution over the weight vector. */
    private MultivariateGaussian weightPrior;

    /** Prior distribution over the output (noise) variance. */
    private InverseGammaDistribution outputVariance;

    /**
     * Creates a regression with default priors and a {@code null} feature map.
     * <p>
     * The weight prior has a mean from {@code VectorFactory.createVector}
     * (expected to be all zeros) and covariance
     * {@link #DEFAULT_WEIGHT_VARIANCE} times the identity. The output-variance
     * prior is a default-constructed {@link InverseGammaDistribution}.
     * NOTE(review): confirm how {@link AbstractBayesianRegression} treats a
     * {@code null} feature map before calling {@link #learn}.
     *
     * @param dimensionality dimensionality of the feature / weight space
     */
    public BayesianRobustLinearRegression(int dimensionality) {
        this(null, new InverseGammaDistribution(), new MultivariateGaussian(
            VectorFactory.getDefault().createVector(dimensionality),
            MatrixFactory.getDefault().createIdentity(dimensionality, dimensionality)
                .scale(DEFAULT_WEIGHT_VARIANCE)));
    }

    /**
     * Creates a regression with the given feature map and priors.
     *
     * @param featureMap maps each input to its feature vector
     * @param outputVariance prior over the output (noise) variance
     * @param weightPrior prior over the weight vector
     */
    public BayesianRobustLinearRegression(
        Evaluator<? super InputType, Vector> featureMap,
        InverseGammaDistribution outputVariance,
        MultivariateGaussian weightPrior) {
        super(featureMap);
        this.setWeightPrior(weightPrior);
        this.setOutputVariance(outputVariance);
    }

    /**
     * Clones this regression.
     *
     * @return a copy whose weight prior and output-variance prior are
     *     themselves cloned, so the copy can be modified independently
     */
    @Override
    @SuppressWarnings("unchecked")
    public BayesianRobustLinearRegression<InputType> clone() {
        final BayesianRobustLinearRegression<InputType> clone =
            (BayesianRobustLinearRegression<InputType>) super.clone();
        clone.setWeightPrior(ObjectUtil.cloneSafe(this.getWeightPrior()));
        clone.setOutputVariance(ObjectUtil.cloneSafe(this.getOutputVariance()));
        return clone;
    }

    /**
     * Returns the prior over the weight vector.
     *
     * @return weight prior
     */
    public MultivariateGaussian getWeightPrior() {
        return this.weightPrior;
    }

    /**
     * Sets the prior over the weight vector.
     *
     * @param weightPrior new weight prior
     */
    public void setWeightPrior(MultivariateGaussian weightPrior) {
        this.weightPrior = weightPrior;
    }

    /**
     * Computes the joint posterior over the weights and the output variance in
     * one batch.
     * <p>
     * With prior mean {@code m0}, prior precision {@code L0}, prior
     * inverse-gamma parameters {@code (a0, b0)}, and per-pair weight
     * {@code beta_i = DatasetUtil.getWeight(pair_i)}:
     * <pre>
     *   Ln = L0 + sum_i beta_i x_i x_i^T
     *   mn = Ln^-1 (L0 m0 + sum_i beta_i y_i x_i)
     *   an = a0 + n / 2
     *   bn = b0 + (m0^T L0 m0 + sum_i beta_i y_i^2 - mn^T Ln mn) / 2
     * </pre>
     * With no data, the posterior equals the prior.
     *
     * @param data input-output pairs; the inputs go through the feature map
     * @return posterior whose Gaussian has mean {@code mn} and covariance
     *     {@code Ln^-1}, and whose inverse gamma has shape {@code an} and
     *     scale {@code bn}
     */
    @Override
    public MultivariateGaussianInverseGammaDistribution learn(
        Collection<? extends InputOutputPair<? extends InputType, Double>> data) {
        final MultivariateGaussian prior = this.weightPrior;
        final Vector priorMean = prior.getMean();
        final Matrix priorPrecision = prior.getCovarianceInverse();

        // Accumulates the posterior precision L0 + sum_i beta_i x_i x_i^T.
        final RingAccumulator<Matrix> precisionSum = new RingAccumulator<Matrix>();
        precisionSum.accumulate(priorPrecision);

        // Accumulates the information vector L0 m0 + sum_i beta_i y_i x_i.
        final Vector priorZ = priorPrecision.times(priorMean);
        // Seed with m0^T L0 m0 so that bn reduces to b0 when there is no data.
        // This matches the incremental SufficientStatistic.
        double outputSumSquared = priorMean.dotProduct(priorZ);
        final RingAccumulator<Vector> zSum = new RingAccumulator<Vector>();
        zSum.accumulate(priorZ);

        final InverseGammaDistribution variancePrior = this.outputVariance;
        double shape = variancePrior.getShape();
        final double priorScale = variancePrior.getScale();

        for (InputOutputPair<? extends InputType, Double> pair : data) {
            final Vector x = this.featureMap.evaluate(pair.getInput());
            final double beta = DatasetUtil.getWeight(pair);
            final double y = pair.getOutput();

            // weightedX = beta * x, reused below as beta * y * x.
            final Vector weightedX = x.clone();
            if (beta != 1.0) {
                weightedX.scaleEquals(beta);
            }
            precisionSum.accumulate(x.outerProduct(weightedX));

            if (y != 1.0) {
                weightedX.scaleEquals(y);
            }
            zSum.accumulate(weightedX);

            // The weight scales y^2 as well, consistent with the precision and
            // information vector; otherwise the scale can go negative.
            outputSumSquared += beta * y * y;
            shape += 0.5;
        }

        final Matrix posteriorPrecision = precisionSum.getSum();
        final Matrix posteriorCovariance = posteriorPrecision.inverse();
        final Vector posteriorMean = posteriorCovariance.times(zSum.getSum());
        final double scale = priorScale + 0.5 * (outputSumSquared
            - posteriorMean.times(posteriorPrecision).dotProduct(posteriorMean));

        return new MultivariateGaussianInverseGammaDistribution(
            new MultivariateGaussian(posteriorMean, posteriorCovariance),
            new InverseGammaDistribution(shape, scale));
    }

    /**
     * Returns the prior over the output (noise) variance.
     *
     * @return output-variance prior
     */
    public InverseGammaDistribution getOutputVariance() {
        return this.outputVariance;
    }

    /**
     * Sets the prior over the output (noise) variance.
     *
     * @param outputVariance new output-variance prior
     */
    public void setOutputVariance(InverseGammaDistribution outputVariance) {
        this.outputVariance = outputVariance;
    }

    /**
     * Creates the distribution of the output for a given input and fixed
     * weights.
     *
     * @param input input to map through the feature map
     * @param weights weight vector to dot with the feature vector
     * @return Gaussian whose mean is {@code featureMap(input) . weights} and
     *     whose variance is the mean of the output-variance prior
     */
    public UnivariateGaussian createConditionalDistribution(InputType input, Vector weights) {
        final double mean = this.featureMap.evaluate(input).dotProduct(weights);
        final double variance = this.getOutputVariance().getMean();
        return new UnivariateGaussian(mean, variance);
    }

    /**
     * Wraps a posterior in an evaluator that produces the Student-t predictive
     * distribution for each input.
     *
     * @param posterior posterior, for example as returned by {@link #learn}
     * @return predictive-distribution evaluator bound to this regression's
     *     feature map
     */
    public PredictiveDistribution createPredictiveDistribution(
        MultivariateGaussianInverseGammaDistribution posterior) {
        return new PredictiveDistribution(posterior);
    }

    /**
     * Incremental variant that accumulates sufficient statistics one
     * input-output pair at a time.
     *
     * @param <InputType> type of the inputs
     */
    public static class IncrementalEstimator<InputType>
        extends BayesianRobustLinearRegression<InputType>
        implements IncrementalLearner<InputOutputPair<? extends InputType, Double>, SufficientStatistic> {

        /**
         * Creates an estimator with default priors and a {@code null} feature
         * map.
         *
         * @param dimensionality dimensionality of the feature / weight space
         */
        public IncrementalEstimator(int dimensionality) {
            super(dimensionality);
        }

        /**
         * Creates an estimator with default priors and the given feature map.
         *
         * @param dimensionality dimensionality of the feature / weight space
         * @param featureMap maps each input to its feature vector
         */
        public IncrementalEstimator(int dimensionality, Evaluator<? super InputType, Vector> featureMap) {
            this(dimensionality);
            this.setFeatureMap(featureMap);
        }

        /**
         * Creates an estimator with the given feature map and priors.
         *
         * @param featureMap maps each input to its feature vector
         * @param outputVariance prior over the output (noise) variance
         * @param weightPrior prior over the weight vector
         */
        public IncrementalEstimator(
            Evaluator<? super InputType, Vector> featureMap,
            InverseGammaDistribution outputVariance,
            MultivariateGaussian weightPrior) {
            super(featureMap, outputVariance, weightPrior);
        }

        /**
         * Creates a sufficient statistic initialized from the current priors.
         *
         * @return new sufficient statistic
         */
        @Override
        public SufficientStatistic createInitialLearnedObject() {
            return new SufficientStatistic(new MultivariateGaussianInverseGammaDistribution(
                this.getWeightPrior(), this.getOutputVariance()));
        }

        /**
         * Computes the posterior by feeding all data through a fresh
         * sufficient statistic.
         *
         * @param data input-output pairs
         * @return posterior distribution
         */
        @Override
        public MultivariateGaussianInverseGammaDistribution learn(
            Collection<? extends InputOutputPair<? extends InputType, Double>> data) {
            final SufficientStatistic target = this.createInitialLearnedObject();
            this.update(target, data);
            return target.create();
        }

        /**
         * Adds one pair to the sufficient statistic.
         *
         * @param target statistic to update
         * @param data pair to add
         */
        @Override
        public void update(SufficientStatistic target, InputOutputPair<? extends InputType, Double> data) {
            target.update(data);
        }

        /**
         * Adds several pairs to the sufficient statistic.
         *
         * @param target statistic to update
         * @param data pairs to add
         */
        @Override
        public void update(
            SufficientStatistic target,
            Iterable<? extends InputOutputPair<? extends InputType, Double>> data) {
            target.update(data);
        }

        /**
         * Sufficient statistic for the joint weight / output-variance
         * posterior.
         */
        public class SufficientStatistic
            extends AbstractSufficientStatistic<InputOutputPair<? extends InputType, Double>, MultivariateGaussianInverseGammaDistribution> {

            /**
             * {@code 2 b0 + m0^T L0 m0 + sum_i beta_i y_i^2}. The prior part is
             * present only when the statistic was built from a non-null prior.
             */
            private double outputSumSquared;

            /** Information vector {@code L0 m0 + sum_i beta_i y_i x_i}; null until initialized. */
            private Vector z;

            /** Precision {@code L0 + sum_i beta_i x_i x_i^T}; null until initialized. */
            private Matrix covarianceInverse;

            /**
             * Creates a statistic seeded from a prior, or an empty statistic
             * when the prior is {@code null}.
             *
             * @param prior joint prior, or {@code null}
             */
            public SufficientStatistic(MultivariateGaussianInverseGammaDistribution prior) {
                if (prior != null) {
                    final Vector mean = prior.getMean();
                    // Copy so that later updates do not modify the prior's matrix.
                    this.covarianceInverse = prior.getGaussian().getCovarianceInverse().clone();
                    this.z = this.covarianceInverse.times(mean);
                    final double a0 = prior.getInverseGamma().getShape();
                    final double b0 = prior.getInverseGamma().getScale();
                    // NOTE(review): ceil rounds a prior shape that is not a
                    // multiple of 0.5 upward, so getShape() can exceed the
                    // batch learn() shape by up to 0.5.
                    this.count = (long) Math.ceil(2.0 * a0);
                    this.outputSumSquared = 2.0 * b0 + mean.dotProduct(this.z);
                } else {
                    this.covarianceInverse = null;
                    this.z = null;
                    this.count = 0L;
                    this.outputSumSquared = 0.0;
                }
            }

            /**
             * Adds one (optionally weighted) input-output pair.
             *
             * @param value pair to add
             */
            @Override
            public void update(InputOutputPair<? extends InputType, Double> value) {
                ++this.count;
                final Vector x = IncrementalEstimator.this.featureMap.evaluate(value.getInput());
                final double y = value.getOutput();
                final double beta = DatasetUtil.getWeight(value);

                // weightedX = beta * x, reused below as beta * y * x.
                final Vector weightedX = x.clone();
                if (beta != 1.0) {
                    weightedX.scaleEquals(beta);
                }
                final Matrix outer = x.outerProduct(weightedX);
                if (this.covarianceInverse == null) {
                    this.covarianceInverse = outer;
                } else {
                    this.covarianceInverse.plusEquals(outer);
                }

                if (y != 1.0) {
                    weightedX.scaleEquals(y);
                }
                if (this.z == null) {
                    this.z = weightedX;
                } else {
                    this.z.plusEquals(weightedX);
                }

                // The weight scales y^2 as well, consistent with the precision
                // and information vector.
                this.outputSumSquared += beta * y * y;
            }

            /**
             * Creates a new posterior distribution from this statistic.
             *
             * @return posterior distribution
             */
            public MultivariateGaussianInverseGammaDistribution create() {
                final MultivariateGaussianInverseGammaDistribution distribution =
                    new MultivariateGaussianInverseGammaDistribution(this.getDimensionality());
                this.create(distribution);
                return distribution;
            }

            /**
             * Writes this statistic's posterior parameters into the given
             * distribution.
             *
             * @param distribution distribution to overwrite
             */
            @Override
            public void create(MultivariateGaussianInverseGammaDistribution distribution) {
                // Compute the mean (one matrix inversion) once and reuse it for the scale.
                final Vector mean = this.getMean();
                distribution.getGaussian().setMean(mean);
                // Hand over a copy so that later update() calls, which modify
                // the precision in place, cannot alter an already-created
                // distribution.
                distribution.getGaussian().setCovarianceInverse(this.getCovarianceInverse().clone());
                distribution.getInverseGamma().setShape(this.getShape());
                distribution.getInverseGamma().setScale(this.scaleFor(mean));
            }

            /**
             * Returns the accumulated precision matrix (the live internal
             * instance, not a copy).
             *
             * @return precision matrix, or {@code null} before initialization
             */
            public Matrix getCovarianceInverse() {
                return this.covarianceInverse;
            }

            /**
             * Returns the accumulated information vector.
             *
             * @return information vector, or {@code null} before initialization
             */
            public Vector getZ() {
                return this.z;
            }

            /**
             * Computes the posterior weight mean {@code covarianceInverse^-1 z}.
             *
             * @return posterior mean
             */
            public Vector getMean() {
                return this.covarianceInverse.inverse().times(this.z);
            }

            /**
             * Returns the dimensionality of the information vector.
             *
             * @return dimensionality
             */
            public int getDimensionality() {
                return this.getZ().getDimensionality();
            }

            /**
             * Returns the accumulated output sum of squares, including any
             * prior contribution.
             *
             * @return output sum of squares
             */
            public double getOutputSumSquared() {
                return this.outputSumSquared;
            }

            /**
             * Returns the posterior inverse-gamma shape, {@code count / 2}.
             *
             * @return shape
             */
            public double getShape() {
                return (double) this.getCount() / 2.0;
            }

            /**
             * Returns the posterior inverse-gamma scale.
             *
             * @return scale
             */
            public double getScale() {
                return this.scaleFor(this.getMean());
            }

            /** Computes the scale given an already-computed posterior mean. */
            private double scaleFor(Vector mean) {
                final Matrix precision = this.covarianceInverse;
                return 0.5 * (this.outputSumSquared - mean.times(precision).dotProduct(mean));
            }
        }
    }

    /**
     * Evaluator that maps an input to the Student-t predictive distribution of
     * its output under a fixed posterior.
     */
    public class PredictiveDistribution
        extends AbstractCloneableSerializable
        implements Evaluator<InputType, StudentTDistribution> {

        /** Posterior over the weights and the output variance. */
        private MultivariateGaussianInverseGammaDistribution posterior;

        /**
         * Creates a predictive distribution for the given posterior.
         *
         * @param posterior joint posterior
         */
        public PredictiveDistribution(MultivariateGaussianInverseGammaDistribution posterior) {
            this.posterior = posterior;
        }

        /**
         * Computes the predictive distribution for one input.
         *
         * @param input input to map through the enclosing regression's feature
         *     map
         * @return Student-t with {@code 2 an} degrees of freedom, location
         *     {@code x . mn}, and precision {@code (an / bn) / (1 + x^T Cn x)}
         */
        @Override
        public StudentTDistribution evaluate(InputType input) {
            final Vector x = BayesianRobustLinearRegression.this.featureMap.evaluate(input);
            final InverseGammaDistribution variancePosterior = this.posterior.getInverseGamma();
            final double shape = variancePosterior.getShape();
            final double scale = variancePosterior.getScale();

            final double mean = x.dotProduct(this.posterior.getMean());
            final double degreesOfFreedom = shape * 2.0;
            final double v = x.times(this.posterior.getGaussian().getCovariance()).dotProduct(x);
            final double precision = (shape / scale) / (1.0 + v);
            return new StudentTDistribution(degreesOfFreedom, mean, precision);
        }
    }
}

