/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.rnnNerveCenter;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.dromara.easyai.i.ActiveFunction;
import org.dromara.easyai.rnnNerveCenter.ModelParameter;
import org.dromara.easyai.rnnNerveCenter.RnnOutNerveStudy;
import org.dromara.easyai.rnnNerveEntity.HiddenNerve;
import org.dromara.easyai.rnnNerveEntity.Nerve;
import org.dromara.easyai.rnnNerveEntity.NerveStudy;
import org.dromara.easyai.rnnNerveEntity.OutNerve;
import org.dromara.easyai.rnnNerveEntity.RnnOutNerveBody;
import org.dromara.easyai.rnnNerveEntity.SensoryNerve;
import org.dromara.easyai.rnnNerveEntity.SoftMax;

public class NerveManager {
    private final int hiddenNerveNub;
    private final int sensoryNerveNub;
    private final int outNerveNub;
    private final int hiddenDepth;
    private final List<SensoryNerve> sensoryNerves = new ArrayList<SensoryNerve>();
    private final List<List<Nerve>> depthNerves = new ArrayList<List<Nerve>>();
    private final List<Nerve> outNerves = new ArrayList<Nerve>();
    private final List<RnnOutNerveBody> rnnOutNerveBodies = new ArrayList<RnnOutNerveBody>();
    private boolean initPower;
    private float studyPoint = 0.1f;
    private final ActiveFunction activeFunction;
    private boolean isRnn = false;
    private List<Float> studyList = new ArrayList<Float>();
    private final int rzType;
    private final float lParam;
    private boolean isSoftMax = true;

    public void setSoftMax(boolean softMax) {
        this.isSoftMax = softMax;
    }

    public List<Float> getStudyList() {
        return this.studyList;
    }

    public void setStudyList(List<Float> studyList) {
        this.studyList = studyList;
    }

    private Map<String, Float> conversion(Map<Integer, Float> map) {
        HashMap<String, Float> cMap = new HashMap<String, Float>();
        for (Map.Entry<Integer, Float> entry : map.entrySet()) {
            cMap.put(String.valueOf(entry.getKey()), entry.getValue());
        }
        return cMap;
    }

    private Map<Integer, Float> unConversion(Map<String, Float> map) {
        HashMap<Integer, Float> cMap = new HashMap<Integer, Float>();
        for (Map.Entry<String, Float> entry : map.entrySet()) {
            cMap.put(Integer.parseInt(entry.getKey()), entry.getValue());
        }
        return cMap;
    }

    public ModelParameter getModelParameter() throws Exception {
        if (this.isRnn) {
            return this.getRnnModelParameter();
        }
        return this.getStaticModelParameter();
    }

    private ModelParameter getRnnModelParameter() {
        ModelParameter modelParameter = new ModelParameter();
        ArrayList<List<NerveStudy>> studyDepthNerves = new ArrayList<List<NerveStudy>>();
        ArrayList<RnnOutNerveStudy> rnnOutNerveStudies = new ArrayList<RnnOutNerveStudy>();
        modelParameter.setDepthNerves(studyDepthNerves);
        modelParameter.setRnnOutNerveStudies(rnnOutNerveStudies);
        this.getHiddenNerveModel(studyDepthNerves);
        for (RnnOutNerveBody rnnOutNerveBody : this.rnnOutNerveBodies) {
            ArrayList<NerveStudy> nerveStudies = new ArrayList<NerveStudy>();
            RnnOutNerveStudy rnnOutNerveStudy = new RnnOutNerveStudy();
            rnnOutNerveStudies.add(rnnOutNerveStudy);
            rnnOutNerveStudy.setDepth(rnnOutNerveBody.getDepth());
            rnnOutNerveStudy.setNerveStudies(nerveStudies);
            List<Nerve> outNerveList = rnnOutNerveBody.getOutNerves();
            this.getRnnOutNerveModel(nerveStudies, outNerveList);
        }
        return modelParameter;
    }

