/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.transFormer;

import java.util.ArrayList;
import java.util.List;
import org.dromara.easyai.config.TfConfig;
import org.dromara.easyai.transFormer.CodecBlock;
import org.dromara.easyai.transFormer.FirstDecoderBlock;
import org.dromara.easyai.transFormer.LineBlock;
import org.dromara.easyai.transFormer.TransWordVector;
import org.dromara.easyai.transFormer.model.CodecBlockModel;
import org.dromara.easyai.transFormer.model.TransFormerModel;
import org.dromara.easyai.transFormer.model.TransWordVectorModel;
import org.dromara.easyai.transFormer.nerve.SensoryNerve;

/**
 * Builds, wires and (de)serializes the block graph of a TransFormer: a stack of encoder
 * {@link CodecBlock}s, a stack of decoder {@link CodecBlock}s, a {@link FirstDecoderBlock},
 * an output {@link LineBlock}, the shared {@link TransWordVector} and the input
 * {@link SensoryNerve}.
 *
 * <p>Initialization is failure-atomic: all components are built in locals and only published
 * to the fields once construction succeeded, and re-initialization replaces (rather than
 * appends to) the existing block lists.
 *
 * <p>NOTE(review): no synchronization is present; presumably an instance is used from a single
 * thread — confirm with callers.
 */
public class TransFormerManager {
    /** Encoder blocks in layer order (index 0 is layer 1). */
    private final List<CodecBlock> encoderBlocks = new ArrayList<CodecBlock>();
    /** Decoder blocks in list order; each is linked to the last encoder block. */
    private final List<CodecBlock> decoderBlocks = new ArrayList<CodecBlock>();
    private SensoryNerve sensoryNerve;
    private FirstDecoderBlock firstDecoderBlock;
    private LineBlock lineBlock;
    private TransWordVector transWordVector;

    /**
     * Returns the word-vector component shared by all blocks.
     *
     * @return the word vector, or {@code null} if the manager has not been initialized yet
     */
    public TransWordVector getTransWordVector() {
        return this.transWordVector;
    }

    /**
     * Returns the input nerve wired to the first encoder block and the first decoder block.
     *
     * @return the sensory nerve, or {@code null} if the manager has not been initialized yet
     */
    public SensoryNerve getSensoryNerve() {
        return this.sensoryNerve;
    }

    /**
     * Snapshots the current network (word vectors, encoder/decoder blocks, first decoder block
     * and line block) into a {@link TransFormerModel}.
     *
     * @return a model that can later be passed to {@link #insertModel(TransFormerModel, TfConfig)}
     * @throws IllegalStateException if called before {@code init} or {@code insertModel}
     * @throws Exception if a component fails to produce its model
     */
    public TransFormerModel getModel() throws Exception {
        if (this.transWordVector == null) {
            throw new IllegalStateException("TransFormerManager is not initialized; call init or insertModel first");
        }
        TransFormerModel transFormerModel = new TransFormerModel();
        transFormerModel.setTransWordVectorModel(this.transWordVector.getModel());
        transFormerModel.setEncoderBlockModels(toBlockModels(this.encoderBlocks));
        transFormerModel.setDecoderBlockModels(toBlockModels(this.decoderBlocks));
        transFormerModel.setFirstDecoderBlockModel(this.firstDecoderBlock.getModel());
        transFormerModel.setLineBlockModel(this.lineBlock.getModel());
        return transFormerModel;
    }

    /**
     * Rebuilds the network from {@code tfConfig} and loads the parameters stored in
     * {@code transFormerModel}. Any previously built network is replaced.
     *
     * <p>If the model holds a different number of encoder (or decoder) blocks than the config
     * produces, only the overlapping prefix of each list is loaded.
     *
     * @param transFormerModel the saved model to load
     * @param tfConfig the configuration used to build the block graph
     * @throws Exception if building the network or loading any component fails
     */
    public void insertModel(TransFormerModel transFormerModel, TfConfig tfConfig) throws Exception {
        this.init(tfConfig, null, transFormerModel.getTransWordVectorModel());
        List<CodecBlockModel> encoderBlockModels = transFormerModel.getEncoderBlockModels();
        List<CodecBlockModel> decoderBlockModels = transFormerModel.getDecoderBlockModels();
        // Bound each list separately: indexing decoder models with the encoder bound could
        // run past the end of a shorter decoder list.
        int encoderCount = Math.min(this.encoderBlocks.size(), encoderBlockModels.size());
        for (int i = 0; i < encoderCount; ++i) {
            this.encoderBlocks.get(i).insertModel(encoderBlockModels.get(i));
        }
        int decoderCount = Math.min(this.decoderBlocks.size(), decoderBlockModels.size());
        for (int i = 0; i < decoderCount; ++i) {
            this.decoderBlocks.get(i).insertModel(decoderBlockModels.get(i));
        }
        this.firstDecoderBlock.insertModel(transFormerModel.getFirstDecoderBlockModel());
        this.lineBlock.insertModel(transFormerModel.getLineBlockModel());
    }

    /**
     * Builds the network on first use, training the word vectors on {@code sentenceList}.
     * If the network already exists, only the word vectors are re-initialized with
     * {@code sentenceList}; the blocks are kept.
     *
     * @param tfConfig the configuration used to build the block graph (ignored when already built)
     * @param sentenceList sentences passed to {@link TransWordVector#init}
     * @throws Exception if building the network or initializing the word vectors fails
     */
    public void init(TfConfig tfConfig, List<String> sentenceList) throws Exception {
        if (this.transWordVector == null) {
            this.init(tfConfig, sentenceList, null);
        } else {
            this.transWordVector.init(sentenceList);
        }
    }

