/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.conv;

import java.util.ArrayList;
import java.util.List;
import org.dromara.easyai.conv.ConvResult;
import org.dromara.easyai.i.ActiveFunction;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.nerveEntity.ConvParameter;
import org.dromara.easyai.nerveEntity.ConvSize;

/**
 * Shared convolution, pooling and back-propagation arithmetic for the convolutional layers.
 *
 * <p>"Down" path: each channel is convolved without padding through an im2col product
 * (output side = input side - (kernel - 1)), then optionally 2x2 average-pooled.
 * "Up" path: a convolution whose output side grows by {@code kernel - 1}, followed by 2x2
 * value replication. The {@code back*} methods propagate error matrices through those steps and
 * update the kernel matrices stored in a {@link ConvParameter}.
 */
public abstract class ConvCount {
    private final MatrixOperation matrixOperation = new MatrixOperation();

    /**
     * Returns how many conv + 2x2-pool stages an {@code xSize} x {@code ySize} input can go through
     * while the resulting side stays strictly above {@code minFeatureValue}. The smaller of the
     * per-axis depths is returned.
     *
     * @throws IllegalArgumentException if the size stops shrinking before reaching
     *     {@code minFeatureValue} (e.g. {@code kernLen == 1} with {@code minFeatureValue == 0})
     */
    protected int getConvMyDep(int xSize, int ySize, int kernLen, int minFeatureValue) {
        int xDeep = this.getConvDeep(xSize, kernLen, minFeatureValue);
        int yDeep = this.getConvDeep(ySize, kernLen, minFeatureValue);
        return Math.min(xDeep, yDeep);
    }

    /**
     * One-axis version of {@code getConvMyDep}: repeatedly applies "valid conv, then pool" to
     * {@code size} and counts the stages whose result stays above {@code minFeatureValue}.
     */
    private int getConvDeep(int size, int kernLen, int minFeatureValue) {
        int x = size;
        int deep = 0;
        while (true) {
            // Unpadded convolution shrinks the side by kernLen - 1 ...
            int afterConv = x - (kernLen - 1);
            // ... and 2x2 pooling halves it, rounding up for positive sizes (same size rule as
            // downPooling).
            int next = afterConv / 2 + afterConv % 2;
            ++deep;
            if (next <= minFeatureValue) {
                return deep - 1;
            }
            // The update is non-decreasing in x, so once it stops shrinking it never will. The old
            // do/while spun forever here.
            if (next >= x) {
                throw new IllegalArgumentException("feature size " + size + " never shrinks to "
                        + minFeatureValue + " with kernel length " + kernLen);
            }
            x = next;
        }
    }

    /** Up-samples by copying every cell into a 2x2 block; the output is {@code 2x by 2y}. */
    private Matrix upPooling(Matrix matrix) throws Exception {
        int x = matrix.getX();
        int y = matrix.getY();
        Matrix myMatrix = new Matrix(x * 2, y * 2);
        for (int i = 0; i < x; ++i) {
            for (int j = 0; j < y; ++j) {
                this.insertMatrixValue(i * 2, j * 2, matrix.getNumber(i, j), myMatrix);
            }
        }
        return myMatrix;
    }

    /** Applies {@code backUpPooling} to every matrix in {@code errorMatrix}. */
    protected List<Matrix> backManyUpPooling(List<Matrix> errorMatrix) throws Exception {
        ArrayList<Matrix> result = new ArrayList<Matrix>();
        for (Matrix matrix : errorMatrix) {
            result.add(this.backUpPooling(matrix));
        }
        return result;
    }

    /**
     * Gradient of {@code upPooling}: each output cell receives the sum of the four error cells of
     * the 2x2 block it was copied into. Output size is {@code (x / 2) x (y / 2)}.
     */
    protected Matrix backUpPooling(Matrix errorMatrix) throws Exception {
        int x = errorMatrix.getX();
        int y = errorMatrix.getY();
        Matrix myMatrix = new Matrix(x / 2, y / 2);
        for (int i = 0; i < x - 1; i += 2) {
            for (int j = 0; j < y - 1; j += 2) {
                float sigma = errorMatrix.getNumber(i, j) + errorMatrix.getNumber(i, j + 1)
                        + errorMatrix.getNumber(i + 1, j) + errorMatrix.getNumber(i + 1, j + 1);
                myMatrix.setNub(i / 2, j / 2, sigma);
            }
        }
        return myMatrix;
    }

