/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.gmm;

import Jama.Matrix;
import gnu.trove.list.array.TDoubleArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import org.openimaj.math.matrix.MatrixUtils;
import org.openimaj.math.statistics.MeanAndCovariance;
import org.openimaj.math.statistics.distribution.AbstractMultivariateGaussian;
import org.openimaj.math.statistics.distribution.DiagonalMultivariateGaussian;
import org.openimaj.math.statistics.distribution.FullMultivariateGaussian;
import org.openimaj.math.statistics.distribution.MixtureOfGaussians;
import org.openimaj.math.statistics.distribution.MultivariateGaussian;
import org.openimaj.math.statistics.distribution.SphericalMultivariateGaussian;
import org.openimaj.ml.clustering.kmeans.DoubleKMeans;
import org.openimaj.util.array.ArrayUtils;
import org.openimaj.util.pair.IndependentPair;

/**
 * Estimates a {@link MixtureOfGaussians} from data using the
 * Expectation-Maximisation (EM) algorithm.
 * <p>
 * Each of the {@code nInit} runs starts from a freshly initialised mixture.
 * Depending on the initialisation options, the starting point uses:
 * <ul>
 * <li>k-means centroids as the means,</li>
 * <li>uniform weights,</li>
 * <li>the sample covariance plus {@code minCovar} on the diagonal.</li>
 * </ul>
 * The run then alternates E- and M-steps. It stops when the absolute change
 * in total log-likelihood between successive iterations falls below
 * {@code thresh}, or after {@code nIters} iterations. The run with the
 * highest finite final log-likelihood is returned.
 * <p>
 * {@link #estimate(double[][])} records its convergence status in this
 * object. A single instance should therefore not be shared between
 * concurrent calls; use {@link #clone()} to obtain independent learners.
 */
public class GaussianMixtureModelEM implements Cloneable {
    private static final double DEFAULT_THRESH = 0.01;
    private static final double DEFAULT_MIN_COVAR = 0.001;
    private static final int DEFAULT_NITERS = 100;
    private static final int DEFAULT_NINIT = 1;

    /** Small offset added to denominators so empty components never divide by zero. */
    private static final double DENOMINATOR_EPS = 1.110223E-15f;

    /** Floor added to every updated mixture weight so no weight becomes exactly zero. */
    private static final double WEIGHT_FLOOR = 1.110223E-16f;

    CovarianceType ctype;
    int nComponents;
    private double thresh;
    private double minCovar;
    private int nIters;
    private int nInit;
    /** Whether the run returned by the last call to estimate reached the threshold. */
    private boolean converged = false;
    private EnumSet<UpdateOptions> initOpts;
    private EnumSet<UpdateOptions> iterOpts;

    /**
     * Construct a learner with full control over the EM parameters.
     *
     * @param nComponents
     *            number of Gaussians in the mixture; must be at least 1
     * @param ctype
     *            the form of covariance matrix to learn
     * @param thresh
     *            convergence threshold on the absolute change in total
     *            log-likelihood between two successive iterations
     * @param minCovar
     *            regularisation added to the covariance estimates
     * @param nIters
     *            maximum number of EM iterations per run; if {@code <= 0}
     *            the initialised mixture is returned without running EM
     * @param nInit
     *            number of independent initialisations to try; the best
     *            run is kept. Must be at least 1
     * @param iterOpts
     *            parameters updated in each M-step (defensively copied)
     * @param initOpts
     *            parameters initialised from the data before each run
     *            (defensively copied)
     * @throws IllegalArgumentException
     *             if {@code nComponents < 1} or {@code nInit < 1}
     */
    public GaussianMixtureModelEM(int nComponents, CovarianceType ctype, double thresh, double minCovar, int nIters, int nInit, EnumSet<UpdateOptions> iterOpts, EnumSet<UpdateOptions> initOpts) {
        if (nComponents < 1) {
            throw new IllegalArgumentException("GMM estimation requires at least one component");
        }
        if (nInit < 1) {
            throw new IllegalArgumentException("GMM estimation requires at least one run");
        }
        this.ctype = ctype;
        this.nComponents = nComponents;
        this.thresh = thresh;
        this.minCovar = minCovar;
        this.nIters = nIters;
        this.nInit = nInit;
        // Copy so later changes to the caller's sets cannot alter this learner.
        this.iterOpts = EnumSet.copyOf(iterOpts);
        this.initOpts = EnumSet.copyOf(initOpts);
    }

