/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.matrix.data;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.exception.MaxCountExceededException;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.MinusMultiply;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.runtime.util.DataConverter;

/**
 * Matrix factorizations and solvers on {@link MatrixBlock}s.
 *
 * <p>Inverse, Cholesky, solve, QR, LU, eigen and SVD delegate to Apache Commons Math after
 * converting the input to a commons-math matrix. The opcodes "qr2", "eigen_qr" and
 * "eigen_lanczos" are implemented directly on {@link MatrixBlock} operations.
 */
public class LibCommonsMath {
    private static final Log LOG = LogFactory.getLog(LibCommonsMath.class.getName());

    /** Relative symmetry threshold handed to commons-math's {@code CholeskyDecomposition}. */
    private static final double RELATIVE_SYMMETRY_THRESHOLD = 1.0E-6;

    /** Absolute positivity threshold handed to commons-math's {@code CholeskyDecomposition}. */
    private static final double CHOLESKY_ABSOLUTE_POSITIVITY_THRESHOLD = 1.0E-10;

    /** Size of the sign-preserving perturbation added by the regularized eigen fallback. */
    private static final double EIGEN_LAMBDA = 1.0E-8;

    /** Iteration count used by "eigen_qr" when the caller does not supply one. */
    private static final int DEFAULT_EIGEN_QR_ITERATIONS = 100;

    /** Tolerance used by "eigen_qr" when the caller does not supply one (see computeEigenQR). */
    private static final double DEFAULT_EIGEN_QR_TOLERANCE = 1.0E-10;

    /** Seed used when a multi-return operation is dispatched without an explicit seed. */
    private static final long DEFAULT_SEED = 1L;

    private LibCommonsMath() {
        // static utility class, not instantiable
    }

    /**
     * Tells whether {@link #unaryOperations(MatrixBlock, String)} supports an opcode.
     *
     * @param opcode operation code
     * @return true for "inverse" and "cholesky"
     */
    public static boolean isSupportedUnaryOperation(String opcode) {
        return opcode.equals("inverse") || opcode.equals("cholesky");
    }

    /**
     * Tells whether an opcode is one of the listed multi-return operations.
     *
     * <p>Note: the dispatcher additionally handles "qr2", "eigen_qr" and "eigen_lanczos", which
     * are not reported here.
     *
     * @param opcode operation code
     * @return true for "qr", "lu", "eigen" and "svd"
     */
    public static boolean isSupportedMultiReturnOperation(String opcode) {
        return opcode.equals("qr") || opcode.equals("lu") || opcode.equals("eigen") || opcode.equals("svd");
    }

    /**
     * Tells whether {@link #matrixMatrixOperations(MatrixBlock, MatrixBlock, String)} supports
     * an opcode.
     *
     * @param opcode operation code
     * @return true for "solve"
     */
    public static boolean isSupportedMatrixMatrixOperation(String opcode) {
        return opcode.equals("solve");
    }

    /**
     * Runs a single-output operation on one matrix.
     *
     * @param inj    input matrix
     * @param opcode "inverse" or "cholesky"
     * @return the result, or {@code null} when the opcode is not supported
     * @throws DMLRuntimeException if the input is not square
     */
    public static MatrixBlock unaryOperations(MatrixBlock inj, String opcode) {
        // Dispatch first so unsupported opcodes do not pay for a dense double[][] copy.
        if (opcode.equals("inverse")) {
            return computeMatrixInverse(DataConverter.convertToArray2DRowRealMatrix(inj));
        }
        if (opcode.equals("cholesky")) {
            return computeCholesky(DataConverter.convertToArray2DRowRealMatrix(inj));
        }
        return null;
    }

    /**
     * Runs a multi-output operation with a single thread and the default seed.
     *
     * @param in     input matrix
     * @param opcode operation code (see {@link #multiReturnOperations(MatrixBlock, String, int, long)})
     * @return the result blocks, or {@code null} when the opcode is not supported
     */
    public static MatrixBlock[] multiReturnOperations(MatrixBlock in, String opcode) {
        return multiReturnOperations(in, opcode, 1, DEFAULT_SEED);
    }