    /**
     * 2x2 average pooling with output size {@code ceil(x / 2) x ceil(y / 2)}.
     *
     * <p>When a side is odd, the last window hangs over the edge. Its missing cells count as zero,
     * so every output cell is (sum of the cells its window covers) / 4. This matches
     * {@code backDownPooling}, which hands {@code value / 4} to every input cell of each window,
     * including the truncated edge windows. As a result, edge rows and columns are no longer
     * silently zeroed.
     */
    private Matrix downPooling(Matrix matrix) throws Exception {
        int x = matrix.getX();
        int y = matrix.getY();
        int outX = x / 2 + x % 2;
        int outY = y / 2 + y % 2;
        Matrix myMatrix = new Matrix(outX, outY);
        for (int i = 0; i < outX; ++i) {
            int rowEnd = Math.min(i * 2 + 2, x);
            for (int j = 0; j < outY; ++j) {
                int colEnd = Math.min(j * 2 + 2, y);
                float sigma = 0.0f;
                for (int r = i * 2; r < rowEnd; ++r) {
                    for (int c = j * 2; c < colEnd; ++c) {
                        sigma += matrix.getNumber(r, c);
                    }
                }
                myMatrix.setNub(i, j, sigma / 4.0f);
            }
        }
        return myMatrix;
    }

    /**
     * Writes {@code value} into the 2x2 block whose top-left corner is {@code (x, y)}. The block is
     * clipped by one row/column when it would run past the matrix's last row/column.
     */
    private void insertMatrixValue(int x, int y, float value, Matrix matrix) throws Exception {
        int xSize = x + 2;
        int ySize = y + 2;
        if (xSize > matrix.getX()) {
            --xSize;
        }
        if (ySize > matrix.getY()) {
            --ySize;
        }
        for (int i = x; i < xSize; ++i) {
            for (int j = y; j < ySize; ++j) {
                matrix.setNub(i, j, value);
            }
        }
    }

    /** Applies {@code backDownPooling} with the same {@code outX}/{@code outY} to every matrix. */
    protected List<Matrix> backDownPoolingByList(List<Matrix> matrixList, int outX, int outY) throws Exception {
        ArrayList<Matrix> result = new ArrayList<Matrix>();
        for (Matrix matrix : matrixList) {
            result.add(this.backDownPooling(matrix, outX, outY));
        }
        return result;
    }

    /**
     * Gradient of {@code downPooling}: every cell of each 2x2 window receives the window's error
     * divided by 4.
     *
     * <p>The result has size {@code (2x - (outX odd ? 1 : 0)) x (2y - (outY odd ? 1 : 0))}.
     * NOTE(review): {@code outX}/{@code outY} are presumably the pre-pooling sizes recorded by
     * {@code downConvAndPooling} via {@code setOutX}/{@code setOutY}; confirm at the call site.
     */
    protected Matrix backDownPooling(Matrix matrix, int outX, int outY) throws Exception {
        int x = matrix.getX();
        int y = matrix.getY();
        int xt = outX % 2 == 1 ? 1 : 0;
        int yt = outY % 2 == 1 ? 1 : 0;
        Matrix myMatrix = new Matrix(x * 2 - xt, y * 2 - yt);
        for (int i = 0; i < x; ++i) {
            for (int j = 0; j < y; ++j) {
                float value = matrix.getNumber(i, j) / 4.0f;
                this.insertMatrixValue(i * 2, j * 2, value, myMatrix);
            }
        }
        return myMatrix;
    }

    /** Output side length of the up-convolution: {@code size + kerSize - 1}. */
    private int getUpSize(int size, int kerSize) {
        return size + kerSize - 1;
    }

    /** Inverse of {@code getUpSize}: {@code size - kerSize + 1}. */
    private int backUpSize(int size, int kerSize) {
        return size - kerSize + 1;
    }