    /**
     * Construct a learner with default parameters. All parameters are both
     * initialised and updated.
     *
     * @param nComponents
     *            number of Gaussians in the mixture
     * @param ctype
     *            the form of covariance matrix to learn
     */
    public GaussianMixtureModelEM(int nComponents, CovarianceType ctype) {
        this(nComponents, ctype, DEFAULT_THRESH, DEFAULT_MIN_COVAR, DEFAULT_NITERS, DEFAULT_NINIT,
                EnumSet.allOf(UpdateOptions.class), EnumSet.allOf(UpdateOptions.class));
    }

    /**
     * @return true if the model returned by the most recent call to
     *         {@code estimate} converged before exhausting its iterations
     */
    public boolean hasConverged() {
        return this.converged;
    }

    /**
     * Estimate a mixture from a data matrix.
     *
     * @param X
     *            the data, one sample per row
     * @return the estimated mixture
     * @see #estimate(double[][])
     */
    public MixtureOfGaussians estimate(Matrix X) {
        return this.estimate(X.getArray());
    }

    /**
     * Estimate a mixture from the given samples.
     *
     * @param X
     *            the data, one sample per row
     * @return the mixture from the run with the highest finite final
     *         log-likelihood, or the initialised mixture when
     *         {@code nIters <= 0}
     * @throws IllegalArgumentException
     *             if there are fewer samples than components
     * @throws RuntimeException
     *             if no run produced a finite log-likelihood
     */
    public MixtureOfGaussians estimate(double[][] X) {
        if (X.length < this.nComponents) {
            throw new IllegalArgumentException(String.format("GMM estimation with %d components, but got only %d samples", this.nComponents, X.length));
        }
        this.converged = false;

        // Without EM there is no likelihood to rank runs by, so one
        // initialisation is all that can be returned.
        if (this.nIters <= 0) {
            return this.initialise(X);
        }

        EMGMM best = null;
        double bestLogLikelihood = Double.NEGATIVE_INFINITY;
        boolean bestConverged = false;

        for (int run = 0; run < this.nInit; ++run) {
            // Each run owns its model. Later runs therefore cannot overwrite
            // (or null out) the best model found so far.
            final EMGMM gmm = this.initialise(X);
            final TDoubleArrayList logLikelihood = new TDoubleArrayList();
            final boolean runConverged = this.runEM(gmm, X, logLikelihood);
            final double finalLogLikelihood = logLikelihood.getQuick(logLikelihood.size() - 1);

            if (isFinite(finalLogLikelihood) && finalLogLikelihood > bestLogLikelihood) {
                best = gmm;
                bestLogLikelihood = finalLogLikelihood;
                bestConverged = runConverged;
            }
        }

        if (best == null) {
            throw new RuntimeException("EM algorithm was never able to compute a valid likelihood given initial parameters. Try different init parameters (or increasing n_init) or check for degenerate data.");
        }
        this.converged = bestConverged;
        return best;
    }

    /** Create a new mixture initialised from the data according to {@link #initOpts}. */
    private EMGMM initialise(double[][] X) {
        final int ndims = X[0].length;
        final EMGMM gmm = new EMGMM(this.nComponents);
        gmm.gaussians = this.ctype.createGaussians(this.nComponents, ndims);

        if (this.initOpts.contains(UpdateOptions.Means)) {
            final DoubleKMeans km = DoubleKMeans.createExact(this.nComponents);
            final DoubleKMeans.Result clusters = km.cluster(X);
            for (int i = 0; i < this.nComponents; ++i) {
                ((AbstractMultivariateGaussian) gmm.gaussians[i]).mean.getArray()[0] = clusters.centroids[i];
            }
        }

        if (this.initOpts.contains(UpdateOptions.Weights)) {
            Arrays.fill(gmm.weights, 1.0 / this.nComponents);
        }

        if (this.initOpts.contains(UpdateOptions.Covariances)) {
            // Regularise the sample covariance so that a degenerate (e.g.
            // constant) feature does not produce a singular starting point.
            final Matrix cv = MeanAndCovariance.computeCovariance(X)
                    .plus(Matrix.identity(ndims, ndims).times(this.minCovar));
            this.ctype.setCovariances(gmm.gaussians, cv);
        }
        return gmm;
    }