    /**
     * Builds and wires the whole block graph. Word vectors are either initialized from
     * {@code sentenceList} (when {@code transWordVectorModel} is {@code null}) or loaded from
     * {@code transWordVectorModel}.
     *
     * <p>All components are constructed into locals first; the fields are only replaced once
     * everything succeeded, so a failure leaves the previous state untouched.
     *
     * @throws Exception if the feature dimension is odd, if a size parameter is out of range,
     *     or if a component fails to initialize
     */
    private void init(TfConfig tfConfig, List<String> sentenceList, TransWordVectorModel transWordVectorModel) throws Exception {
        TransWordVector wordVector = new TransWordVector(tfConfig);
        if (transWordVectorModel == null) {
            wordVector.init(sentenceList);
        } else {
            wordVector.insertModel(transWordVectorModel);
        }
        // With normalization enabled the output size follows the vocabulary size
        // instead of the configured type number.
        int typeNumber = tfConfig.isNorm() ? wordVector.getWordSize() : tfConfig.getTypeNumber();
        int multiNumber = tfConfig.getMultiNumber();
        int featureDimension = tfConfig.getFeatureDimension();
        if (featureDimension % 2 != 0) {
            // "TransFormer word-vector dimension must be even"
            throw new Exception("TransFormer \u8bcd\u5411\u91cf\u7ef4\u5ea6\u5fc5\u987b\u4e3a\u5076\u6570");
        }
        int allDepth = tfConfig.getAllDepth();
        if (multiNumber <= 1 || featureDimension <= 0 || allDepth <= 0 || typeNumber <= 1) {
            throw new Exception("param is null,typeNumber:" + typeNumber + ",featureDimension:" + featureDimension
                    + ",multiNumber:" + multiNumber + ",allDepth:" + allDepth);
        }
        float studyPoint = tfConfig.getStudyRate();
        boolean showLog = tfConfig.isShowLog();
        int regularModel = tfConfig.getRegularModel();
        float regular = tfConfig.getRegular();
        int coreNumber = tfConfig.getCoreNumber();

        List<CodecBlock> newEncoders = new ArrayList<CodecBlock>(allDepth);
        for (int i = 0; i < allDepth; ++i) {
            newEncoders.add(new CodecBlock(multiNumber, featureDimension, studyPoint, i + 1, true, regularModel, regular, coreNumber, wordVector));
        }
        CodecBlock lastEncoderBlock = newEncoders.get(allDepth - 1);

        List<CodecBlock> newDecoders = new ArrayList<CodecBlock>(allDepth);
        for (int i = 0; i < allDepth; ++i) {
            CodecBlock decoderBlock = new CodecBlock(multiNumber, featureDimension, studyPoint, i + 2, false, regularModel, regular, coreNumber, wordVector);
            // Every decoder block consumes the output of the last encoder block.
            decoderBlock.setLastEncoderBlock(lastEncoderBlock);
            newDecoders.add(decoderBlock);
        }
        CodecBlock firstDecoder = newDecoders.get(0);
        CodecBlock lastDecoderBlock = newDecoders.get(allDepth - 1);
        connectCodecBlock(newEncoders);
        connectCodecBlock(newDecoders);

        LineBlock newLineBlock = new LineBlock(typeNumber, featureDimension, studyPoint, lastDecoderBlock, showLog, regularModel, regular, coreNumber, tfConfig.getTimePunValue());
        lastDecoderBlock.setLineBlock(newLineBlock);
        FirstDecoderBlock newFirstDecoderBlock = new FirstDecoderBlock(multiNumber, featureDimension, studyPoint, firstDecoder, coreNumber, wordVector);
        newFirstDecoderBlock.setLastEncoderBlock(lastEncoderBlock);
        firstDecoder.setFirstDecoderBlock(newFirstDecoderBlock);
        SensoryNerve newSensoryNerve = new SensoryNerve(newEncoders.get(0), newFirstDecoderBlock, wordVector);

        // Commit: replace (not append to) the previous network.
        this.encoderBlocks.clear();
        this.encoderBlocks.addAll(newEncoders);
        this.decoderBlocks.clear();
        this.decoderBlocks.addAll(newDecoders);
        this.lineBlock = newLineBlock;
        this.firstDecoderBlock = newFirstDecoderBlock;
        this.sensoryNerve = newSensoryNerve;
        this.transWordVector = wordVector;
    }

    /**
     * Collects the model of every block in {@code blocks}, preserving order.
     */
    private static ArrayList<CodecBlockModel> toBlockModels(List<CodecBlock> blocks) throws Exception {
        ArrayList<CodecBlockModel> models = new ArrayList<CodecBlockModel>(blocks.size());
        for (CodecBlock block : blocks) {
            models.add(block.getModel());
        }
        return models;
    }

    /**
     * Links neighbouring blocks of one stack: block {@code i} gets block {@code i + 1} as its
     * "before" block, and block {@code i + 1} gets block {@code i} as its "after" block.
     *
     * <p>NOTE(review): the before/after naming is inherited from {@link CodecBlock}; the
     * direction it refers to (forward vs. backward pass) is not visible here — confirm there.
     */
    private static void connectCodecBlock(List<CodecBlock> codecBlocks) {
        for (int i = 0; i + 1 < codecBlocks.size(); ++i) {
            CodecBlock current = codecBlocks.get(i);
            CodecBlock next = codecBlocks.get(i + 1);
            current.setBeforeEncoderBlock(next);
            next.setAfterEncoderBlock(current);
        }
    }
}

