/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.function.vector;

import gov.sandia.cognition.learning.function.vector.SquashedMatrixMultiplyVectorFunction;
import gov.sandia.cognition.math.UnivariateScalarFunction;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorizableVectorFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;

public class FeedforwardNeuralNetwork
extends AbstractCloneableSerializable
implements VectorizableVectorFunction {
    private ArrayList<? extends SquashedMatrixMultiplyVectorFunction> layers;

    public FeedforwardNeuralNetwork(ArrayList<Integer> nodesPerLayer, ArrayList<? extends UnivariateScalarFunction> layerActivationFunctions) {
        if (nodesPerLayer.size() != layerActivationFunctions.size() + 1) {
            throw new IllegalArgumentException("Number of layers must equal layerActivationFunction + 1");
        }
        ArrayList<SquashedMatrixMultiplyVectorFunction> localLayers = new ArrayList<SquashedMatrixMultiplyVectorFunction>(layerActivationFunctions.size());
        for (int i = 0; i < nodesPerLayer.size() - 1; ++i) {
            int currentNum = nodesPerLayer.get(i);
            int nextNum = nodesPerLayer.get(i + 1);
            localLayers.add(new SquashedMatrixMultiplyVectorFunction(currentNum, nextNum, layerActivationFunctions.get(i)));
        }
        this.setLayers(localLayers);
    }

    public FeedforwardNeuralNetwork(int numInputs, int numHiddens, int numOutputs, UnivariateScalarFunction activationFunction) {
        ArrayList<SquashedMatrixMultiplyVectorFunction> localLayers = new ArrayList<SquashedMatrixMultiplyVectorFunction>(2);
        localLayers.add(new SquashedMatrixMultiplyVectorFunction(numInputs, numHiddens, activationFunction));
        localLayers.add(new SquashedMatrixMultiplyVectorFunction(numHiddens, numOutputs, activationFunction));
        this.setLayers(localLayers);
    }

    public FeedforwardNeuralNetwork(ArrayList<? extends SquashedMatrixMultiplyVectorFunction> layers) {
        this.setLayers(layers);
    }

    public FeedforwardNeuralNetwork clone() {
        FeedforwardNeuralNetwork clone = (FeedforwardNeuralNetwork)super.clone();
        clone.setLayers(ObjectUtil.cloneSmartElementsAsArrayList(this.getLayers()));
        return clone;
    }

    public Vector convertToVector() {
        int numParams = 0;
        ArrayList<Vector> layerParameters = new ArrayList<Vector>(this.getLayers().size());
        for (int i = 0; i < this.getLayers().size(); ++i) {
            Vector p = this.getLayers().get(i).convertToVector();
            layerParameters.add(p);
            numParams += p.getDimensionality();
        }
        Vector parameters = VectorFactory.getDefault().createVector(numParams);
        int index = 0;
        for (Vector p : layerParameters) {
            int dim = p.getDimensionality();
            for (int i = 0; i < dim; ++i) {
                parameters.setElement(index, p.getElement(i));
                ++index;
            }
        }
        return parameters;
    }

    public void convertFromVector(Vector parameters) {
        int minIndex = 0;
        int maxIndex = -1;
        for (int i = 0; i < this.getLayers().size(); ++i) {
            SquashedMatrixMultiplyVectorFunction layer = this.getLayers().get(i);
            Matrix matrix = layer.getMatrixMultiply().getInternalMatrix();
            int num = matrix.getNumRows() * matrix.getNumColumns();
            minIndex = maxIndex + 1;
            maxIndex = minIndex + num - 1;
            Vector layerParameters = parameters.subVector(minIndex, maxIndex);
            layer.convertFromVector(layerParameters);
        }
    }

    public Vector evaluate(Vector input) {
        ArrayList<Vector> layerActivations = this.evaluateAtEachLayer(input);
        return layerActivations.get(layerActivations.size() - 1);
    }

    protected ArrayList<Vector> evaluateAtEachLayer(Vector input) {
        int N = this.getLayers().size();
        ArrayList<Vector> layerActivations = new ArrayList<Vector>(N + 1);
        layerActivations.add(input);
        Vector activation = input;
        for (SquashedMatrixMultiplyVectorFunction squashedMatrixMultiplyVectorFunction : this.getLayers()) {
            activation = squashedMatrixMultiplyVectorFunction.evaluate(activation);
            layerActivations.add(activation);
        }
        return layerActivations;
    }

    public ArrayList<? extends SquashedMatrixMultiplyVectorFunction> getLayers() {
        return this.layers;
    }

    public void setLayers(ArrayList<? extends SquashedMatrixMultiplyVectorFunction> layers) {
        this.layers = layers;
    }

    public String toString() {
        StringBuilder retval = new StringBuilder();
        retval.append(((Object)((Object)this)).getClass() + " with " + this.getLayers().size() + " Layers.");
        retval.append("\n");
        for (int i = 0; i < this.getLayers().size(); ++i) {
            retval.append("Layer " + i + "->" + (i + 1));
            retval.append("\n");
            retval.append(this.getLayers().get(i).toString());
        }
        return retval.toString();
    }
}

