/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.layers;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KerasLstm
extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasLstm.class);
    public static final String LAYER_FIELD_INNER_INIT = "inner_init";
    public static final String LAYER_FIELD_INNER_ACTIVATION = "inner_activation";
    public static final String LAYER_FIELD_FORGET_BIAS_INIT = "forget_bias_init";
    public static final String LAYER_FIELD_DROPOUT_U = "dropout_U";
    public static final String LAYER_FIELD_UNROLL = "unroll";
    public static final String LSTM_FORGET_BIAS_INIT_ZERO = "zero";
    public static final String LSTM_FORGET_BIAS_INIT_ONE = "one";
    public static final int NUM_TRAINABLE_PARAMS = 12;
    public static final String KERAS_PARAM_NAME_W_C = "W_c";
    public static final String KERAS_PARAM_NAME_W_F = "W_f";
    public static final String KERAS_PARAM_NAME_W_I = "W_i";
    public static final String KERAS_PARAM_NAME_W_O = "W_o";
    public static final String KERAS_PARAM_NAME_U_C = "U_c";
    public static final String KERAS_PARAM_NAME_U_F = "U_f";
    public static final String KERAS_PARAM_NAME_U_I = "U_i";
    public static final String KERAS_PARAM_NAME_U_O = "U_o";
    public static final String KERAS_PARAM_NAME_B_C = "b_c";
    public static final String KERAS_PARAM_NAME_B_F = "b_f";
    public static final String KERAS_PARAM_NAME_B_I = "b_i";
    public static final String KERAS_PARAM_NAME_B_O = "b_o";
    public static final int NUM_WEIGHTS_IN_KERAS_LSTM = 12;
    protected boolean unroll = false;

    public KerasLstm(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, true);
    }

    public KerasLstm(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(layerConfig, enforceTrainingConfig);
        WeightInit weightInit = this.getWeightInitFromConfig(layerConfig, enforceTrainingConfig);
        WeightInit recurrentWeightInit = KerasLstm.getRecurrentWeightInitFromConfig(layerConfig, enforceTrainingConfig);
        if (weightInit != recurrentWeightInit) {
            if (enforceTrainingConfig) {
                throw new UnsupportedKerasConfigurationException("Specifying different initialization for recurrent weights not supported.");
            }
            log.warn("Specifying different initialization for recurrent weights not supported.");
        }
        KerasLstm.getRecurrentDropout(layerConfig);
        this.unroll = KerasLstm.getUnrollRecurrentLayer(layerConfig);
        this.layer = ((GravesLSTM.Builder)((GravesLSTM.Builder)((GravesLSTM.Builder)((GravesLSTM.Builder)((GravesLSTM.Builder)((GravesLSTM.Builder)((GravesLSTM.Builder)((GravesLSTM.Builder)new GravesLSTM.Builder().gateActivationFunction(KerasLstm.getGateActivationFromConfig(layerConfig)).forgetGateBiasInit(KerasLstm.getForgetBiasInitFromConfig(layerConfig, enforceTrainingConfig)).name(this.layerName)).nOut(KerasLstm.getNOutFromConfig(layerConfig))).dropOut(this.dropout)).activation(this.getActivationFromConfig(layerConfig))).weightInit(weightInit)).biasInit(0.0)).l1(this.weightL1Regularization)).l2(this.weightL2Regularization)).build();
    }

    public GravesLSTM getGravesLSTMLayer() {
        return (GravesLSTM)this.layer;
    }

    @Override
    public InputType getOutputType(InputType ... inputType) throws InvalidKerasConfigurationException {
        if (inputType.length > 1) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer accepts only one input (received " + inputType.length + ")");
        }
        return this.getGravesLSTMLayer().getOutputType(-1, inputType[0]);
    }

    @Override
    public int getNumParams() {
        return 12;
    }

    @Override
    public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
        this.weights = new HashMap();
        if (!weights.containsKey(KERAS_PARAM_NAME_W_C)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter W_c");
        }
        INDArray W_c = weights.get(KERAS_PARAM_NAME_W_C);
        if (!weights.containsKey(KERAS_PARAM_NAME_W_F)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter W_f");
        }
        INDArray W_f = weights.get(KERAS_PARAM_NAME_W_F);
        if (!weights.containsKey(KERAS_PARAM_NAME_W_O)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter W_o");
        }
        INDArray W_o = weights.get(KERAS_PARAM_NAME_W_O);
        if (!weights.containsKey(KERAS_PARAM_NAME_W_I)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter W_i");
        }
        INDArray W_i = weights.get(KERAS_PARAM_NAME_W_I);
        INDArray W = Nd4j.zeros((int)W_c.rows(), (int)(W_c.columns() + W_f.columns() + W_o.columns() + W_i.columns()));
        W.put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)W.rows()), NDArrayIndex.interval((int)0, (int)W_c.columns())}, W_c);
        W.put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)W.rows()), NDArrayIndex.interval((int)W_c.columns(), (int)(W_c.columns() + W_f.columns()))}, W_f);
        W.put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)W.rows()), NDArrayIndex.interval((int)(W_c.columns() + W_f.columns()), (int)(W_c.columns() + W_f.columns() + W_o.columns()))}, W_o);
        W.put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)W.rows()), NDArrayIndex.interval((int)(W_c.columns() + W_f.columns() + W_o.columns()), (int)(W_c.columns() + W_f.columns() + W_o.columns() + W_i.columns()))}, W_i);
        this.weights.put("W", W);
        if (!weights.containsKey(KERAS_PARAM_NAME_U_C)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter U_c");
        }
        INDArray U_c = weights.get(KERAS_PARAM_NAME_U_C);
        if (!weights.containsKey(KERAS_PARAM_NAME_U_F)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter U_f");
        }
        INDArray U_f = weights.get(KERAS_PARAM_NAME_U_F);
        if (!weights.containsKey(KERAS_PARAM_NAME_U_O)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter U_o");
        }
        INDArray U_o = weights.get(KERAS_PARAM_NAME_U_O);
        if (!weights.containsKey(KERAS_PARAM_NAME_U_I)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter U_i");
        }
        INDArray U_i = weights.get(KERAS_PARAM_NAME_U_I);
        INDArray U = Nd4j.zeros((int)U_c.rows(), (int)(U_c.columns() + U_f.columns() + U_o.columns() + U_i.columns() + 3));
        U.put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)U.rows()), NDArrayIndex.interval((int)0, (int)U_c.columns())}, U_c);
        U.put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)U.rows()), NDArrayIndex.interval((int)U_c.columns(), (int)(U_c.columns() + U_f.columns()))}, U_f);
        U.put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)U.rows()), NDArrayIndex.interval((int)(U_c.columns() + U_f.columns()), (int)(U_c.columns() + U_f.columns() + U_o.columns()))}, U_o);
        U.put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)U.rows()), NDArrayIndex.interval((int)(U_c.columns() + U_f.columns() + U_o.columns()), (int)(U_c.columns() + U_f.columns() + U_o.columns() + U_i.columns()))}, U_i);
        this.weights.put("RW", U);
        if (!weights.containsKey(KERAS_PARAM_NAME_B_C)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter b_c");
        }
        INDArray b_c = weights.get(KERAS_PARAM_NAME_B_C);
        if (!weights.containsKey(KERAS_PARAM_NAME_B_F)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter b_f");
        }
        INDArray b_f = weights.get(KERAS_PARAM_NAME_B_F);
        if (!weights.containsKey(KERAS_PARAM_NAME_B_O)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter b_o");
        }
        INDArray b_o = weights.get(KERAS_PARAM_NAME_B_O);
        if (!weights.containsKey(KERAS_PARAM_NAME_B_I)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer does not contain parameter b_i");
        }
        INDArray b_i = weights.get(KERAS_PARAM_NAME_B_I);
        INDArray b = Nd4j.zeros((int)b_c.rows(), (int)(b_c.columns() + b_f.columns() + b_o.columns() + b_i.columns()));
        b.put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)b.rows()), NDArrayIndex.interval((int)0, (int)b_c.columns())}, b_c);
        b.put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)b.rows()), NDArrayIndex.interval((int)b_c.columns(), (int)(b_c.columns() + b_f.columns()))}, b_f);
        b.put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)b.rows()), NDArrayIndex.interval((int)(b_c.columns() + b_f.columns()), (int)(b_c.columns() + b_f.columns() + b_o.columns()))}, b_o);
        b.put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)b.rows()), NDArrayIndex.interval((int)(b_c.columns() + b_f.columns() + b_o.columns()), (int)(b_c.columns() + b_f.columns() + b_o.columns() + b_i.columns()))}, b_i);
        this.weights.put("b", b);
        if (weights.size() > 12) {
            Set<String> paramNames = weights.keySet();
            paramNames.remove(KERAS_PARAM_NAME_W_C);
            paramNames.remove(KERAS_PARAM_NAME_W_F);
            paramNames.remove(KERAS_PARAM_NAME_W_I);
            paramNames.remove(KERAS_PARAM_NAME_W_O);
            paramNames.remove(KERAS_PARAM_NAME_U_C);
            paramNames.remove(KERAS_PARAM_NAME_U_F);
            paramNames.remove(KERAS_PARAM_NAME_U_I);
            paramNames.remove(KERAS_PARAM_NAME_U_O);
            paramNames.remove(KERAS_PARAM_NAME_B_C);
            paramNames.remove(KERAS_PARAM_NAME_B_F);
            paramNames.remove(KERAS_PARAM_NAME_B_I);
            paramNames.remove(KERAS_PARAM_NAME_B_O);
            String unknownParamNames = paramNames.toString();
            log.warn("Attemping to set weights for unknown parameters: " + unknownParamNames.substring(1, unknownParamNames.length() - 1));
        }
    }

    public boolean getUnroll() {
        return this.unroll;
    }

    public static boolean getUnrollRecurrentLayer(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLstm.getInnerLayerConfigFromConfig(layerConfig);
        if (!innerConfig.containsKey(LAYER_FIELD_UNROLL)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer config missing unroll field");
        }
        return (Boolean)innerConfig.get(LAYER_FIELD_UNROLL);
    }

    public static WeightInit getRecurrentWeightInitFromConfig(Map<String, Object> layerConfig, boolean train) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        WeightInit init;
        Map<String, Object> innerConfig = KerasLstm.getInnerLayerConfigFromConfig(layerConfig);
        if (!innerConfig.containsKey(LAYER_FIELD_INNER_INIT)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer config missing inner_init field");
        }
        String kerasInit = (String)innerConfig.get(LAYER_FIELD_INNER_INIT);
        try {
            init = KerasLstm.mapWeightInitialization(kerasInit);
        }
        catch (UnsupportedKerasConfigurationException e) {
            if (train) {
                throw e;
            }
            init = WeightInit.XAVIER;
            log.warn("Unknown weight initializer " + kerasInit + " (Using XAVIER instead).");
        }
        return init;
    }

    public static double getRecurrentDropout(Map<String, Object> layerConfig) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLstm.getInnerLayerConfigFromConfig(layerConfig);
        double dropout = 1.0;
        if (innerConfig.containsKey(LAYER_FIELD_DROPOUT_U)) {
            dropout = 1.0 - (Double)innerConfig.get(LAYER_FIELD_DROPOUT_U);
        }
        if (dropout < 1.0) {
            throw new UnsupportedKerasConfigurationException("Dropout > 0 on LSTM recurrent connections not supported.");
        }
        return dropout;
    }

    public static IActivation getGateActivationFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLstm.getInnerLayerConfigFromConfig(layerConfig);
        if (!innerConfig.containsKey(LAYER_FIELD_INNER_ACTIVATION)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer config missing inner_activation field");
        }
        return KerasLstm.mapActivation((String)innerConfig.get(LAYER_FIELD_INNER_ACTIVATION));
    }

    public static double getForgetBiasInitFromConfig(Map<String, Object> layerConfig, boolean train) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLstm.getInnerLayerConfigFromConfig(layerConfig);
        if (!innerConfig.containsKey(LAYER_FIELD_FORGET_BIAS_INIT)) {
            throw new InvalidKerasConfigurationException("Keras LSTM layer config missing forget_bias_init field");
        }
        String kerasForgetBiasInit = (String)innerConfig.get(LAYER_FIELD_FORGET_BIAS_INIT);
        double init = 0.0;
        switch (kerasForgetBiasInit) {
            case "zero": {
                init = 0.0;
                break;
            }
            case "one": {
                init = 1.0;
                break;
            }
            default: {
                if (train) {
                    throw new UnsupportedKerasConfigurationException("Unsupported LSTM forget gate bias initialization: " + kerasForgetBiasInit);
                }
                init = 1.0;
                log.warn("Unsupported LSTM forget gate bias initialization: " + kerasForgetBiasInit + " (using 1 instead)");
            }
        }
        return init;
    }
}

