/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.unet;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.dromara.easyai.conv.ConvCount;
import org.dromara.easyai.entity.ThreeChannelMatrix;
import org.dromara.easyai.i.ActiveFunction;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.nerveEntity.ConvParameter;
import org.dromara.easyai.nerveEntity.ConvSize;
import org.dromara.easyai.unet.UNetDecoder;

/**
 * Encoder stage of the UNet pipeline in this package.
 *
 * <p>Stages are linked through {@link #setAfterEncoder} / {@link #setBeforeEncoder}. On the forward
 * pass each stage runs {@code downConvAndPooling} on its input and hands the result to the next
 * encoder, or to the {@link UNetDecoder} when no next encoder is set. On the backward pass
 * ({@link #backError}) the error is propagated to the previous encoder; the first stage (no previous
 * encoder) finishes by back-propagating into its {@code oneConvPower} weights.
 */
public class UNetEncoder extends ConvCount {
    /**
     * Number of input matrices (R, G, B) that {@link #sendThreeChannel} passes to the first stage.
     * It sizes each per-output-channel entry of the {@code oneConvPower} weights.
     */
    private static final int INPUT_CHANNEL_COUNT = 3;

    private final ConvParameter convParameter = new ConvParameter();
    private final MatrixOperation matrixOperation = new MatrixOperation();
    private final int kerSize;
    /** Learning rate passed to {@code backAllDownConv}. */
    private final float studyRate;
    /** Stage depth; only {@code deep == 1} initialises {@code oneConvPower} weights. */
    private final int deep;
    /** Number of output channels (kernels) of this stage. */
    private final int channelNo;
    /** Error contributed by the decoder; added to the pooled-back error in {@link #backError}. */
    private List<Matrix> decodeErrorMatrix;
    private final ActiveFunction activeFunction;
    private UNetEncoder afterEncoder;
    private UNetEncoder beforeEncoder;
    private UNetDecoder decoder;
    /** Expected input width/height checked by {@link #sendThreeChannel}. */
    private final int xSize;
    private final int ySize;
    /** Learning rate passed to {@code backOneConvByList} by the first stage. */
    private final float oneStudyRate;

    /**
     * Creates an encoder stage with randomly initialised kernel weights.
     *
     * <p>For each of {@code channelNo} channels, a {@code (kerSize*kerSize) x 1} weight matrix with
     * entries in {@code [0, 1/kerSize)} and a fresh {@link ConvSize} are appended to the
     * {@link ConvParameter}. When {@code deep == 1}, {@code channelNo} lists of
     * {@value #INPUT_CHANNEL_COUNT} weights in {@code [0, 1/3)} are also created and stored as
     * {@code oneConvPower}.
     *
     * @param kerSize        kernel side length
     * @param channelNo      number of output channels of this stage
     * @param deep           stage depth; {@code 1} marks the first stage
     * @param activeFunction activation function used in forward and backward passes
     * @param studyRate      learning rate for the down-convolution weights
     * @param xSize          expected input width (checked for the first stage's input)
     * @param ySize          expected input height (checked for the first stage's input)
     * @param oneStudyRate   learning rate for the {@code oneConvPower} weights
     * @throws Exception if weight matrix creation fails
     */
    public UNetEncoder(int kerSize, int channelNo, int deep, ActiveFunction activeFunction, float studyRate, int xSize, int ySize, float oneStudyRate) throws Exception {
        Random random = new Random();
        this.xSize = xSize;
        this.ySize = ySize;
        this.oneStudyRate = oneStudyRate;
        this.studyRate = studyRate;
        this.kerSize = kerSize;
        this.activeFunction = activeFunction;
        this.deep = deep;
        this.channelNo = channelNo;
        List<Matrix> nerveMatrixList = this.convParameter.getNerveMatrixList();
        List<ConvSize> convSizeList = this.convParameter.getConvSizeList();
        for (int i = 0; i < channelNo; ++i) {
            this.initNervePowerMatrix(random, nerveMatrixList);
            convSizeList.add(new ConvSize());
        }
        if (deep == 1) {
            this.convParameter.setOneConvPower(createOneConvPowers(random, channelNo));
        }
    }