    /**
     * Runs a multi-output operation with explicit iteration settings.
     *
     * <p>For "eigen_qr", {@code num_iterations} and {@code tol} are forwarded. All other
     * opcodes ignore them and are dispatched with the default seed.
     *
     * @param in             input matrix
     * @param opcode         operation code
     * @param threads        degree of parallelism for the block operations
     * @param num_iterations number of QR iterations for "eigen_qr"
     * @param tol            tolerance for "eigen_qr" (currently not used by the algorithm)
     * @return the result blocks, or {@code null} when the opcode is not supported
     */
    public static MatrixBlock[] multiReturnOperations(MatrixBlock in, String opcode, int threads, int num_iterations, double tol) {
        if (opcode.equals("eigen_qr")) {
            return computeEigenQR(in, num_iterations, tol, threads);
        }
        return multiReturnOperations(in, opcode, threads, DEFAULT_SEED);
    }

    /**
     * Runs a multi-output operation.
     *
     * @param in      input matrix
     * @param opcode  one of "qr", "qr2", "lu", "eigen", "eigen_lanczos", "eigen_qr", "svd"
     * @param threads degree of parallelism for operations implemented on MatrixBlock
     * @param seed    seed for the random start vector of "eigen_lanczos"
     * @return the result blocks, or {@code null} when the opcode is not supported
     */
    public static MatrixBlock[] multiReturnOperations(MatrixBlock in, String opcode, int threads, long seed) {
        switch (opcode) {
            case "qr":
                return computeQR(in);
            case "qr2":
                return computeQR2(in, threads);
            case "lu":
                return computeLU(in);
            case "eigen":
                return computeEigen(in);
            case "eigen_lanczos":
                return computeEigenLanczos(in, threads, seed);
            case "eigen_qr":
                return computeEigenQR(in, threads);
            case "svd":
                return computeSvd(in);
            default:
                return null;
        }
    }

    /**
     * Runs a binary matrix-matrix operation.
     *
     * @param in1    left operand (A in solve(A, b))
     * @param in2    right operand (b in solve(A, b))
     * @param opcode "solve"
     * @return the result, or {@code null} when the opcode is not supported
     * @throws DMLRuntimeException if A is not square
     */
    public static MatrixBlock matrixMatrixOperations(MatrixBlock in1, MatrixBlock in2, String opcode) {
        if (opcode.equals("solve")) {
            if (in1.getNumRows() != in1.getNumColumns()) {
                throw new DMLRuntimeException("The A matrix, in solve(A,b) should have squared dimensions.");
            }
            return computeSolve(in1, in2);
        }
        return null;
    }

    /** Solves A x = b using the solver of commons-math's QR decomposition of A. */
    private static MatrixBlock computeSolve(MatrixBlock in1, MatrixBlock in2) {
        BlockRealMatrix matrixInput = DataConverter.convertToBlockRealMatrix(in1);
        BlockRealMatrix vectorInput = DataConverter.convertToBlockRealMatrix(in2);
        DecompositionSolver solver = new QRDecomposition(matrixInput).getSolver();
        RealMatrix solutionMatrix = solver.solve(vectorInput);
        return DataConverter.convertToMatrixBlock(solutionMatrix);
    }

