在Keras中定义自定义LSTM Cell? [英] Define custom LSTM Cell in Keras?

查看:648
本文介绍了在Keras中定义自定义LSTM Cell?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我将Keras与TensorFlow一起用作后端.如果我想对LSTM单元进行修改,例如移除"输出门,该怎么办?这是一个乘法门,因此无论如何我都必须将其设置为固定值,这样无论乘以它,都不会起作用.

I use Keras with TensorFlow as back-end. If I want to make a modification to an LSTM cell, such as "removing" the output gate, how can I do it? It is a multiplicative gate, so somehow I will have to set it to fixed values so that whatever multiplies it, has no effect.

推荐答案

首先,您应该定义

First of all, you should define your own custom layer. If you need some intuition how to implement your own cell see LSTMCell in Keras repository. E.g. your custom cell will be:

class MinimalRNNCell(keras.layers.Layer):

    def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = units
        super(MinimalRNNCell, self).__init__(**kwargs)

    def build(self, input_shape):
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer='uniform',
                                      name='kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
        self.built = True

    def call(self, inputs, states):
        prev_output = states[0]
        h = K.dot(inputs, self.kernel)
        output = h + K.dot(prev_output, self.recurrent_kernel)
        return output, [output]

然后,使用 tf.keras.layers.RNN 使用您的单元格:

Then, use tf.keras.layers.RNN to use your cell:

cell = MinimalRNNCell(32)
x = keras.Input((None, 5))
layer = RNN(cell)
y = layer(x)

# Here's how to use the cell to build a stacked RNN:

cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
x = keras.Input((None, 5))
layer = RNN(cells)
y = layer(x)

这篇关于在Keras中定义自定义LSTM Cell?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