/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.transFormer;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.dromara.easyai.function.ReLu;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.transFormer.FirstDecoderBlock;
import org.dromara.easyai.transFormer.LineBlock;
import org.dromara.easyai.transFormer.TransWordVector;
import org.dromara.easyai.transFormer.model.CodecBlockModel;
import org.dromara.easyai.transFormer.nerve.HiddenNerve;
import org.dromara.easyai.transFormer.nerve.Nerve;
import org.dromara.easyai.transFormer.seflAttention.LayNorm;
import org.dromara.easyai.transFormer.seflAttention.MultiSelfAttention;

/**
 * One encoder or decoder block of the transformer: a multi-head self-attention stage followed by a
 * layer norm, a two-layer feed-forward network of {@link HiddenNerve}s, and a second layer norm.
 *
 * <p>Blocks are chained through setters after construction. Note that the field names describe the
 * position in the chain as seen by this class's own routing code:
 * <ul>
 *   <li>{@code beforeEncoderBlock} is the block that receives this block's <em>forward</em> output
 *       (see {@link #sendOutputMatrix});</li>
 *   <li>{@code afterEncoderBlock} is the block that receives this block's <em>backward</em> error
 *       (see {@link #backCodecError}).</li>
 * </ul>
 *
 * <p>NOTE(review): {@code outMatrixMap} is a plain {@link HashMap}; this class performs no
 * synchronization of its own — confirm callers do not touch one block from several threads.
 */
public class CodecBlock {
    private final MultiSelfAttention multiSelfAttention;
    /** Layer norm wired to the attention stage and to the first feed-forward layer. */
    private final LayNorm attentionLayNorm;
    /** First feed-forward layer (constructed with a ReLu activation). */
    private final List<HiddenNerve> fistHiddenNerves = new ArrayList<HiddenNerve>();
    /** Second feed-forward layer (constructed without an activation function). */
    private final List<HiddenNerve> secondHiddenNerves = new ArrayList<HiddenNerve>();
    /** Layer norm wired to the second feed-forward layer; entry point for back-propagated errors. */
    private final LayNorm lineLayNorm;
    /** Receives the backward error when no other block is wired behind this one. */
    private final TransWordVector transWordVector;
    /** Receiver of this block's backward error, if wired. */
    private CodecBlock afterEncoderBlock;
    /** Receiver of this block's forward output, if wired. */
    private CodecBlock beforeEncoderBlock;
    /** Target of {@link #backLastEncoderError(Matrix)}. */
    private CodecBlock lastEncoderBlock;
    /** Forward outputs cached per event id when this is an encoder block with no forward successor. */
    private final Map<Long, Matrix> outMatrixMap = new HashMap<Long, Matrix>();
    private final boolean encoder;
    /** Receiver of the forward output for a decoder block with no forward successor. */
    private LineBlock lineBlock;
    /** Receiver of the backward error when no {@code afterEncoderBlock} is wired. */
    private FirstDecoderBlock firstDecoderBlock;
    private final MatrixOperation matrixOperation;
    private final int coreNumber;

