/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.unet;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.dromara.easyai.conv.ConvCount;
import org.dromara.easyai.entity.ThreeChannelMatrix;
import org.dromara.easyai.i.ActiveFunction;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.nerveEntity.ConvParameter;
import org.dromara.easyai.nerveEntity.ConvSize;
import org.dromara.easyai.unet.Cutting;
import org.dromara.easyai.unet.UNetEncoder;

/**
 * One decoder stage of the U-Net.
 *
 * <p>In the forward pass ({@link #sendFeature}) a stage optionally merges the skip-connection
 * features of its paired encoder ({@link #setMyUNetEncoder}) into the incoming features. It then
 * runs the inherited {@code upConvAndPooling} with its own weights. Finally it either forwards the
 * result to the next stage ({@link #setAfterDecoder}) or, when it is the last layer, combines the
 * channels into a single output map.
 *
 * <p>In the backward pass, errors flow from later stages to earlier ones through
 * {@link #backErrorMatrix}. Along the way each stage also hands a share of the error to its paired
 * encoder.
 */
public class UNetDecoder extends ConvCount {
    /** Weights owned by this stage (conv, up-conv and, for the last layer, channel-combining weights). */
    private final ConvParameter convParameter = new ConvParameter();
    private final MatrixOperation matrixOperation = new MatrixOperation();
    /** Kernel side length; each per-channel weight matrix holds kerSize * kerSize entries. */
    private final int kerSize;
    /** Depth of this stage; skip-connection features are merged only when deep > 1. */
    private final int deep;
    private final float studyRate;
    private final int channelNo;
    /** True for the final decoder, which produces the output image instead of forwarding features. */
    private final boolean lastLay;
    private final ActiveFunction activeFunction;
    /** Next stage in the forward pass. */
    private UNetDecoder afterDecoder;
    /** Previous stage; receives this stage's back-propagated error. */
    private UNetDecoder beforeDecoder;
    /** Receives the back-propagated error when there is no previous decoder stage. */
    private UNetEncoder encoder;
    /** Paired encoder supplying skip-connection features and receiving its share of the error. */
    private UNetEncoder myUNetEncoder;
    /** Records the skip-connection feature size (set during study) used to shape encoder errors. */
    private final ConvSize convSize = new ConvSize();
    /** Optional post-processing applied to the final output; may be null. */
    private final Cutting cutting;
    /** Learning rate for the last layer's channel-combining weights. */
    private final float oneConvStudyRate;

    /**
     * Creates a decoder stage and randomly initialises its weights.
     *
     * <p>For each channel it creates one 1 x (kerSize^2) up-path weight row, one (kerSize^2) x 1
     * weight column and one {@link ConvSize}. Every entry is drawn as
     * {@code Random.nextFloat() / kerSize}. A last-layer stage additionally gets one
     * channel-combining weight per channel, drawn as {@code nextFloat() / channelNo}.
     *
     * @param kerSize          kernel side length
     * @param deep             depth of this stage (skip connections are used when {@code deep > 1})
     * @param channelNo        number of feature channels
     * @param activeFunction   activation used by the convolution helpers
     * @param lastLay          whether this is the final decoder stage
     * @param studyRate        learning rate for the convolution weights
     * @param cutting          optional output post-processor, may be {@code null}
     * @param oneConvStudyRate learning rate for the channel-combining weights
     * @throws Exception propagated from matrix construction
     */
    public UNetDecoder(int kerSize, int deep, int channelNo, ActiveFunction activeFunction, boolean lastLay, float studyRate, Cutting cutting, float oneConvStudyRate) throws Exception {
        this.cutting = cutting;
        this.kerSize = kerSize;
        this.oneConvStudyRate = oneConvStudyRate;
        this.deep = deep;
        this.studyRate = studyRate;
        this.lastLay = lastLay;
        this.channelNo = channelNo;
        this.activeFunction = activeFunction;
        Random random = new Random();
        List<Matrix> nerveMatrixList = this.convParameter.getNerveMatrixList();
        List<Matrix> upNeverMatrixList = this.convParameter.getUpNerveMatrixList();
        List<ConvSize> convSizeList = this.convParameter.getConvSizeList();
        for (int i = 0; i < channelNo; ++i) {
            upNeverMatrixList.add(this.initUpNervePowerMatrix(random));
            this.initNervePowerMatrix(random, nerveMatrixList);
            convSizeList.add(new ConvSize());
        }
        if (lastLay) {
            ArrayList<Float> oneConvPower = new ArrayList<Float>();
            for (int i = 0; i < channelNo; ++i) {
                oneConvPower.add(Float.valueOf(random.nextFloat() / (float)channelNo));
            }
            this.convParameter.setUpOneConvPower(oneConvPower);
        }
    }

