How do I set TensorFlow RNN state when state_is_tuple=True?

Question

I have written an RNN language model using TensorFlow. The model is implemented as an RNN class. The graph structure is built in the constructor, while the RNN.train and RNN.test methods run it.

I want to be able to reset the RNN state when I move to a new document in the training set, or when I want to run a validation set during training. I do this by managing the state inside the training loop, passing it into the graph via a feed dictionary.

In the constructor I define the RNN like so:

    cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
    rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
    self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
    self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
    self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
                                                  initial_state=self.state)

The training loop looks like this:

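    for document in documents:
        # Start each document from a zeroed state
        state = session.run(self.reset_state)
        for x, y in document:
            # Feed the previous batch's final state in and fetch the new one
            _, state = session.run([self.train_step, self.next_state],
                                   feed_dict={self.x: x, self.y: y, self.state: state})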

x and y are batches of training data in a document. The idea is that I pass the latest state along after each batch, except when I start a new document, when I zero out the state by running self.reset_state.
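The state-carrying pattern described above can be sketched without TensorFlow. In this minimal sketch, `rnn_step` is a hypothetical stand-in for `session.run([self.train_step, self.next_state], ...)`: a single tanh RNN layer with fixed random weights, which is an assumption and not the question's model. Targets are omitted. `zero_state` plays the role of `session.run(self.reset_state)`, and `run_documents` is likewise a hypothetical name.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, input_size, hidden_units = 2, 3, 4

# Hypothetical fixed weights for a single tanh RNN layer
W_x = rng.standard_normal((input_size, hidden_units))
W_h = rng.standard_normal((hidden_units, hidden_units))

def rnn_step(x, state):
    """Hypothetical stand-in for one session.run call: returns the next state."""
    return np.tanh(x @ W_x + state @ W_h)

def zero_state():
    """Plays the role of session.run(self.reset_state)."""
    return np.zeros((batch_size, hidden_units))

def run_documents(documents):
    """Carry the state across batches, resetting it at each new document."""
    final_states = []
    for document in documents:
        state = zero_state()  # new document: start from a zeroed state
        for x in document:
            state = rnn_step(x, state)  # latest state is passed to the next batch
        final_states.append(state)
    return final_states

# Two documents of three batches each (random data for illustration)
docs = [[rng.standard_normal((batch_size, input_size)) for _ in range(3)] for _ in range(2)]
states = run_documents(docs)
```

Because the state is reset at every document boundary, each document's final state depends only on that document's batches.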

This all works. Now I want to change my RNN to use the recommended state_is_tuple=True. However, I don't know how to pass the more complicated LSTM state object via a feed dictionary. Also I don't know what arguments to pass to the self.state = tf.placeholder(...) line in my constructor.

What is the correct strategy here? There still isn't much example code or documentation for dynamic_rnn available.

TensorFlow issues 2695 and 2838 appear relevant.

A blog post on WILDML addresses these issues but doesn't directly spell out the answer.

See also TensorFlow: Remember LSTM state for next batch (stateful LSTM).

Recommended answer

One problem with a TensorFlow placeholder is that you can only feed it with a Python list or NumPy array (I think). So you can't keep the state between runs as a tuple of LSTMStateTuple objects.

I solved this by saving the whole state in a single 4-D array like this:

    import numpy as np
    initial_state = np.zeros((num_layers, 2, batch_size, state_size), dtype=np.float32)

There are two components in each LSTM layer, the cell state and the hidden state; that's where the "2" comes from. (This article is great: https://arxiv.org/pdf/1506.00019.pdf)
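To make the layout concrete, here is a small NumPy sketch with arbitrary sizes chosen for illustration. Axis 0 indexes the layer and axis 1 selects the cell state `c` (index 0) or the hidden state `h` (index 1), which matches the field order of `LSTMStateTuple(c, h)`.

```python
import numpy as np

# Arbitrary sizes for illustration
num_layers, batch_size, state_size = 3, 4, 5

# axis 0: layer; axis 1: 0 = cell state c, 1 = hidden state h
initial_state = np.zeros((num_layers, 2, batch_size, state_size), dtype=np.float32)

# Give layer 1 a nonzero hidden state so the two slices can be told apart
initial_state[1, 1] = 1.0

c_layer1 = initial_state[1, 0]  # cell state of layer 1, shape (batch_size, state_size)
h_layer1 = initial_state[1, 1]  # hidden state of layer 1, same shape
print(initial_state.shape, c_layer1.shape, h_layer1.sum())
```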

When building the graph, you unpack the placeholder and create the tuple state like this:

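    # Feed this placeholder with the 4-D NumPy array on every run
    state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
    # Split along the layer axis (tf.unpack was renamed tf.unstack in TensorFlow 1.0)
    l = tf.unpack(state_placeholder, axis=0)
    # One LSTMStateTuple(c, h) per layer: index 0 is the cell state, index 1 the hidden state
    rnn_tuple_state = tuple(
        [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
         for idx in range(num_layers)]
    )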
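The unpacking logic can be mirrored in plain NumPy to check which slice ends up where. In this sketch `LSTMStateTuple` is a local `collections.namedtuple` with the same `(c, h)` fields as TensorFlow's class, and `list(packed)` stands in for `tf.unpack(state_placeholder, axis=0)`; the sizes and values are made up for illustration.

```python
import collections
import numpy as np

# Local stand-in with the same (c, h) fields as tf.nn.rnn_cell.LSTMStateTuple
LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))

num_layers, batch_size, state_size = 2, 3, 4
# Distinct values so each slice can be identified
packed = np.arange(num_layers * 2 * batch_size * state_size, dtype=np.float32).reshape(
    (num_layers, 2, batch_size, state_size))

# NumPy analogue of tf.unpack(state_placeholder, axis=0): one array per layer
layers = list(packed)
# Same construction as the graph code: c = [0], h = [1] for each layer
rnn_tuple_state = tuple(
    LSTMStateTuple(layers[idx][0], layers[idx][1]) for idx in range(num_layers)
)
```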

Then you get the new state the usual way:

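    # Build a separate LSTMCell for each layer: reusing one cell object via
    # [cell] * num_layers raises an error in later TensorFlow 1.x releases
    cells = [tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True) for _ in range(num_layers)]
    cell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)

    outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)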
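To carry the state to the next batch, the fetched `state` has to be turned back into the 4-D array the placeholder expects. A minimal sketch, assuming (as in TensorFlow 1.x) that `session.run` returns the fetched state with the same nested tuple structure, filled with NumPy arrays: because namedtuples are tuples, `np.array` stacks them back into the `[num_layers, 2, batch_size, state_size]` layout. `fake_run` is a hypothetical stand-in for `session.run([train_step, state], feed_dict={..., state_placeholder: current_state})` with a made-up update rule, and `train_document` is likewise hypothetical.

```python
import collections
import numpy as np

# Local stand-in with the same (c, h) fields as tf.nn.rnn_cell.LSTMStateTuple
LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))
num_layers, batch_size, state_size = 2, 3, 4

def fake_run(current_state):
    """Hypothetical stand-in for session.run: returns a tuple of LSTMStateTuple."""
    return tuple(
        LSTMStateTuple(current_state[i, 0] + 1.0, np.tanh(current_state[i, 1] + 1.0))
        for i in range(num_layers)
    )

def train_document(num_batches):
    # Zeroed state at the start of a document, in the placeholder's 4-D layout
    current_state = np.zeros((num_layers, 2, batch_size, state_size), dtype=np.float32)
    for _ in range(num_batches):
        new_state = fake_run(current_state)  # nested tuples, as fetched from the graph
        # Stack the per-layer (c, h) pairs back into one array for the next feed
        current_state = np.array(new_state)
    return current_state

final_state = train_document(num_batches=3)
```

Resetting for a new document is then just a matter of starting again from `np.zeros(...)`, as in the question's loop.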

It shouldn't have to be like this... perhaps they are working on a solution.