    /** Runs {@code backUpConv} for every channel; channel {@code i} uses {@code errorMatrixList.get(i)}. */
    protected List<Matrix> backManyUpConv(List<Matrix> errorMatrixList, int kerSize, ConvParameter convParameter, float studyRate, ActiveFunction activeFunction) throws Exception {
        ArrayList<Matrix> result = new ArrayList<Matrix>();
        int size = errorMatrixList.size();
        for (int i = 0; i < size; ++i) {
            result.add(this.backUpConv(errorMatrixList.get(i), kerSize, convParameter, studyRate, activeFunction, i));
        }
        return result;
    }

    /**
     * Back-propagates one channel's error through the up-convolution and replaces that channel's
     * up-kernel in {@code convParameter} with {@code kernel + studyRate * gradient}.
     *
     * <p>Side effect: {@code errorMatrix} is modified in place. Each cell is multiplied by
     * {@code activeFunction.functionG} (presumably the activation derivative) of the stored up-conv
     * output.
     *
     * @return the error for the up-conv input, shaped {@code (x - kerSize + 1) x (y - kerSize + 1)}
     */
    protected Matrix backUpConv(Matrix errorMatrix, int kerSize, ConvParameter convParameter, float studyRate, ActiveFunction activeFunction, int index) throws Exception {
        int myX = errorMatrix.getX();
        int myY = errorMatrix.getY();
        Matrix outMatrix = convParameter.getUpOutMatrixList().get(index);
        for (int i = 0; i < myX; ++i) {
            for (int j = 0; j < myY; ++j) {
                float value = activeFunction.functionG(outMatrix.getNumber(i, j)) * errorMatrix.getNumber(i, j);
                errorMatrix.setNub(i, j, value);
            }
        }
        int x = this.backUpSize(myX, kerSize);
        int y = this.backUpSize(myY, kerSize);
        Matrix upNerveMatrix = convParameter.getUpNerveMatrixList().get(index);
        Matrix vector = convParameter.getUpFeatureMatrixList().get(index);
        Matrix error = this.matrixOperation.im2col(errorMatrix, kerSize, 1);
        // flag false -> used as the kernel update; flag true -> used as the input-side error.
        Matrix subNerveMatrix = this.matrixOperation.matrixMulPd(error, vector, upNerveMatrix, false);
        Matrix errorFeature = this.matrixOperation.matrixMulPd(error, vector, upNerveMatrix, true);
        this.matrixOperation.mathMul(subNerveMatrix, studyRate);
        convParameter.getUpNerveMatrixList().set(index, this.matrixOperation.add(subNerveMatrix, upNerveMatrix));
        return this.matrixOperation.vectorToMatrix(errorFeature, x, y);
    }

    /**
     * Up-convolves the first {@code channelNo} matrices, each with its own kernel from
     * {@code nervePowerMatrixList}, and applies {@code activeFunction.function} to every cell.
     * Each output side grows by {@code kerSize - 1}.
     *
     * @return a result whose left-matrix list holds the flattened inputs (needed by
     *     {@code backUpConv}) and whose result-matrix list holds the activated outputs
     */
    private ConvResult upConv(List<Matrix> matrixList, int kerSize, List<Matrix> nervePowerMatrixList, ActiveFunction activeFunction, int channelNo) throws Exception {
        ConvResult convResult = new ConvResult();
        int x = this.getUpSize(matrixList.get(0).getX(), kerSize);
        int y = this.getUpSize(matrixList.get(0).getY(), kerSize);
        ArrayList<Matrix> vectorList = new ArrayList<Matrix>();
        ArrayList<Matrix> resultList = new ArrayList<Matrix>();
        convResult.setLeftMatrixList(vectorList);
        convResult.setResultMatrixList(resultList);
        for (int k = 0; k < channelNo; ++k) {
            Matrix matrix = matrixList.get(k);
            Matrix nervePowerMatrix = nervePowerMatrixList.get(k);
            Matrix vector = this.matrixOperation.matrixToVector(matrix, false);
            Matrix im2colMatrix = this.matrixOperation.mulMatrix(vector, nervePowerMatrix);
            // Fold the column products back into an x-by-y map (the trailing 1 is presumably the stride).
            Matrix out = this.matrixOperation.reverseIm2col(im2colMatrix, kerSize, 1, x, y);
            Matrix outMatrix = new Matrix(x, y);
            for (int i = 0; i < x; ++i) {
                for (int j = 0; j < y; ++j) {
                    outMatrix.setNub(i, j, activeFunction.function(out.getNumber(i, j)));
                }
            }
            vectorList.add(vector);
            resultList.add(outMatrix);
        }
        return convResult;
    }

