/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.knn.approximate;

import cern.jet.random.Uniform;
import cern.jet.random.engine.MersenneTwister;
import cern.jet.random.engine.RandomEngine;
import jal.objects.BinaryPredicate;
import jal.objects.Sorting;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import org.openimaj.knn.ShortNearestNeighbours;
import org.openimaj.util.array.IntArrayView;
import org.openimaj.util.pair.FloatIntPair;
import org.openimaj.util.pair.FloatObjectPair;
import org.openimaj.util.pair.IntFloatPair;

public class ShortKDTreeEnsemble {
    private static final int leaf_max_points = 14;
    private static final int varest_max_points = 128;
    private static final int varest_max_randsz = 5;
    Uniform rng;
    public final ShortKDTreeNode[] trees;
    public final short[][] pnts;

    public ShortKDTreeEnsemble(short[][] pnts) {
        this(pnts, 8, 42);
    }

    public ShortKDTreeEnsemble(short[][] pnts, int ntrees) {
        this(pnts, ntrees, 42);
    }

    public ShortKDTreeEnsemble(short[][] pnts, int ntrees, int seed) {
        int N = pnts.length;
        this.pnts = pnts;
        this.rng = new Uniform((RandomEngine)new MersenneTwister(seed));
        IntArrayView inds = new IntArrayView(N);
        for (int n = 0; n < N; ++n) {
            inds.setFast(n, n);
        }
        this.trees = new ShortKDTreeNode[ntrees];
        for (int t = 0; t < ntrees; ++t) {
            this.trees[t] = new ShortKDTreeNode(pnts, inds, this.rng);
        }
    }

    void search(short[] qu, int numnn, IntFloatPair[] ret_nns, int nchecks) {
        int N = this.pnts.length;
        if (nchecks < numnn) {
            nchecks = numnn;
        }
        if (nchecks > N) {
            nchecks = N;
        }
        PriorityQueue<FloatObjectPair<ShortKDTreeNode>> pri_branch = new PriorityQueue<FloatObjectPair<ShortKDTreeNode>>(11, new Comparator<FloatObjectPair<ShortKDTreeNode>>(){

            @Override
            public int compare(FloatObjectPair<ShortKDTreeNode> o1, FloatObjectPair<ShortKDTreeNode> o2) {
                if (o1.first > o2.first) {
                    return 1;
                }
                if (o2.first > o1.first) {
                    return -1;
                }
                return 0;
            }
        });
        ArrayList<IntFloatPair> nns = new ArrayList<IntFloatPair>(3 * nchecks / 2);
        boolean[] seen = new boolean[N];
        for (int t = 0; t < this.trees.length; ++t) {
            this.trees[t].search(qu, pri_branch, nns, seen, this.pnts, 0.0f);
        }
        while (nns.size() < nchecks) {
            FloatObjectPair<ShortKDTreeNode> pr = pri_branch.poll();
            ((ShortKDTreeNode)pr.second).search(qu, pri_branch, nns, seen, this.pnts, pr.first);
        }
        Object[] nns_arr = nns.toArray(new IntFloatPair[nns.size()]);
        Sorting.partial_sort((Object[])nns_arr, (int)0, (int)numnn, (int)nns_arr.length, (BinaryPredicate)new BinaryPredicate(){

            public boolean apply(Object lhs, Object rhs) {
                return ((IntFloatPair)lhs).second < ((IntFloatPair)rhs).second;
            }
        });
        System.arraycopy(nns_arr, 0, ret_nns, 0, Math.min(numnn, nchecks));
    }

    public static class ShortKDTreeNode {
        ShortKDTreeNode left;
        NodeData node_data;
        private Uniform rng;

        boolean is_leaf() {
            return this.left == null;
        }