    /**
     * Creates a block and wires its internal stages together.
     *
     * @param multiNumber      forwarded to {@link MultiSelfAttention} (presumably the head count — confirm)
     * @param featureDimension number of hidden nerves created in each feed-forward layer; also
     *                         forwarded to the layer norms and the attention stage
     * @param studyPoint       forwarded to every trainable component (presumably the learning rate — confirm)
     * @param depth            forwarded to the layer norms and the attention stage
     * @param encoder          {@code true} if this block belongs to the encoder stack
     * @param regularModel     forwarded to every {@link HiddenNerve}
     * @param regular          forwarded to every {@link HiddenNerve}
     * @param coreNumber       forwarded to {@link MatrixOperation} and the other components
     * @param transWordVector  fallback receiver of the backward error
     * @throws Exception if any component fails to initialise
     */
    public CodecBlock(int multiNumber, int featureDimension, float studyPoint, int depth, boolean encoder, int regularModel, float regular, int coreNumber, TransWordVector transWordVector) throws Exception {
        this.matrixOperation = new MatrixOperation(coreNumber);
        this.encoder = encoder;
        this.transWordVector = transWordVector;
        // Must be assigned before initLine(), which reads this.coreNumber.
        this.coreNumber = coreNumber;
        // Both layer norms must exist before initLine(), which hands them to the hidden nerves.
        this.attentionLayNorm = new LayNorm(1, featureDimension, this, null, studyPoint, coreNumber, encoder, depth);
        this.lineLayNorm = new LayNorm(2, featureDimension, this, null, studyPoint, coreNumber, encoder, depth);
        this.multiSelfAttention = new MultiSelfAttention(multiNumber, studyPoint, depth, featureDimension, encoder, this, coreNumber, null);
        this.multiSelfAttention.setLayNorm(this.attentionLayNorm);
        this.attentionLayNorm.setMultiSelfAttention(this.multiSelfAttention);
        this.initLine(featureDimension, studyPoint, regularModel, regular);
        this.attentionLayNorm.setHiddenNerves(this.fistHiddenNerves);
        this.lineLayNorm.setHiddenNerves(this.secondHiddenNerves);
    }

    /**
     * Exports the parameters of every trainable component of this block.
     *
     * @return a freshly populated model holding the attention, layer-norm and hidden-nerve parameters
     * @throws Exception if any component fails to export its parameters
     */
    public CodecBlockModel getModel() throws Exception {
        int nerveCount = this.fistHiddenNerves.size();
        ArrayList<float[][]> firstNerveModel = new ArrayList<float[][]>(nerveCount);
        ArrayList<float[][]> secondNerveModel = new ArrayList<float[][]>(nerveCount);
        for (int i = 0; i < nerveCount; ++i) {
            firstNerveModel.add(this.fistHiddenNerves.get(i).getModel());
            secondNerveModel.add(this.secondHiddenNerves.get(i).getModel());
        }
        CodecBlockModel codecBlockModel = new CodecBlockModel();
        codecBlockModel.setMultiSelfAttentionModel(this.multiSelfAttention.getModel());
        codecBlockModel.setAttentionLayNormModel(this.attentionLayNorm.getModel());
        codecBlockModel.setFistNervesModel(firstNerveModel);
        codecBlockModel.setSecondNervesModel(secondNerveModel);
        codecBlockModel.setLineLayNormModel(this.lineLayNorm.getModel());
        return codecBlockModel;
    }

    /**
     * Loads previously exported parameters into this block.
     *
     * <p>The model is checked before anything is written, so a model that cannot cover every
     * hidden nerve is rejected without leaving this block partially overwritten.
     *
     * @param codecBlockModel the parameters to load; must provide at least one entry per hidden
     *                        nerve in both feed-forward layers
     * @throws IllegalArgumentException if {@code codecBlockModel} is {@code null} or holds fewer
     *                                  nerve entries than this block has hidden nerves
     * @throws Exception                if any component rejects its part of the model
     */
    public void insertModel(CodecBlockModel codecBlockModel) throws Exception {
        if (codecBlockModel == null) {
            throw new IllegalArgumentException("codecBlockModel must not be null");
        }
        int nerveCount = this.fistHiddenNerves.size();
        List<float[][]> firstNerveModel = codecBlockModel.getFistNervesModel();
        List<float[][]> secondNerveModel = codecBlockModel.getSecondNervesModel();
        // Validate up front: failing inside the loop below would leave the attention stage,
        // the first layer norm and some of the nerves already overwritten.
        requireNerveModels(firstNerveModel, nerveCount, "first");
        requireNerveModels(secondNerveModel, nerveCount, "second");

        this.multiSelfAttention.insertModel(codecBlockModel.getMultiSelfAttentionModel());
        this.attentionLayNorm.insertModel(codecBlockModel.getAttentionLayNormModel());
        for (int i = 0; i < nerveCount; ++i) {
            this.fistHiddenNerves.get(i).insertModel(firstNerveModel.get(i));
            this.secondHiddenNerves.get(i).insertModel(secondNerveModel.get(i));
        }
        this.lineLayNorm.insertModel(codecBlockModel.getLineLayNormModel());
    }

