/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.learner.matlib;

import ch.akuhn.matrix.Matrix;
import ch.akuhn.matrix.SparseMatrix;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.openimaj.io.ReadWriteableBinary;
import org.openimaj.math.matrix.DiagonalMatrix;
import org.openimaj.math.matrix.MatlibMatrixUtils;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.OnlineLearner;
import org.openimaj.ml.linear.learner.matlib.init.InitStrategy;
import org.openimaj.ml.linear.learner.matlib.init.SparseSingleValueInitStrat;
import org.openimaj.ml.linear.learner.matlib.loss.LossFunction;
import org.openimaj.ml.linear.learner.matlib.loss.MatLossFunction;
import org.openimaj.ml.linear.learner.matlib.regul.Regulariser;

/**
 * Online learner for a sparse bilinear model built on the matlib
 * ({@code ch.akuhn.matrix}) {@link Matrix} types.
 * <p>
 * The model holds:
 * <ul>
 * <li>a feature ("word") weight matrix {@link #w} ({@code nwords x ntasks});</li>
 * <li>a user weight matrix {@link #u} ({@code nusers x ntasks});</li>
 * <li>a task bias matrix {@link #bias} ({@code ntasks x ntasks}), which is only
 * used when the {@code "bias"} parameter is enabled.</li>
 * </ul>
 * {@link #predict(Matrix)} computes {@code (u^T . x^T) . w}. In bias mode it
 * adds the bias, and it returns the result wrapped in a {@link DiagonalMatrix}.
 * <p>
 * Each call to {@link #process(Matrix, Matrix)} first scales the current
 * parameters by {@code 1 - dampening}. It then repeats the following update
 * until the mean relative (Frobenius) change drops below
 * {@code "biconvex_tol"} or {@code "biconvex_maxiter"} iterations have run:
 * <ul>
 * <li>a gradient step followed by the regulariser's {@code prox} on {@code w};</li>
 * <li>the same on {@code u};</li>
 * <li>in bias mode, a plain gradient step on the bias.</li>
 * </ul>
 * <p>
 * All hyper-parameters come from a {@link BilinearLearnerParameters} instance.
 * This class performs no internal synchronisation.
 */