    private void getRnnOutNerveModel(List<NerveStudy> nerveStudies, List<Nerve> outNerveList) {
        for (Nerve nerve : outNerveList) {
            NerveStudy nerveStudy = new NerveStudy();
            nerveStudy.setThreshold(nerve.getThreshold());
            nerveStudy.setDendrites(this.conversion(nerve.getDendrites()));
            nerveStudies.add(nerveStudy);
        }
    }

    private void getHiddenNerveModel(List<List<NerveStudy>> studyDepthNerves) {
        for (List<Nerve> depthNerve : this.depthNerves) {
            ArrayList<NerveStudy> deepNerve = new ArrayList<NerveStudy>();
            for (Nerve nerve : depthNerve) {
                NerveStudy nerveStudy = new NerveStudy();
                nerveStudy.setThreshold(nerve.getThreshold());
                nerveStudy.setDendrites(this.conversion(nerve.getDendrites()));
                deepNerve.add(nerveStudy);
            }
            studyDepthNerves.add(deepNerve);
        }
    }

    private ModelParameter getStaticModelParameter() {
        ModelParameter modelParameter = new ModelParameter();
        ArrayList<List<NerveStudy>> studyDepthNerves = new ArrayList<List<NerveStudy>>();
        ArrayList<NerveStudy> outStudyNerves = new ArrayList<NerveStudy>();
        this.getHiddenNerveModel(studyDepthNerves);
        this.getRnnOutNerveModel(outStudyNerves, this.outNerves);
        modelParameter.setDepthNerves(studyDepthNerves);
        modelParameter.setOutNerves(outStudyNerves);
        return modelParameter;
    }

    public void insertModelParameter(ModelParameter modelParameter) throws Exception {
        if (this.isRnn) {
            this.insertRnnModelParameter(modelParameter);
        } else {
            this.insertBpModelParameter(modelParameter);
        }
    }

    private void insertRnnModelParameter(ModelParameter modelParameter) {
        List<List<NerveStudy>> depthStudyNerves = modelParameter.getDepthNerves();
        List<RnnOutNerveStudy> rnnOutNerveStudies = modelParameter.getRnnOutNerveStudies();
        for (int i = 0; i < this.depthNerves.size(); ++i) {
            List<NerveStudy> depth = depthStudyNerves.get(i);
            List<Nerve> depthNerve = this.depthNerves.get(i);
            for (int j = 0; j < depthNerve.size(); ++j) {
                Nerve nerve = depthNerve.get(j);
                NerveStudy nerveStudy = depth.get(j);
                Map<Integer, Float> studyDendrites = this.unConversion(nerveStudy.getDendrites());
                Map<Integer, Float> dendrites = nerve.getDendrites();
                nerve.setThreshold(nerveStudy.getThreshold());
                for (Map.Entry<Integer, Float> entry : dendrites.entrySet()) {
                    int key = entry.getKey();
                    dendrites.put(key, studyDendrites.get(key));
                }
            }
        }
        for (RnnOutNerveStudy rnnOutNerveStudy : rnnOutNerveStudies) {
            RnnOutNerveBody rnnOutNerveBody = this.getRnnOutNerveBody(rnnOutNerveStudy.getDepth());
            List<NerveStudy> outStudyNerves = rnnOutNerveStudy.getNerveStudies();
            List<Nerve> outNerveBody = rnnOutNerveBody.getOutNerves();
            for (int i = 0; i < outNerveBody.size(); ++i) {
                Nerve outNerve = outNerveBody.get(i);
                NerveStudy nerveStudy = outStudyNerves.get(i);
                outNerve.setThreshold(nerveStudy.getThreshold());
                Map<Integer, Float> dendrites = outNerve.getDendrites();
                Map<Integer, Float> studyDendrites = this.unConversion(nerveStudy.getDendrites());
                for (Map.Entry<Integer, Float> outEntry : dendrites.entrySet()) {
                    int key = outEntry.getKey();
                    dendrites.put(key, studyDendrites.get(key));
                }
            }
        }
    }

    private RnnOutNerveBody getRnnOutNerveBody(int depth) {
        RnnOutNerveBody myRnnOutNerveBody = null;
        for (RnnOutNerveBody rnnOutNerveBody : this.rnnOutNerveBodies) {
            if (rnnOutNerveBody.getDepth() != depth) continue;
            myRnnOutNerveBody = rnnOutNerveBody;
            break;
        }
        return myRnnOutNerveBody;
    }