    /** Rejects a nerve-model list that cannot supply {@code required} entries. */
    private static void requireNerveModels(List<float[][]> models, int required, String layerName) {
        int available = models == null ? 0 : models.size();
        if (available < required) {
            throw new IllegalArgumentException("model provides " + available + " entries for the "
                    + layerName + " hidden layer but this block has " + required + " hidden nerves");
        }
    }

    /** Sets the receiver of the backward error used when no {@code afterEncoderBlock} is wired. */
    public void setFirstDecoderBlock(FirstDecoderBlock firstDecoderBlock) {
        this.firstDecoderBlock = firstDecoderBlock;
    }

    /** Sets the receiver of the forward output for a decoder block without a forward successor. */
    public void setLineBlock(LineBlock lineBlock) {
        this.lineBlock = lineBlock;
    }

    /** Sets the block targeted by {@link #backLastEncoderError(Matrix)}. */
    public void setLastEncoderBlock(CodecBlock lastEncoderBlock) {
        this.lastEncoderBlock = lastEncoderBlock;
    }

    /** Sets the block that receives this block's backward error. */
    public void setAfterEncoderBlock(CodecBlock afterEncoderBlock) {
        this.afterEncoderBlock = afterEncoderBlock;
    }

    /** Sets the block that receives this block's forward output. */
    public void setBeforeEncoderBlock(CodecBlock beforeEncoderBlock) {
        this.beforeEncoderBlock = beforeEncoderBlock;
    }

    /**
     * Starts back-propagation through this block by handing the error to the second layer norm.
     *
     * @param eventID     identifier of the sample being processed
     * @param errorMatrix error arriving from the component wired behind this block
     * @throws Exception if back-propagation fails
     */
    public void backError(long eventID, Matrix errorMatrix) throws Exception {
        this.lineLayNorm.backErrorFromLine(errorMatrix, eventID);
    }

    /** Drops the cached forward output for {@code eventID}, if any. */
    public void removeOutMatrix(long eventID) {
        this.outMatrixMap.remove(eventID);
    }

    /**
     * Returns the cached forward output for {@code eventID}.
     *
     * @return the cached matrix, or {@code null} if nothing is cached for that id
     */
    public Matrix getOutMatrix(long eventID) {
        return this.outMatrixMap.get(eventID);
    }