    /**
     * Run EM iterations on {@code gmm}, updating it in place.
     *
     * @param gmm
     *            the model to refine
     * @param X
     *            the data, one sample per row
     * @param logLikelihood
     *            receives the total log-likelihood computed by each E-step
     * @return true if the change in log-likelihood fell below the threshold
     *         before the iteration limit was reached
     */
    private boolean runEM(EMGMM gmm, double[][] X, TDoubleArrayList logLikelihood) {
        for (int i = 0; i < this.nIters; ++i) {
            // E-step. The first element is summed into the total
            // log-likelihood; the second element is the responsibility
            // matrix consumed by the M-step.
            final IndependentPair<double[], double[][]> score = gmm.scoreSamples(X);
            logLikelihood.add(ArrayUtils.sumValues(score.firstObject()));

            if (i > 0 && Math.abs(logLikelihood.getQuick(i) - logLikelihood.getQuick(i - 1)) < this.thresh) {
                return true;
            }

            this.mstep(gmm, X, score.secondObject());
        }
        return false;
    }

    /** Returns true when {@code v} is neither NaN nor infinite. */
    private static boolean isFinite(double v) {
        return !Double.isNaN(v) && !Double.isInfinite(v);
    }

    /**
     * Maximisation step. Updates, in place, the weights, means and
     * covariances of {@code gmm} that are enabled in {@link #iterOpts}.
     *
     * @param gmm
     *            the model to update
     * @param X
     *            the data, one sample per row
     * @param responsibilities
     *            responsibilities from the E-step, one row per sample and
     *            one column per component
     */
    protected void mstep(EMGMM gmm, double[][] X, double[][] responsibilities) {
        final double[] weights = ArrayUtils.colSum(responsibilities);
        final Matrix resMat = new Matrix(responsibilities);
        final Matrix Xmat = new Matrix(X);
        final Matrix weightedXsum = resMat.transpose().times(Xmat);

        final double[] inverseWeights = new double[weights.length];
        for (int i = 0; i < inverseWeights.length; ++i) {
            inverseWeights[i] = 1.0 / (weights[i] + DENOMINATOR_EPS);
        }

        if (this.iterOpts.contains(UpdateOptions.Weights)) {
            final double sum = ArrayUtils.sumValues(weights);
            for (int i = 0; i < weights.length; ++i) {
                gmm.weights[i] = weights[i] / (sum + DENOMINATOR_EPS) + WEIGHT_FLOOR;
            }
        }

        if (this.iterOpts.contains(UpdateOptions.Means)) {
            final double[][] wx = weightedXsum.getArray();
            for (int i = 0; i < this.nComponents; ++i) {
                final double[] mean = ((AbstractMultivariateGaussian) gmm.gaussians[i]).mean.getArray()[0];
                for (int j = 0; j < mean.length; ++j) {
                    mean[j] = wx[i][j] * inverseWeights[i];
                }
            }
        }

        if (this.iterOpts.contains(UpdateOptions.Covariances)) {
            this.ctype.mstep(gmm, this, Xmat, resMat, weightedXsum, inverseWeights);
        }
    }

    /**
     * Per-component diagonal covariance estimate (1 x d) shared by the
     * Spherical and Diagonal M-steps.
     */
    private static Matrix diagonalCovariance(EMGMM gmm, int i, double minCovar, Matrix avgX2uw, Matrix weightedXsum, double[] norm) {
        final Matrix weightedXsumi = new Matrix(new double[][] { weightedXsum.getArray()[i] });
        final Matrix avgX2 = new Matrix(new double[][] { avgX2uw.getArray()[i] }).times(norm[i]);
        final Matrix mu = ((AbstractMultivariateGaussian) gmm.gaussians[i]).mean;
        final Matrix avgMeans2 = MatrixUtils.pow(mu, 2.0);
        final Matrix avgXmeans = mu.arrayTimes(weightedXsumi).times(norm[i]);
        return MatrixUtils.plus(avgX2.minus(avgXmeans.times(2.0)).plus(avgMeans2), minCovar);
    }