    /** Builds {@code channelNo} lists of {@value #INPUT_CHANNEL_COUNT} random weights in [0, 1/3). */
    private static List<List<Float>> createOneConvPowers(Random random, int channelNo) {
        List<List<Float>> oneConvPowers = new ArrayList<>(channelNo);
        for (int k = 0; k < channelNo; ++k) {
            List<Float> oneConvPower = new ArrayList<>(INPUT_CHANNEL_COUNT);
            for (int i = 0; i < INPUT_CHANNEL_COUNT; ++i) {
                oneConvPower.add(random.nextFloat() / (float) INPUT_CHANNEL_COUNT);
            }
            oneConvPowers.add(oneConvPower);
        }
        return oneConvPowers;
    }

    /** Returns this stage's convolution parameters (weights, sizes, cached feature maps). */
    public ConvParameter getConvParameter() {
        return this.convParameter;
    }

    /**
     * Stores the error matrices that {@link #backError} adds to this stage's pooled-back error.
     *
     * @param decodeErrorMatrix error list presumably produced by the decoder side; TODO confirm
     *                          against the caller
     */
    protected void setDecodeErrorMatrix(List<Matrix> decodeErrorMatrix) {
        this.decodeErrorMatrix = decodeErrorMatrix;
    }

    /**
     * Removes and returns the feature maps cached for {@code eventID}.
     *
     * @param eventID id of the forward pass whose feature maps are requested
     * @return the cached matrices, or {@code null} if nothing is cached for this id (assuming
     *         {@code getFeatureMap()} is a standard {@code Map})
     */
    protected List<Matrix> getAfterConvMatrix(long eventID) {
        // Map.remove returns the previous value, so one call replaces get-then-remove.
        return this.convParameter.getFeatureMap().remove(eventID);
    }

    /**
     * Entry point for an RGB image: validates it and feeds its R, G, B matrices forward.
     *
     * @param eventID  id of this forward pass
     * @param outBack  callback handed down the chain
     * @param feature  input image; its dimensions must equal {@code xSize} x {@code ySize}
     * @param featureE expected output; required when {@code study} is {@code true}
     * @param study    {@code true} during training
     * @throws Exception if {@code study} is set without {@code featureE}, if either image dimension
     *                   differs from the configured size, or if a downstream stage fails
     */
    public void sendThreeChannel(long eventID, OutBack outBack, ThreeChannelMatrix feature, ThreeChannelMatrix featureE, boolean study) throws Exception {
        if (study && featureE == null) {
            throw new Exception("\u8bad\u7ec3\u65f6\u671f\u671b\u77e9\u9635\u4e0d\u80fd\u4e3a\u7a7a");
        }
        // Reject the image if EITHER dimension is wrong (the original '&&' only rejected when both were).
        if (feature.getX() != this.xSize || feature.getY() != this.ySize) {
            throw new Exception("\u8f93\u5165\u56fe\u7247\u5c3a\u5bf8\u4e0e\u521d\u59cb\u5316\u53c2\u6570\u4e0d\u4e00\u81f4");
        }
        List<Matrix> matrixList = new ArrayList<>(INPUT_CHANNEL_COUNT);
        matrixList.add(feature.getMatrixR());
        matrixList.add(feature.getMatrixG());
        matrixList.add(feature.getMatrixB());
        if (study) {
            // Kept for backError, which passes it to backOneConvByList on the first stage.
            this.convParameter.setFeatureMatrixList(matrixList);
        }
        this.sendMatrixList(eventID, outBack, featureE, matrixList, study, feature);
    }

    /**
     * Forward pass for a non-first stage: down-convolves and pools {@code myFeatures} and forwards
     * the result to the next encoder or the decoder.
     *
     * @throws IllegalStateException if neither a next encoder nor a decoder is configured
     * @throws Exception             if convolution or a downstream stage fails
     */
    protected void sendFeature(long eventID, OutBack outBack, ThreeChannelMatrix featureE, List<Matrix> myFeatures, boolean study, ThreeChannelMatrix backGround) throws Exception {
        this.convolveAndForward(eventID, outBack, featureE, myFeatures, study, backGround);
    }

