Tensorflow RNN-LSTM - reset hidden state


Problem description

I'm building a stateful LSTM used for language recognition. Being stateful, I can train the network on smaller files, and a new batch acts like the next sentence in a discussion. However, for the network to be trained properly, I need to reset the hidden state of the LSTM between some batches.

I'm using a variable to store the hidden_state of the LSTM for performance:

    with tf.variable_scope('Hidden_state'):
        hidden_state = tf.get_variable("hidden_state", [self.num_layers, 2, self.batch_size, self.hidden_size],
                                       tf.float32, initializer=tf.constant_initializer(0.0), trainable=False)
        # Arrange it to a tuple of LSTMStateTuple as needed
        l = tf.unstack(hidden_state, axis=0)
        rnn_tuple_state = tuple([tf.contrib.rnn.LSTMStateTuple(l[idx][0], l[idx][1])
                                for idx in range(self.num_layers)])

    # Build the RNN
    with tf.name_scope('LSTM'):
        rnn_output, _ = tf.nn.dynamic_rnn(cell, rnn_inputs, sequence_length=input_seq_lengths,
                                          initial_state=rnn_tuple_state, time_major=True)

Now I'm confused about how to reset the hidden state. I've tried two solutions, but neither works:

First solution

Reset the "hidden_state" variable with:

rnn_state_zero_op = hidden_state.assign(tf.zeros_like(hidden_state))

It doesn't work, and I think it's because the unstack and tuple construction are not "re-played" into the graph after running the rnn_state_zero_op operation.

Second solution

Following LSTMStateTuple vs cell.zero_state() for RNN in Tensorflow, I tried to use:

rnn_state_zero_op = cell.zero_state(self.batch_size, tf.float32)

It doesn't seem to work either.

Question

I have another solution in mind, but it's guesswork at best: I'm not keeping the state returned by tf.nn.dynamic_rnn. I've thought about it, but I get a tuple and I can't find a way to build an op to reset the tuple.

At this point I have to admit that I don't quite understand the internal workings of tensorflow, or whether it's even possible to do what I'm trying to do. Is there a proper way to do it?

Thanks!

Answer

Thanks to this answer to another question, I was able to find a way to have complete control over whether (and when) the internal state of the RNN should be reset to zero.

First, you need to define some variables to store the state of the RNN; this way you will have control over it:

with tf.variable_scope('Hidden_state'):
    state_variables = []
    for state_c, state_h in cell.zero_state(self.batch_size, tf.float32):
        state_variables.append(tf.nn.rnn_cell.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
    rnn_tuple_state = tuple(state_variables)
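The structure being built here is just a tuple with one (c, h) pair per layer. As a framework-free illustration (a namedtuple stands in for tf.nn.rnn_cell.LSTMStateTuple, and plain-Python zero matrices stand in for the variables; the sizes are arbitrary examples), the layout looks like this:

```python
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple, used here only
# to show the layout that dynamic_rnn's initial_state expects
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

num_layers, batch_size, hidden_size = 2, 4, 8  # example sizes

def zeros(rows, cols):
    # plain-Python stand-in for a [rows, cols] float32 tensor of zeros
    return [[0.0] * cols for _ in range(rows)]

# One (c, h) pair per layer, mirroring what cell.zero_state(batch_size,
# tf.float32) yields: cell state c and hidden output h for each layer
rnn_tuple_state = tuple(
    LSTMStateTuple(c=zeros(batch_size, hidden_size),
                   h=zeros(batch_size, hidden_size))
    for _ in range(num_layers)
)

assert len(rnn_tuple_state) == num_layers
assert len(rnn_tuple_state[0].c) == batch_size
```

In the real graph each c and h is a non-trainable tf.Variable of shape [batch_size, hidden_size], so the state survives between session runs.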

Note that this version defines the variables used by the LSTM directly. This is much better than the version in my question, because you don't have to unstack and rebuild the tuple, which adds some ops to the graph that you cannot run explicitly.

Second, build the RNN and retrieve the final state:

# Build the RNN
with tf.name_scope('LSTM'):
    rnn_output, new_states = tf.nn.dynamic_rnn(cell, rnn_inputs,
                                               sequence_length=input_seq_lengths,
                                               initial_state=rnn_tuple_state,
                                               time_major=True)


So now you have the new internal state of the RNN. You can define two ops to manage it.

The first one updates the variables for the next batch, so that in the next batch the "initial_state" of the RNN is fed with the final state of the previous batch:

# Define an op to keep the hidden state between batches
update_ops = []
for state_variable, new_state in zip(rnn_tuple_state, new_states):
    # Assign the new state to the state variables on this layer
    update_ops.extend([state_variable[0].assign(new_state[0]),
                       state_variable[1].assign(new_state[1])])
# Return a tuple in order to combine all update_ops into a single operation.
# The tuple's actual value should not be used.
rnn_keep_state_op = tf.tuple(update_ops)

You should run this op in your session any time you want to process a batch and keep the internal state.

Beware: if you run batch 1 with this op, then batch 2 will start with batch 1's final state; but if you don't call it again when running batch 2, then batch 3 will also start with batch 1's final state. My advice is to add this op every time you run the RNN.
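To make the caveat concrete, here is a minimal framework-free sketch (plain Python; StateManager, run_batch, keep and reset are hypothetical names standing in for the session, the state variables, rnn_keep_state_op and rnn_state_zero_op) of what happens when the keep op is skipped for one batch:

```python
class StateManager:
    """Plain-Python sketch of the stateful-LSTM bookkeeping: `state` plays
    the role of the state variables, keep() of rnn_keep_state_op and
    reset() of rnn_state_zero_op."""

    def __init__(self, size):
        self.state = [0.0] * size

    def run_batch(self, batch_id):
        # Stand-in for one sess.run over a batch: reads the stored state
        # and returns (initial state used, final state produced).
        initial = list(self.state)
        final = [s + batch_id for s in initial]  # fake RNN update
        return initial, final

    def keep(self, final_state):   # like running rnn_keep_state_op
        self.state = list(final_state)

    def reset(self):               # like running rnn_state_zero_op
        self.state = [0.0] * len(self.state)

mgr = StateManager(size=1)
_, final1 = mgr.run_batch(1)
mgr.keep(final1)                   # batch 2 starts from batch 1's final state
init2, final2 = mgr.run_batch(2)   # keep() NOT called after this batch...
init3, _ = mgr.run_batch(3)        # ...so batch 3 ALSO starts from batch 1's state
assert init2 == init3 == final1
```

The final assertion is exactly the pitfall described above: skipping the keep op for batch 2 silently replays batch 1's final state into batch 3.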


The second op will be used to reset the internal state of the RNN to zeros:

# Define an op to reset the hidden state to zeros
update_ops = []
for state_variable in rnn_tuple_state:
    # Assign the new state to the state variables on this layer
    update_ops.extend([state_variable[0].assign(tf.zeros_like(state_variable[0])),
                       state_variable[1].assign(tf.zeros_like(state_variable[1]))])
# Return a tuple in order to combine all update_ops into a single operation.
# The tuple's actual value should not be used.
rnn_state_zero_op = tf.tuple(update_ops)


You can call this op whenever you want to reset the internal state.
