/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.function.vector;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.gradient.GradientDescendable;
import gov.sandia.cognition.learning.function.scalar.AtanFunction;
import gov.sandia.cognition.learning.function.vector.ElementWiseVectorFunction;
import gov.sandia.cognition.math.DifferentiableUnivariateScalarFunction;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.UnivariateScalarFunction;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.math.matrix.VectorOutputEvaluator;
import gov.sandia.cognition.math.matrix.VectorizableVectorFunction;
import gov.sandia.cognition.util.AbstractRandomized;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Random;

@PublicationReference(author={"Wikipedia"}, title="Multilayer perceptron", type=PublicationType.WebPage, year=2009, url="http://en.wikipedia.org/wiki/Multilayer_perceptron")
public class ThreeLayerFeedforwardNeuralNetwork
extends AbstractRandomized
implements VectorizableVectorFunction,
VectorInputEvaluator<Vector, Vector>,
VectorOutputEvaluator<Vector, Vector>,
GradientDescendable {
    public static final double DEFAULT_INITIALIZATION_RANGE = 0.001;
    public static final DifferentiableUnivariateScalarFunction DEFAULT_SQUASHING_FUNCTION = new AtanFunction();
    public static final int DEFAULT_RANDOM_SEED = 1;
    protected Matrix inputToHiddenWeights;
    protected Vector inputToHiddenBiasWeights;
    protected Matrix hiddenToOutputWeights;
    protected Vector hiddenToOutputBiasWeights;
    private DifferentiableUnivariateScalarFunction squashingFunction;
    private double initializationRange;

    public ThreeLayerFeedforwardNeuralNetwork() {
        this(1, 1, 1);
    }

    public ThreeLayerFeedforwardNeuralNetwork(int numInputs, int numHidden, int numOutputs) {
        this(numInputs, numHidden, numOutputs, DEFAULT_SQUASHING_FUNCTION);
    }

    public ThreeLayerFeedforwardNeuralNetwork(int numInputs, int numHidden, int numOutputs, DifferentiableUnivariateScalarFunction squashingFunction) {
        this(numInputs, numHidden, numOutputs, squashingFunction, 1, 0.001);
    }

    public ThreeLayerFeedforwardNeuralNetwork(int numInputs, int numHidden, int numOutputs, DifferentiableUnivariateScalarFunction squashingFunction, int randomSeed, double initializationRange) {
        super(new Random(randomSeed));
        this.setInitializationRange(initializationRange);
        this.setSquashingFunction(squashingFunction);
        this.initializeWeights(numInputs, numHidden, numOutputs);
    }

    @Override
    public ThreeLayerFeedforwardNeuralNetwork clone() {
        ThreeLayerFeedforwardNeuralNetwork clone = (ThreeLayerFeedforwardNeuralNetwork)super.clone();
        clone.inputToHiddenWeights = (Matrix)ObjectUtil.cloneSafe((CloneableSerializable)this.inputToHiddenWeights);
        clone.inputToHiddenBiasWeights = (Vector)ObjectUtil.cloneSafe((CloneableSerializable)this.inputToHiddenBiasWeights);
        clone.hiddenToOutputWeights = (Matrix)ObjectUtil.cloneSafe((CloneableSerializable)this.hiddenToOutputWeights);
        clone.hiddenToOutputBiasWeights = (Vector)ObjectUtil.cloneSafe((CloneableSerializable)this.hiddenToOutputBiasWeights);
        clone.squashingFunction = (DifferentiableUnivariateScalarFunction)ObjectUtil.cloneSafe((CloneableSerializable)this.squashingFunction);
        return clone;
    }

    @Override
    public Matrix computeParameterGradient(Vector input) {
        int i;
        int numInputs = this.getInputDimensionality();
        int numHidden = this.getHiddenDimensionality();
        int numOutput = this.getOutputDimensionality();
        int num1 = numInputs * numHidden;
        int num2 = numHidden;
        int num3 = numHidden * numOutput;
        int num4 = numOutput;
        int N = num1 + num2 + num3 + num4;
        int M = numOutput;
        Vector hiddenActivation = this.evaluateHiddenLayerActivation(input);
        Vector squashedHiddenLayerActivation = this.evaluateSquashedHiddenLayerActivation(hiddenActivation);
        double[] squashedDerivativeHiddenLayerActivation = new double[numHidden];
        for (int i2 = 0; i2 < squashedDerivativeHiddenLayerActivation.length; ++i2) {
            squashedDerivativeHiddenLayerActivation[i2] = this.squashingFunction.differentiate(hiddenActivation.getElement(i2));
        }
        Matrix gradient = MatrixFactory.getDefault().createMatrix(M, N);
        int columnIndex = N - num4 - num3;
        for (int j = 0; j < numHidden; ++j) {
            double hj = squashedHiddenLayerActivation.getElement(j);
            for (int i3 = 0; i3 < numOutput; ++i3) {
                gradient.setElement(i3, columnIndex, hj);
                ++columnIndex;
            }
        }
        int offset = N - num4;
        for (i = 0; i < numOutput; ++i) {
            gradient.setElement(i, i + offset, 1.0);
        }
        offset = numHidden * numInputs;
        for (i = 0; i < numOutput; ++i) {
            for (int j = 0; j < numHidden; ++j) {
                double W2ij = this.hiddenToOutputWeights.getElement(i, j);
                double dfdhj = squashedDerivativeHiddenLayerActivation[j];
                double dyi_db1j = W2ij * dfdhj;
                gradient.setElement(i, j + offset, dyi_db1j);
                for (int k = 0; k < numInputs; ++k) {
                    double dyi_dW1jk = W2ij * dfdhj * input.getElement(k);
                    gradient.setElement(i, k * numHidden + j, dyi_dW1jk);
                }
            }
        }
        return gradient;
    }

    public Vector convertToVector() {
        int i;
        Vector p1 = this.inputToHiddenWeights.convertToVector();
        Vector p2 = this.inputToHiddenBiasWeights;
        Vector p3 = this.hiddenToOutputWeights.convertToVector();
        Vector p4 = this.hiddenToOutputBiasWeights;
        int num = p1.getDimensionality() + p2.getDimensionality() + p3.getDimensionality() + p4.getDimensionality();
        Vector parameters = VectorFactory.getDefault().createVector(num);
        int index = 0;
        for (i = 0; i < p1.getDimensionality(); ++i) {
            parameters.setElement(index, p1.getElement(i));
            ++index;
        }
        for (i = 0; i < p2.getDimensionality(); ++i) {
            parameters.setElement(index, p2.getElement(i));
            ++index;
        }
        for (i = 0; i < p3.getDimensionality(); ++i) {
            parameters.setElement(index, p3.getElement(i));
            ++index;
        }
        for (i = 0; i < p4.getDimensionality(); ++i) {
            parameters.setElement(index, p4.getElement(i));
            ++index;
        }
        return parameters;
    }

    public int getNumParameters() {
        int numInputs = this.getInputDimensionality();
        int numHidden = this.getHiddenDimensionality();
        int numOutput = this.getOutputDimensionality();
        int num1 = numInputs * numHidden;
        int num2 = numHidden;
        int num3 = numHidden * numOutput;
        int num4 = numOutput;
        return num1 + num2 + num3 + num4;
    }

    public void convertFromVector(Vector parameters) {
        int numInputs = this.getInputDimensionality();
        int numHidden = this.getHiddenDimensionality();
        int numOutput = this.getOutputDimensionality();
        int num1 = numInputs * numHidden;
        int num2 = numHidden;
        int num3 = numHidden * numOutput;
        int num4 = numOutput;
        int num = num1 + num2 + num3 + num4;
        parameters.assertDimensionalityEquals(num);
        Vector p1 = parameters.subVector(0, num1 - 1);
        Vector p2 = parameters.subVector(num1, num1 + num2 - 1);
        Vector p3 = parameters.subVector(num1 + num2, num1 + num2 + num3 - 1);
        Vector p4 = parameters.subVector(num1 + num2 + num3, num - 1);
        this.inputToHiddenWeights.convertFromVector(p1);
        this.inputToHiddenBiasWeights = p2;
        this.hiddenToOutputWeights.convertFromVector(p3);
        this.hiddenToOutputBiasWeights = p4;
    }

    public Vector evaluate(Vector input) {
        Vector hiddenActivation = this.evaluateHiddenLayerActivation(input);
        Vector squashedHiddenActivation = this.evaluateSquashedHiddenLayerActivation(hiddenActivation);
        return this.evaluateOutputFromSquashedHiddenLayerActivation(squashedHiddenActivation);
    }

    protected Vector evaluateHiddenLayerActivation(Vector input) {
        Vector hiddenActivation = this.inputToHiddenWeights.times(input);
        hiddenActivation.plusEquals((Ring)this.inputToHiddenBiasWeights);
        return hiddenActivation;
    }

    protected Vector evaluateSquashedHiddenLayerActivation(Vector hiddenActivation) {
        return ElementWiseVectorFunction.evaluate(hiddenActivation, (UnivariateScalarFunction)this.getSquashingFunction());
    }

    protected Vector evaluateOutputFromSquashedHiddenLayerActivation(Vector squashedHiddenActivation) {
        Vector outputActivation = this.hiddenToOutputWeights.times(squashedHiddenActivation);
        outputActivation.plusEquals((Ring)this.hiddenToOutputBiasWeights);
        return outputActivation;
    }

    public void reinitializeWeights() {
        this.initializeWeights(this.getInputDimensionality(), this.getHiddenDimensionality(), this.getOutputDimensionality());
    }

    public void initializeWeights(int inputDimensionality, int hiddenDimensionality, int outputDimensionality) {
        if (inputDimensionality < 1) {
            throw new IllegalArgumentException("inputDimensionality must be >= 1");
        }
        if (hiddenDimensionality < 1) {
            throw new IllegalArgumentException("hiddenDimensionality must be >= 1");
        }
        if (outputDimensionality < 1) {
            throw new IllegalArgumentException("outputDimensionality must be >= 1");
        }
        this.inputToHiddenWeights = MatrixFactory.getDefault().createUniformRandom(hiddenDimensionality, inputDimensionality, -this.getInitializationRange(), this.getInitializationRange(), this.getRandom());
        this.inputToHiddenBiasWeights = VectorFactory.getDefault().createUniformRandom(hiddenDimensionality, -this.getInitializationRange(), this.getInitializationRange(), this.random);
        this.hiddenToOutputWeights = MatrixFactory.getDefault().createUniformRandom(outputDimensionality, hiddenDimensionality, -this.getInitializationRange(), this.getInitializationRange(), this.getRandom());
        this.hiddenToOutputBiasWeights = VectorFactory.getDefault().createUniformRandom(outputDimensionality, -this.getInitializationRange(), this.getInitializationRange(), this.random);
    }

    public int getOutputDimensionality() {
        return this.hiddenToOutputWeights.getNumRows();
    }

    public void setOutputDimensionality(int outputDimensionality) {
        this.initializeWeights(this.getInputDimensionality(), this.getHiddenDimensionality(), outputDimensionality);
    }

    public int getHiddenDimensionality() {
        return this.hiddenToOutputWeights.getNumColumns();
    }

    public void setHiddenDimensionality(int hiddenDimensionality) {
        this.initializeWeights(this.getInputDimensionality(), hiddenDimensionality, this.getOutputDimensionality());
    }

    public int getInputDimensionality() {
        return this.inputToHiddenWeights.getNumColumns();
    }

    public void setInputDimensionality(int inputDimensionality) {
        this.initializeWeights(inputDimensionality, this.getHiddenDimensionality(), this.getOutputDimensionality());
    }

    public DifferentiableUnivariateScalarFunction getSquashingFunction() {
        return this.squashingFunction;
    }

    public void setSquashingFunction(DifferentiableUnivariateScalarFunction squashingFunction) {
        if (squashingFunction == null) {
            throw new IllegalArgumentException("Squashing function cannot be null!");
        }
        this.squashingFunction = squashingFunction;
    }

    public double getInitializationRange() {
        return this.initializationRange;
    }

    public void setInitializationRange(double initializationRange) {
        if (initializationRange < 0.0) {
            throw new IllegalArgumentException("initializationRange must be >= 0.0");
        }
        this.initializationRange = initializationRange;
    }
}