    /**
     * Backward pass: pools the incoming error back, adds the decoder error, back-propagates through
     * this stage's down-convolution, then continues to the previous encoder. The first stage instead
     * back-propagates into its {@code oneConvPower} weights.
     *
     * <p>NOTE(review): {@code decodeErrorMatrix} must have been set via
     * {@link #setDecodeErrorMatrix}; how {@code addMatrixList} handles {@code null} is not visible
     * here.
     *
     * @param errorMatrix error arriving from the following stage
     * @throws Exception if a back-propagation helper fails
     */
    protected void backError(List<Matrix> errorMatrix) throws Exception {
        List<Matrix> errorList = this.backDownPoolingByList(errorMatrix, this.convParameter.getOutX(), this.convParameter.getOutY());
        List<Matrix> errorMatrixList = this.matrixOperation.addMatrixList(errorList, this.decodeErrorMatrix);
        List<Matrix> myErrorMatrix = this.backAllDownConv(this.convParameter, errorMatrixList, this.studyRate, this.activeFunction, this.channelNo, this.kerSize);
        if (this.beforeEncoder != null) {
            this.beforeEncoder.backError(myErrorMatrix);
        } else {
            this.backOneConvByList(myErrorMatrix, this.convParameter.getFeatureMatrixList(), this.convParameter.getOneConvPower(), this.oneStudyRate, true);
        }
    }

    /**
     * Forward pass for the first stage: applies {@code manyOneConv} with the {@code oneConvPower}
     * weights, then down-convolves, pools and forwards the result.
     *
     * <p>NOTE(review): {@code oneConvPower} is only initialised when this encoder was built with
     * {@code deep == 1}.
     *
     * @throws IllegalStateException if neither a next encoder nor a decoder is configured
     * @throws Exception             if convolution or a downstream stage fails
     */
    public void sendMatrixList(long eventID, OutBack outBack, ThreeChannelMatrix featureE, List<Matrix> feature, boolean study, ThreeChannelMatrix backGround) throws Exception {
        List<Matrix> myFeatures = this.manyOneConv(feature, this.convParameter.getOneConvPower());
        this.convolveAndForward(eventID, outBack, featureE, myFeatures, study, backGround);
    }

    /** Shared forward step: down-conv and pooling, then route to the next encoder or the decoder. */
    private void convolveAndForward(long eventID, OutBack outBack, ThreeChannelMatrix featureE, List<Matrix> input, boolean study, ThreeChannelMatrix backGround) throws Exception {
        // Fail fast with a clear message instead of an NPE after the convolution work is done.
        if (this.afterEncoder == null && this.decoder == null) {
            throw new IllegalStateException("UNetEncoder has neither a following encoder nor a decoder to forward to");
        }
        List<Matrix> convMatrixList = this.downConvAndPooling(input, this.convParameter, this.channelNo, this.activeFunction, this.kerSize, true, eventID);
        if (this.afterEncoder != null) {
            this.afterEncoder.sendFeature(eventID, outBack, featureE, convMatrixList, study, backGround);
        } else {
            this.decoder.sendFeature(eventID, outBack, featureE, convMatrixList, study, backGround);
        }
    }

    /**
     * Appends one {@code (kerSize*kerSize) x 1} weight matrix with entries drawn from
     * {@code [0, 1/kerSize)}.
     */
    private void initNervePowerMatrix(Random random, List<Matrix> nervePowerMatrixList) throws Exception {
        int convSize = this.kerSize * this.kerSize;
        Matrix nervePowerMatrix = new Matrix(convSize, 1);
        for (int i = 0; i < convSize; ++i) {
            float power = random.nextFloat() / (float) this.kerSize;
            nervePowerMatrix.setNub(i, 0, power);
        }
        nervePowerMatrixList.add(nervePowerMatrix);
    }

    public UNetEncoder getAfterEncoder() {
        return this.afterEncoder;
    }

    public void setAfterEncoder(UNetEncoder afterEncoder) {
        this.afterEncoder = afterEncoder;
    }

    public UNetEncoder getBeforeEncoder() {
        return this.beforeEncoder;
    }

    public void setBeforeEncoder(UNetEncoder beforeEncoder) {
        this.beforeEncoder = beforeEncoder;
    }

    public UNetDecoder getDecoder() {
        return this.decoder;
    }

    public void setDecoder(UNetDecoder decoder) {
        this.decoder = decoder;
    }
}