    /**
     * Routes this block's forward output to whatever is wired in front of it.
     *
     * <p>If a forward successor block is wired the output becomes its input. Otherwise an encoder
     * block caches the output under {@code eventID}, and a decoder block passes it to the
     * {@link LineBlock}.
     *
     * @throws IllegalStateException if this is a decoder block with neither a forward successor
     *                               nor a {@link LineBlock} wired
     * @throws Exception             if the receiver fails
     */
    public void sendOutputMatrix(long eventID, Matrix out, boolean isStudy, OutBack outBack, List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {
        if (this.beforeEncoderBlock != null) {
            this.beforeEncoderBlock.sendInputMatrix(eventID, out, isStudy, outBack, E, encoderFeature, outAllPro);
            return;
        }
        if (this.encoder) {
            this.outMatrixMap.put(eventID, out);
            return;
        }
        if (this.lineBlock == null) {
            throw new IllegalStateException("decoder block has no forward receiver: call setBeforeEncoderBlock or setLineBlock first");
        }
        this.lineBlock.sendParameter(eventID, out, isStudy, outBack, E, outAllPro);
    }

    /**
     * Adds {@code allFeature} to {@code errorMatrix} and passes the sum further back.
     *
     * <p>The receiver is the first one available of: the wired {@code afterEncoderBlock}, the
     * wired {@link FirstDecoderBlock}, or the {@link TransWordVector} given at construction.
     *
     * @throws Exception if the addition or the receiver fails
     */
    public void backCodecError(Matrix errorMatrix, long eventID, Matrix allFeature) throws Exception {
        Matrix error = this.matrixOperation.add(errorMatrix, allFeature);
        if (this.afterEncoderBlock != null) {
            this.afterEncoderBlock.backError(eventID, error);
        } else if (this.firstDecoderBlock != null) {
            this.firstDecoderBlock.backError(eventID, error);
        } else {
            this.transWordVector.backEncoderError(error);
        }
    }

    /**
     * Forwards {@code error} to the second layer norm of the wired last encoder block.
     *
     * @throws IllegalStateException if no last encoder block has been wired
     * @throws Exception             if the receiver fails
     */
    public void backLastEncoderError(Matrix error) throws Exception {
        if (this.lastEncoderBlock == null) {
            throw new IllegalStateException("no last encoder block wired: call setLastEncoderBlock first");
        }
        this.lastEncoderBlock.backLastError(error);
    }

    /** Hands {@code error} to this block's second layer norm. */
    private void backLastError(Matrix error) throws Exception {
        this.lineLayNorm.backLastError(error);
    }

    /**
     * Asks the second layer norm to start the encoder backward pass for {@code eventID}.
     *
     * @throws Exception if the layer norm fails
     */
    public void encoderBackStart(long eventID) throws Exception {
        this.lineLayNorm.encoderBackStart(eventID);
    }

    /**
     * Feeds {@code feature} into this block by passing it to the self-attention stage.
     *
     * @throws Exception if the attention stage fails
     */
    public void sendInputMatrix(long eventID, Matrix feature, boolean isStudy, OutBack outBack, List<Integer> E, Matrix encoderFeature, boolean outAllPro) throws Exception {
        this.multiSelfAttention.sendMatrixMessage(eventID, feature, isStudy, outBack, E, encoderFeature, outAllPro);
    }

    /**
     * Builds the two feed-forward layers ({@code featureDimension} nerves each), attaches them to
     * the layer norms and fully connects the first layer to the second.
     */
    private void initLine(int featureDimension, float studyPoint, int regularModel, float regular) throws Exception {
        ArrayList<Nerve> firstNerves = new ArrayList<Nerve>(featureDimension);
        ArrayList<Nerve> secondNerves = new ArrayList<Nerve>(featureDimension);
        // First layer: nerve ids are 1-based, ReLu activation, attached to the attention layer norm.
        for (int i = 0; i < featureDimension; ++i) {
            HiddenNerve hiddenNerve1 = new HiddenNerve(i + 1, 1, studyPoint, new ReLu(), featureDimension, featureDimension, null, regularModel, regular, this.coreNumber);
            this.fistHiddenNerves.add(hiddenNerve1);
            hiddenNerve1.setAfterLayNorm(this.attentionLayNorm);
            firstNerves.add(hiddenNerve1);
        }
        // Second layer: no activation function, attached to the closing layer norm.
        for (int i = 0; i < featureDimension; ++i) {
            HiddenNerve hiddenNerve2 = new HiddenNerve(i + 1, 2, studyPoint, null, featureDimension, 1, null, regularModel, regular, this.coreNumber);
            hiddenNerve2.setBeforeLayNorm(this.lineLayNorm);
            this.secondHiddenNerves.add(hiddenNerve2);
            secondNerves.add(hiddenNerve2);
        }
        // Link the layers in both directions: forward (connect) and backward (connectFather).
        for (Nerve hiddenNerve : firstNerves) {
            hiddenNerve.connect(secondNerves);
        }
        for (Nerve hiddenNerve : secondNerves) {
            hiddenNerve.connectFather(firstNerves);
        }
    }
}

