/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.bayesian.conjugate;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.math.ComplexNumber;
import gov.sandia.cognition.math.MultivariateStatisticsUtil;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.bayesian.AbstractBayesianParameter;
import gov.sandia.cognition.statistics.bayesian.BayesianParameter;
import gov.sandia.cognition.statistics.bayesian.conjugate.AbstractConjugatePriorBayesianEstimator;
import gov.sandia.cognition.statistics.bayesian.conjugate.ConjugatePriorBayesianEstimatorPredictor;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import java.util.Arrays;

/**
 * Conjugate-prior Bayesian estimator for the mean of a multivariate Gaussian
 * whose covariance is known.
 * <p>
 * The conditional distribution is a {@link MultivariateGaussian} with the known
 * covariance. The belief (prior/posterior) over the unknown mean is also a
 * {@link MultivariateGaussian}. Given a prior with mean {@code mu0} and
 * precision {@code Ci0}, the known precision {@code Ci}, and {@code N}
 * observations with sample mean {@code xbar}, an update produces
 * <pre>
 *   precisionN = Ci0 + N * Ci
 *   meanN      = precisionN^-1 * (Ci0 * mu0 + N * Ci * xbar)
 * </pre>
 */
@PublicationReference(author={"William M. Bolstad"}, title="Introduction to Bayesian Statistics: Second Edition", type=PublicationType.Book, year=2007, pages={208})
public class MultivariateGaussianMeanBayesianEstimator
extends AbstractConjugatePriorBayesianEstimator<Vector, Vector, MultivariateGaussian, MultivariateGaussian>
implements ConjugatePriorBayesianEstimatorPredictor<Vector, Vector, MultivariateGaussian, MultivariateGaussian> {
    /** Dimensionality used by the no-argument constructor. */
    public static final int DEFAULT_DIMENSIONALITY = 1;

    /**
     * Creates an estimator of dimensionality {@link #DEFAULT_DIMENSIONALITY}
     * whose known covariance inverse is the identity.
     */
    public MultivariateGaussianMeanBayesianEstimator() {
        this(DEFAULT_DIMENSIONALITY);
    }

    /**
     * Creates an estimator whose known covariance inverse is the identity.
     *
     * @param dimensionality number of dimensions of the observed vectors
     */
    public MultivariateGaussianMeanBayesianEstimator(int dimensionality) {
        this(MatrixFactory.getDefault().createIdentity(dimensionality, dimensionality));
    }

    /**
     * Creates an estimator with the given known covariance inverse.
     * <p>
     * The prior belief uses a newly created mean vector (presumably all zeros)
     * and an identity covariance of matching size.
     *
     * @param knownCovarianceInverse known inverse covariance (precision) of the observations
     */
    public MultivariateGaussianMeanBayesianEstimator(Matrix knownCovarianceInverse) {
        this(knownCovarianceInverse,
            new MultivariateGaussian(
                VectorFactory.getDefault().createVector(knownCovarianceInverse.getNumRows()),
                MatrixFactory.getDefault().createIdentity(
                    knownCovarianceInverse.getNumRows(), knownCovarianceInverse.getNumColumns())));
    }

    /**
     * Creates an estimator with the given known covariance inverse and prior belief.
     *
     * @param knownCovarianceInverse known inverse covariance (precision) of the observations;
     *     it is inverted to form the conditional distribution's covariance
     * @param belief prior belief over the unknown mean
     */
    public MultivariateGaussianMeanBayesianEstimator(Matrix knownCovarianceInverse, MultivariateGaussian belief) {
        this(new MultivariateGaussian(
                VectorFactory.getDefault().createVector(knownCovarianceInverse.getNumRows()),
                knownCovarianceInverse.inverse()),
            belief);
    }

    /**
     * Creates an estimator from explicit conditional and prior distributions.
     *
     * @param conditional conditional distribution whose covariance is the known covariance
     * @param prior prior belief over the unknown mean
     */
    public MultivariateGaussianMeanBayesianEstimator(MultivariateGaussian conditional, MultivariateGaussian prior) {
        this(new Parameter(conditional, prior));
    }

    /**
     * Creates an estimator from an existing Bayesian parameter.
     *
     * @param parameter parameter holding the conditional and prior distributions
     */
    protected MultivariateGaussianMeanBayesianEstimator(BayesianParameter<Vector, MultivariateGaussian, MultivariateGaussian> parameter) {
        super(parameter);
    }

    /**
     * Creates a new mean parameter linking the given conditional and prior.
     *
     * @param conditional conditional distribution
     * @param prior prior belief over the mean
     * @return a new {@link Parameter}
     */
    public Parameter createParameter(MultivariateGaussian conditional, MultivariateGaussian prior) {
        return new Parameter(conditional, prior);
    }

    /**
     * Gets the known covariance inverse, as reported by the conditional
     * distribution's {@code getCovarianceInverse()}. Whether this is a copy or
     * the internal matrix is not visible here, so callers should not mutate it.
     *
     * @return known inverse covariance of the observations
     */
    public Matrix getKnownCovarianceInverse() {
        return this.parameter.getConditionalDistribution().getCovarianceInverse();
    }

    /**
     * Sets the known covariance inverse by storing its inverse as the
     * conditional distribution's covariance.
     *
     * @param knownCovarianceInverse new known inverse covariance
     * @throws IllegalArgumentException if the matrix is not symmetric or not full rank
     */
    public void setKnownCovarianceInverse(Matrix knownCovarianceInverse) {
        if (!knownCovarianceInverse.isSymmetric() || knownCovarianceInverse.rank() != knownCovarianceInverse.getNumRows()) {
            throw new IllegalArgumentException("Covariance inverse must be symmetric and invertible!");
        }
        this.parameter.getConditionalDistribution().setCovariance(knownCovarianceInverse.inverse());
    }

    /**
     * Updates {@code target} in place to the posterior after observing {@code data}.
     * An empty {@code data} leaves {@code target} unchanged, because the posterior
     * for zero observations is the prior.
     *
     * @param target belief to update in place
     * @param data observed vectors
     */
    @Override
    public void update(MultivariateGaussian target, Iterable<? extends Vector> data) {
        final int N = CollectionUtil.size(data);
        if (N <= 0) {
            // No observations: nothing to do. This also avoids computing the
            // mean of an empty collection below.
            return;
        }
        final Matrix Ci0 = target.getCovarianceInverse();
        // Clone so scaling/adding does not modify the conditional's own matrix.
        final Matrix CiN = this.getKnownCovarianceInverse().clone();
        if (N > 1) {
            CiN.scaleEquals(N);
        }
        final Vector sampleMean = MultivariateStatisticsUtil.computeMean(data);

        // t0 = Ci0 * mu0 + N * Ci * xbar
        final Vector t0 = Ci0.times(target.getMean());
        t0.plusEquals(CiN.times(sampleMean));

        // Posterior precision = Ci0 + N * Ci.
        CiN.plusEquals(Ci0);
        final Matrix updatedCovariance = CiN.inverse();
        final Vector updatedMean = updatedCovariance.times(t0);
        target.setMean(updatedMean);
        target.setCovariance(updatedCovariance);
    }

    /**
     * Updates {@code updater} in place with a single observation.
     *
     * @param updater belief to update in place
     * @param data single observed vector
     */
    @Override
    public void update(MultivariateGaussian updater, Vector data) {
        this.update(updater, Arrays.asList(data));
    }

    /**
     * Computes the number of observations the belief is "worth". It returns
     * {@code (det(beliefPrecision) / det(knownPrecision))^(1/d)}, where {@code d}
     * is the dimensionality. The result is below 1 when the belief is less
     * precise than a single observation.
     *
     * @param belief belief over the mean
     * @return equivalent sample size
     */
    @Override
    public double computeEquivalentSampleSize(MultivariateGaussian belief) {
        final ComplexNumber logRatio = belief.getCovarianceInverse().logDeterminant().minus(
            this.getKnownCovarianceInverse().logDeterminant());
        // Use the signed real part. The magnitude would discard the sign and map
        // a ratio r < 1 to 1/r. Log-determinants of positive-definite precision
        // matrices are real, so the imaginary part carries no information here.
        return Math.exp(logRatio.getRealPart() / belief.getMean().getDimensionality());
    }

    /**
     * Creates the posterior predictive distribution of a new observation.
     * Its mean is a copy of the posterior mean, and its covariance is the
     * posterior covariance plus the known covariance.
     *
     * @param posterior posterior belief over the mean
     * @return predictive distribution
     */
    public MultivariateGaussian createPredictiveDistribution(MultivariateGaussian posterior) {
        final Vector mean = posterior.getMean().clone();
        final Matrix C = posterior.getCovariance().plus(
            this.parameter.getConditionalDistribution().getCovariance());
        return new MultivariateGaussian(mean, C);
    }

    /**
     * Creates the conditional distribution for the given mean after checking
     * that its dimensionality matches the conditional distribution.
     *
     * @param parameter mean vector
     * @return conditional distribution with that mean
     */
    @Override
    public MultivariateGaussian createConditionalDistribution(Vector parameter) {
        parameter.assertDimensionalityEquals(this.parameter.getConditionalDistribution().getInputDimensionality());
        return super.createConditionalDistribution(parameter);
    }

    /**
     * Bayesian parameter whose value is the mean of the conditional
     * {@link MultivariateGaussian}.
     */
    public static class Parameter
    extends AbstractBayesianParameter<Vector, MultivariateGaussian, MultivariateGaussian> {
        /** Name of this parameter. */
        public static final String NAME = "mean";

        /**
         * Creates a new mean parameter.
         *
         * @param conditional conditional distribution whose mean is the parameter value
         * @param prior prior belief over the mean
         */
        public Parameter(MultivariateGaussian conditional, MultivariateGaussian prior) {
            super(conditional, NAME, prior);
        }

        /**
         * Sets the conditional distribution's mean after checking dimensionality.
         *
         * @param value new mean
         */
        @Override
        public void setValue(Vector value) {
            value.assertDimensionalityEquals(this.conditionalDistribution.getInputDimensionality());
            this.conditionalDistribution.setMean(value);
        }

        /**
         * Gets the conditional distribution's mean.
         *
         * @return current mean
         */
        public Vector getValue() {
            return this.conditionalDistribution.getMean();
        }
    }
}