    /**
     * Runs the down-convolution (without pooling and without recording a feature map). When
     * {@code pooling} is true, it then also runs the up-convolution and 2x2 up-sampling, storing the
     * up-conv outputs and flattened inputs in {@code convParameter} for the backward pass.
     *
     * @return the up-sampled maps when {@code pooling} is true, otherwise the down-conv outputs
     */
    protected List<Matrix> upConvAndPooling(List<Matrix> matrixList, ConvParameter convParameter, int channelNo, ActiveFunction activeFunction, int kernLen, boolean pooling) throws Exception {
        List<Matrix> downConvMatrixList = this.downConvAndPooling(matrixList, convParameter, channelNo, activeFunction, kernLen, false, -1L);
        if (!pooling) {
            return downConvMatrixList;
        }
        ConvResult result = this.upConv(downConvMatrixList, kernLen, convParameter.getUpNerveMatrixList(), activeFunction, channelNo);
        convParameter.setUpOutMatrixList(result.getResultMatrixList());
        convParameter.setUpFeatureMatrixList(result.getLeftMatrixList());
        ArrayList<Matrix> upPoolingMatrixList = new ArrayList<Matrix>();
        for (Matrix matrix : result.getResultMatrixList()) {
            upPoolingMatrixList.add(this.upPooling(matrix));
        }
        return upPoolingMatrixList;
    }

    /**
     * Down-convolves the first {@code channelNo} matrices, each with its own kernel.
     *
     * <p>Side effects on {@code convParameter}:
     * <ul>
     *   <li>records each channel's input size in its {@code ConvSize};</li>
     *   <li>replaces the contents of the im2col and output lists used by {@code backAllDownConv};</li>
     *   <li>records the (pre-pooling) output size via {@code setOutX}/{@code setOutY};</li>
     *   <li>when {@code eventID >= 0}, stores the conv outputs in the feature map under that id.</li>
     * </ul>
     *
     * @return the 2x2 average-pooled outputs when {@code pooling} is true, otherwise the conv
     *     outputs themselves
     */
    protected List<Matrix> downConvAndPooling(List<Matrix> matrixList, ConvParameter convParameter, int channelNo, ActiveFunction activeFunction, int kernLen, boolean pooling, long eventID) throws Exception {
        List<ConvSize> convSizeList = convParameter.getConvSizeList();
        List<Matrix> nerveMatrixList = convParameter.getNerveMatrixList();
        List<Matrix> im2colMatrixList = convParameter.getIm2colMatrixList();
        List<Matrix> outMatrixList = convParameter.getOutMatrixList();
        im2colMatrixList.clear();
        outMatrixList.clear();
        ArrayList<Matrix> resultMatrixList = new ArrayList<Matrix>();
        for (int i = 0; i < channelNo; ++i) {
            Matrix matrix = matrixList.get(i);
            ConvSize convSize = convSizeList.get(i);
            convSize.setXInput(matrix.getX());
            convSize.setYInput(matrix.getY());
            ConvResult convResult = this.downConvCount(matrix, activeFunction, kernLen, nerveMatrixList.get(i));
            im2colMatrixList.add(convResult.getLeftMatrix());
            Matrix myMatrix = convResult.getResultMatrix();
            outMatrixList.add(myMatrix);
            resultMatrixList.add(myMatrix);
        }
        if (eventID >= 0L) {
            convParameter.getFeatureMap().put(eventID, resultMatrixList);
        }
        Matrix first = resultMatrixList.get(0);
        convParameter.setOutX(first.getX());
        convParameter.setOutY(first.getY());
        if (pooling) {
            ArrayList<Matrix> poolMatrixList = new ArrayList<Matrix>();
            for (Matrix matrix : resultMatrixList) {
                poolMatrixList.add(this.downPooling(matrix));
            }
            return poolMatrixList;
        }
        return resultMatrixList;
    }

