/*
 * Decompiled with CFR 0.152.
 */
package org.dromara.easyai.nerveEntity;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.nerveEntity.Nerve;
import org.dromara.easyai.nerveEntity.OutNerve;

/**
 * Output nerve that applies a softmax to the values collected for one event.
 *
 * <p>In study (training) mode, each entry of {@code outNerves} receives one error term via
 * {@code getGBySoftMax}. The term is the one-hot target value minus the predicted probability
 * for that class. In inference mode, the winning class id, its probability and the full
 * probability list are reported to the supplied {@link OutBack}.
 */
public class SoftMax
extends Nerve {
    private final List<OutNerve> outNerves;
    private final boolean isShowLog;

    /**
     * Creates a softmax nerve.
     *
     * @param upNub      number of upstream inputs, passed to {@code Nerve}
     * @param isDynamic  passed through to {@code Nerve}
     * @param outNerves  nerves that receive one error term each in study mode; index {@code i}
     *                   receives the error for class {@code i + 1}
     * @param isShowLog  when true, prints the target/prediction of every study step to stdout
     * @param coreNumber passed through to {@code Nerve}
     * @throws Exception if the {@code Nerve} constructor throws
     */
    public SoftMax(int upNub, boolean isDynamic, List<OutNerve> outNerves, boolean isShowLog, int coreNumber) throws Exception {
        // The remaining positional arguments are fixed for this nerve type; their individual
        // meanings are defined by Nerve (not visible here).
        super(0, upNub, "softMax", 0, 0.0f, false, null, isDynamic, 0, 0.0f, 0, 0, 0, 0, coreNumber, 0, 0.0f, false);
        this.outNerves = outNerves;
        this.isShowLog = isShowLog;
    }

    /**
     * Receives one upstream value for {@code eventId} and, once the event is complete, evaluates
     * the softmax and either back-propagates (study) or reports the result (inference).
     *
     * @param eventId   identifier of the event the value belongs to
     * @param parameter the upstream value
     * @param isStudy   true for training, false for inference
     * @param E         training targets, keyed by 1-based class id; the first entry whose value
     *                  exceeds 0.9 is taken as the target class (only read when {@code isStudy})
     * @param outBack   receiver of the inference result (only used when {@code !isStudy})
     * @throws Exception if {@code outBack} is null in inference mode
     */
    @Override
    protected void input(long eventId, float parameter, boolean isStudy, Map<Integer, Float> E, OutBack outBack) throws Exception {
        // NOTE(review): insertParameter is defined in Nerve; presumably it returns true once all
        // upstream values for this event have arrived — confirm.
        boolean allReady = this.insertParameter(eventId, parameter);
        if (!allReady) {
            return;
        }
        Mes mes = this.softMax(eventId);
        if (isStudy) {
            int key = targetKeyOf(E);
            if (this.isShowLog) {
                System.out.println("softMax==" + key + ",out==" + mes.poi + ",nerveId==" + mes.typeID);
            }
            List<Float> errors = this.error(mes, key);
            this.features.remove(eventId);
            int size = this.outNerves.size();
            for (int i = 0; i < size; ++i) {
                this.outNerves.get(i).getGBySoftMax(errors.get(i).floatValue(), eventId);
            }
        } else {
            // The event's stored parameters are released before outBack is validated, matching
            // the original ordering.
            this.destoryParameter(eventId);
            if (outBack == null) {
                throw new Exception("not find outBack");
            }
            outBack.getBack(mes.poi, mes.typeID, eventId);
            outBack.getSoftMaxBack(eventId, mes.softMax);
        }
    }

    /**
     * Returns the key of the first entry in {@code E} whose value is greater than 0.9, or 0 when
     * no such entry exists.
     */
    private static int targetKeyOf(Map<Integer, Float> E) {
        for (Map.Entry<Integer, Float> entry : E.entrySet()) {
            // Float is unboxed and widened to double, exactly as the original comparison did.
            if (entry.getValue() > 0.9) {
                return entry.getKey();
            }
        }
        return 0;
    }

    /**
     * Computes the per-class error terms: {@code 1 - p} for the target class and {@code -p}
     * for every other class.
     *
     * @param mes softmax result for the event
     * @param key 1-based target class id; 0 (no target found) yields {@code -p} for all classes
     * @return one error term per softmax output, in the same order
     */
    private List<Float> error(Mes mes, int key) {
        int t = key - 1; // class ids are 1-based, list indices 0-based
        List<Float> softMax = mes.softMax;
        ArrayList<Float> error = new ArrayList<Float>(softMax.size());
        for (int i = 0; i < softMax.size(); ++i) {
            float self = softMax.get(i).floatValue();
            float myError = i != t ? -self : 1.0f - self;
            error.add(Float.valueOf(myError));
        }
        return error;
    }

    /**
     * Applies a numerically stable softmax to the values stored for {@code eventId}.
     *
     * <p>The maximum value is subtracted before exponentiating. This does not change the
     * mathematical result, but it prevents {@code exp} from overflowing to infinity for large
     * inputs and from underflowing every term to zero for very negative inputs. Both cases
     * previously produced NaN probabilities.
     *
     * @param eventId identifier of the event whose values are read from {@code features}
     * @return the probabilities, the 1-based index of the first maximum ({@code typeID}, 0 if
     *         the list is empty) and that maximum probability ({@code poi})
     */
    private Mes softMax(long eventId) {
        // features is declared in Nerve; its elements are read as Float, as in the original.
        List<?> featuresList = (List<?>) this.features.get(eventId);
        int n = featuresList.size();

        // Single pass over the list (cheap even if it is not random-access).
        float[] logits = new float[n];
        float max = Float.NEGATIVE_INFINITY;
        int idx = 0;
        for (Object element : featuresList) {
            float value = ((Float) element).floatValue();
            logits[idx++] = value;
            max = Math.max(max, value);
        }

        // Exponentiate once per element; accumulate in double for precision.
        double[] exps = new double[n];
        double sigma = 0.0;
        for (int i = 0; i < n; ++i) {
            exps[i] = Math.exp((double) logits[i] - max);
            sigma += exps[i];
        }

        ArrayList<Float> softMax = new ArrayList<Float>(n);
        int id = 0;
        float poi = 0.0f;
        for (int i = 0; i < n; ++i) {
            float value = (float) (exps[i] / sigma);
            softMax.add(Float.valueOf(value));
            // Strict comparison: the first of several equal maxima wins.
            if (value > poi) {
                poi = value;
                id = i + 1;
            }
        }

        Mes mes = new Mes();
        mes.softMax = softMax;
        mes.typeID = id;
        mes.poi = poi;
        return mes;
    }

    /** Result of one softmax evaluation. */
    static class Mes {
        /** 1-based index of the most probable class; 0 if none was selected. */
        int typeID;
        /** Probability of {@link #typeID}. */
        float poi;
        /** Probability for every class, in input order. */
        List<Float> softMax;

        Mes() {
        }
    }
}

