/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.knn;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.io.PrintWriter;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.PriorityQueue;
import java.util.Scanner;
import java.util.Stack;
import org.openimaj.knn.CoordinateBruteForce;
import org.openimaj.knn.CoordinateIndex;
import org.openimaj.knn.KDNode;
import org.openimaj.math.geometry.point.Coordinate;

/**
 * A k-d tree over {@link Coordinate}s supporting insertion, axis-aligned range
 * search, nearest-neighbour and k-nearest-neighbour queries.
 *
 * <p>Each node splits on a single dimension (its discriminate). The root splits
 * on dimension 0, and each child splits on the next dimension, wrapping back to 0.
 * A point whose ordinate on the parent's discriminate is strictly greater than
 * the parent's goes to the right subtree; otherwise (including ties) it goes to
 * the left subtree.
 *
 * <p>Distances used internally are <em>squared</em> Euclidean distances computed
 * in {@code float} precision.
 *
 * @param <T> the coordinate type stored in the tree
 */
public class CoordinateKDTree<T extends Coordinate>
implements CoordinateIndex<T> {
    KDNode<T> _root = null;

    /** Creates an empty tree. */
    public CoordinateKDTree() {
    }

    /**
     * Creates a tree containing all of the given coordinates, inserted in
     * iteration order.
     *
     * @param coords the coordinates to insert
     */
    public CoordinateKDTree(Collection<T> coords) {
        this.insertAll(coords);
    }

    /**
     * Inserts every coordinate of the collection, in iteration order.
     *
     * @param coords the coordinates to insert
     */
    public void insertAll(Collection<T> coords) {
        for (T c : coords) {
            this.insert(c);
        }
    }

    /**
     * Inserts a single point. The point becomes a new leaf below the node reached
     * by walking down the tree according to the split rule described on the class.
     *
     * @param point the point to insert
     */
    @Override
    public void insert(T point) {
        if (this._root == null) {
            this._root = new KDNode<T>(point, 0);
            return;
        }

        KDNode<T> parent = this._root;
        boolean goRight;
        while (true) {
            final int disc = parent._discriminate;
            goRight = point.getOrdinate(disc).doubleValue() > parent._point.getOrdinate(disc).doubleValue();
            final KDNode<T> child = goRight ? parent._right : parent._left;
            if (child == null) {
                break;
            }
            parent = child;
        }

        // The new leaf splits on the dimension after its parent's, wrapping to 0.
        int nextDisc = parent._discriminate + 1;
        if (nextDisc >= point.getDimensions()) {
            nextDisc = 0;
        }
        final KDNode<T> leaf = new KDNode<T>(point, nextDisc);
        if (goRight) {
            parent._right = leaf;
        } else {
            parent._left = leaf;
        }
    }

    /**
     * Tests whether {@code point} lies inside the closed axis-aligned box
     * {@code [lower, upper]} in every one of {@code point}'s dimensions.
     *
     * @return {@code true} if no ordinate is below its lower bound or above its upper bound
     */
    static final boolean isContained(Coordinate point, Coordinate lower, Coordinate upper) {
        final int dimensions = point.getDimensions();
        for (int i = 0; i < dimensions; ++i) {
            final double value = point.getOrdinate(i).doubleValue();
            if (value < lower.getOrdinate(i).doubleValue() || value > upper.getOrdinate(i).doubleValue()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Adds to {@code results} every stored point inside the closed box
     * {@code [lowerExtreme, upperExtreme]}. Nothing is added for an empty tree.
     *
     * @param results      sink for matching points
     * @param lowerExtreme inclusive lower corner of the box
     * @param upperExtreme inclusive upper corner of the box
     */
    @Override
    public void rangeSearch(Collection<T> results, Coordinate lowerExtreme, Coordinate upperExtreme) {
        if (this._root == null) {
            return;
        }
        final Stack<KDNode<T>> pending = new Stack<KDNode<T>>();
        pending.push(this._root);
        while (!pending.empty()) {
            final KDNode<T> node = pending.pop();
            final int disc = node._discriminate;
            final double split = node._point.getOrdinate(disc).doubleValue();
            // The left subtree holds ordinates <= split (ties go left on insert), so it
            // can contain matches whenever split >= lower -- including split == lower.
            if (node._left != null && split >= lowerExtreme.getOrdinate(disc).doubleValue()) {
                pending.push(node._left);
            }
            // The right subtree holds ordinates strictly > split.
            if (node._right != null && split < upperExtreme.getOrdinate(disc).doubleValue()) {
                pending.push(node._right);
            }
            if (CoordinateKDTree.isContained(node._point, lowerExtreme, upperExtreme)) {
                results.add(node._point);
            }
        }
    }

    /**
     * Squared Euclidean distance between {@code a} and {@code b}, computed in
     * {@code float}. It iterates over {@code a}'s dimensions.
     *
     * @return the sum of squared per-dimension differences
     */
    protected static final float distance(Coordinate a, Coordinate b) {
        final int dims = a.getDimensions();
        float sum = 0.0f;
        for (int i = 0; i < dims; ++i) {
            final float diff = a.getOrdinate(i).floatValue() - b.getOrdinate(i).floatValue();
            sum += diff * diff;
        }
        return sum;
    }

    /**
     * Finds the stored point closest to {@code query}.
     *
     * @param query the query point
     * @return the nearest stored point, or {@code null} if the tree is empty
     */
    @Override
    public T nearestNeighbour(Coordinate query) {
        final Stack<KDNode<T>> path = this.walkdown(query);
        if (path == null) {
            return null;
        }
        final NNState state = new NNState();
        state.best = path.peek()._point;
        state.bestDist = CoordinateKDTree.distance(query, state.best);
        // Unwind from the leaf towards the root; an exact match cannot be beaten.
        while (state.bestDist != 0.0f && !path.isEmpty()) {
            this.checkSubtree(path.pop(), query, state);
        }
        return state.best;
    }

    /**
     * Adds up to {@code k} stored points nearest to {@code query} to
     * {@code result}, closest first. Nothing is added for an empty tree or
     * when {@code k <= 0}.
     *
     * @param result sink for the neighbours
     * @param query  the query point
     * @param k      the number of neighbours requested
     */
    @Override
    public void kNearestNeighbour(Collection<T> result, Coordinate query, int k) {
        if (k <= 0) {
            return;
        }
        final Stack<KDNode<T>> path = this.walkdown(query);
        if (path == null) {
            return;
        }
        // NNState orders by descending distance, so the head is the current worst.
        final PriorityQueue<NNState> state = new PriorityQueue<NNState>(k);
        final NNState initialState = new NNState();
        initialState.best = path.peek()._point;
        initialState.bestDist = CoordinateKDTree.distance(query, initialState.best);
        state.add(initialState);
        while (!path.isEmpty()) {
            this.checkSubtreeK(path.pop(), query, state, k);
        }
        // Draining the heap yields farthest-first; emit in reverse for closest-first.
        final ArrayList<NNState> farthestFirst = new ArrayList<NNState>(state.size());
        while (!state.isEmpty()) {
            farthestFirst.add(state.poll());
        }
        for (int i = farthestFirst.size() - 1; i >= 0; --i) {
            result.add(farthestFirst.get(i).best);
        }
    }

    /**
     * Recursively improves {@code state} with points from the subtree at
     * {@code node}. The far side of a split is skipped when the squared distance to
     * the splitting plane already exceeds the best squared distance found.
     */
    private void checkSubtree(KDNode<T> node, Coordinate query, NNState state) {
        if (node == null) {
            return;
        }
        final float dist = CoordinateKDTree.distance(query, node._point);
        if (dist < state.bestDist) {
            state.best = node._point;
            state.bestDist = dist;
        }
        if (state.bestDist == 0.0f) {
            return;
        }
        final int disc = node._discriminate;
        final float d = node._point.getOrdinate(disc).floatValue() - query.getOrdinate(disc).floatValue();
        if (d * d > state.bestDist) {
            // Only the side of the plane containing the query can hold a closer point.
            if (query.getOrdinate(disc).doubleValue() > node._point.getOrdinate(disc).doubleValue()) {
                this.checkSubtree(node._right, query, state);
            } else {
                this.checkSubtree(node._left, query, state);
            }
        } else {
            this.checkSubtree(node._left, query, state);
            this.checkSubtree(node._right, query, state);
        }
    }

    /**
     * k-NN analogue of {@link #checkSubtree}. It keeps at most {@code k} best
     * candidates in {@code state}, whose head is the farthest retained candidate.
     * Points that are {@code equals} to an existing candidate are not added again.
     */
    private void checkSubtreeK(KDNode<T> node, Coordinate query, PriorityQueue<NNState> state, int k) {
        if (node == null) {
            return;
        }
        final float dist = CoordinateKDTree.distance(query, node._point);

        boolean alreadyPresent = false;
        for (NNState s : state) {
            if (s.best.equals(node._point)) {
                alreadyPresent = true;
                break;
            }
        }
        if (!alreadyPresent) {
            if (state.size() < k) {
                final NNState s = new NNState();
                s.best = node._point;
                s.bestDist = dist;
                state.add(s);
            } else if (dist < state.peek().bestDist) {
                // Recycle the evicted worst entry for the new, closer candidate.
                final NNState s = state.poll();
                s.best = node._point;
                s.bestDist = dist;
                state.add(s);
            }
        }

        final int disc = node._discriminate;
        final float d = node._point.getOrdinate(disc).floatValue() - query.getOrdinate(disc).floatValue();
        if (d * d > state.peek().bestDist) {
            if (query.getOrdinate(disc).doubleValue() > node._point.getOrdinate(disc).doubleValue()) {
                this.checkSubtreeK(node._right, query, state, k);
            } else {
                this.checkSubtreeK(node._left, query, state, k);
            }
        } else {
            this.checkSubtreeK(node._left, query, state, k);
            this.checkSubtreeK(node._right, query, state, k);
        }
    }

    /**
     * Records the root-to-leaf path that {@code point} would follow on insertion.
     * The walk stops early at a node whose stored point is the very same object
     * (reference equality) as {@code point}.
     *
     * @return the visited nodes with the deepest on top, or {@code null} if the tree is empty
     */
    private Stack<KDNode<T>> walkdown(Coordinate point) {
        if (this._root == null) {
            return null;
        }
        final Stack<KDNode<T>> path = new Stack<KDNode<T>>();
        KDNode<T> node = this._root;
        while (node != null) {
            path.push(node);
            if (node._point == point) {
                break;
            }
            final int disc = node._discriminate;
            node = point.getOrdinate(disc).doubleValue() > node._point.getOrdinate(disc).doubleValue()
                    ? node._right
                    : node._left;
        }
        return path;
    }

    /**
     * Approximate k-nearest-neighbour search.
     *
     * <p>A box centred on {@code query} is grown until a range search inside it
     * finds at least {@code k} points. A brute-force k-NN is then run over just
     * those candidates. The first box has half-width {@code k}, and the half-width
     * doubles on each retry. Once the half-width overflows to infinity, the whole
     * tree has been searched and the loop stops, even if fewer than {@code k}
     * points exist.
     *
     * <p>NOTE(review): when the tree holds fewer than {@code k} points, the
     * brute-force index receives fewer than {@code k} candidates. Confirm that
     * {@code CoordinateBruteForce.kNearestNeighbour} tolerates that.
     *
     * @param result sink for the neighbours
     * @param query  the query point
     * @param k      the number of neighbours requested; nothing is added when {@code k <= 0}
     */
    public void fastKNN(Collection<T> result, Coordinate query, int k) {
        if (this._root == null || k <= 0) {
            return;
        }
        final ArrayList<T> candidates = new ArrayList<T>();
        final Coord centre = new Coord(query);
        final int dims = centre.getDimensions();
        final Coord lowerExtreme = new Coord(dims);
        final Coord upperExtreme = new Coord(dims);

        for (float halfWidth = (float) k; ; halfWidth *= 2.0f) {
            candidates.clear();
            for (int i = 0; i < dims; ++i) {
                lowerExtreme.coords[i] = centre.coords[i] - halfWidth;
                upperExtreme.coords[i] = centre.coords[i] + halfWidth;
            }
            this.rangeSearch(candidates, lowerExtreme, upperExtreme);
            // An infinite box already covers every stored point; growing further is pointless.
            if (candidates.size() >= k || Float.isInfinite(halfWidth)) {
                break;
            }
        }

        final CoordinateBruteForce<T> bf = new CoordinateBruteForce<T>(candidates);
        bf.kNearestNeighbour(result, query, k);
    }

    /**
     * Minimal mutable float-backed {@link Coordinate}, used to describe the
     * search box in {@link #fastKNN}. Serialization methods are not supported.
     */
    class Coord
    implements Coordinate {
        float[] coords;

        /** Creates a zero-filled coordinate with {@code i} dimensions. */
        public Coord(int i) {
            this.coords = new float[i];
        }

        /** Creates a float copy of {@code c}. */
        public Coord(Coordinate c) {
            this(c.getDimensions());
            for (int i = 0; i < this.coords.length; ++i) {
                this.coords[i] = c.getOrdinate(i).floatValue();
            }
        }

        public int getDimensions() {
            return this.coords.length;
        }

        public Float getOrdinate(int dimension) {
            return Float.valueOf(this.coords[dimension]);
        }

        public void readASCII(Scanner in) throws IOException {
            throw new RuntimeException("not implemented");
        }

        public String asciiHeader() {
            throw new RuntimeException("not implemented");
        }

        public void readBinary(DataInput in) throws IOException {
            throw new RuntimeException("not implemented");
        }

        public byte[] binaryHeader() {
            throw new RuntimeException("not implemented");
        }

        public void writeASCII(PrintWriter out) throws IOException {
            throw new RuntimeException("not implemented");
        }

        public void writeBinary(DataOutput out) throws IOException {
            throw new RuntimeException("not implemented");
        }

        public void setOrdinate(int dimension, Number value) {
            this.coords[dimension] = value.floatValue();
        }
    }

    /**
     * A candidate neighbour and its squared distance to the query.
     *
     * <p>The natural ordering is <em>descending</em> by distance, so a
     * {@link PriorityQueue} of these keeps the farthest candidate at its head.
     */
    class NNState
    implements Comparable<NNState> {
        T best;
        float bestDist;

        NNState() {
        }

        @Override
        public int compareTo(NNState o) {
            // Reversed arguments give descending order; Float.compare gives a total order.
            return Float.compare(o.bestDist, this.bestDist);
        }

        @Override
        public String toString() {
            return this.bestDist + "";
        }
    }
}