    /** Applies {@code oneConv} once per weight list in {@code oneConvPower}; one output per weight list. */
    protected List<Matrix> manyOneConv(List<Matrix> feature, List<List<Float>> oneConvPower) throws Exception {
        ArrayList<Matrix> result = new ArrayList<Matrix>();
        for (List<Float> convPower : oneConvPower) {
            result.add(this.oneConv(feature, convPower));
        }
        return result;
    }

    /**
     * 1x1 convolution: the weighted sum {@code sum_i oneConvPower[i] * feature[i]}.
     *
     * @return the weighted sum, or {@code null} when {@code oneConvPower} is empty
     */
    protected Matrix oneConv(List<Matrix> feature, List<Float> oneConvPower) throws Exception {
        int size = oneConvPower.size();
        Matrix sigmaMatrix = null;
        for (int i = 0; i < size; ++i) {
            Matrix pMatrix = this.matrixOperation.mathMulBySelf(feature.get(i), oneConvPower.get(i).floatValue());
            sigmaMatrix = i == 0 ? pMatrix : this.matrixOperation.add(sigmaMatrix, pMatrix);
        }
        return sigmaMatrix;
    }

    /**
     * Updates every weight list in {@code oneConvPower} with {@code backOneConv}, pairing weight
     * list {@code i} with {@code errorMatrixList.get(i)}.
     *
     * @throws Exception if the number of error matrices differs from the number of weight lists
     */
    protected void backOneConvByList(List<Matrix> errorMatrixList, List<Matrix> matrixList, List<List<Float>> oneConvPower, float studyRate, boolean UNet) throws Exception {
        int size = errorMatrixList.size();
        if (size != oneConvPower.size()) {
            // Message: "error matrix size does not match the number of channels".
            throw new Exception("\u8bef\u5dee\u77e9\u9635\u5927\u5c0f\u4e0e\u901a\u9053\u6570\u4e0d\u76f8\u7b26");
        }
        for (int i = 0; i < size; ++i) {
            this.backOneConv(errorMatrixList.get(i), matrixList, oneConvPower.get(i), studyRate, UNet);
        }
    }

    /**
     * Gradient step for one 1x1-conv weight list, applied in place to {@code oneConvPower}.
     *
     * <p>For each input channel {@code t}, the weight grows by {@code studyRate * sum(input .* error)}.
     * When {@code uNet} is true, that sum is divided by {@code sqrt(x * y)} of the input matrix.
     * Assumes each input matrix has the same shape as {@code errorMatrix}.
     */
    protected void backOneConv(Matrix errorMatrix, List<Matrix> matrixList, List<Float> oneConvPower, float studyRate, boolean uNet) throws Exception {
        int size = oneConvPower.size();
        for (int t = 0; t < size; ++t) {
            Matrix myMatrix = matrixList.get(t);
            int x = myMatrix.getX();
            int y = myMatrix.getY();
            float power = oneConvPower.get(t).floatValue();
            float allSubPower = 0.0f;
            float len = (float) Math.sqrt(x * y);
            for (int i = 0; i < x; ++i) {
                for (int j = 0; j < y; ++j) {
                    allSubPower += myMatrix.getNumber(i, j) * errorMatrix.getNumber(i, j) * studyRate;
                }
            }
            float sup = uNet ? allSubPower / len : allSubPower;
            oneConvPower.set(t, Float.valueOf(power + sup));
        }
    }

