/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.naturalLanguage.word;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.dromara.easyai.config.SentenceConfig;
import org.dromara.easyai.entity.SentenceModel;
import org.dromara.easyai.entity.WordMatrix;
import org.dromara.easyai.entity.WordTwoVectorModel;
import org.dromara.easyai.function.Tanh;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.MatrixList;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.rnnJumpNerveEntity.MyWordFeature;
import org.dromara.easyai.rnnNerveCenter.NerveManager;
import org.dromara.easyai.rnnNerveEntity.SensoryNerve;

/**
 * Trains and queries word embeddings using a network built by {@link NerveManager}.
 *
 * <p>Every vocabulary word is one-hot encoded over the sensory (input) nerves: the nerve whose
 * position equals the word's index in {@link #getWordList()} receives {@code 1.0f}, all others
 * {@code 0.0f}. Training ({@link #start()}) feeds, for every sentence, every ordered pair of
 * distinct positions (input word, context word) to the network as a study sample. Lookup
 * ({@link #getEmbedding}) runs a non-study pass and collects the result through a
 * {@link WordMatrix} passed as the {@link OutBack} callback.
 *
 * <p>Usage: call {@link #setConfig} first. Then either call {@link #init} followed by
 * {@link #start()} to train, or call {@link #insertModel} to load previously trained parameters.
 * This class performs no synchronization.
 */