    /** Returns the weights owned by this stage (live object, not a copy). */
    public ConvParameter getConvParameter() {
        return this.convParameter;
    }

    /**
     * Brings {@code picture} to {@code heightSize} rows.
     *
     * <p>A taller picture is cropped starting at row {@code (height - heightSize) / 2}. A shorter
     * picture is pasted into a freshly allocated matrix at that (absolute) offset. In both cases
     * the offset is at least 1. The width is not adjusted; the picture is assumed to already be
     * {@code widthSize} wide.
     *
     * @return the adjusted picture, or {@code null} when the heights already match
     */
    private ThreeChannelMatrix fillColor(ThreeChannelMatrix picture, int heightSize, int widthSize) throws Exception {
        int myFaceHeight = picture.getX();
        int sub = myFaceHeight - heightSize;
        int fillHeight = sub / 2;
        if (fillHeight == 0) {
            fillHeight = 1;
        }
        ThreeChannelMatrix fillMatrix = null;
        if (sub > 0) {
            fillMatrix = picture.cutChannel(fillHeight, 0, heightSize, widthSize);
        } else if (sub < 0) {
            fillMatrix = this.getFaceMatrix(heightSize, widthSize);
            fillMatrix.fill(Math.abs(fillHeight), 0, picture);
        }
        return fillMatrix;
    }

    /** Allocates a height x width three-channel matrix (R, G, B and H) of new matrices. */
    private ThreeChannelMatrix getFaceMatrix(int height, int width) {
        ThreeChannelMatrix threeChannelMatrix = new ThreeChannelMatrix();
        Matrix matrixR = new Matrix(height, width);
        Matrix matrixG = new Matrix(height, width);
        Matrix matrixB = new Matrix(height, width);
        Matrix matrixH = new Matrix(height, width);
        threeChannelMatrix.setX(height);
        threeChannelMatrix.setY(width);
        threeChannelMatrix.setMatrixR(matrixR);
        threeChannelMatrix.setMatrixG(matrixG);
        threeChannelMatrix.setMatrixB(matrixB);
        threeChannelMatrix.setH(matrixH);
        return threeChannelMatrix;
    }

    /**
     * Merges each encoder (skip-connection) feature into the matching decoder feature in place.
     * Assumes {@code myFeatures} has at least as many entries as {@code encoderFeatures}.
     */
    private void addFeatures(List<Matrix> encoderFeatures, List<Matrix> myFeatures, boolean study) throws Exception {
        int size = encoderFeatures.size();
        for (int i = 0; i < size; ++i) {
            this.addFeature(encoderFeatures.get(i), myFeatures.get(i), study);
        }
    }

    /**
     * Replaces every cell of {@code myFeature} with the average of itself and the corresponding
     * encoder cell. Encoder cells outside the encoder matrix count as 0. In study mode the encoder
     * matrix size is recorded so that {@link #sendEncoderError} can shape the error to match it.
     */
    private void addFeature(Matrix encoderFeature, Matrix myFeature, boolean study) throws Exception {
        if (study) {
            this.convSize.setXInput(encoderFeature.getX());
            this.convSize.setYInput(encoderFeature.getY());
        }
        int tx = encoderFeature.getX();
        int ty = encoderFeature.getY();
        int x = myFeature.getX();
        int y = myFeature.getY();
        for (int i = 0; i < x; ++i) {
            for (int j = 0; j < y; ++j) {
                float encoderValue = 0.0f;
                if (i < tx && j < ty) {
                    encoderValue = encoderFeature.getNumber(i, j);
                }
                float value = (myFeature.getNumber(i, j) + encoderValue) / 2.0f;
                myFeature.setNub(i, j, value);
            }
        }
    }