    /** Commons-math QR decomposition; returns {H (Householder vectors), R}. */
    private static MatrixBlock[] computeQR(MatrixBlock in) {
        Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in);
        QRDecomposition qrdecompose = new QRDecomposition(matrixInput);
        RealMatrix H = qrdecompose.getH();
        RealMatrix R = qrdecompose.getR();
        MatrixBlock mbH = DataConverter.convertToMatrixBlock(H.getData());
        MatrixBlock mbR = DataConverter.convertToMatrixBlock(R.getData());
        return new MatrixBlock[]{mbH, mbR};
    }

    /** Commons-math LU decomposition of a square matrix; returns {P, L, U}. */
    private static MatrixBlock[] computeLU(MatrixBlock in) {
        if (in.getNumRows() != in.getNumColumns()) {
            throw new DMLRuntimeException("LU Decomposition can only be done on a square matrix. Input matrix is rectangular (rows=" + in.getNumRows() + ", cols=" + in.getNumColumns() + ")");
        }
        Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in);
        LUDecomposition ludecompose = new LUDecomposition(matrixInput);
        RealMatrix P = ludecompose.getP();
        RealMatrix L = ludecompose.getL();
        RealMatrix U = ludecompose.getU();
        MatrixBlock mbP = DataConverter.convertToMatrixBlock(P.getData());
        MatrixBlock mbL = DataConverter.convertToMatrixBlock(L.getData());
        MatrixBlock mbU = DataConverter.convertToMatrixBlock(U.getData());
        return new MatrixBlock[]{mbP, mbL, mbU};
    }

    /**
     * Commons-math eigen decomposition of a square matrix.
     *
     * <p>If commons-math exceeds its iteration budget, the decomposition is retried once on a
     * slightly perturbed copy (see computeEigenRegularized). Only the real parts of the
     * eigenvalues are returned.
     *
     * @return {eigenvalues (column vector, ascending), eigenvectors (columns, same order)}
     */
    private static MatrixBlock[] computeEigen(MatrixBlock in) {
        if (in.getNumRows() != in.getNumColumns()) {
            throw new DMLRuntimeException("Eigen Decomposition can only be done on a square matrix. Input matrix is rectangular (rows=" + in.getNumRows() + ", cols=" + in.getNumColumns() + ")");
        }
        EigenDecomposition eigendecompose;
        try {
            Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in);
            eigendecompose = new EigenDecomposition(matrixInput);
        }
        catch (MaxCountExceededException ex) {
            LOG.warn("Eigen: " + ex.getMessage() + ". Falling back to regularized eigen factorization.");
            eigendecompose = computeEigenRegularized(in);
        }
        double[][] eVectors = eigendecompose.getV().getData();
        double[] eValues = eigendecompose.getRealEigenvalues();
        return sortEVs(eValues, eVectors);
    }

    /**
     * Eigen decomposition of a perturbed dense copy of {@code in}.
     *
     * <p>Each entry is moved EIGEN_LAMBDA further from zero in the direction of its sign.
     * Zero entries are left unchanged because signum(0) == 0.
     *
     * @throws DMLRuntimeException if {@code in} is null or empty
     */
    private static EigenDecomposition computeEigenRegularized(MatrixBlock in) {
        if (in == null || in.isEmptyBlock(false)) {
            throw new DMLRuntimeException("Invalid empty block");
        }
        MatrixBlock in2 = new MatrixBlock(in, false);
        DenseBlock a = in2.getDenseBlock();
        for (int i = 0; i < in2.rlen; ++i) {
            double[] avals = a.values(i);
            int apos = a.pos(i);
            for (int j = 0; j < in2.clen; ++j) {
                int ix = apos + j;
                avals[ix] += Math.signum(avals[ix]) * EIGEN_LAMBDA;
            }
        }
        return new EigenDecomposition(DataConverter.convertToArray2DRowRealMatrix(in2));
    }

    /** Commons-math SVD; returns {U, diag(sigma), V}. */
    private static MatrixBlock[] computeSvd(MatrixBlock in) {
        Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in);
        SingularValueDecomposition svd = new SingularValueDecomposition(matrixInput);
        double[] sigma = svd.getSingularValues();
        RealMatrix u = svd.getU();
        RealMatrix v = svd.getV();
        MatrixBlock U = DataConverter.convertToMatrixBlock(u.getData());
        MatrixBlock Sigma = DataConverter.convertToMatrixBlock(sigma, true);
        // expand the singular-value column vector into a square diagonal matrix
        Sigma = LibMatrixReorg.diag(Sigma, new MatrixBlock(Sigma.rlen, Sigma.rlen, true));
        MatrixBlock V = DataConverter.convertToMatrixBlock(v.getData());
        return new MatrixBlock[]{U, Sigma, V};
    }

    /** Inverse of a square matrix via commons-math's QR solver. */
    private static MatrixBlock computeMatrixInverse(Array2DRowRealMatrix in) {
        if (!in.isSquare()) {
            throw new DMLRuntimeException("Input to inv() must be square matrix -- given: a " + in.getRowDimension() + "x" + in.getColumnDimension() + " matrix.");
        }
        DecompositionSolver solver = new QRDecomposition(in).getSolver();
        RealMatrix inverseMatrix = solver.getInverse();
        return DataConverter.convertToMatrixBlock(inverseMatrix.getData());
    }

    /** Lower-triangular Cholesky factor L of a square matrix (commons-math). */
    private static MatrixBlock computeCholesky(Array2DRowRealMatrix in) {
        if (!in.isSquare()) {
            throw new DMLRuntimeException("Input to cholesky() must be square matrix -- given: a " + in.getRowDimension() + "x" + in.getColumnDimension() + " matrix.");
        }
        CholeskyDecomposition cholesky = new CholeskyDecomposition(in, RELATIVE_SYMMETRY_THRESHOLD, CHOLESKY_ABSOLUTE_POSITIVITY_THRESHOLD);
        RealMatrix rmL = cholesky.getL();
        return DataConverter.convertToMatrixBlock(rmL.getData());
    }

    /**
     * Random dim x 1 vector with unit sum of squares.
     *
     * <p>Draws a uniform random vector u, then returns sqrt(u / sum(u)), so the squares sum to
     * one. The entries are nonnegative assuming the draw range is [0, 1].
     *
     * @throws DMLRuntimeException if the result is not normalized to within 1e-7
     */
    private static MatrixBlock randNormalizedVect(int dim, int threads, long seed) {
        MatrixBlock v1 = MatrixBlock.randOperations(dim, 1, 1.0, 0.0, 1.0, "UNIFORM", seed);
        double v1_sum = v1.sum();
        RightScalarOperator op_div_scalar = new RightScalarOperator(Divide.getDivideFnObject(), v1_sum, threads);
        v1 = v1.scalarOperations(op_div_scalar, new MatrixBlock());
        UnaryOperator op_sqrt = new UnaryOperator(Builtin.getBuiltinFnObject(Builtin.BuiltinCode.SQRT), threads, true);
        v1 = v1.unaryOperations(op_sqrt, new MatrixBlock());
        if (Math.abs(v1.sumSq() - 1.0) >= 1.0E-7) {
            throw new DMLRuntimeException("v1 not correctly normalized (maybe try changing the seed)");
        }
        return v1;
    }

    /**
     * Eigen decomposition via Lanczos tridiagonalization.
     *
     * <p>Builds the tridiagonal matrix T (alpha on the diagonal, beta on the off-diagonals) and
     * the basis TV whose columns are the Lanczos vectors. It then decomposes T with
     * computeEigen and maps T's eigenvectors back through TV.
     *
     * <p>No reorthogonalization is performed. NOTE(review): if beta becomes 0 (the Krylov
     * space is invariant, e.g. A = I), the next vector is divided by zero and the result
     * contains NaN/Inf.
     */
    private static MatrixBlock[] computeEigenLanczos(MatrixBlock in, int threads, long seed) {
        if (in.getNumRows() != in.getNumColumns()) {
            throw new DMLRuntimeException("Lanczos algorithm and Eigen Decomposition can only be done on a square matrix. Input matrix is rectangular (rows=" + in.getNumRows() + ", cols=" + in.getNumColumns() + ")");
        }
        int m = in.getNumRows();
        MatrixBlock v0 = new MatrixBlock(m, 1, 0.0);
        MatrixBlock v1 = randNormalizedVect(m, threads, seed);
        MatrixBlock T = new MatrixBlock(m, m, 0.0);
        MatrixBlock TV = new MatrixBlock(m, m, 0.0);
        ReorgOperator op_t = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), threads);
        TernaryOperator op_minus_mul = new TernaryOperator(MinusMultiply.getFnObject(), threads);
        AggregateBinaryOperator op_mul_agg = InstructionUtils.getMatMultOperator(threads);
        ScalarOperator op_div_scalar = new RightScalarOperator(Divide.getDivideFnObject(), 1.0, threads);
        MatrixBlock beta = new MatrixBlock(1, 1, 0.0);
        for (int i = 0; i < m; ++i) {
            // store the current Lanczos vector as column i of TV
            v1.putInto(TV, 0, i, false);
            MatrixBlock w1 = in.aggregateBinaryOperations(in, v1, op_mul_agg);
            MatrixBlock alpha = w1.aggregateBinaryOperations(v1.reorgOperations(op_t, new MatrixBlock(), 0, 0, m), w1, op_mul_agg);
            if (i < m - 1) {
                // w = A v1 - alpha v1 - beta v0
                w1 = w1.ternaryOperations(op_minus_mul, v1, alpha, new MatrixBlock());
                w1 = w1.ternaryOperations(op_minus_mul, v0, beta, new MatrixBlock());
                beta.setValue(0, 0, Math.sqrt(w1.sumSq()));
                v0.copy(v1);
                // v1 = w / beta (written in place into v1)
                op_div_scalar = op_div_scalar.setConstant(beta.getDouble(0, 0));
                w1.scalarOperations(op_div_scalar, v1);
                T.setValue(i + 1, i, beta.getValue(0, 0));
                T.setValue(i, i + 1, beta.getValue(0, 0));
            }
            T.setValue(i, i, alpha.getValue(0, 0));
        }
        MatrixBlock[] e = computeEigen(T);
        TV.setNonZeros((long) m * (long) m);
        e[1] = TV.aggregateBinaryOperations(TV, e[1], op_mul_agg);
        return e;
    }

    /**
     * Householder QR of a square matrix, built on MatrixBlock operations.
     *
     * <p>Each step k reflects column k onto the k-th unit vector:
     * H_k = I - 2 v v^T, A <- H_k A, Q <- Q H_k.
     *
     * @return {Q, R} with Q orthogonal, R upper triangular, and Q R = in
     */
    private static MatrixBlock[] computeQR2(MatrixBlock in, int threads) {
        if (in.getNumRows() != in.getNumColumns()) {
            throw new DMLRuntimeException("QR2 Decomposition can only be done on a square matrix. Input matrix is rectangular (rows=" + in.getNumRows() + ", cols=" + in.getNumColumns() + ")");
        }
        int m = in.rlen;
        MatrixBlock A_n = new MatrixBlock();
        A_n.copy(in);
        MatrixBlock Q_n = new MatrixBlock(m, m, true);
        for (int i = 0; i < m; ++i) {
            Q_n.setValue(i, i, 1.0);
        }
        ReorgOperator op_t = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), threads);
        AggregateBinaryOperator op_mul_agg = InstructionUtils.getMatMultOperator(threads);
        BinaryOperator op_sub = InstructionUtils.parseExtendedBinaryOperator("-");
        ScalarOperator op_div_scalar = new RightScalarOperator(Divide.getDivideFnObject(), 1.0, threads);
        LeftScalarOperator op_mult_2 = new LeftScalarOperator(Multiply.getMultiplyFnObject(), 2.0, threads);
        for (int k = 0; k < m; ++k) {
            // z = A_n[k:m-1, k], the part of column k that the reflector acts on
            MatrixBlock z = A_n.slice(k, m - 1, k, k);
            double zNorm = Math.sqrt(z.sumSq());
            if (zNorm == 0.0) {
                // The column is already zero on and below the diagonal. The reflector would
                // be 0/0 (NaN), so this step uses the identity (H_k = I) instead.
                continue;
            }
            double z0 = z.getValue(0, 0);
            // The sign of z0 avoids cancellation in u = z + sign * ||z|| * e1. Math.signum(0)
            // is 0, which would leave u == z and reflect z to -z instead of onto e1, so a
            // zero leading entry is treated as positive.
            double sign = z0 < 0.0 ? -1.0 : 1.0;
            MatrixBlock uk = new MatrixBlock(m - k, 1, 0.0);
            uk.copy(z);
            uk.setValue(0, 0, z0 + sign * zNorm);
            // |u_0| = |z0| + ||z|| > 0 here, so the normalization is well-defined
            op_div_scalar = op_div_scalar.setConstant(Math.sqrt(uk.sumSq()));
            uk = uk.scalarOperations(op_div_scalar, new MatrixBlock());
            // embed the unit reflector in an m x 1 vector with zeros above row k
            MatrixBlock vk = new MatrixBlock(m, 1, 0.0);
            vk.copy(k, m - 1, 0, 0, uk, true);
            MatrixBlock vkt = vk.reorgOperations(op_t, new MatrixBlock(), 0, 0, m);
            MatrixBlock vkt2 = vkt.scalarOperations(op_mult_2, new MatrixBlock());
            MatrixBlock vkvkt2 = vk.aggregateBinaryOperations(vk, vkt2, op_mul_agg);
            // A <- A - (2 v v^T) A ; Q <- Q - Q (2 v v^T)
            A_n = A_n.binaryOperations(op_sub, A_n.aggregateBinaryOperations(vkvkt2, A_n, op_mul_agg));
            Q_n = Q_n.binaryOperations(op_sub, Q_n.aggregateBinaryOperations(Q_n, vkvkt2, op_mul_agg));
        }
        return new MatrixBlock[]{Q_n, A_n};
    }

    /** Runs "eigen_qr" with the default iteration count and tolerance. */
    private static MatrixBlock[] computeEigenQR(MatrixBlock in, int threads) {
        return computeEigenQR(in, DEFAULT_EIGEN_QR_ITERATIONS, DEFAULT_EIGEN_QR_TOLERANCE, threads);
    }

    /**
     * Eigen decomposition by unshifted QR iteration: A <- R Q, with the Q factors accumulated.
     *
     * <p>Eigenvalues are read from the diagonal of the final iterate. This is meaningful once
     * the iteration has converged, e.g. for symmetric inputs with real eigenvalues.
     *
     * @param num_iterations number of QR steps to run (none when <= 0)
     * @param tol            NOTE(review): currently unused; exactly num_iterations steps run
     * @return {eigenvalues (column vector, ascending), eigenvectors (columns, same order)}
     */
    private static MatrixBlock[] computeEigenQR(MatrixBlock in, int num_iterations, double tol, int threads) {
        if (in.getNumRows() != in.getNumColumns()) {
            throw new DMLRuntimeException("Eigen Decomposition (QR) can only be done on a square matrix. Input matrix is rectangular (rows=" + in.getNumRows() + ", cols=" + in.getNumColumns() + ")");
        }
        int m = in.rlen;
        AggregateBinaryOperator op_mul_agg = InstructionUtils.getMatMultOperator(threads);
        MatrixBlock Q_prod = new MatrixBlock(m, m, 0.0);
        for (int i = 0; i < m; ++i) {
            Q_prod.setValue(i, i, 1.0);
        }
        for (int i = 0; i < num_iterations; ++i) {
            MatrixBlock[] QR = computeQR2(in, threads);
            Q_prod = Q_prod.aggregateBinaryOperations(Q_prod, QR[0], op_mul_agg);
            in = QR[1].aggregateBinaryOperations(QR[1], QR[0], op_mul_agg);
        }
        // Read the results element-wise instead of via getDenseBlockValues(). The iterates, or
        // the input itself when num_iterations <= 0, may be sparse or unallocated, and then no
        // dense value array is available.
        double[] eval = new double[m];
        for (int i = 0; i < m; ++i) {
            eval[i] = in.getDouble(i, i);
        }
        double[] evec = new double[m * m];
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < m; ++j) {
                evec[i * m + j] = Q_prod.getDouble(i, j);
            }
        }
        return sortEVs(eval, evec);
    }

    /**
     * Householder tridiagonalization: repeatedly applies A <- P A P with
     * P = I - 2 v v^T for k = 0 .. m-3.
     *
     * <p>Not called from within this class. NOTE(review): if r evaluates to 0 (e.g. a column
     * already reduced), the division by 2r yields NaN/Inf.
     */
    private static MatrixBlock computeHouseholder(MatrixBlock in, int threads) {
        int m = in.rlen;
        MatrixBlock A_n = new MatrixBlock(m, m, 0.0);
        A_n.copy(in);
        // loop-invariant operators, created once
        ReorgOperator op_t = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), threads);
        AggregateBinaryOperator op_mul_agg = InstructionUtils.getMatMultOperator(threads);
        BinaryOperator op_add = InstructionUtils.parseExtendedBinaryOperator("+");
        BinaryOperator op_sub = InstructionUtils.parseExtendedBinaryOperator("-");
        for (int k = 0; k < m - 2; ++k) {
            // column k with rows 0..k zeroed out
            MatrixBlock ajk = A_n.slice(0, m - 1, k, k);
            for (int i = 0; i <= k; ++i) {
                ajk.setValue(i, 0, 0.0);
            }
            double alpha = Math.sqrt(ajk.sumSq());
            double ak1k = A_n.getDouble(k + 1, k);
            if (ak1k > 0.0) {
                alpha *= -1.0;
            }
            double r = Math.sqrt(0.5 * (alpha * alpha - ak1k * alpha));
            MatrixBlock v = new MatrixBlock(m, 1, 0.0);
            v.copy(ajk);
            v.setValue(k + 1, 0, ak1k - alpha);
            RightScalarOperator op_div_scalar = new RightScalarOperator(Divide.getDivideFnObject(), 2.0 * r, threads);
            v = v.scalarOperations(op_div_scalar, new MatrixBlock());
            MatrixBlock P = new MatrixBlock(m, m, 0.0);
            for (int i = 0; i < m; ++i) {
                P.setValue(i, i, 1.0);
            }
            MatrixBlock v_t = v.reorgOperations(op_t, new MatrixBlock(), 0, 0, m);
            v_t = v_t.binaryOperations(op_add, v_t); // 2 v^T
            MatrixBlock v_v_t_2 = A_n.aggregateBinaryOperations(v, v_t, op_mul_agg);
            P = P.binaryOperations(op_sub, v_v_t_2);
            A_n = A_n.aggregateBinaryOperations(P, A_n.aggregateBinaryOperations(A_n, P, op_mul_agg), op_mul_agg);
        }
        return A_n;
    }

    /**
     * Sorts eigenvalues ascending (selection sort, in place) and permutes the eigenvector
     * columns of {@code eVectors} (row-major double[][]) in the same way.
     *
     * @return {eigenvalues as a column vector, eigenvector matrix}
     */
    private static MatrixBlock[] sortEVs(double[] eValues, double[][] eVectors) {
        int n = eValues.length;
        for (int i = 0; i < n; ++i) {
            int k = i;
            double p = eValues[i];
            for (int j = i + 1; j < n; ++j) {
                if (eValues[j] < p) {
                    k = j;
                    p = eValues[j];
                }
            }
            if (k == i) {
                continue;
            }
            eValues[k] = eValues[i];
            eValues[i] = p;
            for (int j = 0; j < n; ++j) {
                double tmp = eVectors[j][i];
                eVectors[j][i] = eVectors[j][k];
                eVectors[j][k] = tmp;
            }
        }
        MatrixBlock eval = DataConverter.convertToMatrixBlock(eValues, true);
        MatrixBlock evec = DataConverter.convertToMatrixBlock(eVectors);
        return new MatrixBlock[]{eval, evec};
    }

    /**
     * Same as the double[][] overload, but eigenvectors are given as a row-major n*n array.
     * Both input arrays are modified in place.
     *
     * @return {eigenvalues as a column vector, eigenvector matrix}
     */
    private static MatrixBlock[] sortEVs(double[] eValues, double[] eVectors) {
        int n = eValues.length;
        for (int i = 0; i < n; ++i) {
            int k = i;
            double p = eValues[i];
            for (int j = i + 1; j < n; ++j) {
                if (eValues[j] < p) {
                    k = j;
                    p = eValues[j];
                }
            }
            if (k == i) {
                continue;
            }
            eValues[k] = eValues[i];
            eValues[i] = p;
            for (int j = 0; j < n; ++j) {
                double tmp = eVectors[j * n + i];
                eVectors[j * n + i] = eVectors[j * n + k];
                eVectors[j * n + k] = tmp;
            }
        }
        MatrixBlock eval = DataConverter.convertToMatrixBlock(eValues, true);
        MatrixBlock evec = new MatrixBlock(n, n, false);
        evec.init(eVectors, n, n);
        return new MatrixBlock[]{eval, evec};
    }
}