public class WordEmbedding
extends MatrixOperation {
    /** {@code resIndex} value meaning "no expected output word" (used for pure lookups). */
    private static final int NO_TARGET = -1;

    private NerveManager nerveManager;
    private SentenceModel sentenceModel;
    /** Vocabulary; a word's position in this list is its one-hot input/output index. */
    private final List<String> wordList = new ArrayList<String>();
    private SentenceConfig config;
    private int wordVectorDimension;
    /** Number of passes {@link #start()} makes over the sentence list. */
    private int studyTimes = 1;

    /**
     * Sets how many passes {@link #start()} makes over the whole sentence list.
     *
     * @param studyTimes number of training passes; values {@code <= 0} make {@link #start()}
     *     skip training entirely
     */
    public void setStudyTimes(int studyTimes) {
        this.studyTimes = studyTimes;
    }

    /**
     * Sets the configuration read by {@link #init} and {@link #insertModel}.
     * It must be set before either of them is called.
     */
    public void setConfig(SentenceConfig config) {
        this.config = config;
    }

    /** Returns the embedding dimension given to the last {@link #init}/{@link #insertModel} call. */
    public int getWordVectorDimension() {
        return this.wordVectorDimension;
    }

    /**
     * Builds an untrained network sized for the vocabulary of {@code sentenceModel}.
     *
     * <p>Any vocabulary left from an earlier {@code init}/{@code insertModel} call is discarded
     * first. Re-initialising therefore neither duplicates words nor enlarges the network.
     *
     * @param sentenceModel source of the vocabulary ({@code getWordSet()}) and of the training
     *     sentences later used by {@link #start()}
     * @param wordVectorDimension embedding size passed to the network
     * @throws IllegalStateException if {@link #setConfig} has not been called
     * @throws Exception propagated from the network construction/initialisation
     */
    public void init(SentenceModel sentenceModel, int wordVectorDimension) throws Exception {
        SentenceConfig cfg = this.requireConfig();
        this.wordVectorDimension = wordVectorDimension;
        this.sentenceModel = sentenceModel;
        this.wordList.clear();
        this.wordList.addAll(sentenceModel.getWordSet());
        // NOTE(review): arguments presumably are (input size, hidden size, output size, hidden
        // depth, activation, learning rate, regularisation model, regularisation param) -- confirm
        // against NerveManager.
        this.nerveManager = new NerveManager(this.wordList.size(), wordVectorDimension, this.wordList.size(), 1, new Tanh(), cfg.getWeStudyPoint(), cfg.getRzModel(), cfg.getWeLParam());
        this.nerveManager.init(true, false, true);
    }

    /**
     * Returns the live vocabulary list, not a copy.
     * Modifying it changes the index assigned to each word.
     */
    public List<String> getWordList() {
        return this.wordList;
    }

    /**
     * Returns the vocabulary word at {@code id}.
     *
     * @throws IndexOutOfBoundsException if {@code id} is outside the vocabulary
     */
    public String getWord(int id) {
        return this.wordList.get(id);
    }

    /**
     * Replaces the vocabulary and network with a previously trained model.
     *
     * @param wordTwoVectorModel model whose word list and parameters are loaded; its word list is
     *     copied, so the model may safely share a list with this instance
     * @param wordVectorDimension embedding size the model was trained with
     * @throws IllegalStateException if {@link #setConfig} has not been called
     * @throws Exception propagated from network construction or parameter insertion
     */
    public void insertModel(WordTwoVectorModel wordTwoVectorModel, int wordVectorDimension) throws Exception {
        SentenceConfig cfg = this.requireConfig();
        // Snapshot before clearing. The model's list may be this instance's own wordList (for
        // example one obtained via getWordList()), and clearing it first would leave the
        // vocabulary empty.
        List<String> myWordList = new ArrayList<String>(wordTwoVectorModel.getWordList());
        this.wordList.clear();
        this.wordVectorDimension = wordVectorDimension;
        this.wordList.addAll(myWordList);
        // Unlike init(), the two arguments init() takes from config.getRzModel() and
        // config.getWeLParam() are fixed to 0 / 0.0f here.
        this.nerveManager = new NerveManager(this.wordList.size(), wordVectorDimension, this.wordList.size(), 1, new Tanh(), cfg.getWeStudyPoint(), 0, 0.0f);
        this.nerveManager.init(true, false, true);
        this.nerveManager.insertModelParameter(wordTwoVectorModel.getModelParameter());
    }

    /**
     * Computes embedding features for {@code word} with a non-study pass through the network.
     *
     * <p>When {@code once} is {@code false}, each UTF-16 {@code char} of {@code word} is looked up
     * separately and contributes one row. When {@code once} is {@code true}, the whole string is
     * looked up as a single vocabulary entry and contributes exactly one row. Strings missing from
     * the vocabulary fall back to index 0 (see {@link #getID}).
     *
     * <p>NOTE(review): splitting is per UTF-16 unit, so a supplementary character becomes two lone
     * surrogates. Confirm this matches how the vocabulary was built.
     *
     * @param word text to embed; must be non-empty
     * @param eventId event id forwarded to every sensory nerve for this pass
     * @param once whether to treat {@code word} as one vocabulary entry
     * @return a feature whose first-feature list comes from the first row's {@link WordMatrix}
     *     and whose feature matrix stacks every row's vector
     * @throws IllegalArgumentException if {@code word} is empty
     */
    public MyWordFeature getEmbedding(String word, long eventId, boolean once) throws Exception {
        if (word.isEmpty()) {
            throw new IllegalArgumentException("word must not be empty");
        }
        MyWordFeature myWordFeature = new MyWordFeature();
        int rows = once ? 1 : word.length();
        MatrixList matrixList = null;
        for (int i = 0; i < rows; ++i) {
            WordMatrix wordMatrix = new WordMatrix(this.wordVectorDimension);
            String myWord = once ? word : word.substring(i, i + 1);
            this.studyDNN(eventId, this.getID(myWord), NO_TARGET, wordMatrix, false);
            if (matrixList == null) {
                myWordFeature.setFirstFeatureList(wordMatrix.getList());
                matrixList = new MatrixList(wordMatrix.getVector(), true);
            } else {
                matrixList.add(wordMatrix.getVector());
            }
        }
        myWordFeature.setFeatureMatrix(matrixList.getMatrix());
        return myWordFeature;
    }

    /**
     * Sends one one-hot encoded word through all sensory nerves.
     *
     * @param eventId event id forwarded to each nerve
     * @param featureIndex vocabulary index whose input nerve receives {@code 1.0f}
     * @param resIndex vocabulary index of the expected output word, or a negative value
     *     ({@link #NO_TARGET}) when there is none
     * @param outBack callback receiving the result, or {@code null}
     * @param isStudy whether this pass trains the network
     */
    private void studyDNN(long eventId, int featureIndex, int resIndex, OutBack outBack, boolean isStudy) throws Exception {
        List<SensoryNerve> sensoryNerves = this.nerveManager.getSensoryNerves();
        HashMap<Integer, Float> map = new HashMap<Integer, Float>();
        if (resIndex >= 0) {
            // The +1 maps the 0-based vocabulary index to the key the network expects
            // (presumably 1-based output nerve ids -- confirm in NerveManager). Index 0 is a real
            // word, so it must also be allowed as a target.
            map.put(resIndex + 1, 1.0f);
        }
        int size = sensoryNerves.size();
        for (int i = 0; i < size; ++i) {
            float feature = i == featureIndex ? 1.0f : 0.0f;
            sensoryNerves.get(i).postMessage(eventId, feature, isStudy, map, outBack, true, null);
        }
    }

    /**
     * Trains the network on every sentence of the model given to {@link #init}, repeated
     * {@link #setStudyTimes studyTimes} times.
     *
     * <p>Prints a start banner ("word embedding training starting") and, after each sentence, a
     * progress line with the sentence count, the sentence index within the pass, the elapsed
     * milliseconds and the overall completion percentage.
     *
     * @return a model holding the network parameters and a copy of the vocabulary
     * @throws IllegalStateException if {@link #init} has not been called
     * @throws Exception propagated from training
     */
    public WordTwoVectorModel start() throws Exception {
        if (this.sentenceModel == null) {
            throw new IllegalStateException("init must be called before start");
        }
        List<String[]> sentenceList = this.sentenceModel.getSentenceList();
        int size = sentenceList.size();
        System.out.println("\u8bcd\u5d4c\u5165\u8bad\u7ec3\u542f\u52a8...");
        int allTimes = this.studyTimes * size;
        int index = 0;
        for (int k = 0; k < this.studyTimes; ++k) {
            for (int i = 0; i < size; ++i) {
                long start = System.currentTimeMillis();
                this.study(sentenceList.get(i));
                long elapsed = System.currentTimeMillis() - start;
                float r = (float)(++index) / (float)allTimes * 100.0f;
                String result = String.format("%.6f", Float.valueOf(r));
                System.out.println("size:" + size + ",index:" + i + ",\u8017\u65f6:" + elapsed + ",\u5b8c\u6210\u5ea6:" + result + "%");
            }
        }
        WordTwoVectorModel wordTwoVectorModel = new WordTwoVectorModel();
        wordTwoVectorModel.setModelParameter(this.nerveManager.getModelParameter());
        // Hand out a copy, so later init/insertModel calls on this instance cannot clear or
        // rewrite the vocabulary of a model that was already returned.
        wordTwoVectorModel.setWordList(new ArrayList<String>(this.wordList));
        return wordTwoVectorModel;
    }

    /**
     * Runs one study pass for every ordered pair of distinct positions {@code (i, j)} in the
     * sentence, using word {@code i} as input and word {@code j} as the expected output.
     */
    private void study(String[] word) throws Exception {
        int[] indexArray = new int[word.length];
        for (int i = 0; i < word.length; ++i) {
            indexArray[i] = this.getID(word[i]);
        }
        for (int i = 0; i < indexArray.length; ++i) {
            for (int j = 0; j < indexArray.length; ++j) {
                if (i != j) {
                    this.studyDNN(1L, indexArray[i], indexArray[j], null, true);
                }
            }
        }
    }

    /**
     * Returns the vocabulary index of {@code word}.
     *
     * <p>Unknown words map to 0, which is indistinguishable from the word stored at index 0.
     */
    public int getID(String word) {
        int index = this.wordList.indexOf(word);
        return index < 0 ? 0 : index;
    }

    /** Returns the configuration, failing fast with a descriptive error if it was never set. */
    private SentenceConfig requireConfig() {
        if (this.config == null) {
            throw new IllegalStateException("setConfig must be called before init/insertModel");
        }
        return this.config;
    }
}

