如何在Keras中的每个时间步从LSTM提取细胞状态? [英] How to extract cell state from a LSTM at each timestep in Keras?
问题描述
在Keras中是否有一种方法可以在给定输入的每个时间步上检索LSTM层的单元状态(即 c 矢量)?
似乎 return_state 参数只返回计算完成后的最后一个单元状态,但我还需要各个中间时间步的状态.另外,我不想将这些单元状态传递给下一层,只希望能够访问它们.
最好使用TensorFlow作为后端.
谢谢
我知道已经很晚了,希望对您有所帮助.
从技术上讲,您的需求可以通过修改 LSTM 单元的 call 方法来实现.我对其进行了修改,使其在设置 return_sequences=True 时返回 4 维输出,而不是 3 维.
代码
从keras.layers.recurrent导入 类Mod_LSTMCELL(LSTMCell):def调用(自身,输入,状态,训练=无):如果0<自我辍学<1并且self._dropout_mask为None:self._dropout_mask = _generate_dropout_mask(K.ones_like(输入),自我辍学训练=训练,数= 4)如果(0< self.recurrent_dropout< 1并且self._recurrent_dropout_mask为无):self._recurrent_dropout_mask = _generate_dropout_mask(K.ones_like(states [0]),self.recurrent_dropout,训练=训练,数= 4)#输入单元的辍学矩阵dp_mask = self._dropout_mask#递归单位的辍学矩阵rec_dp_mask =自我._recurrent_dropout_maskh_tm1 =状态[0]#先前的内存状态c_tm1 =状态[1]#先前的进位状态如果self.implementation == 1:如果0<自我辍学<1 .:inputs_i =输入* dp_mask [0]input_f =输入* dp_mask [1]input_c =输入* dp_mask [2]inputs_o =输入* dp_mask [3]别的:输入_i =输入input_f =输入input_c =输入input_o =输入x_i = K.dot(inputs_i,self.kernel_i)x_f = K.dot(inputs_f,self.kernel_f)x_c = K.dot(inputs_c,self.kernel_c)x_o = K.dot(inputs_o,self.kernel_o)如果self.use_bias:x_i = K.bias_add(x_i,self.bias_i)x_f = K.bias_add(x_f,self.bias_f)x_c = K.bias_add(x_c,self.bias_c)x_o = K.bias_add(x_o,self.bias_o)如果0<self.recurrent_dropout<1 .:h_tm1_i = h_tm1 * rec_dp_mask [0]h_tm1_f = h_tm1 * rec_dp_mask [1]h_tm1_c = h_tm1 * rec_dp_mask [2]h_tm1_o = h_tm1 * rec_dp_mask [3]别的:h_tm1_i = h_tm1h_tm1_f = h_tm1h_tm1_c = h_tm1h_tm1_o = h_tm1i = self.recurrent_activation(x_i + K.dot(h_tm1_i,self.recurrent_kernel_i))f = self.recurrent_activation(x_f + K.dot(h_tm1_f,self.recurrent_kernel_f))c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c,self.recurrent_kernel_c))o = self.recurrent_activation(x_o + K.dot(h_tm1_o,self.recurrent_kernel_o))别的:如果为0.自我辍学<1 .:输入* = dp_mask [0]z = K.dot(输入,self.kernel)如果为0.self.recurrent_dropout<1 .:h_tm1 * = rec_dp_mask [0]z + = K.dot(h_tm1,self.recurrent_kernel)如果self.use_bias:z = K.bias_add(z,self.bias)z0 = z [:,:self.units]z1 = z [:, self.units:2 * self.units]z2 = z [:, 2 * self.units:3 * self.units]z3 = z [:, 3 * self.units:]我= self.recurrent_activation(z0)f = self.recurrent_activation(z1)c = f * c_tm1 + i * self.activation(z2)o = self.recurrent_activation(z3)h = o * self.activation(c)如果0<self.dropout + 
self.recurrent_dropout:如果培训为无":h._uses_learning_phase = True返回tf.expand_dims(tf.concat([h,c],axis = 0),0),[h,c]
示例代码
#创建一个单元格测试= Mod_LSTMCELL(100)#输入时间步长= 10,功能= 7in1 =输入(形状=(10,7))out1 = RNN(测试,return_sequences = True)(in1)M =模型(输入= [输入1],输出= [输出1])M.compile(keras.optimizers.Adam(),loss ='mse')ans = M.predict(np.arange(7 * 10,dtype = np.float32).reshape(1,10,7))打印(形状)#state_h打印(ans [0,0,0 ,:])#state_c打印(ans [0,0,1 ,:])
Is there a way in Keras to retrieve the cell state (i.e., c vector) of a LSTM layer at every timestep of a given input?
It seems the return_state argument returns only the last cell state after the computation is done, but I also need the intermediate ones. Also, I don't want to pass these cell states to the next layer; I only want to be able to access them.
Preferably using TensorFlow as backend.
Thanks
I know it's pretty late, but I hope this helps.
What you are asking for is technically possible by modifying the call method of the LSTM cell. I modified it so that it returns 4 dimensions instead of 3 when you set return_sequences=True. The extra dimension holds the hidden state h (index 0) and the cell state c (index 1) for every timestep.
Code
import keras
import numpy as np
from keras import backend as K
from keras.layers import RNN, Input, LSTMCell
from keras.layers.recurrent import _generate_dropout_mask
from keras.models import Model
class Mod_LSTMCELL(LSTMCell):
    """LSTM cell whose per-step output carries both ``h`` and ``c``.

    The gate computation is the same as the stock Keras ``LSTMCell.call``
    (both ``implementation`` modes). Only the per-step *output* differs:
    instead of the hidden state ``h`` alone, it returns ``h`` and the cell
    state ``c`` stacked on a new axis 1, i.e. ``(batch, 2, units)``, with
    index 0 = ``h`` and index 1 = ``c``. The states handed to the next
    timestep are still ``[h, c]``, so the recurrence is unchanged.

    When wrapped in ``RNN(cell, return_sequences=True)`` the per-step
    outputs are stacked over time, so the layer output is expected to be
    ``(batch, timesteps, 2, units)``.
    NOTE(review): ``RNN.compute_output_shape`` infers the static shape from
    the cell's ``output_size``/``state_size`` and may not report the extra
    axis -- confirm against the Keras version in use.
    """

    def call(self, inputs, states, training=None):
        """Run one LSTM timestep.

        Args:
            inputs: input tensor for the current timestep.
            states: ``[h_tm1, c_tm1]``, the previous hidden (memory) state
                and previous cell (carry) state.
            training: forwarded to ``_generate_dropout_mask`` when the
                dropout masks are first created.

        Returns:
            ``(output, [h, c])`` where ``output = K.stack([h, c], axis=1)``.
        """
        # Dropout masks are created lazily on the first step and cached on
        # the cell, so every timestep of a sequence reuses the same masks.
        if 0 < self.dropout < 1 and self._dropout_mask is None:
            self._dropout_mask = _generate_dropout_mask(
                K.ones_like(inputs),
                self.dropout,
                training=training,
                count=4)
        if (0 < self.recurrent_dropout < 1 and
                self._recurrent_dropout_mask is None):
            self._recurrent_dropout_mask = _generate_dropout_mask(
                K.ones_like(states[0]),
                self.recurrent_dropout,
                training=training,
                count=4)

        # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask

        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        if self.implementation == 1:
            # One kernel per gate (i, f, c, o), each with its own mask.
            if 0 < self.dropout < 1.:
                inputs_i = inputs * dp_mask[0]
                inputs_f = inputs * dp_mask[1]
                inputs_c = inputs * dp_mask[2]
                inputs_o = inputs * dp_mask[3]
            else:
                inputs_i = inputs
                inputs_f = inputs
                inputs_c = inputs
                inputs_o = inputs
            x_i = K.dot(inputs_i, self.kernel_i)
            x_f = K.dot(inputs_f, self.kernel_f)
            x_c = K.dot(inputs_c, self.kernel_c)
            x_o = K.dot(inputs_o, self.kernel_o)
            if self.use_bias:
                x_i = K.bias_add(x_i, self.bias_i)
                x_f = K.bias_add(x_f, self.bias_f)
                x_c = K.bias_add(x_c, self.bias_c)
                x_o = K.bias_add(x_o, self.bias_o)

            if 0 < self.recurrent_dropout < 1.:
                h_tm1_i = h_tm1 * rec_dp_mask[0]
                h_tm1_f = h_tm1 * rec_dp_mask[1]
                h_tm1_c = h_tm1 * rec_dp_mask[2]
                h_tm1_o = h_tm1 * rec_dp_mask[3]
            else:
                h_tm1_i = h_tm1
                h_tm1_f = h_tm1
                h_tm1_c = h_tm1
                h_tm1_o = h_tm1
            i = self.recurrent_activation(
                x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
            f = self.recurrent_activation(
                x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
            c = f * c_tm1 + i * self.activation(
                x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
            o = self.recurrent_activation(
                x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))
        else:
            # Single fused kernel; the result is sliced into the four
            # gate pre-activations below. Only mask 0 is used here.
            if 0. < self.dropout < 1.:
                inputs *= dp_mask[0]
            z = K.dot(inputs, self.kernel)
            if 0. < self.recurrent_dropout < 1.:
                h_tm1 *= rec_dp_mask[0]
            z += K.dot(h_tm1, self.recurrent_kernel)
            if self.use_bias:
                z = K.bias_add(z, self.bias)

            z0 = z[:, :self.units]
            z1 = z[:, self.units: 2 * self.units]
            z2 = z[:, 2 * self.units: 3 * self.units]
            z3 = z[:, 3 * self.units:]

            i = self.recurrent_activation(z0)
            f = self.recurrent_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.recurrent_activation(z3)

        h = o * self.activation(c)

        # Stack on a NEW axis 1 so the batch axis stays first:
        # two (batch, units) tensors -> (batch, 2, units). Concatenating on
        # axis 0 instead would interleave samples of the batch and only
        # line up with the batch dimension when batch size is 1.
        output = K.stack([h, c], axis=1)

        if 0 < self.dropout + self.recurrent_dropout:
            if training is None:
                h._uses_learning_phase = True
                # The RNN loop inspects the flag on the returned step
                # output, not on ``h``; without it dropout may never be
                # switched on by the learning phase.
                output._uses_learning_phase = True
        return output, [h, c]
Sample code
# Create the modified cell with 100 units.
cell = Mod_LSTMCELL(100)
# Input: 10 timesteps, 7 features per timestep.
in1 = Input(shape=(10, 7))
# With return_sequences=True the RNN stacks the per-step outputs over time,
# giving (1, timesteps, 2, units) for the single-sample input used below.
out1 = RNN(cell, return_sequences=True)(in1)
M = Model(inputs=[in1], outputs=[out1])
M.compile(keras.optimizers.Adam(), loss='mse')
ans = M.predict(np.arange(7 * 10, dtype=np.float32).reshape(1, 10, 7))
print(ans.shape)  # expected: (1, 10, 2, 100)

# Index 0 on axis 2 is the hidden state h, index 1 is the cell state c.
# Slicing gives the full per-timestep sequences, each (batch, timesteps, units).
all_h = ans[:, :, 0, :]
all_c = ans[:, :, 1, :]

# state_h of the first sample at the first timestep
print(all_h[0, 0])
# state_c of the first sample at the first timestep
print(all_c[0, 0])
这篇关于如何在Keras中的每个时间步从LSTM提取细胞状态?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!