    private void insertBpModelParameter(ModelParameter modelParameter) {
        int i;
        List<List<NerveStudy>> depthStudyNerves = modelParameter.getDepthNerves();
        List<NerveStudy> outStudyNerves = modelParameter.getOutNerves();
        for (i = 0; i < this.depthNerves.size(); ++i) {
            List<NerveStudy> depth = depthStudyNerves.get(i);
            List<Nerve> depthNerve = this.depthNerves.get(i);
            for (int j = 0; j < depthNerve.size(); ++j) {
                Nerve nerve = depthNerve.get(j);
                NerveStudy nerveStudy = depth.get(j);
                Map<Integer, Float> studyDendrites = this.unConversion(nerveStudy.getDendrites());
                Map<Integer, Float> dendrites = nerve.getDendrites();
                nerve.setThreshold(nerveStudy.getThreshold());
                for (Map.Entry<Integer, Float> entry : dendrites.entrySet()) {
                    int key = entry.getKey();
                    dendrites.put(key, studyDendrites.get(key));
                }
            }
        }
        for (i = 0; i < this.outNerves.size(); ++i) {
            Nerve outNerve = this.outNerves.get(i);
            NerveStudy nerveStudy = outStudyNerves.get(i);
            outNerve.setThreshold(nerveStudy.getThreshold());
            Map<Integer, Float> dendrites = outNerve.getDendrites();
            Map<Integer, Float> studyDendrites = this.unConversion(nerveStudy.getDendrites());
            for (Map.Entry<Integer, Float> outEntry : dendrites.entrySet()) {
                int key = outEntry.getKey();
                dendrites.put(key, studyDendrites.get(key));
            }
        }
    }

    public NerveManager(int sensoryNerveNub, int hiddenNerveNub, int outNerveNub, int hiddenDepth, ActiveFunction activeFunction, float studyPoint, int rzType, float lParam) throws Exception {
        if (sensoryNerveNub > 0 && hiddenNerveNub > 0 && outNerveNub > 0 && hiddenDepth > 0 && activeFunction != null) {
            this.hiddenNerveNub = hiddenNerveNub;
            this.sensoryNerveNub = sensoryNerveNub;
            this.outNerveNub = outNerveNub;
            this.hiddenDepth = hiddenDepth;
            this.activeFunction = activeFunction;
            this.rzType = rzType;
            this.lParam = lParam;
            if (studyPoint > 0.0f && studyPoint < 1.0f) {
                this.studyPoint = studyPoint;
            }
        } else {
            throw new Exception("param is null");
        }
    }

    public List<SensoryNerve> getSensoryNerves() {
        return this.sensoryNerves;
    }

    public void init(boolean initPower, boolean isShowLog, boolean isSoftMax) throws Exception {
        this.initPower = initPower;
        this.initDepthNerve();
        List<Nerve> nerveList = this.depthNerves.get(0);
        List<Nerve> lastNerveList = this.depthNerves.get(this.depthNerves.size() - 1);
        ArrayList<OutNerve> outNerveList = new ArrayList<OutNerve>();
        for (int i = 1; i < this.outNerveNub + 1; ++i) {
            OutNerve outNerve = new OutNerve(i, this.hiddenNerveNub, 0, this.studyPoint, initPower, this.activeFunction, isShowLog, this.rzType, this.lParam, isSoftMax);
            outNerve.connectFather(lastNerveList);
            this.outNerves.add(outNerve);
            outNerveList.add(outNerve);
        }
        this.createSoftMax(isShowLog, isSoftMax, outNerveList, this.outNerves);
        for (Nerve nerve : lastNerveList) {
            nerve.connect(this.outNerves);
        }
        for (int i = 1; i < this.sensoryNerveNub + 1; ++i) {
            SensoryNerve sensoryNerve = new SensoryNerve(i, 0);
            sensoryNerve.connect(nerveList);
            this.sensoryNerves.add(sensoryNerve);
        }
    }