    /**
     * Back-propagates the first {@code channelNo} error matrices through the down-convolution.
     * Each channel's kernel in {@code convParameter} is replaced with its updated version, using
     * the im2col/output/input-size data recorded by {@code downConvAndPooling}.
     *
     * @return the error for each channel's conv input
     */
    protected List<Matrix> backAllDownConv(ConvParameter convParameter, List<Matrix> errorMatrixList, float studyPoint, ActiveFunction activeFunction, int channelNo, int kernLen) throws Exception {
        List<Matrix> outMatrixList = convParameter.getOutMatrixList();
        List<Matrix> im2colMatrixList = convParameter.getIm2colMatrixList();
        List<Matrix> nerveMatrixList = convParameter.getNerveMatrixList();
        List<ConvSize> convSizeList = convParameter.getConvSizeList();
        ArrayList<Matrix> resultMatrixList = new ArrayList<Matrix>();
        for (int i = 0; i < channelNo; ++i) {
            ConvSize convSize = convSizeList.get(i);
            ConvResult convResult = this.backDownConv(errorMatrixList.get(i), outMatrixList.get(i), activeFunction,
                    im2colMatrixList.get(i), nerveMatrixList.get(i), studyPoint, kernLen,
                    convSize.getXInput(), convSize.getYInput());
            nerveMatrixList.set(i, convResult.getNervePowerMatrix());
            resultMatrixList.add(convResult.getResultMatrix());
        }
        return resultMatrixList;
    }

    /**
     * Backward pass of {@code downConvCount} for one channel.
     *
     * <p>The error is multiplied element-wise by {@code activeFunction.functionG} (presumably the
     * activation derivative) of the stored output and flattened row-major into an
     * {@code (x * y) x 1} column. That column yields the kernel update (scaled by
     * {@code studyRate} and added to the kernel) and the input-side error, which is folded back
     * to {@code xInput x yInput}.
     *
     * @return the updated kernel (nerve-power matrix) and the input-side error (result matrix)
     */
    private ConvResult backDownConv(Matrix errorMatrix, Matrix outMatrix, ActiveFunction activeFunction, Matrix im2col, Matrix nerveMatrix, float studyRate, int kernSize, int xInput, int yInput) throws Exception {
        ConvResult convResult = new ConvResult();
        int x = errorMatrix.getX();
        int y = errorMatrix.getY();
        Matrix resultError = new Matrix(x * y, 1);
        for (int i = 0; i < x; ++i) {
            for (int j = 0; j < y; ++j) {
                float error = errorMatrix.getNumber(i, j) * activeFunction.functionG(outMatrix.getNumber(i, j));
                resultError.setNub(y * i + j, 0, error);
            }
        }
        Matrix wSub = this.matrixOperation.matrixMulPd(resultError, im2col, nerveMatrix, false);
        Matrix im2colSub = this.matrixOperation.matrixMulPd(resultError, im2col, nerveMatrix, true);
        this.matrixOperation.mathMul(wSub, studyRate);
        Matrix updatedNerve = this.matrixOperation.add(nerveMatrix, wSub);
        Matrix gNext = this.matrixOperation.reverseIm2col(im2colSub, kernSize, 1, xInput, yInput);
        convResult.setNervePowerMatrix(updatedNerve);
        convResult.setResultMatrix(gNext);
        return convResult;
    }

    /**
     * Unpadded convolution of one channel: the im2col matrix times the kernel column, activated
     * with {@code activeFunction.function}.
     *
     * <p>The output has size {@code (xInput - kerSize + 1) x (yInput - kerSize + 1)}. Row
     * {@code i * y + j} of the product becomes cell {@code (i, j)}.
     *
     * @return the im2col matrix (left matrix, kept for the backward pass) and the activated output
     */
    private ConvResult downConvCount(Matrix matrix, ActiveFunction activeFunction, int kerSize, Matrix nervePowerMatrix) throws Exception {
        ConvResult convResult = new ConvResult();
        int sub = kerSize - 1;
        int x = matrix.getX() - sub;
        int y = matrix.getY() - sub;
        Matrix myMatrix = new Matrix(x, y);
        Matrix im2col = this.matrixOperation.im2col(matrix, kerSize, 1);
        convResult.setLeftMatrix(im2col);
        Matrix matrixOut = this.matrixOperation.mulMatrix(im2col, nervePowerMatrix);
        for (int i = 0; i < x; ++i) {
            for (int j = 0; j < y; ++j) {
                myMatrix.setNub(i, j, activeFunction.function(matrixOut.getNumber(i * y + j, 0)));
            }
        }
        convResult.setResultMatrix(myMatrix);
        return convResult;
    }
}

