/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.transFormer;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.dromara.easyai.config.TfConfig;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixList;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.transFormer.WordIds;
import org.dromara.easyai.transFormer.model.TransWordVectorModel;

public class TransWordVector {
    private final List<String> wordList = new ArrayList<String>();
    private final List<Matrix> wordVectorList = new ArrayList<Matrix>();
    private final Matrix positionCodeMatrix;
    private final WordIds wordIds = new WordIds();
    private final String splitWord;
    private final int featureDimension;
    private final Random random = new Random();
    private final String startWord;
    private final String endWord;
    private final MatrixOperation matrixOperation = new MatrixOperation();
    private final float studyRate;
    private final int maxLength;

    public int getEndID() {
        return 2;
    }

    public int getStartID() {
        return 1;
    }

    public TransWordVector(TfConfig tfConfig) throws Exception {
        this.splitWord = tfConfig.getSplitWord();
        this.studyRate = tfConfig.getStudyRate();
        this.featureDimension = tfConfig.getFeatureDimension();
        this.startWord = tfConfig.getStartWord();
        this.endWord = tfConfig.getEndWord();
        this.maxLength = tfConfig.getMaxLength() + 2;
        this.positionCodeMatrix = new Matrix(this.maxLength, this.featureDimension);
        this.wordList.add(this.startWord);
        this.wordList.add(this.endWord);
        this.initWordVector();
        this.initWordVector();
        this.initPositionMatrix();
    }

    private void initPositionMatrix() throws Exception {
        int x = this.positionCodeMatrix.getX();
        int y = this.positionCodeMatrix.getY();
        Random random = new Random();
        for (int i = 0; i < x; ++i) {
            for (int j = 0; j < y; ++j) {
                float value = random.nextFloat();
                if (i == 0) {
                    value += 1.0f;
                }
                this.positionCodeMatrix.setNub(i, j, value);
            }
        }
    }

    public TransWordVectorModel getModel() {
        TransWordVectorModel transWordVectorModel = new TransWordVectorModel();
        transWordVectorModel.setWordList(this.wordList);
        transWordVectorModel.setPositionMatrix(this.positionCodeMatrix.getMatrixModel());
        transWordVectorModel.setX(this.wordVectorList.get(0).getX());
        transWordVectorModel.setY(this.wordVectorList.get(0).getY());
        ArrayList<Float[]> wordVectorModel = new ArrayList<Float[]>();
        transWordVectorModel.setWordVectorModel(wordVectorModel);
        for (Matrix matrix : this.wordVectorList) {
            wordVectorModel.add(matrix.getMatrixModel());
        }
        return transWordVectorModel;
    }

    public void insertModel(TransWordVectorModel transWordVectorModel) {
        int x = transWordVectorModel.getX();
        int y = transWordVectorModel.getY();
        this.wordList.clear();
        this.wordVectorList.clear();
        this.wordList.addAll(transWordVectorModel.getWordList());
        this.positionCodeMatrix.insertMatrixModel(transWordVectorModel.getPositionMatrix());
        List<Float[]> wordVectorModel = transWordVectorModel.getWordVectorModel();
        for (Float[] floats : wordVectorModel) {
            Matrix matrix = new Matrix(x, y);
            matrix.insertMatrixModel(floats);
            this.wordVectorList.add(matrix);
        }
    }

    public void backEncoderError(Matrix error) throws Exception {
        List<Integer> ids = this.wordIds.getEncoder();
        int size = ids.size();
        if (size != error.getX()) {
            throw new Exception("\u7f16\u7801\u5668\u8bef\u5dee\u8fd4\u56de\u957f\u5ea6\u4e0d\u4e00\u81f4,size:" + size + ",errorSize:" + error.getX());
        }
        this.updateWordVector(ids, error);
        this.wordIds.getEncoder().clear();
    }

    private void updatePositionCode(Matrix error) throws Exception {
        int x = error.getX();
        int y = error.getY();
        for (int i = 0; i < x; ++i) {
            for (int j = 0; j < y; ++j) {
                float value = this.positionCodeMatrix.getNumber(i, j) + error.getNumber(i, j);
                this.positionCodeMatrix.setNub(i, j, value);
            }
        }
    }

    private void updateWordVector(List<Integer> ids, Matrix error) throws Exception {
        int size = ids.size();
        this.matrixOperation.mathMul(error, this.studyRate);
        this.updatePositionCode(error);
        for (int i = 0; i < size; ++i) {
            int index = ids.get(i);
            Matrix wordError = error.getRow(i);
            Matrix wordVector = this.wordVectorList.get(index);
            wordVector = this.matrixOperation.add(wordVector, wordError);
            this.wordVectorList.set(index, wordVector);
        }
    }

    public void backDecoderError(Matrix errorMatrix, Matrix allFeature) throws Exception {
        Matrix error = this.matrixOperation.add(errorMatrix, allFeature);
        List<Integer> ids = this.wordIds.getDecoder();
        int size = ids.size();
        if (size != error.getX()) {
            throw new Exception("\u89e3\u7801\u5668\u8bef\u5dee\u8fd4\u56de\u957f\u5ea6\u4e0d\u4e00\u81f4");
        }
        this.updateWordVector(ids, error);
        this.wordIds.getDecoder().clear();
    }