        IntFloatPair choose_split(short[][] pnts, IntArrayView inds) {
            int d;
            int D = pnts[0].length;
            float[] sum_x = new float[D];
            float[] sum_xx = new float[D];
            int count = Math.min(inds.size(), 128);
            for (int n = 0; n < count; ++n) {
                for (d = 0; d < D; ++d) {
                    int n2 = d;
                    sum_x[n2] = sum_x[n2] + (float)pnts[inds.getFast(n)][d];
                    int n3 = d;
                    sum_xx[n3] = sum_xx[n3] + (float)(pnts[inds.getFast(n)][d] * pnts[inds.getFast(n)][d]);
                }
            }
            Object[] var_dim = new FloatIntPair[D];
            for (d = 0; d < D; ++d) {
                var_dim[d] = new FloatIntPair();
                var_dim[d].first = count <= 1 ? 0.0f : (sum_xx[d] - 1.0f / (float)count * sum_x[d] * sum_x[d]) / (float)(count - 1);
                ((FloatIntPair)var_dim[d]).second = d;
            }
            int nrand = Math.min(5, D);
            Sorting.partial_sort((Object[])var_dim, (int)0, (int)nrand, (int)var_dim.length, (BinaryPredicate)new BinaryPredicate(){

                public boolean apply(Object arg0, Object arg1) {
                    FloatIntPair p1 = (FloatIntPair)arg0;
                    FloatIntPair p2 = (FloatIntPair)arg1;
                    if (p1.first > p2.first) {
                        return true;
                    }
                    if (p2.first > p1.first) {
                        return false;
                    }
                    return p1.second > p2.second;
                }
            });
            int randd = ((FloatIntPair)var_dim[this.rng.nextIntFromTo((int)0, (int)(nrand - 1))]).second;
            return new IntFloatPair(randd, sum_x[randd] / (float)count);
        }

        void split_points(short[][] pnts, IntArrayView inds) {
            IntFloatPair spl = this.choose_split(pnts, inds);
            ((InternalNodeData)this.node_data).disc_dim = spl.first;
            ((InternalNodeData)this.node_data).disc = spl.second;
            int N = inds.size();
            int l = 0;
            int r = N;
            while (l != r) {
                if ((float)pnts[inds.getFast(l)][((InternalNodeData)this.node_data).disc_dim] < ((InternalNodeData)this.node_data).disc) {
                    ++l;
                    continue;
                }
                int t = inds.getFast(l);
                inds.setFast(l, inds.getFast(--r));
                inds.setFast(r, t);
            }
            if (l == 0 || l == N) {
                l = N / 2;
            }
            this.left = new ShortKDTreeNode(pnts, inds.subView(0, l), this.rng);
            ((InternalNodeData)this.node_data).right = new ShortKDTreeNode(pnts, inds.subView(l, N), this.rng);
        }

        public ShortKDTreeNode() {
        }

        public ShortKDTreeNode(short[][] pnts, IntArrayView inds, Uniform rng) {
            this.rng = rng;
            if (inds.size() > 14) {
                this.node_data = new InternalNodeData();
                this.split_points(pnts, inds);
            } else {
                this.node_data = new LeafNodeData();
                ((LeafNodeData)this.node_data).indices = inds.toArray();
            }
        }

        void search(short[] qu, PriorityQueue<FloatObjectPair<ShortKDTreeNode>> pri_branch, List<IntFloatPair> nns, boolean[] seen, short[][] pnts, float mindsq) {
            ShortKDTreeNode cur = this;
            ShortKDTreeNode other = null;
            while (!cur.is_leaf()) {
                float diff = (float)qu[((InternalNodeData)cur.node_data).disc_dim] - ((InternalNodeData)cur.node_data).disc;
                if (diff < 0.0f) {
                    other = ((InternalNodeData)cur.node_data).right;
                    cur = cur.left;
                } else {
                    other = cur.left;
                    cur = ((InternalNodeData)cur.node_data).right;
                }
                pri_branch.add((FloatObjectPair<ShortKDTreeNode>)new FloatObjectPair(mindsq + diff * diff, (Object)other));
            }
            int[] cur_inds = ((LeafNodeData)cur.node_data).indices;
            int ncur_inds = cur_inds.length;
            float[] dsq = new float[1];
            for (int i = 0; i < ncur_inds; ++i) {
                int ci = cur_inds[i];
                if (seen[ci]) continue;
                ShortNearestNeighbours.distanceFunc(qu, new short[][]{pnts[ci]}, dsq);
                nns.add(new IntFloatPair(ci, dsq[0]));
                seen[ci] = true;
            }
        }

        class LeafNodeData
        extends NodeData {
            int[] indices;

            LeafNodeData() {
            }
        }

        class InternalNodeData
        extends NodeData {
            ShortKDTreeNode right;
            float disc;
            int disc_dim;

            InternalNodeData() {
            }
        }

        class NodeData {
            NodeData() {
            }
        }
    }
}