    /**
     * Create an independent copy of this learner. The option sets are copied,
     * so changes to one learner's options never affect the other.
     */
    @Override
    public GaussianMixtureModelEM clone() {
        try {
            final GaussianMixtureModelEM copy = (GaussianMixtureModelEM) super.clone();
            copy.initOpts = EnumSet.copyOf(this.initOpts);
            copy.iterOpts = EnumSet.copyOf(this.iterOpts);
            return copy;
        }
        catch (CloneNotSupportedException e) {
            // Unreachable: this class implements Cloneable.
            throw new RuntimeException(e);
        }
    }

    /** Mixture whose weights start uniform and whose gaussians are assigned by the learner. */
    protected static class EMGMM
    extends MixtureOfGaussians {
        EMGMM(int nComponents) {
            super(null, null);
            this.weights = new double[nComponents];
            Arrays.fill(this.weights, 1.0 / (double)nComponents);
        }
    }

    /** Model parameters that can be initialised and/or updated during EM. */
    public enum UpdateOptions {
        Means,
        Weights,
        Covariances;
    }

    /** The form of covariance matrix each mixture component uses. */
    public enum CovarianceType {
        /** Each component has a single variance shared by all dimensions. */
        Spherical {
            @Override
            protected void setCovariances(MultivariateGaussian[] gaussians, Matrix cv) {
                // Mean over every entry of cv, off-diagonal entries included.
                double total = 0.0;
                for (int i = 0; i < cv.getRowDimension(); ++i) {
                    for (int j = 0; j < cv.getColumnDimension(); ++j) {
                        total += cv.get(i, j);
                    }
                }
                final double mean = total / (cv.getColumnDimension() * cv.getRowDimension());
                for (MultivariateGaussian mg : gaussians) {
                    ((SphericalMultivariateGaussian) mg).variance = mean;
                }
            }

            @Override
            protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) {
                final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss];
                for (int i = 0; i < ngauss; ++i) {
                    arr[i] = new SphericalMultivariateGaussian(ndims);
                }
                return arr;
            }