    /**
     * Last-layer handling: combines the channel features with the per-channel weights into one map.
     *
     * <p>In study mode the target ({@code featureE}) is scaled and height-fitted to the feature
     * size and converted to average grayscale. The difference {@code target - output} is then
     * back-propagated through the channel-combining weights and the convolutions.
     *
     * <p>Otherwise the output map is centered on a canvas the size of {@code backGround}. Parts
     * that do not fit are clipped. The result is emitted, as identical R/G/B channels, through
     * {@code cutting} when present, or directly to {@code outBack} otherwise.
     */
    private void toThreeChannelMatrix(List<Matrix> features, ThreeChannelMatrix featureE, boolean study, OutBack outBack, ThreeChannelMatrix backGround) throws Exception {
        int x = features.get(0).getX();
        int y = features.get(0).getY();
        List<Float> upOneConvPower = this.convParameter.getUpOneConvPower();
        Matrix feature = this.oneConv(features, upOneConvPower);
        if (study) {
            ThreeChannelMatrix sfe = featureE.scale(true, y);
            ThreeChannelMatrix fe = this.fillColor(sfe, x, y);
            if (fe == null) {
                fe = sfe;
            }
            Matrix he = fe.CalculateAvgGrayscale();
            Matrix errorMatrix = this.matrixOperation.sub(he, feature);
            ArrayList<Matrix> errorMatrixList = new ArrayList<Matrix>();
            // Per-channel errors use the weights as they were BEFORE backOneConv updates them.
            for (int i = 0; i < this.channelNo; ++i) {
                float power = upOneConvPower.get(i).floatValue();
                Matrix error = this.matrixOperation.mathMulBySelf(errorMatrix, power);
                errorMatrixList.add(error);
            }
            this.backOneConv(errorMatrix, features, upOneConvPower, this.oneConvStudyRate, true);
            this.backLastError(errorMatrixList);
        } else {
            int mx = backGround.getX();
            int my = backGround.getY();
            Matrix myMatrix = this.centerOnCanvas(feature, mx, my);
            ThreeChannelMatrix threeChannelMatrix = new ThreeChannelMatrix();
            // Report the size of the matrices actually attached (the canvas), not the feature size.
            threeChannelMatrix.setX(mx);
            threeChannelMatrix.setY(my);
            // NOTE(review): R, G and B share one Matrix instance; mutating one mutates all three.
            threeChannelMatrix.setMatrixR(myMatrix);
            threeChannelMatrix.setMatrixG(myMatrix);
            threeChannelMatrix.setMatrixB(myMatrix);
            if (this.cutting != null) {
                this.cutting.cut(backGround, threeChannelMatrix, outBack);
            } else {
                outBack.getBackThreeChannelMatrix(threeChannelMatrix);
            }
        }
    }

    /**
     * Copies {@code feature} into a new {@code mx} x {@code my} matrix, centered.
     *
     * <p>The offset is {@code (canvas - feature) / 2} on each axis. Feature cells that land outside
     * the canvas (when the feature is larger) are dropped. Canvas cells not covered by the feature
     * keep the value of a newly constructed {@link Matrix}.
     */
    private Matrix centerOnCanvas(Matrix feature, int mx, int my) throws Exception {
        int fx = feature.getX();
        int fy = feature.getY();
        int startX = (mx - fx) / 2;
        int startY = (my - fy) / 2;
        Matrix canvas = new Matrix(mx, my);
        for (int i = 0; i < fx; ++i) {
            int targetX = i + startX;
            if (targetX < 0 || targetX >= mx) {
                continue;
            }
            for (int j = 0; j < fy; ++j) {
                int targetY = j + startY;
                if (targetY >= 0 && targetY < my) {
                    canvas.setNub(targetX, targetY, feature.getNumber(i, j));
                }
            }
        }
        return canvas;
    }

    /** Back-propagates the last layer's per-channel errors through this stage's convolutions. */
    private void backLastError(List<Matrix> errorMatrixList) throws Exception {
        List<Matrix> errorList = this.backAllDownConv(this.convParameter, errorMatrixList, this.studyRate, this.activeFunction, this.channelNo, this.kerSize);
        this.propagateError(errorList);
    }

    /**
     * Hands this stage's back-propagated error to the skip-connection encoder (if any). It then
     * passes the error to the previous decoder, or to {@link #encoder} when there is no previous
     * decoder.
     */
    private void propagateError(List<Matrix> backList) throws Exception {
        if (this.myUNetEncoder != null) {
            this.sendEncoderError(backList);
        }
        if (this.beforeDecoder != null) {
            this.beforeDecoder.backErrorMatrix(backList);
        } else {
            this.encoder.backError(backList);
        }
    }