    public String getWordByID(int id) {
        return this.wordList.get(id - 1);
    }

    public int getWordID(String word) {
        int id = -1;
        int size = this.wordList.size();
        for (int i = 0; i < size; ++i) {
            if (!this.wordList.get(i).equals(word)) continue;
            id = i + 1;
            break;
        }
        return id;
    }

    public List<Integer> getE(String word) {
        ArrayList<Integer> result = new ArrayList<Integer>();
        if (this.splitWord == null) {
            for (int i = 0; i < word.length(); ++i) {
                result.add(this.getWordID(word.substring(i, i + 1)));
            }
        } else {
            String[] words;
            for (String s : words = word.split(this.splitWord)) {
                result.add(this.getWordID(s));
            }
        }
        result.add(2);
        return result;
    }

    public Matrix getVector(String word) {
        int size = this.wordList.size();
        Matrix feature = null;
        for (int i = 0; i < size; ++i) {
            if (!this.wordList.get(i).equals(word)) continue;
            feature = this.wordVectorList.get(i);
            break;
        }
        return feature;
    }

    private Matrix getVectorByStudy(String word, boolean decoder, boolean study) {
        int size = this.wordList.size();
        Matrix feature = null;
        List<Integer> ids = null;
        if (decoder && study) {
            ids = this.wordIds.getDecoder();
        } else if (!decoder && study) {
            ids = this.wordIds.getEncoder();
        }
        for (int i = 0; i < size; ++i) {
            if (!this.wordList.get(i).equals(word)) continue;
            if (ids != null) {
                ids.add(i);
            }
            feature = this.wordVectorList.get(i);
            break;
        }
        if (feature == null) {
            feature = new Matrix(1, this.featureDimension);
        }
        return feature;
    }

    public Matrix getWordVector(String word, boolean decoder, boolean study) throws Exception {
        MatrixList matrixList;
        block9: {
            matrixList = null;
            if (decoder) {
                if (study) {
                    this.wordIds.getDecoder().add(0);
                }
                matrixList = new MatrixList(this.wordVectorList.get(0), true, this.maxLength + 10);
            }
            if (word == null || word.isEmpty()) break block9;
            if (word.length() > this.maxLength - 2) {
                throw new Exception("\u8bed\u53e5\u957f\u5ea6\u8d85\u8fc7\u8bbe\u5b9a\u7684\u6700\u5927\u503c");
            }
            if (this.splitWord == null) {
                int size = word.length();
                for (int i = 0; i < size; ++i) {
                    Matrix feature = this.getVectorByStudy(word.substring(i, i + 1), decoder, study);
                    if (matrixList == null) {
                        matrixList = new MatrixList(feature, true, this.maxLength + 10);
                        continue;
                    }
                    matrixList.add(feature);
                }
            } else {
                String[] myWord;
                for (String s : myWord = word.split(this.splitWord)) {
                    Matrix feature = this.getVectorByStudy(s, decoder, study);
                    if (matrixList == null) {
                        matrixList = new MatrixList(feature, true, this.maxLength + 10);
                        continue;
                    }
                    matrixList.add(feature);
                }
            }
        }
        return this.addPositionMatrix(matrixList.getMatrix());
    }

    private Matrix addPositionMatrix(Matrix matrix) throws Exception {
        int x = matrix.getX();
        int y = matrix.getY();
        Matrix positionCode = this.positionCodeMatrix.getSonOfMatrix(0, 0, x, y);
        return this.matrixOperation.add(matrix, positionCode);
    }

    private void initWordVector() throws Exception {
        Matrix matrix = new Matrix(1, this.featureDimension);
        for (int j = 0; j < this.featureDimension; ++j) {
            matrix.setNub(0, j, this.random.nextFloat());
        }
        this.wordVectorList.add(matrix);
    }

    private void insertWord(String word) throws Exception {
        if (!word.equals(this.startWord) && !word.equals(this.endWord)) {
            boolean here = false;
            for (String myWord : this.wordList) {
                if (!myWord.equals(word)) continue;
                here = true;
                break;
            }
            if (!here) {
                this.wordList.add(word);
                this.initWordVector();
            }
        } else {
            throw new Exception("\u4efb\u4f55\u5b57\u8bcd\u4e0d\u53ef\u4ee5\u4e0e\u7ed3\u675f\u7b26\u6216\u5f00\u59cb\u7b26\u91cd\u53e0");
        }
    }

    public void init(List<String> sentenceList) throws Exception {
        for (String sentence : sentenceList) {
            String[] myWord;
            if (sentence == null || sentence.isEmpty()) continue;
            if (this.splitWord == null) {
                for (int i = 0; i < sentence.length(); ++i) {
                    this.insertWord(sentence.substring(i, i + 1));
                }
                continue;
            }
            for (String s : myWord = sentence.split(this.splitWord)) {
                this.insertWord(s);
            }
        }
    }

    public int getWordSize() {
        return this.wordList.size();
    }
}

