How to extract cell state from a LSTM at each timestep in Keras?


Problem description


Is there a way in Keras to retrieve the cell state (i.e., the c vector) of an LSTM layer at every timestep of a given input?

It seems the return_state argument returns only the last cell state after the computation is done, but I also need the intermediate ones. Also, I don't want to pass these cell states to the next layer; I only want to be able to access them.

Preferably using TensorFlow as the backend.
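
For reference, a minimal sketch of the behaviour I am describing (layer size and input shape are arbitrary): return_sequences=True gives the hidden state at every timestep, while return_state=True only adds the final hidden and cell state.

from keras.layers import Input, LSTM
from keras.models import Model

inp = Input(shape=(10, 7))
# seq: hidden states for all timesteps, shape (batch, 10, 32)
# last_h, last_c: only the final hidden and cell state, shape (batch, 32)
seq, last_h, last_c = LSTM(32, return_sequences=True, return_state=True)(inp)
M = Model(inputs=inp, outputs=[seq, last_h, last_c])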

Thanks

Solution

I know it's pretty late, but I hope this can help.

What you are asking is technically possible by modifying the LSTM cell's call method. I modified it so that it returns a 4-dimensional output instead of a 3-dimensional one when you set return_sequences=True.

Code

import tensorflow as tf
from keras import backend as K
from keras.layers.recurrent import LSTMCell, _generate_dropout_mask
class Mod_LSTMCELL(LSTMCell):
    def call(self, inputs, states, training=None):
        if 0 < self.dropout < 1 and self._dropout_mask is None:
            self._dropout_mask = _generate_dropout_mask(
                K.ones_like(inputs),
                self.dropout,
                training=training,
                count=4)
        if (0 < self.recurrent_dropout < 1 and
                self._recurrent_dropout_mask is None):
            self._recurrent_dropout_mask = _generate_dropout_mask(
                K.ones_like(states[0]),
                self.recurrent_dropout,
                training=training,
                count=4)

        # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask

        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        if self.implementation == 1:
            if 0 < self.dropout < 1.:
                inputs_i = inputs * dp_mask[0]
                inputs_f = inputs * dp_mask[1]
                inputs_c = inputs * dp_mask[2]
                inputs_o = inputs * dp_mask[3]
            else:
                inputs_i = inputs
                inputs_f = inputs
                inputs_c = inputs
                inputs_o = inputs
            x_i = K.dot(inputs_i, self.kernel_i)
            x_f = K.dot(inputs_f, self.kernel_f)
            x_c = K.dot(inputs_c, self.kernel_c)
            x_o = K.dot(inputs_o, self.kernel_o)
            if self.use_bias:
                x_i = K.bias_add(x_i, self.bias_i)
                x_f = K.bias_add(x_f, self.bias_f)
                x_c = K.bias_add(x_c, self.bias_c)
                x_o = K.bias_add(x_o, self.bias_o)

            if 0 < self.recurrent_dropout < 1.:
                h_tm1_i = h_tm1 * rec_dp_mask[0]
                h_tm1_f = h_tm1 * rec_dp_mask[1]
                h_tm1_c = h_tm1 * rec_dp_mask[2]
                h_tm1_o = h_tm1 * rec_dp_mask[3]
            else:
                h_tm1_i = h_tm1
                h_tm1_f = h_tm1
                h_tm1_c = h_tm1
                h_tm1_o = h_tm1
            i = self.recurrent_activation(x_i + K.dot(h_tm1_i,
                                                      self.recurrent_kernel_i))
            f = self.recurrent_activation(x_f + K.dot(h_tm1_f,
                                                      self.recurrent_kernel_f))
            c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c,
                                                            self.recurrent_kernel_c))
            o = self.recurrent_activation(x_o + K.dot(h_tm1_o,
                                                      self.recurrent_kernel_o))
        else:
            if 0. < self.dropout < 1.:
                inputs *= dp_mask[0]
            z = K.dot(inputs, self.kernel)
            if 0. < self.recurrent_dropout < 1.:
                h_tm1 *= rec_dp_mask[0]
            z += K.dot(h_tm1, self.recurrent_kernel)
            if self.use_bias:
                z = K.bias_add(z, self.bias)

            z0 = z[:, :self.units]
            z1 = z[:, self.units: 2 * self.units]
            z2 = z[:, 2 * self.units: 3 * self.units]
            z3 = z[:, 3 * self.units:]

            i = self.recurrent_activation(z0)
            f = self.recurrent_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.recurrent_activation(z3)

        h = o * self.activation(c)
        if 0 < self.dropout + self.recurrent_dropout:
            if training is None:
                h._uses_learning_phase = True
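        # stack h and c so the RNN layer emits both states at every timestep;
        # with batch size 1 the per-step output has shape (1, 2, units)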
        return tf.expand_dims(tf.concat([h,c],axis=0),0), [h, c]

Sample code

import numpy as np
import keras
from keras.layers import Input, RNN
from keras.models import Model

# create a cell
test = Mod_LSTMCELL(100)

# Input timesteps=10, features=7
in1 = Input(shape=(10,7))
out1 = RNN(test, return_sequences=True)(in1)

M = Model(inputs=[in1],outputs=[out1])
M.compile(keras.optimizers.Adam(),loss='mse')

ans = M.predict(np.arange(7*10,dtype=np.float32).reshape(1, 10, 7))

print(ans.shape)
# state_h
print(ans[0,0,0,:])
# state_c
print(ans[0,0,1,:])
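
The returned array ans has shape (1, 10, 2, 100): batch, timesteps, then index 0 for the hidden state and index 1 for the cell state, then units. To recover the full state sequences for every timestep you can slice that third axis; a minimal sketch for this batch-of-one example:

# hidden-state sequence for all timesteps, shape (1, 10, 100)
h_seq = ans[:, :, 0, :]
# cell-state sequence for all timesteps, shape (1, 10, 100)
c_seq = ans[:, :, 1, :]

The slicing above assumes a batch size of 1, as in the sample code.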