    private void createSoftMax(boolean isShowLog, boolean isSoftMax, List<OutNerve> outNerveList, List<Nerve> outNerves) throws Exception {
        if (isSoftMax) {
            ArrayList<Nerve> softMaxList = new ArrayList<Nerve>();
            SoftMax softMax = new SoftMax(this.outNerveNub, outNerveList, isShowLog);
            softMaxList.add(softMax);
            for (Nerve nerve : outNerves) {
                nerve.connect(softMaxList);
            }
        }
    }

    private void createRnnOutNerve(boolean initPower, boolean isShowLog, List<Nerve> nerveList, int depth) throws Exception {
        RnnOutNerveBody rnnOutNerveBody = new RnnOutNerveBody();
        ArrayList<Nerve> rnnOutNerves = new ArrayList<Nerve>();
        ArrayList<OutNerve> outNerveList = new ArrayList<OutNerve>();
        rnnOutNerveBody.setDepth(depth);
        rnnOutNerveBody.setOutNerves(rnnOutNerves);
        for (int i = 1; i < this.outNerveNub + 1; ++i) {
            OutNerve outNerve = new OutNerve(i, this.hiddenNerveNub, 0, this.studyPoint, initPower, this.activeFunction, isShowLog, this.rzType, this.lParam, this.isSoftMax);
            outNerve.connectFather(nerveList);
            rnnOutNerves.add(outNerve);
            outNerveList.add(outNerve);
        }
        this.createSoftMax(isShowLog, this.isSoftMax, outNerveList, rnnOutNerves);
        for (Nerve nerve : nerveList) {
            nerve.connectOut(rnnOutNerves);
        }
        this.rnnOutNerveBodies.add(rnnOutNerveBody);
    }

    public void initRnn(boolean initPower, boolean isShowLog) throws Exception {
        this.isRnn = true;
        this.initPower = initPower;
        this.initDepthNerve();
        for (int i = 0; i < this.depthNerves.size(); ++i) {
            this.createRnnOutNerve(initPower, isShowLog, this.depthNerves.get(i), i + 1);
        }
        List<Nerve> nerveList = this.depthNerves.get(0);
        for (int i = 1; i < this.sensoryNerveNub + 1; ++i) {
            SensoryNerve sensoryNerve = new SensoryNerve(i, 0);
            sensoryNerve.connect(nerveList);
            this.sensoryNerves.add(sensoryNerve);
        }
    }

    private void initDepthNerve() throws Exception {
        for (int i = 0; i < this.hiddenDepth; ++i) {
            ArrayList<HiddenNerve> hiddenNerveList = new ArrayList<HiddenNerve>();
            float studyPoint = this.studyPoint;
            if (this.studyList.contains(i)) {
                studyPoint = this.studyList.get(i).floatValue();
            }
            if (studyPoint <= 0.0f || studyPoint > 1.0f) {
                throw new Exception("studyPoint Values range from 0 to 1");
            }
            for (int j = 1; j < this.hiddenNerveNub + 1; ++j) {
                int upNub = i == 0 ? this.sensoryNerveNub : this.hiddenNerveNub;
                int downNub = i == this.hiddenDepth - 1 ? this.outNerveNub : this.hiddenNerveNub;
                HiddenNerve hiddenNerve = new HiddenNerve(j, i + 1, upNub, downNub, studyPoint, this.initPower, this.activeFunction, this.rzType, this.lParam, this.outNerveNub);
                hiddenNerveList.add(hiddenNerve);
            }
            this.depthNerves.add(hiddenNerveList);
        }
        this.initHiddenNerve();
    }

    private void initHiddenNerve() {
        for (int i = 0; i < this.hiddenDepth - 1; ++i) {
            List<Nerve> hiddenNerveList = this.depthNerves.get(i);
            List<Nerve> nextHiddenNerveList = this.depthNerves.get(i + 1);
            for (Nerve hiddenNerve : hiddenNerveList) {
                hiddenNerve.connect(nextHiddenNerveList);
            }
            for (Nerve nextHiddenNerve : nextHiddenNerveList) {
                nextHiddenNerve.connectFather(hiddenNerveList);
            }
        }
    }
}

