/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.rnnNerveEntity;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.dromara.easyai.i.ActiveFunction;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;

public abstract class Nerve {
    /** Nerves that receive this nerve's output through {@link #sendMessage}. */
    private final List<Nerve> son = new ArrayList<Nerve>();
    /** Nerves that receive this nerve's output through {@link #sendRnnMessage}. */
    private final List<Nerve> rnnOut = new ArrayList<Nerve>();
    /** Upstream nerves; father {@code i} receives the back-propagated value {@code wg.get(i + 1)}. */
    private final List<Nerve> father = new ArrayList<Nerve>();
    /** Input weights keyed {@code 1..upNub}; key {@code i + 1} weights the {@code i}-th collected input. */
    protected Map<Integer, Float> dendrites = new HashMap<Integer, Float>();
    /** Per-dendrite values {@code weight * gradient} computed in the last weight update, sent to fathers. */
    protected Map<Integer, Float> wg = new HashMap<Integer, Float>();
    private final int id;
    /** True when this nerve was constructed with the name {@code "OutNerve"}. */
    boolean fromOutNerve = false;
    /** Number of inputs (dendrites) this nerve collects per event. */
    protected int upNub;
    /** Back messages to wait for before updating, when the senders are not "OutNerve" nerves. */
    protected int downNub;
    /** Back messages to wait for before updating, when the senders are "OutNerve" nerves. */
    protected int rnnOutNub;
    /** Input values collected so far, keyed by event id. */
    protected Map<Long, List<Float>> features = new HashMap<Long, List<Float>>();
    /** Only exposed through the accessors here; NOTE(review): presumably consumed by subclasses. */
    protected Matrix nerveMatrix;
    /** Bias subtracted from the weighted sum in {@link #calculation}; adjusted in {@link #updatePower}. */
    protected float threshold;
    protected String name;
    /** Output value passed to {@code activeFunction.functionG}; NOTE(review): presumably set by subclasses. */
    protected float outNub;
    /** NOTE(review): not used in this class; presumably used by subclasses. */
    protected float E;
    /** {@code functionG(outNub) * sigmaW}, computed once all expected back messages have arrived. */
    protected float gradient;
    /** Step size multiplied into every weight/threshold update and into the regularization term. */
    protected float studyPoint;
    /** Running sum of back-propagated values received since the last update. */
    protected float sigmaW;
    /** Number of back messages received since the last update. */
    private int backNub = 0;
    protected ActiveFunction activeFunction;
    /** Regularization mode: 0 none, 1 sign-based (L1-style), 2 weight-proportional (L2-style). */
    private final int rzType;
    /** Regularization coefficient. */
    private final float lParam;
    /** Per event, the index of the input recorded by {@link #insertParameter} in embedding mode. */
    private final Map<Long, Integer> embeddingIndex = new HashMap<Long, Integer>();

    /** Returns the live (not copied) weight map keyed {@code 1..upNub}. */
    public Map<Integer, Float> getDendrites() {
        return this.dendrites;
    }

    public Matrix getNerveMatrix() {
        return this.nerveMatrix;
    }

    public void setNerveMatrix(Matrix nerveMatrix) {
        this.nerveMatrix = nerveMatrix;
    }

    /** Replaces the weight map; the map is stored as-is, not copied. */
    public void setDendrites(Map<Integer, Float> dendrites) {
        this.dendrites = dendrites;
    }

    public float getThreshold() {
        return this.threshold;
    }

    public void setThreshold(float threshold) {
        this.threshold = threshold;
    }

    /**
     * Creates a nerve and initializes its weights and threshold.
     *
     * @param id identifier returned by {@link #getId()}
     * @param upNub number of inputs; weights are created for keys {@code 1..upNub}
     * @param name nerve name; {@code "OutNerve"} makes this nerve's fathers count its back messages
     *     against their {@code rnnOutNub} instead of their {@code downNub}
     * @param downNub back messages to wait for from non-"OutNerve" senders
     * @param studyPoint step size used in weight and threshold updates
     * @param init when true, weights and threshold are drawn from {@code [0, 1/sqrt(upNub))};
     *     otherwise they are 0
     * @param activeFunction activation whose {@code functionG(outNub)} multiplies the accumulated
     *     back-propagated sum
     * @param rzType regularization mode: 0 none, 1 sign-based, 2 weight-proportional
     * @param lParam regularization coefficient
     * @param rnnOutNub back messages to wait for from "OutNerve" senders
     * @throws Exception declared by the signature; the body shown here throws no checked exception
     */
    protected Nerve(int id, int upNub, String name, int downNub, float studyPoint, boolean init, ActiveFunction activeFunction, int rzType, float lParam, int rnnOutNub) throws Exception {
        this.id = id;
        this.upNub = upNub;
        this.name = name;
        this.downNub = downNub;
        this.studyPoint = studyPoint;
        this.activeFunction = activeFunction;
        this.rzType = rzType;
        this.lParam = lParam;
        this.rnnOutNub = rnnOutNub;
        // Constant-first comparison so a null name simply means "not an OutNerve".
        this.fromOutNerve = "OutNerve".equals(name);
        this.initPower(init);
    }