public class MatlibBilinearSparseOnlineLearner
implements OnlineLearner<Matrix, Matrix>,
ReadWriteableBinary {
    static Logger logger = LogManager.getLogger(MatlibBilinearSparseOnlineLearner.class);
    /** Hyper-parameters; several of them are cached in the fields below by {@link #reinitParams()}. */
    protected BilinearLearnerParameters params;
    /** Feature weights ({@code nwords x ntasks}); {@code null} until first initialised. */
    protected Matrix w;
    /** User weights ({@code nusers x ntasks}); {@code null} until first initialised. */
    protected Matrix u;
    /** Loss used for every gradient computation (always a matrix loss after {@link #reinitParams()}). */
    protected LossFunction loss;
    /** Regulariser whose {@code prox} is applied after each w/u gradient step. */
    protected Regulariser regul;
    protected Double lambda_w;
    protected Double lambda_u;
    /** Whether the {@link #bias} term is learned and used. */
    protected Boolean biasMode;
    /** Task bias ({@code ntasks x ntasks}). */
    protected Matrix bias;
    /** {@code new DiagonalMatrix(ntasks, 1.0)}; set as the loss's X when updating the bias. */
    protected Matrix diagX;
    protected Double eta0_u;
    protected Double eta0_w;
    private Boolean forceSparcity;
    private Boolean zStandardise;
    /** True until the first iteration of {@link #process} has consumed data. */
    private boolean nodataseen;

    /** Creates a learner with default {@link BilinearLearnerParameters}. */
    public MatlibBilinearSparseOnlineLearner() {
        this(new BilinearLearnerParameters());
    }

    /**
     * Creates a learner with the given hyper-parameters.
     *
     * @param params the hyper-parameters (held by reference, not copied)
     */
    public MatlibBilinearSparseOnlineLearner(BilinearLearnerParameters params) {
        this.params = params;
        this.reinitParams();
    }

    /**
     * Re-reads the cached hyper-parameters from {@link #params}.
     * <p>
     * If the configured loss is not a matrix loss, this method wraps it in a
     * {@link MatLossFunction}. It also marks that no data has been seen. The
     * learned matrices {@code w}, {@code u} and {@code bias} are left untouched.
     */
    public void reinitParams() {
        this.loss = (LossFunction) this.params.getTyped("loss");
        this.regul = (Regulariser) this.params.getTyped("regul");
        this.lambda_w = (Double) this.params.getTyped("lambda_w");
        this.lambda_u = (Double) this.params.getTyped("lambda_u");
        this.biasMode = (Boolean) this.params.getTyped("bias");
        this.eta0_u = (Double) this.params.getTyped("eta0u");
        this.eta0_w = (Double) this.params.getTyped("eta0w");
        this.forceSparcity = (Boolean) this.params.getTyped("forcesparcity");
        this.zStandardise = (Boolean) this.params.getTyped("z_standardise");
        if (!this.loss.isMatrixLoss()) {
            this.loss = new MatLossFunction(this.loss);
        }
        this.nodataseen = true;
    }

    /**
     * Creates {@code w}, {@code u} and {@code bias} with the configured init
     * strategies. In bias mode it also creates {@code diagX}.
     */
    private void initParams(Matrix x, Matrix y, int xrows, int xcols, int ycols) {
        InitStrategy wstrat = this.getInitStrat("winitstrat", x, y);
        InitStrategy ustrat = this.getInitStrat("uinitstrat", x, y);
        this.w = wstrat.init(xrows, ycols);
        this.u = ustrat.init(xcols, ycols);
        this.bias = SparseMatrix.sparse(ycols, ycols);
        if (this.biasMode.booleanValue()) {
            InitStrategy bstrat = this.getInitStrat("biasinitstrat", x, y);
            this.bias = bstrat.init(ycols, ycols);
            this.diagX = new DiagonalMatrix(ycols, 1.0);
        }
    }

    /**
     * Looks up the {@link InitStrategy} stored under the given parameter key.
     * The {@code x} and {@code y} arguments are currently unused.
     */
    private InitStrategy getInitStrat(String initstrat, Matrix x, Matrix y) {
        return (InitStrategy) this.params.getTyped(initstrat);
    }

    /**
     * Updates the model with one batch of data.
     * <p>
     * The matrices are created on the first call. Only the first row of
     * {@code Y} is used (see {@link #expandY(Matrix)}).
     *
     * @param X the input, with {@code nfeatures} rows and {@code nusers} columns
     * @param Y the targets, with one column per task
     */
    @Override
    public void process(Matrix X, Matrix Y) {
        final int nfeatures = X.rowCount();
        final int nusers = X.columnCount();
        final int ntasks = Y.columnCount();
        if (this.w == null) {
            this.initParams(X, Y, nfeatures, nusers, ntasks);
        }

        // Shrink the existing parameters by (1 - dampening) before fitting this batch.
        final Double dampening = (Double) this.params.getTyped("dampening");
        final double weighting = 1.0 - dampening;
        logger.debug("... dampening w, u and bias by: " + weighting);
        MatlibMatrixUtils.scaleInplace(this.w, weighting);
        MatlibMatrixUtils.scaleInplace(this.u, weighting);
        if (this.biasMode.booleanValue()) {
            MatlibMatrixUtils.scaleInplace(this.bias, weighting);
        }

        final SparseMatrix Yexp = MatlibBilinearSparseOnlineLearner.expandY(Y);
        this.loss.setY((Matrix) Yexp);

        // The stopping criteria do not change inside the loop, so read them once.
        final Double biconvextol = (Double) this.params.getTyped("biconvex_tol");
        final Integer maxiter = (Integer) this.params.getTyped("biconvex_maxiter");

        double totalw = 0.0;
        double totalu = 0.0;
        double totalbias = 0.0;
        int iter = 0;
        while (true) {
            ++iter;
            // The bias is part of the loss used for the w and u updates.
            if (this.biasMode.booleanValue()) {
                this.loss.setBias(this.bias);
            }
            final double uLossWeight = this.etat(iter, this.eta0_u);
            final double wLossWeighted = this.etat(iter, this.eta0_w);
            final double weightedLambda_u = this.lambdat(iter, this.lambda_u);
            final double weightedLambda_w = this.lambdat(iter, this.lambda_w);

            final Matrix Dprime;
            if (this.nodataseen) {
                // On the very first iteration, use a stand-in for u^T in which
                // every element is set to 1.0 (SparseSingleValueInitStrat(1.0)),
                // instead of the freshly initialised u.
                this.nodataseen = false;
                final Matrix fakeut = new SparseSingleValueInitStrat(1.0).init(this.u.columnCount(), this.u.rowCount());
                Dprime = MatlibMatrixUtils.dotProductTranspose(fakeut, X);
            } else {
                Dprime = MatlibMatrixUtils.dotProductTransposeTranspose(this.u, X);
            }
            // NOTE: reinitParams() reads "z_standardise", but no
            // standardisation of Dprime is implemented here.
            this.loss.setX(Dprime);
            final Matrix neww = this.updateW(this.w, wLossWeighted, weightedLambda_w);
            final Matrix Vt = MatlibMatrixUtils.transposeDotProduct(neww, X);
            this.loss.setX(Vt);
            final Matrix newu = this.updateU(this.u, uLossWeight, weightedLambda_u);

            // Relative change of each parameter, measured against its value
            // from before this iteration.
            final double sumchangew = MatlibMatrixUtils.normF((Matrix) MatlibMatrixUtils.minus(neww, this.w));
            totalw = MatlibMatrixUtils.normF(this.w);
            final double sumchangeu = MatlibMatrixUtils.normF((Matrix) MatlibMatrixUtils.minus(newu, this.u));
            totalu = MatlibMatrixUtils.normF(this.u);
            final double ratioU = relativeChange(sumchangeu, totalu);
            final double ratioW = relativeChange(sumchangew, totalw);
            double ratio = ratioU + ratioW;

            // Commit the new parameters so that the next iteration, and
            // predict(), use them.
            this.w = neww;
            this.u = newu;

            totalbias = 0.0;
            if (this.biasMode.booleanValue()) {
                if (this.diagX == null) {
                    // readBinary() and clone() do not restore diagX; rebuild it as initParams() does.
                    this.diagX = new DiagonalMatrix(ntasks, 1.0);
                }
                Matrix mult = MatlibMatrixUtils.dotProductTransposeTranspose(newu, X);
                mult = MatlibMatrixUtils.dotProduct(mult, neww);
                MatlibMatrixUtils.plusInplace(mult, this.bias);
                // mult already includes the bias, presumably so the loss must
                // not add it a second time.
                this.loss.setBias(null);
                this.loss.setX(this.diagX);
                final Matrix biasGrad = this.loss.gradient(mult);
                final Matrix newbias = this.updateBias(biasGrad, this.biasEtat(iter));
                final double sumchangebias = MatlibMatrixUtils.normF((Matrix) MatlibMatrixUtils.minus(newbias, this.bias));
                totalbias = MatlibMatrixUtils.normF(this.bias);
                this.bias = newbias;
                ratio = (ratio + relativeChange(sumchangebias, totalbias)) / 3.0;
            } else {
                ratio /= 2.0;
            }

            if (iter % 3 == 0) {
                logger.debug(String.format("Iter: %d. Last Ratio: %2.3f", iter, ratio));
                this.logModelState(totalu, totalw, totalbias);
            }
            // A negative tolerance means exactly one iteration. A NaN ratio
            // never satisfies "ratio < tol", so it runs until maxiter.
            if (biconvextol < 0.0 || ratio < biconvextol || iter >= maxiter) {
                break;
            }
        }
        logger.debug("tolerance reached after iteration: " + iter);
        this.logModelState(totalu, totalw, totalbias);
    }

    /** Returns {@code change / total}, or {@code 0} when {@code total} is zero. */
    private static double relativeChange(double change, double total) {
        return total != 0.0 ? change / total : 0.0;
    }

    /**
     * Logs the sparsity of w and u and the given magnitudes at debug level.
     * Returns early when debug logging is disabled, so the sparsity is not
     * computed unnecessarily.
     */
    private void logModelState(double totalu, double totalw, double totalbias) {
        if (!logger.isDebugEnabled()) {
            return;
        }
        logger.debug("W row sparcity: " + MatlibMatrixUtils.sparsity(this.w));
        logger.debug("U row sparcity: " + MatlibMatrixUtils.sparsity(this.u));
        logger.debug("Total U magnitude: " + totalu);
        logger.debug("Total W magnitude: " + totalw);
        logger.debug("Total Bias: " + totalbias);
    }

    /**
     * Takes one gradient step on the bias (no regularisation).
     *
     * @param biasGrad the bias gradient; it is scaled in place
     * @param biasLossWeight the step size
     * @return {@code bias - biasLossWeight * biasGrad}
     */
    protected Matrix updateBias(Matrix biasGrad, double biasLossWeight) {
        final Matrix step = (Matrix) MatlibMatrixUtils.scaleInplace(biasGrad, biasLossWeight);
        return MatlibMatrixUtils.minus(this.bias, step);
    }

    /**
     * Takes one gradient step on {@code w} and then applies the regulariser's
     * {@code prox}.
     *
     * @param currentW the current feature weights
     * @param wLossWeighted the step size
     * @param weightedLambda the regularisation weight passed to {@code prox}
     * @return the updated feature weights
     */
    protected Matrix updateW(Matrix currentW, double wLossWeighted, double weightedLambda) {
        Matrix gradW = this.loss.gradient(currentW);
        MatlibMatrixUtils.scaleInplace(gradW, wLossWeighted);
        Matrix neww = MatlibMatrixUtils.minus(currentW, gradW);
        neww = this.regul.prox(neww, weightedLambda);
        return neww;
    }

    /**
     * Takes one gradient step on {@code u} and then applies the regulariser's
     * {@code prox}.
     *
     * @param currentU the current user weights
     * @param uLossWeight the step size
     * @param uWeightedLambda the regularisation weight passed to {@code prox}
     * @return the updated user weights
     */
    protected Matrix updateU(Matrix currentU, double uLossWeight, double uWeightedLambda) {
        Matrix gradU = this.loss.gradient(currentU);
        MatlibMatrixUtils.scaleInplace(gradU, uLossWeight);
        Matrix newu = MatlibMatrixUtils.minus(currentU, gradU);
        newu = this.regul.prox(newu, uWeightedLambda);
        return newu;
    }

    /** Regularisation weight for iteration {@code iter}: {@code lambda / iter}. */
    private double lambdat(int iter, double lambda) {
        return lambda / (double) iter;
    }

    /**
     * Expands the first row of {@code Y} into a square {@code ntasks x ntasks}
     * matrix. The diagonal holds {@code Y(0, t)} and every off-diagonal entry
     * is {@link Double#NaN}.
     *
     * @param Y the targets; only row 0 is read
     * @return the expanded target matrix
     */
    public static SparseMatrix expandY(Matrix Y) {
        final int ntasks = Y.columnCount();
        final SparseMatrix Yexp = SparseMatrix.sparse(ntasks, ntasks);
        for (int row = 0; row < ntasks; ++row) {
            for (int col = 0; col < ntasks; ++col) {
                Yexp.put(row, col, row == col ? Y.get(0, col) : Double.NaN);
            }
        }
        return Yexp;
    }

    /** Bias step size for iteration {@code iter}: {@code biaseta0 / sqrt(iter)}. */
    private double biasEtat(int iter) {
        final Double biasEta0 = (Double) this.params.getTyped("biaseta0");
        return biasEta0 / Math.sqrt(iter);
    }

    /** Step size for iteration {@code iter}: {@code eta0 / sqrt(ceil(iter / etasteps))}. */
    private double etat(int iter, double eta0) {
        final Integer etaSteps = (Integer) this.params.getTyped("etasteps");
        final double sqrtCeil = Math.sqrt(Math.ceil((double) iter / (double) etaSteps.intValue()));
        return this.eta(eta0) / sqrtCeil;
    }

    /** Base step size; currently returns {@code eta0} unchanged. */
    private double eta(double eta0) {
        return eta0;
    }

    /** @return the hyper-parameters (the live instance, not a copy) */
    public BilinearLearnerParameters getParams() {
        return this.params;
    }

    /** @return the user weights, or {@code null} if not yet initialised */
    public Matrix getU() {
        return this.u;
    }

    /** @return the feature weights, or {@code null} if not yet initialised */
    public Matrix getW() {
        return this.w;
    }

    /** @return the bias in bias mode, otherwise {@code null} */
    public Matrix getBias() {
        return this.biasMode.booleanValue() ? this.bias : null;
    }

    /**
     * Appends {@code newUsers} rows to {@code u}, initialised with the
     * {@code "expandeduinitstrat"} strategy. Does nothing if {@code u} has
     * not been initialised.
     */
    public void addU(int newUsers) {
        if (this.u == null) {
            return;
        }
        final InitStrategy ustrat = this.getInitStrat("expandeduinitstrat", null, null);
        final Matrix newU = ustrat.init(newUsers, this.u.columnCount());
        this.u = MatlibMatrixUtils.vstack(new Matrix[] {this.u, newU});
    }

    /**
     * Appends {@code newWords} rows to {@code w}, initialised with the
     * {@code "expandedwinitstrat"} strategy. Does nothing if {@code w} has
     * not been initialised.
     */
    public void addW(int newWords) {
        if (this.w == null) {
            return;
        }
        final InitStrategy wstrat = this.getInitStrat("expandedwinitstrat", null, null);
        final Matrix newW = wstrat.init(newWords, this.w.columnCount());
        this.w = MatlibMatrixUtils.vstack(new Matrix[] {this.w, newW});
    }

    /**
     * Returns a copy of this learner. The copy shares the same
     * {@link BilinearLearnerParameters} instance and holds deep copies of
     * {@code w}, {@code u} and {@code bias}.
     */
    @Override
    public MatlibBilinearSparseOnlineLearner clone() {
        final MatlibBilinearSparseOnlineLearner ret = new MatlibBilinearSparseOnlineLearner(this.getParams());
        ret.u = copyOrNull(this.u);
        ret.w = copyOrNull(this.w);
        // Copy the bias even outside bias mode: initParams() and readBinary()
        // always create one, and writeBinary() always serialises it.
        ret.bias = copyOrNull(this.bias);
        ret.nodataseen = this.nodataseen;
        return ret;
    }

    /** Deep-copies {@code m}, or returns {@code null} if {@code m} is {@code null}. */
    private static Matrix copyOrNull(Matrix m) {
        return m == null ? null : MatlibMatrixUtils.copy(m);
    }

    /** Replaces the user weights. */
    public void setU(Matrix newu) {
        this.u = newu;
    }

    /** Replaces the feature weights. */
    public void setW(Matrix neww) {
        this.w = neww;
    }

    /**
     * Reads a model written by {@link #writeBinary(DataOutput)}.
     * <p>
     * The stream holds three ints ({@code nwords}, {@code nusers},
     * {@code ntasks}), followed by {@code w}, {@code u} and {@code bias} as
     * column-major doubles.
     */
    public void readBinary(DataInput in) throws IOException {
        final int nwords = in.readInt();
        final int nusers = in.readInt();
        final int ntasks = in.readInt();
        this.w = readColumnMajor(in, nwords, ntasks);
        this.u = readColumnMajor(in, nusers, ntasks);
        // writeBinary() emits the bias column-major as well, so read it the same way.
        this.bias = readColumnMajor(in, ntasks, ntasks);
    }

    /** Reads a {@code rows x cols} matrix stored column-major, skipping zero entries. */
    private static SparseMatrix readColumnMajor(DataInput in, int rows, int cols) throws IOException {
        final SparseMatrix m = SparseMatrix.sparse(rows, cols);
        for (int c = 0; c < cols; ++c) {
            for (int r = 0; r < rows; ++r) {
                final double value = in.readDouble();
                if (value != 0.0) {
                    m.put(r, c, value);
                }
            }
        }
        return m;
    }

    /** @return an empty header */
    public byte[] binaryHeader() {
        return new byte[0];
    }

    /**
     * Writes the model as three ints ({@code w} rows, {@code u} rows,
     * {@code u} columns), followed by {@code w}, {@code u} and {@code bias}
     * as column-major doubles.
     */
    public void writeBinary(DataOutput out) throws IOException {
        out.writeInt(this.w.rowCount());
        out.writeInt(this.u.rowCount());
        out.writeInt(this.u.columnCount());
        writeColumnMajor(out, this.w);
        writeColumnMajor(out, this.u);
        writeColumnMajor(out, this.bias);
    }

    /** Writes every entry of {@code m} in column-major order. */
    private static void writeColumnMajor(DataOutput out, Matrix m) throws IOException {
        for (double value : m.asColumnMajorArray()) {
            out.writeDouble(value);
        }
    }

    /**
     * Computes {@code (u^T . x^T) . w}, adds the bias in bias mode, and wraps
     * the result in a {@link DiagonalMatrix} (presumably keeping only its
     * diagonal; confirm against DiagonalMatrix).
     *
     * @param x the input, shaped like the {@code X} passed to {@link #process}
     * @return the prediction
     */
    @Override
    public Matrix predict(Matrix x) {
        final Matrix xt = MatlibMatrixUtils.transpose(x);
        final Matrix ut = (Matrix) MatlibMatrixUtils.transpose(this.u);
        final Matrix mult = MatlibMatrixUtils.dotProduct((Matrix) MatlibMatrixUtils.dotProduct(ut, xt), this.w);
        if (this.biasMode.booleanValue()) {
            MatlibMatrixUtils.plusInplace(mult, this.bias);
        }
        return new DiagonalMatrix(mult);
    }
}

