/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.unet;

import java.util.ArrayList;
import java.util.List;
import org.dromara.easyai.config.UNetConfig;
import org.dromara.easyai.conv.ConvCount;
import org.dromara.easyai.function.ReLu;
import org.dromara.easyai.function.Tanh;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.nerveEntity.ConvParameter;
import org.dromara.easyai.unet.ConvModel;
import org.dromara.easyai.unet.Cutting;
import org.dromara.easyai.unet.UNetDecoder;
import org.dromara.easyai.unet.UNetEncoder;
import org.dromara.easyai.unet.UNetInput;
import org.dromara.easyai.unet.UNetModel;

/**
 * Builds and wires a U-Net made of {@code deep} encoder layers and {@code deep + 1} decoder
 * layers, and converts the learned convolution weights to and from a {@link UNetModel}.
 *
 * <p>{@code deep} is computed by {@link ConvCount}'s {@code getConvMyDep} from the input size,
 * the kernel size and {@code minFeatureValue} from the {@link UNetConfig}.
 */
public class UNetManager
extends ConvCount {
    /** Encoder layers in order; the first one receives the network input. */
    private final List<UNetEncoder> encoderList = new ArrayList<>();
    /** Decoder layers in order; holds {@code deep + 1} entries. */
    private final List<UNetDecoder> decoderList = new ArrayList<>();
    /** Kernel size taken from {@link UNetConfig#getKerSize()}. */
    private final int kernLen;
    /** Channel count taken from {@link UNetConfig#getChannelNo()}. */
    private final int channelNo;
    /** Number of encoder layers; always greater than 1 after construction. */
    private final int deep;
    /** Learning rate passed to every encoder and decoder. */
    private final float studyRate;
    /** Second learning rate passed to every encoder and decoder ({@code getOneStudyRate()}). */
    private final float oneStudyRate;
    /** Input wrapper bound to the first encoder layer. */
    private UNetInput input;

    /**
     * Returns the input entry point of the network.
     *
     * @return the {@link UNetInput} wrapping the first encoder layer
     */
    public UNetInput getInput() {
        return this.input;
    }

    /**
     * Creates the encoder and decoder chains described by {@code uNetConfig} and connects them.
     *
     * @param uNetConfig network configuration (input size, kernel size, channels, learning rates,
     *                   optional cutting threshold)
     * @throws Exception if the computed depth is 1 or less, i.e. {@code minFeatureValue} is too
     *                   large for the input size, or if layer construction fails
     */
    public UNetManager(UNetConfig uNetConfig) throws Exception {
        int xSize = uNetConfig.getXSize();
        int ySize = uNetConfig.getYSize();
        int minFeatureValue = uNetConfig.getMinFeatureValue();
        this.kernLen = uNetConfig.getKerSize();
        this.channelNo = uNetConfig.getChannelNo();
        this.studyRate = uNetConfig.getStudyRate();
        this.oneStudyRate = uNetConfig.getOneStudyRate();
        this.deep = this.getConvMyDep(xSize, ySize, this.kernLen, minFeatureValue);
        if (this.deep <= 1) {
            // Message: "minFeatureValue is set too large".
            throw new Exception("minFeatureValue \u8bbe\u7f6e\u7684\u503c\u592a\u5927\u4e86");
        }
        this.initEncoder(xSize, ySize);
        this.initDecoder(uNetConfig.isCutting(), uNetConfig.getCutTh());
        this.connectionCoder();
    }

    /** Unboxes a {@code Float[]} into a new {@code float[]} of the same length. */
    private float[] getFValue(Float[] values) {
        float[] fValue = new float[values.length];
        for (int i = 0; i < values.length; ++i) {
            fValue[i] = values[i].floatValue();
        }
        return fValue;
    }

    /** Boxes a {@code float[]} into a new {@code Float[]} of the same length. */
    private Float[] getValue(float[] values) {
        Float[] result = new Float[values.length];
        for (int i = 0; i < values.length; ++i) {
            result[i] = Float.valueOf(values[i]);
        }
        return result;
    }

    /**
     * Loads weights previously exported with {@link #getModel()} into this network.
     *
     * @param uNetModel model holding exactly {@code deep} encoder entries and {@code deep + 1}
     *                  decoder entries
     * @throws Exception if the encoder or decoder count does not match this network's depth, or if
     *                   the number of weight arrays of any layer does not match its matrix count
     */
    public void insertModel(UNetModel uNetModel) throws Exception {
        List<ConvModel> encoderModel = uNetModel.getEncoderModels();
        List<ConvModel> decoderModel = uNetModel.getDecoderModels();
        // Check both halves before any weight is written. Previously only the encoder count was
        // checked, so a short decoder list failed with IndexOutOfBounds mid-load.
        if (encoderModel == null || encoderModel.size() != this.deep
                || decoderModel == null || decoderModel.size() != this.deep + 1) {
            // Message: "model depth does not match".
            throw new Exception("\u6a21\u578b\u6df1\u5ea6\u4e0d\u5339\u914d");
        }
        for (int i = 0; i < this.deep; ++i) {
            ConvParameter convParameter = this.encoderList.get(i).getConvParameter();
            ConvModel convModel = encoderModel.get(i);
            convParameter.setOneConvPower(convModel.getOneNervePowerList());
            this.writePowers(convParameter.getNerveMatrixList(), convModel.getDownNervePower());
        }
        for (int i = 0; i < this.deep + 1; ++i) {
            ConvParameter convParameter = this.decoderList.get(i).getConvParameter();
            ConvModel convModel = decoderModel.get(i);
            convParameter.setUpOneConvPower(convModel.getOneNervePower());
            this.writePowers(convParameter.getUpNerveMatrixList(), convModel.getUpNervePower());
            this.writePowers(convParameter.getNerveMatrixList(), convModel.getDownNervePower());
        }
    }

    /**
     * Exports the current weights of every encoder and decoder layer.
     *
     * @return a new {@link UNetModel} with {@code deep} encoder entries and {@code deep + 1}
     *         decoder entries
     */
    public UNetModel getModel() {
        UNetModel unetModel = new UNetModel();
        List<ConvModel> encoderModel = new ArrayList<>();
        List<ConvModel> decoderModel = new ArrayList<>();
        unetModel.setEncoderModels(encoderModel);
        unetModel.setDecoderModels(decoderModel);
        for (UNetEncoder encoder : this.encoderList) {
            ConvParameter convParameter = encoder.getConvParameter();
            ConvModel convModel = new ConvModel();
            convModel.setDownNervePower(this.readPowers(convParameter.getNerveMatrixList()));
            List<List<Float>> onePowers = convParameter.getOneConvPower();
            // Only exported when present, matching what insertModel later hands back.
            if (onePowers != null && !onePowers.isEmpty()) {
                convModel.setOneNervePowerList(onePowers);
            }
            encoderModel.add(convModel);
        }
        for (UNetDecoder decoder : this.decoderList) {
            ConvParameter convParameter = decoder.getConvParameter();
            ConvModel convModel = new ConvModel();
            convModel.setOneNervePower(convParameter.getUpOneConvPower());
            convModel.setUpNervePower(this.readPowers(convParameter.getUpNerveMatrixList()));
            convModel.setDownNervePower(this.readPowers(convParameter.getNerveMatrixList()));
            decoderModel.add(convModel);
        }
        return unetModel;
    }

    /**
     * Snapshots each matrix's values as a boxed array, in list order.
     *
     * @param matrices matrices to read; {@code null} is treated as empty
     * @return one {@code Float[]} per matrix
     */
    private List<Float[]> readPowers(List<Matrix> matrices) {
        List<Float[]> powers = new ArrayList<>();
        if (matrices != null) {
            for (Matrix matrix : matrices) {
                powers.add(this.getValue(matrix.getCudaMatrix()));
            }
        }
        return powers;
    }

    /**
     * Writes one weight array into each matrix, keeping the matrix's current dimensions.
     *
     * @param matrices target matrices; {@code null} is treated as empty
     * @param powers   weight arrays in the same order as {@code matrices}; {@code null} is treated
     *                 as empty
     * @throws Exception if the two lists differ in length, or if writing a matrix fails
     */
    private void writePowers(List<Matrix> matrices, List<Float[]> powers) throws Exception {
        int expected = matrices == null ? 0 : matrices.size();
        int actual = powers == null ? 0 : powers.size();
        // A mismatch would otherwise either crash mid-load or leave some matrices untouched.
        if (expected != actual) {
            throw new Exception("weight count mismatch: layer has " + expected
                    + " matrices but the model provides " + actual);
        }
        for (int j = 0; j < expected; ++j) {
            Matrix matrix = matrices.get(j);
            matrix.setCudaMatrix(this.getFValue(powers.get(j)), matrix.getX(), matrix.getY());
        }
    }

    /**
     * Connects the two chains. The last encoder and the first decoder are linked to each other.
     * Encoder {@code i} is attached to decoder {@code deep - i} via {@code setMyUNetEncoder},
     * which is presumably the U-Net skip connection. Decoder 0 gets no such encoder.
     */
    private void connectionCoder() {
        UNetEncoder lastUNetEncoder = this.encoderList.get(this.deep - 1);
        UNetDecoder firstUNetDecoder = this.decoderList.get(0);
        lastUNetEncoder.setDecoder(firstUNetDecoder);
        firstUNetDecoder.setEncoder(lastUNetEncoder);
        for (int i = 0; i < this.deep; ++i) {
            UNetEncoder uNetEncoder = this.encoderList.get(i);
            UNetDecoder uNetDecoder = this.decoderList.get(this.deep - i);
            uNetDecoder.setMyUNetEncoder(uNetEncoder);
        }
    }

    /**
     * Creates {@code deep + 1} decoders with Tanh activation and links neighbours both ways.
     *
     * @param cutting whether to create a {@link Cutting} instance shared by all decoders
     * @param cutTh   threshold handed to {@link Cutting}; ignored when {@code cutting} is false
     * @throws Exception if decoder construction fails
     */
    private void initDecoder(boolean cutting, float cutTh) throws Exception {
        Cutting myCut = cutting ? new Cutting(cutTh) : null;
        for (int i = 0; i < this.deep + 1; ++i) {
            // The boolean argument flags the final decoder (index deep).
            UNetDecoder uNetDecoder = new UNetDecoder(this.kernLen, i + 1, this.channelNo, new Tanh(),
                    i == this.deep, this.studyRate, myCut, this.oneStudyRate);
            this.decoderList.add(uNetDecoder);
        }
        for (int i = 0; i < this.deep; ++i) {
            UNetDecoder uNetDecoder = this.decoderList.get(i);
            UNetDecoder nextUNetDecoder = this.decoderList.get(i + 1);
            uNetDecoder.setAfterDecoder(nextUNetDecoder);
            nextUNetDecoder.setBeforeDecoder(uNetDecoder);
        }
    }

    /**
     * Creates {@code deep} encoders with ReLu activation, binds {@link #input} to the first one,
     * and links neighbours both ways.
     *
     * @param xSize input width handed to every encoder
     * @param ySize input height handed to every encoder
     * @throws Exception if encoder construction fails
     */
    private void initEncoder(int xSize, int ySize) throws Exception {
        for (int i = 0; i < this.deep; ++i) {
            UNetEncoder uNetEncoder = new UNetEncoder(this.kernLen, this.channelNo, i + 1, new ReLu(),
                    this.studyRate, xSize, ySize, this.oneStudyRate);
            if (i == 0) {
                this.input = new UNetInput(uNetEncoder);
            }
            this.encoderList.add(uNetEncoder);
        }
        for (int i = 0; i < this.deep - 1; ++i) {
            UNetEncoder uNetEncoder = this.encoderList.get(i);
            UNetEncoder nextUNetEncoder = this.encoderList.get(i + 1);
            uNetEncoder.setAfterEncoder(nextUNetEncoder);
            nextUNetEncoder.setBeforeEncoder(uNetEncoder);
        }
    }
}