            @Override
            protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities, Matrix weightedXsum, double[] norm) {
                final Matrix avgX2uw = responsibilities.transpose().times(X.arrayTimes(X));
                for (int i = 0; i < gmm.gaussians.length; ++i) {
                    final Matrix covar = diagonalCovariance(gmm, i, learner.minCovar, avgX2uw, weightedXsum, norm);
                    // Average the per-dimension variances into a single value.
                    ((SphericalMultivariateGaussian) gmm.gaussians[i]).variance = MatrixUtils.sum(covar) / X.getColumnDimension();
                }
            }
        },

        /** Each component has an independent variance per dimension. */
        Diagonal {
            @Override
            protected void setCovariances(MultivariateGaussian[] gaussians, Matrix cv) {
                for (MultivariateGaussian mg : gaussians) {
                    ((DiagonalMultivariateGaussian) mg).variance = MatrixUtils.diagVector(cv);
                }
            }

            @Override
            protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) {
                final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss];
                for (int i = 0; i < ngauss; ++i) {
                    arr[i] = new DiagonalMultivariateGaussian(ndims);
                }
                return arr;
            }

            @Override
            protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities, Matrix weightedXsum, double[] norm) {
                final Matrix avgX2uw = responsibilities.transpose().times(X.arrayTimes(X));
                for (int i = 0; i < gmm.gaussians.length; ++i) {
                    final Matrix covar = diagonalCovariance(gmm, i, learner.minCovar, avgX2uw, weightedXsum, norm);
                    ((DiagonalMultivariateGaussian) gmm.gaussians[i]).variance = covar.getArray()[0];
                }
            }
        },

        /** Each component has its own full covariance matrix. */
        Full {
            @Override
            protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) {
                final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss];
                for (int i = 0; i < ngauss; ++i) {
                    arr[i] = new FullMultivariateGaussian(ndims);
                }
                return arr;
            }

            @Override
            protected void setCovariances(MultivariateGaussian[] gaussians, Matrix cv) {
                // Each component gets its own copy because M-steps replace them independently.
                for (MultivariateGaussian mg : gaussians) {
                    ((FullMultivariateGaussian) mg).covar = cv.copy();
                }
            }

            @Override
            protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities, Matrix weightedXsum, double[] norm) {
                final int nfeatures = X.getColumnDimension();
                for (int c = 0; c < learner.nComponents; ++c) {
                    final Matrix post = responsibilities.getMatrix(0, X.getRowDimension() - 1, c, c).transpose();
                    final double factor = 1.0 / (ArrayUtils.sumValues(post.getArray()) + DENOMINATOR_EPS);

                    // Scale each sample (a column of X^T) by its responsibility for component c.
                    final Matrix pXt = X.transpose();
                    for (int i = 0; i < pXt.getRowDimension(); ++i) {
                        for (int j = 0; j < pXt.getColumnDimension(); ++j) {
                            pXt.set(i, j, pXt.get(i, j) * post.get(0, j));
                        }
                    }

                    final Matrix argcv = pXt.times(X).times(factor);
                    final Matrix mu = ((FullMultivariateGaussian) gmm.gaussians[c]).mean;
                    ((FullMultivariateGaussian) gmm.gaussians[c]).covar = argcv.minusEquals(mu.transpose().times(mu))
                            .plusEquals(Matrix.identity(nfeatures, nfeatures).times(learner.minCovar));
                }
            }
        },

        /** All components share one full covariance matrix. */
        Tied {
            @Override
            protected void setCovariances(MultivariateGaussian[] gaussians, Matrix cv) {
                // Deliberately the same instance for every component.
                for (MultivariateGaussian mg : gaussians) {
                    ((FullMultivariateGaussian) mg).covar = cv;
                }
            }

            @Override
            protected MultivariateGaussian[] createGaussians(int ngauss, int ndims) {
                final MultivariateGaussian[] arr = new MultivariateGaussian[ngauss];
                final Matrix covar = new Matrix(ndims, ndims);
                for (int i = 0; i < ngauss; ++i) {
                    arr[i] = new FullMultivariateGaussian(new Matrix(1, ndims), covar);
                }
                return arr;
            }

            @Override
            protected void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities, Matrix weightedXsum, double[] norm) {
                final int nfeatures = X.getColumnDimension();
                final Matrix avgX2 = X.transpose().times(X);

                final double[][] mudata = new double[gmm.gaussians.length][];
                for (int i = 0; i < mudata.length; ++i) {
                    mudata[i] = ((FullMultivariateGaussian) gmm.gaussians[i]).mean.getArray()[0];
                }
                final Matrix mu = new Matrix(mudata);
                final Matrix avgMeans2 = mu.transpose().times(weightedXsum);

                // NOTE(review): minCovar is added before the 1/N scaling, so
                // the effective regularisation is minCovar / N -- confirm this
                // is intended.
                final Matrix covar = avgX2.minus(avgMeans2)
                        .plus(Matrix.identity(nfeatures, nfeatures).times(learner.minCovar))
                        .times(1.0 / X.getRowDimension());
                for (int i = 0; i < learner.nComponents; ++i) {
                    ((FullMultivariateGaussian) gmm.gaussians[i]).covar = covar;
                }
            }
        };

        /**
         * Create {@code ngauss} Gaussians of dimensionality {@code ndims}
         * with this covariance type.
         */
        protected abstract MultivariateGaussian[] createGaussians(int ngauss, int ndims);

        /** Initialise the covariances of {@code gaussians} from the matrix {@code cv}. */
        protected abstract void setCovariances(MultivariateGaussian[] gaussians, Matrix cv);

        /**
         * Covariance part of the M-step.
         *
         * @param gmm
         *            model whose covariances are updated in place
         * @param learner
         *            the learner supplying {@code minCovar} and {@code nComponents}
         * @param X
         *            the data, one sample per row
         * @param responsibilities
         *            responsibilities, one row per sample
         * @param weightedXsum
         *            responsibilities^T * X
         * @param norm
         *            per-component inverse total responsibility
         */
        protected abstract void mstep(EMGMM gmm, GaussianMixtureModelEM learner, Matrix X, Matrix responsibilities, Matrix weightedXsum, double[] norm);
    }
}

