/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.clustering.kmeans;

import com.rits.cloning.Cloner;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import org.openimaj.data.ArrayBackedDataSource;
import org.openimaj.data.DataSource;
import org.openimaj.feature.FeatureVector;
import org.openimaj.knn.ObjectNearestNeighbours;
import org.openimaj.knn.ObjectNearestNeighboursExact;
import org.openimaj.knn.ObjectNearestNeighboursProvider;
import org.openimaj.ml.clustering.FeatureVectorCentroidsResult;
import org.openimaj.ml.clustering.IndexClusters;
import org.openimaj.ml.clustering.SpatialClusterer;
import org.openimaj.ml.clustering.assignment.HardAssigner;
import org.openimaj.ml.clustering.assignment.hard.ExactFeatureVectorAssigner;
import org.openimaj.ml.clustering.kmeans.FeatureVectorKMeansInit;
import org.openimaj.ml.clustering.kmeans.KMeansConfiguration;
import org.openimaj.util.comparator.DistanceComparator;
import org.openimaj.util.pair.IntFloatPair;

public class FeatureVectorKMeans<T extends FeatureVector>
implements SpatialClusterer<FeatureVectorCentroidsResult<T>, T> {
    /** Strategy used to pick the initial centroids; a {@code FeatureVectorKMeansInit.RANDOM} instance unless replaced via {@link #setInit}. */
    private FeatureVectorKMeansInit<T> init = new FeatureVectorKMeansInit.RANDOM();
    /** Clustering configuration; this class reads its K, niters, blockSize, threadpool and factory members. */
    private KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf;
    /** Random source handed to the array-backed data sources built by this class; replaced by {@link #seed(long)}. */
    private Random rng = new Random();

    /**
     * Creates a clusterer that uses the given configuration.
     *
     * @param conf the configuration to cluster with; stored by reference, not copied
     */
    public FeatureVectorKMeans(KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf) {
        this.conf = conf;
    }

    /**
     * Creates a clusterer backed by a default-constructed configuration.
     */
    protected FeatureVectorKMeans() {
        // Explicit type arguments instead of the raw type left by decompilation.
        this(new KMeansConfiguration<ObjectNearestNeighbours<T>, T>());
    }

    /**
     * Returns the centroid initialisation strategy currently in use.
     *
     * @return the current initialisation strategy
     */
    public FeatureVectorKMeansInit<T> getInit() {
        return this.init;
    }

    /**
     * Replaces the centroid initialisation strategy.
     *
     * @param init the strategy that subsequent clustering runs will use to fill the initial centroids
     */
    public void setInit(FeatureVectorKMeansInit<T> init) {
        this.init = init;
    }

    /**
     * Replaces the internal random number generator.
     *
     * @param seed a non-negative value creates a generator seeded with that value;
     *             a negative value creates an unseeded generator
     */
    public void seed(long seed) {
        if (seed >= 0L) {
            this.rng = new Random(seed);
        } else {
            this.rng = new Random();
        }
    }

    /**
     * Clusters a list of feature vectors by copying it into an array and
     * delegating to {@link #cluster(FeatureVector[])}.
     *
     * @param data the vectors to cluster; the first element is dereferenced, so
     *             an empty list makes {@code data.get(0)} throw
     * @return the clustering result
     */
    @Override
    public Result<T> cluster(List<T> data) {
        // The array's runtime component type is the class of the first element,
        // so every element must be an instance of that class for toArray to succeed.
        @SuppressWarnings("unchecked")
        T[] d = (T[])Array.newInstance(data.get(0).getClass(), data.size());
        d = data.toArray(d);
        return this.cluster(d);
    }

    /**
     * Clusters an array of feature vectors into {@code conf.K} centroids.
     *
     * <p>The array is wrapped in an {@link ArrayBackedDataSource} whose
     * dimensionality is the length of its first vector. After clustering, the
     * result's nearest-neighbour object is built from the final centroids using
     * the configured factory.
     *
     * @param data the vectors to cluster
     * @return the clustering result
     * @throws RuntimeException wrapping any exception raised while clustering
     */
    @Override
    public Result<T> cluster(T[] data) {
        ArrayBackedDataSource<T> ds = new ArrayBackedDataSource<T>(data, this.rng){

            public int numDimensions() {
                // Reads the data source's own backing array.
                return this.data[0].length();
            }
        };
        try {
            Result<T> result = this.cluster(ds, this.conf.K);
            result.nn = this.conf.factory.create(result.centroids);
            return result;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Clusters the data and then assigns each input vector to a cluster using
     * the result's default hard assigner.
     *
     * @param data the vectors to cluster and assign
     * @return the clusters produced by {@link IndexClusters} from the per-vector assignments
     */
    public int[][] performClustering(T[] data) {
        Result<T> clusters = this.cluster(data);
        return new IndexClusters(clusters.defaultHardAssigner().assign(data)).clusters();
    }

    /**
     * List-based variant of {@link #performClustering(FeatureVector[])}: copies
     * the list into an array, clusters it, and assigns every vector to a cluster
     * using the result's default hard assigner.
     *
     * @param data the vectors to cluster and assign; the first element is
     *             dereferenced, so an empty list makes {@code data.get(0)} throw
     * @return the clusters produced by {@link IndexClusters} from the per-vector assignments
     */
    public int[][] performClustering(List<T> data) {
        // The array's runtime component type is the class of the first element.
        @SuppressWarnings("unchecked")
        T[] d = (T[])Array.newInstance(data.get(0).getClass(), data.size());
        d = data.toArray(d);
        Result<T> clusters = this.cluster(d);
        return new IndexClusters(clusters.defaultHardAssigner().assign(d)).clusters();
    }

    /**
     * Allocates {@code K} centroids, initialises them with the current
     * initialisation strategy and runs the clustering iterations.
     *
     * <p>The returned result's nearest-neighbour object is not set here.
     *
     * @param data the data source to cluster
     * @param K the number of centroids to allocate
     * @return a result holding the final centroids
     * @throws Exception declared for callers; see {@code initKMeans} and the iteration method for actual sources
     */
    protected Result<T> cluster(DataSource<T> data, int K) throws Exception {
        Result<T> result = new Result<T>();
        result.centroids = data.createTemporaryArray(K);
        this.init.initKMeans(data, result.centroids);
        this.cluster(data, result);
        return result;
    }

    /**
     * Runs the clustering iterations over an array, starting from the centroids
     * already present in {@code result}.
     *
     * <p>The array is wrapped in an {@link ArrayBackedDataSource} whose
     * dimensionality is the length of its first vector. The centroids in
     * {@code result} are updated in place.
     *
     * @param data the vectors to cluster
     * @param result supplies the starting centroids and receives the updated state
     * @throws InterruptedException if interrupted while waiting for the assignment jobs
     */
    public void cluster(T[] data, Result<T> result) throws InterruptedException {
        ArrayBackedDataSource<T> ds = new ArrayBackedDataSource<T>(data, this.rng){

            public int numDimensions() {
                // Reads the data source's own backing array.
                return this.data[0].length();
            }
        };
        this.cluster(ds, result);
    }

    /**
     * Main k-means loop: refines the centroids held by {@code result} in place.
     *
     * <p>Each iteration splits the data into blocks of {@code conf.blockSize}
     * rows, submits one assignment job per block to {@code conf.threadpool}, and
     * then recomputes every centroid as the mean of the points assigned to it.
     * A centroid that received no points is replaced by a deep clone of a random
     * row. The loop runs for at most {@code conf.niters} iterations and stops
     * early once an iteration changes no centroid.
     *
     * @param data the data source to cluster
     * @param result supplies the starting centroids; its centroids, iteration
     *               count and changed-centroid count are updated
     * @throws InterruptedException if interrupted while waiting for the assignment jobs
     */
    protected void cluster(DataSource<T> data, Result<T> result) throws InterruptedException {
        T[] centroids = result.centroids;
        int K = centroids.length;
        int D = centroids[0].length();
        int N = data.size();
        double[][] centroids_accum = new double[K][D];
        int[] new_counts = new int[K];
        ExecutorService service = this.conf.threadpool;
        for (int i = 0; i < this.conf.niters; ++i) {
            ++result.iterations;

            // Reset the per-centroid sums and counts for this pass.
            for (int j = 0; j < K; ++j) {
                Arrays.fill(centroids_accum[j], 0.0);
            }
            Arrays.fill(new_counts, 0);

            // Nearest-neighbour search is rebuilt over the current centroids.
            ObjectNearestNeighbours<T> nno = this.conf.factory.create(centroids);
            List<CentroidAssignmentJob<T>> jobs = new ArrayList<CentroidAssignmentJob<T>>();
            for (int bl = 0; bl < N; bl += this.conf.blockSize) {
                int br = Math.min(bl + this.conf.blockSize, N);
                jobs.add(new CentroidAssignmentJob<T>(data, bl, br, nno, centroids_accum, new_counts));
            }
            // Blocks until all jobs finish; the returned futures are discarded,
            // so task outcomes are not inspected here.
            service.invokeAll(jobs);

            result.changedCentroidCount = 0;
            for (int k = 0; k < K; ++k) {
                if (new_counts[k] == 0) {
                    // Empty cluster: replace its centroid with a clone of a random row.
                    new_counts[k] = 1;
                    T[] rnd = data.createTemporaryArray(1);
                    data.getRandomRows(rnd);
                    Cloner cloner = new Cloner();
                    centroids[k] = cloner.deepClone(rnd[0]);
                    ++result.changedCentroidCount;
                    continue;
                }
                // Move the centroid to the mean and track how far it moved.
                double ssd = 0.0;
                for (int d = 0; d < D; ++d) {
                    double newValue = centroids_accum[k][d] / (double)new_counts[k];
                    double diff = newValue - centroids[k].getAsDouble(d);
                    ssd += diff * diff;
                    centroids[k].setFromDouble(d, newValue);
                }
                if (ssd != 0.0) {
                    ++result.changedCentroidCount;
                }
            }
            if (result.changedCentroidCount == 0) break;
        }
    }

    /**
     * Clusters a data source into {@code conf.K} centroids and builds the
     * result's nearest-neighbour object from the final centroids using the
     * configured factory.
     *
     * @param ds the data source to cluster
     * @return the clustering result
     * @throws RuntimeException wrapping any exception raised while clustering
     */
    @Override
    public FeatureVectorCentroidsResult<T> cluster(DataSource<T> ds) {
        try {
            Result<T> result = this.cluster(ds, this.conf.K);
            result.nn = this.conf.factory.create(result.centroids);
            return result;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the configuration object held by this clusterer (the same
     * instance, not a copy).
     *
     * @return the current configuration
     */
    public KMeansConfiguration<ObjectNearestNeighbours<T>, T> getConfiguration() {
        return this.conf;
    }

    /**
     * Replaces the configuration used by subsequent clustering calls.
     *
     * @param conf the new configuration; stored by reference
     */
    public void setConfiguration(KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf) {
        this.conf = conf;
    }

    /**
     * Creates a clusterer whose nearest-neighbour searches are built by an
     * {@link ObjectNearestNeighboursExact} factory over the given distance.
     *
     * @param <T> the feature vector type
     * @param K the number of clusters
     * @param distance the comparator handed to the nearest-neighbour factory
     * @return a new clusterer
     */
    public static <T extends FeatureVector> FeatureVectorKMeans<T> createExact(int K, DistanceComparator<? super T> distance) {
        KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf =
                new KMeansConfiguration<ObjectNearestNeighbours<T>, T>(K, new ObjectNearestNeighboursExact.Factory<T>(distance));
        return new FeatureVectorKMeans<T>(conf);
    }

    /**
     * Creates a clusterer whose nearest-neighbour searches are built by an
     * {@link ObjectNearestNeighboursExact} factory over the given distance,
     * passing an iteration count to the configuration.
     *
     * @param <T> the feature vector type
     * @param K the number of clusters
     * @param distance the comparator handed to the nearest-neighbour factory
     * @param niters the iteration count passed to the configuration
     * @return a new clusterer
     */
    public static <T extends FeatureVector> FeatureVectorKMeans<T> createExact(int K, DistanceComparator<? super T> distance, int niters) {
        KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf =
                new KMeansConfiguration<ObjectNearestNeighbours<T>, T>(K, new ObjectNearestNeighboursExact.Factory<T>(distance), niters);
        return new FeatureVectorKMeans<T>(conf);
    }

    /**
     * Describes this clusterer as {@code <SimpleClassName>: {K=<K>, NN=<factory simple class name>}}.
     *
     * @return the description string
     */
    public String toString() {
        String ownName = this.getClass().getSimpleName();
        int clusterCount = this.conf.K;
        String factoryName = this.conf.getNearestNeighbourFactory().getClass().getSimpleName();
        return String.format("%s: {K=%d, NN=%s}", ownName, clusterCount, factoryName);
    }

    /**
     * Result of a k-means run: the centroids (inherited from
     * {@link FeatureVectorCentroidsResult}) plus a nearest-neighbour object and
     * counters describing the run.
     *
     * @param <T> the feature vector type
     */
    public static class Result<T extends FeatureVector>
    extends FeatureVectorCentroidsResult<T>
    implements ObjectNearestNeighboursProvider<T> {
        /** Nearest-neighbour object; assigned by the enclosing class's cluster methods that return a result. */
        protected ObjectNearestNeighbours<T> nn;
        /** Number of iterations performed; incremented once per pass of the main loop. */
        protected int iterations;
        /** Number of centroids that changed during the most recent iteration. */
        protected int changedCentroidCount;

        /**
         * @return the nearest-neighbour object held by this result
         */
        public ObjectNearestNeighbours<T> getNearestNeighbours() {
            return this.nn;
        }

        /**
         * Creates an exact assigner over this result's centroids, using the
         * distance comparator of the held nearest-neighbour object.
         *
         * @return a new hard assigner; requires the nearest-neighbour object to be set
         */
        @Override
        public HardAssigner<T, float[], IntFloatPair> defaultHardAssigner() {
            return new ExactFeatureVectorAssigner<T>(this, this.nn.distanceComparator());
        }

        /**
         * @return the number of iterations performed
         */
        public int numIterations() {
            return this.iterations;
        }

        /**
         * @return the number of centroids changed in the most recent iteration
         */
        public int numChangedCentroids() {
            return this.changedCentroidCount;
        }
    }

    /**
     * Task that assigns one block of rows to their nearest centroids and adds
     * those rows into shared per-centroid sums and counts.
     *
     * @param <T> the feature vector type
     */
    private static class CentroidAssignmentJob<T extends FeatureVector>
    implements Callable<Boolean> {
        private final DataSource<T> ds;
        private final int startRow;
        private final int stopRow;
        private final ObjectNearestNeighbours<T> nno;
        private final double[][] centroids_accum;
        private final int[] counts;

        /**
         * @param ds the data source to read rows from
         * @param startRow first row of the block, passed as the start argument of {@code getData}
         * @param stopRow end of the block, passed as the stop argument of {@code getData}
         * @param nno nearest-neighbour search over the current centroids
         * @param centroids_accum shared per-centroid coordinate sums; also used as the lock
         * @param counts shared per-centroid point counts
         */
        public CentroidAssignmentJob(DataSource<T> ds, int startRow, int stopRow, ObjectNearestNeighbours<T> nno, double[][] centroids_accum, int[] counts) {
            this.ds = ds;
            this.startRow = startRow;
            this.stopRow = stopRow;
            this.nno = nno;
            this.centroids_accum = centroids_accum;
            this.counts = counts;
        }

        /**
         * Loads the block, finds the nearest centroid of every row, and adds
         * each row to the sums and count of its centroid.
         *
         * <p>Any exception is printed with {@code printStackTrace()} and not
         * rethrown, so the job returns {@code true} whether or not it succeeded.
         *
         * @return always {@code true}
         */
        @Override
        public Boolean call() {
            try {
                int D = this.ds.getData(0).length();
                T[] points = this.ds.createTemporaryArray(this.stopRow - this.startRow);
                this.ds.getData(this.startRow, this.stopRow, points);
                int[] argmins = new int[points.length];
                float[] mins = new float[points.length];
                this.nno.searchNN(points, argmins, mins);
                // The accumulators are shared between jobs; updates are serialised
                // on the accumulator array.
                synchronized (this.centroids_accum) {
                    for (int i = 0; i < points.length; ++i) {
                        int k = argmins[i];
                        double[] vector = points[i].asDoubleVector();
                        for (int d = 0; d < D; ++d) {
                            this.centroids_accum[k][d] += vector[d];
                        }
                        this.counts[k] += 1;
                    }
                }
            }
            catch (Exception e) {
                e.printStackTrace();
            }
            return true;
        }
    }
}

