How do I set TensorFlow RNN state when state_is_tuple=True?

Question

I have written an RNN language model using TensorFlow. The model is implemented as an RNN class. The graph structure is built in the constructor, while RNN.train and RNN.test methods run it.

I want to be able to reset the RNN state when I move to a new document in the training set, or when I want to run a validation set during training. I do this by managing the state inside the training loop, passing it into the graph via a feed dictionary.

In the constructor I define the RNN like so:

    cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
    rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
    self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
    self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
    self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
                                                  initial_state=self.state)

The training loop looks like this:

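    for document in documents:
        # Start each document from a zeroed LSTM state.
        state = session.run(self.reset_state)
        for x, y in document:
            # Carry the state returned for this batch into the next one.
            _, state = session.run([self.train_step, self.next_state],
                                   feed_dict={self.x: x, self.y: y, self.state: state})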

x and y are batches of training data in a document. The idea is that I pass the latest state along after each batch, except at the start of a new document, where I zero out the state by running self.reset_state.
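The reset-and-carry pattern itself can be illustrated without TensorFlow. This is a minimal NumPy sketch, assuming a hypothetical single-layer tanh RNN (`rnn_step`, with made-up random weights) in place of `session.run`, and treating each batch as a single time step for brevity; labels (`y`) are omitted. It only shows how the state is threaded through the batches of a document and zeroed at each document boundary.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, input_size, hidden_units = 2, 3, 4
# Hypothetical weights standing in for the trained graph.
W_x = rng.standard_normal((input_size, hidden_units))
W_h = rng.standard_normal((hidden_units, hidden_units))


def rnn_step(x, state):
    """Hypothetical stand-in for session.run([..., self.next_state]): one tanh RNN step."""
    return np.tanh(x @ W_x + state @ W_h)


def zero_state():
    """Plays the role of running self.reset_state."""
    return np.zeros((batch_size, hidden_units))


def run_documents(documents):
    final_states = []
    for document in documents:
        state = zero_state()  # new document: start from zeros
        for x in document:
            state = rnn_step(x, state)  # carry the state into the next batch
        final_states.append(state)
    return final_states


doc_a = [rng.standard_normal((batch_size, input_size)) for _ in range(3)]
doc_b = [rng.standard_normal((batch_size, input_size)) for _ in range(2)]
together = run_documents([doc_a, doc_b])
alone = run_documents([doc_b])
# Because the state is reset, doc_b's final state does not depend on doc_a.
print(np.allclose(together[1], alone[0]))  # True
```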

This all works. Now I want to change my RNN to use the recommended state_is_tuple=True. However, I don't know how to pass the more complicated LSTM state object via a feed dictionary. Also I don't know what arguments to pass to the self.state = tf.placeholder(...) line in my constructor.
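It helps to compare the two layouts. With state_is_tuple=False, each LSTM layer's state is its cell state c and hidden state h concatenated along the last axis, and MultiRNNCell concatenates the layers, so a single [batch_size, layers * 2 * hidden_units] placeholder can hold everything. With state_is_tuple=True the state is instead a tuple holding one (c, h) pair per layer. The NumPy sketch below converts between the two; `StatePair` is a hypothetical stand-in for LSTMStateTuple, and the layer-by-layer, c-before-h ordering is an assumption about the legacy concatenated layout.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple.
StatePair = namedtuple("StatePair", ["c", "h"])


def flat_to_tuple(flat, layers, hidden_units):
    """Split a [batch, layers * 2 * hidden_units] state into per-layer (c, h) pairs."""
    assert flat.shape[1] == layers * 2 * hidden_units
    per_layer = np.split(flat, layers, axis=1)  # one [batch, 2 * hidden_units] block per layer
    return tuple(StatePair(*np.split(block, 2, axis=1)) for block in per_layer)


def tuple_to_flat(state):
    """Concatenate per-layer (c, h) pairs back into one [batch, layers * 2 * hidden_units] matrix."""
    return np.concatenate([np.concatenate([pair.c, pair.h], axis=1) for pair in state], axis=1)


batch_size, layers, hidden_units = 2, 3, 4
flat = np.arange(batch_size * layers * 2 * hidden_units, dtype=np.float32).reshape(batch_size, -1)
state = flat_to_tuple(flat, layers, hidden_units)
print(len(state), state[0].c.shape)                 # 3 (2, 4)
print(np.array_equal(tuple_to_flat(state), flat))  # True
```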

What is the correct strategy here? There still isn't much example code or documentation for dynamic_rnn available.

TensorFlow issues 2695 and 2838 appear relevant.

A blog post on WILDML addresses these issues but doesn't directly spell out the answer.

See also TensorFlow: Remember LSTM state for next batch (stateful LSTM).

Recommended answer

One problem with a TensorFlow placeholder is that you can only feed it with a Python list or NumPy array (I think). So you can't feed the state between runs as a tuple of LSTMStateTuple objects.

I solved this by saving the state in a single array, like this:

    import numpy as np
    initial_state = np.zeros((num_layers, 2, batch_size, state_size))  # axis 1 holds (c, h)

An LSTM layer has two state components, the cell state and the hidden state; that's where the "2" comes from. (This article is great: https://arxiv.org/pdf/1506.00019.pdf)
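The two components are easiest to see in a single LSTM step. This is a minimal NumPy sketch of the standard LSTM equations (no peepholes or projections); the fused weight matrix `W`, the bias `b`, and the i, f, g, o gate ordering are illustrative assumptions, not TensorFlow's internal layout. The step both consumes and returns c and h, which is why each layer's saved state needs two slots.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, c, h, W, b):
    """One LSTM step. x: [batch, input], c/h: [batch, units], W: [input + units, 4 * units]."""
    z = np.concatenate([x, h], axis=1) @ W + b
    i, f, g, o = np.split(z, 4, axis=1)  # input, forget, candidate, output gates
    new_c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)  # cell state
    new_h = sigmoid(o) * np.tanh(new_c)               # hidden state (the layer's output)
    return new_c, new_h


rng = np.random.default_rng(0)
batch_size, input_size, units = 2, 3, 5
W = rng.standard_normal((input_size + units, 4 * units))
b = np.zeros(4 * units)
c = h = np.zeros((batch_size, units))
x = rng.standard_normal((batch_size, input_size))
c, h = lstm_step(x, c, h, W, b)
print(c.shape, h.shape)  # (2, 5) (2, 5)
```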

When building the graph you unpack and create the tuple state like this:

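    state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
    # Split along the layer axis (tf.unpack was renamed tf.unstack in TensorFlow 1.0).
    l = tf.unpack(state_placeholder, axis=0)
    # Rebuild one LSTMStateTuple(c, h) per layer.
    rnn_tuple_state = tuple(
        [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
         for idx in range(num_layers)]
    )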
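What this unpacking does to the data can be checked with NumPy alone. The sketch below mirrors `tf.unpack(state_placeholder, axis=0)` by iterating over the first axis and uses a hypothetical `LSTMStateTuple` namedtuple in place of TensorFlow's class; the sizes are made up.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple (fields c, h).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

num_layers, batch_size, state_size = 3, 2, 4
packed = np.random.default_rng(0).standard_normal((num_layers, 2, batch_size, state_size))

# Like tf.unpack(..., axis=0): one [2, batch, state] slice per layer.
l = list(packed)
# Same indexing as the graph code: layer idx -> (l[idx][0], l[idx][1]).
rnn_tuple_state = tuple(LSTMStateTuple(l[idx][0], l[idx][1]) for idx in range(num_layers))

print(len(rnn_tuple_state), rnn_tuple_state[0].c.shape)  # 3 (2, 4)
```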

Then you get the new state the usual way

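    # A separate LSTMCell per layer; [cell] * num_layers reuses one cell object,
    # which can fail or silently share weights in later TensorFlow 1.x releases.
    cell = tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True) for _ in range(num_layers)],
        state_is_tuple=True)

    outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)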
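The answer does not show the other half of the round trip: `session.run` hands `state` back as a tuple of `LSTMStateTuple` pairs of NumPy arrays, and that has to be repacked into the `[num_layers, 2, batch_size, state_size]` layout before it can be fed to `state_placeholder` on the next batch. Because those pairs are nested tuples of equal-shaped arrays, `np.array(state)` should do the repacking. A sketch of the loop, assuming a hypothetical `fake_run` in place of `session.run([train_step, state], feed_dict={state_placeholder: ...})` (its arithmetic is meaningless) and a namedtuple in place of `LSTMStateTuple`:

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])
num_layers, batch_size, state_size = 2, 3, 4


def fake_run(packed_state, x):
    """Hypothetical stand-in for session.run: returns the new state as nested tuples."""
    return tuple(
        LSTMStateTuple(packed_state[i][0] + x, packed_state[i][1] - x)
        for i in range(num_layers)
    )


def train(documents):
    for document in documents:
        # New document: reset to the zero state.
        state = np.zeros((num_layers, 2, batch_size, state_size))
        for x in document:
            new_state = fake_run(state, x)
            # Nested tuples of [batch, state] arrays -> [num_layers, 2, batch, state].
            state = np.array(new_state)
    return state


final = train([[1.0, 2.0], [0.5]])
print(final.shape)  # (2, 2, 3, 4)
```

The same `np.array(...)` call works for any number of layers, so the placeholder shape never has to change.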

It shouldn't be like this... perhaps they are working on a solution.