    /**
     * Sends the skip-connection share of the error to the paired encoder.
     *
     * <p>Each error is reshaped to the encoder feature size recorded by {@link #addFeature}. Cells
     * beyond the error's extent become 0, and extra error cells are dropped. Every value is halved,
     * mirroring the averaging done in {@link #addFeature}.
     */
    private void sendEncoderError(List<Matrix> errors) throws Exception {
        ArrayList<Matrix> encoderErrors = new ArrayList<Matrix>();
        for (Matrix error : errors) {
            int x = this.convSize.getXInput();
            int y = this.convSize.getYInput();
            Matrix encoderError = new Matrix(x, y);
            int tx = error.getX();
            int ty = error.getY();
            for (int i = 0; i < x; ++i) {
                for (int j = 0; j < y; ++j) {
                    float value = 0.0f;
                    if (i < tx && j < ty) {
                        value = error.getNumber(i, j) / 2.0f;
                    }
                    encoderError.setNub(i, j, value);
                }
            }
            encoderErrors.add(encoderError);
        }
        this.myUNetEncoder.setDecodeErrorMatrix(encoderErrors);
    }

    /**
     * Receives the error from the following decoder stage and back-propagates it through this
     * stage's up-pooling, up-convolution and convolution. The resulting error is then passed on
     * (see {@link #propagateError}).
     */
    protected void backErrorMatrix(List<Matrix> myErrorMatrixList) throws Exception {
        List<Matrix> errorList = this.backManyUpPooling(myErrorMatrixList);
        List<Matrix> errorMatrixList = this.backManyUpConv(errorList, this.kerSize, this.convParameter, this.studyRate, this.activeFunction);
        List<Matrix> backList = this.backAllDownConv(this.convParameter, errorMatrixList, this.studyRate, this.activeFunction, this.channelNo, this.kerSize);
        this.propagateError(backList);
    }

    /**
     * Forward pass for one event.
     *
     * <p>When {@code deep > 1} the paired encoder's features for {@code eventID} are first merged
     * into {@code myFeatures} (in place). The features then go through {@code upConvAndPooling};
     * pooling is skipped on the last layer. The result is either turned into the output image (last
     * layer) or forwarded to the next decoder.
     */
    protected void sendFeature(long eventID, OutBack outBack, ThreeChannelMatrix featureE, List<Matrix> myFeatures, boolean study, ThreeChannelMatrix backGround) throws Exception {
        if (this.deep > 1) {
            List<Matrix> encoderMatrixList = this.myUNetEncoder.getAfterConvMatrix(eventID);
            this.addFeatures(encoderMatrixList, myFeatures, study);
        }
        List<Matrix> upConvMatrixList = this.upConvAndPooling(myFeatures, this.convParameter, this.channelNo, this.activeFunction, this.kerSize, !this.lastLay);
        if (this.lastLay) {
            this.toThreeChannelMatrix(upConvMatrixList, featureE, study, outBack, backGround);
        } else {
            this.afterDecoder.sendFeature(eventID, outBack, featureE, upConvMatrixList, study, backGround);
        }
    }

    /** Creates a 1 x (kerSize^2) weight row with entries {@code nextFloat() / kerSize}. */
    private Matrix initUpNervePowerMatrix(Random random) throws Exception {
        int convSize = this.kerSize * this.kerSize;
        Matrix nervePowerMatrix = new Matrix(1, convSize);
        for (int j = 0; j < convSize; ++j) {
            float power = random.nextFloat() / (float)this.kerSize;
            nervePowerMatrix.setNub(0, j, power);
        }
        return nervePowerMatrix;
    }

    /** Appends a (kerSize^2) x 1 weight column with entries {@code nextFloat() / kerSize}. */
    private void initNervePowerMatrix(Random random, List<Matrix> nervePowerMatrixList) throws Exception {
        int convSize = this.kerSize * this.kerSize;
        Matrix nervePowerMatrix = new Matrix(convSize, 1);
        for (int i = 0; i < convSize; ++i) {
            float power = random.nextFloat() / (float)this.kerSize;
            nervePowerMatrix.setNub(i, 0, power);
        }
        nervePowerMatrixList.add(nervePowerMatrix);
    }

    public UNetDecoder getAfterDecoder() {
        return this.afterDecoder;
    }

    public void setAfterDecoder(UNetDecoder afterDecoder) {
        this.afterDecoder = afterDecoder;
    }

    public UNetDecoder getBeforeDecoder() {
        return this.beforeDecoder;
    }

    public void setBeforeDecoder(UNetDecoder beforeDecoder) {
        this.beforeDecoder = beforeDecoder;
    }

    public UNetEncoder getEncoder() {
        return this.encoder;
    }

    public void setEncoder(UNetEncoder encoder) {
        this.encoder = encoder;
    }

    public void setMyUNetEncoder(UNetEncoder myUNetEncoder) {
        this.myUNetEncoder = myUNetEncoder;
    }
}