    protected void setStudyPoint(float studyPoint) {
        this.studyPoint = studyPoint;
    }

    /**
     * Forwards a value to every nerve registered through {@link #connect}.
     *
     * @throws Exception if no nerve has been connected, or if a target's {@code input} throws
     */
    public void sendMessage(long eventId, float parameter, boolean isStudy, Map<Integer, Float> E, OutBack outBack, boolean isEmbedding, Matrix rnnMatrix) throws Exception {
        this.dispatch(this.son, eventId, parameter, isStudy, E, outBack, isEmbedding, rnnMatrix);
    }

    /**
     * Forwards a value to every nerve registered through {@link #connectOut}.
     *
     * @throws Exception if no nerve has been connected, or if a target's {@code input} throws
     */
    public void sendRnnMessage(long eventId, float parameter, boolean isStudy, Map<Integer, Float> E, OutBack outBack, boolean isEmbedding, Matrix rnnMatrix) throws Exception {
        this.dispatch(this.rnnOut, eventId, parameter, isStudy, E, outBack, isEmbedding, rnnMatrix);
    }

    /** Calls {@code input} on each target in order; fails when there are no targets. */
    private void dispatch(List<Nerve> targets, long eventId, float parameter, boolean isStudy, Map<Integer, Float> E, OutBack outBack, boolean isEmbedding, Matrix rnnMatrix) throws Exception {
        if (targets.isEmpty()) {
            throw new Exception("this layer is lastIndex");
        }
        for (Nerve target : targets) {
            target.input(eventId, parameter, isStudy, E, outBack, isEmbedding, rnnMatrix);
        }
    }

    /** Sends {@code wg.get(i + 1)} back to the {@code i}-th father nerve. */
    private void backSendMessage(long eventId, boolean fromOutNerve) throws Exception {
        for (int i = 0; i < this.father.size(); ++i) {
            this.father.get(i).backGetMessage(this.wg.get(i + 1), eventId, fromOutNerve);
        }
    }

    /** Receives a forwarded value; no-op here, overridden by concrete nerves. */
    protected void input(long eventId, float parameter, boolean isStudy, Map<Integer, Float> E, OutBack imageBack, boolean isEmbedding, Matrix rnnMatrix) throws Exception {
    }

    /**
     * Accumulates one back-propagated value; once the expected number has arrived, computes the
     * gradient and updates this nerve.
     */
    private void backGetMessage(float parameter, long eventId, boolean fromOutNerve) throws Exception {
        ++this.backNub;
        this.sigmaW += parameter;
        // Which counter applies depends on whether the *sender* is an "OutNerve".
        int expected = fromOutNerve ? this.rnnOutNub : this.downNub;
        if (this.backNub == expected) {
            this.backNub = 0;
            this.gradient = this.activeFunction.functionG(this.outNub) * this.sigmaW;
            this.updatePower(eventId);
        }
    }

    /**
     * Applies the current gradient: adjusts the threshold and weights, resets the accumulated back
     * value, and propagates to the father nerves.
     */
    protected void updatePower(long eventId) throws Exception {
        float h = this.gradient * this.studyPoint;
        this.threshold -= h;
        this.updateW(h, eventId);
        this.sigmaW = 0.0f;
        this.backSendMessage(eventId, this.fromOutNerve);
    }

    /**
     * Returns the regularization delta added to weight {@code w}.
     *
     * <p>Mode 2 returns {@code param * -w}. Mode 1 returns {@code -param} for positive weights,
     * {@code param} for negative weights, and 0 for zero. Any other mode returns 0.
     */
    private float regularization(float w, float param) {
        if (this.rzType == 2) {
            return param * -w;
        }
        if (this.rzType == 1) {
            if (w > 0.0f) {
                return -param;
            }
            if (w < 0.0f) {
                return param;
            }
        }
        return 0.0f;
    }

    /**
     * Updates every weight from the inputs stored for {@code eventId}, records {@code w * gradient}
     * into {@link #wg}, and then drops the stored inputs.
     *
     * @param h {@code gradient * studyPoint}; multiplied by each input to form the weight step
     */
    private void updateW(float h, long eventId) {
        List<Float> inputs = this.features.get(eventId);
        float param = 0.0f;
        if (this.rzType != 0) {
            // Sum of squared weights (mode 2) or absolute weights (other modes), in double.
            double sigma = 0.0;
            for (Float weight : this.dendrites.values()) {
                float value = weight.floatValue();
                if (this.rzType == 2) {
                    sigma += (float) Math.pow(value, 2.0);
                } else {
                    sigma += Math.abs(value);
                }
            }
            param = (float) sigma * this.lParam * this.studyPoint;
        }
        for (Map.Entry<Integer, Float> entry : this.dendrites.entrySet()) {
            int key = entry.getKey();
            float w = entry.getValue();
            float wp = inputs.get(key - 1) * h;
            float regular = this.regularization(w, param);
            // The back-propagated value uses the weight from *before* this update.
            this.wg.put(key, w * this.gradient);
            // setValue is the supported way to change a value while iterating the entry set.
            entry.setValue(w + regular + wp);
        }
        this.features.remove(eventId);
    }

    /**
     * Appends one input value for {@code eventId}.
     *
     * <p>In embedding mode, a value greater than 0.5 records its position so that
     * {@link #calculation} and {@link #getWOne} use only that input. If several such values
     * arrive, the last one wins.
     *
     * @return true once at least {@code upNub} values have been collected for the event
     */
    protected boolean insertParameter(long eventId, float parameter, boolean embedding) {
        List<Float> featuresList = this.features.get(eventId);
        if (featuresList == null) {
            featuresList = new ArrayList<Float>();
            this.features.put(eventId, featuresList);
        }
        if (embedding && parameter > 0.5) {
            // Index the value is about to occupy.
            this.embeddingIndex.put(eventId, featuresList.size());
        }
        featuresList.add(parameter);
        return featuresList.size() >= this.upNub;
    }

    /** Discards the inputs collected for {@code eventId}. */
    protected void destoryParameter(long eventId) {
        this.features.remove(eventId);
    }

    /**
     * Returns the weight of the embedding input recorded for {@code eventId}.
     *
     * <p>Unlike {@link #calculation}, this does not clear the recorded index. It throws
     * {@link NullPointerException} if no index was recorded.
     */
    protected float getWOne(long eventId) {
        int index = this.embeddingIndex.get(eventId);
        return this.dendrites.get(index + 1);
    }

    /**
     * Computes the weighted input sum minus {@link #threshold} for {@code eventId}.
     *
     * <p>In embedding mode, only the recorded embedding input contributes, and its recorded index
     * is cleared. The collected inputs themselves are not removed here.
     */
    protected float calculation(long eventId, boolean isEmbedding) {
        List<Float> featuresList = this.features.get(eventId);
        float sigma = 0.0f;
        if (isEmbedding) {
            int index = this.embeddingIndex.get(eventId);
            sigma = featuresList.get(index) * this.dendrites.get(index + 1);
            this.embeddingIndex.remove(eventId);
        } else {
            for (int i = 0; i < featuresList.size(); ++i) {
                float value = featuresList.get(i);
                float w = this.dendrites.get(i + 1);
                sigma = w * value + sigma;
            }
        }
        return sigma - this.threshold;
    }

    /**
     * Creates weights {@code 1..upNub} and sets the threshold.
     *
     * <p>When {@code init} is true, values are drawn from {@code [0, 1/sqrt(upNub))}; otherwise
     * they are 0. Nothing happens when {@code upNub <= 0}.
     */
    private void initPower(boolean init) {
        if (this.upNub <= 0) {
            return;
        }
        Random random = new Random();
        float sh = (float) Math.sqrt(this.upNub);
        for (int i = 1; i <= this.upNub; ++i) {
            this.dendrites.put(i, init ? random.nextFloat() / sh : 0.0f);
        }
        // Drawn after the weights so the random sequence matches weight order.
        this.threshold = init ? random.nextFloat() / sh : 0.0f;
    }

    public int getId() {
        return this.id;
    }

    /** Adds targets for {@link #sendMessage}. */
    public void connect(List<Nerve> nerveList) {
        this.son.addAll(nerveList);
    }

    /** Adds targets for {@link #sendRnnMessage}. */
    public void connectOut(List<Nerve> nerveList) {
        this.rnnOut.addAll(nerveList);
    }

    /** Adds upstream nerves that receive back-propagated values, in dendrite-key order. */
    public void connectFather(List<Nerve> nerveList) {
        this.father.addAll(nerveList);
    }
}

