Tensorflow-批次内的LSTM状态重用 [英] Tensorflow - LSTM state reuse within batch

查看:169
本文介绍了Tensorflow-批次内的LSTM状态重用的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在研究使用LSTM跟踪参数(时间序列数据回归问题)的Tensorflow NN.一批训练数据包含连续观测值的batch_size.我想将LSTM状态用作下一个示例的输入.因此,如果我有一批数据观测值,我想将第一个观测值的状态作为第二个观测值的输入,依此类推.在下面,我将lstm状态定义为size = batch_size的张量.我想在 批内重用状态:

I am working on a Tensorflow NN which uses an LSTM to track a parameter (time series data regression problem). A batch of training data contains a batch_size of consecutive observations. I would like to use the LSTM state as input to the next sample. So, if I have a batch of data observations, I would like to feed the state of the first observation as input to the second observation and so on. Below I define the lstm state as a tensor of size = batch_size. I would like to reuse the state within a batch:

state = tf.Variable(cell.zero_states(batch_size, tf.float32), trainable=False)
cell = tf.nn.rnn_cell.BasicLSTMCell(100)
output, curr_state = tf.nn.rnn(cell, data, initial_state=state) 

API中有一个tf.nn.state_saving_rnn,但是文档有点含糊. 我的问题:如何在训练批次中重复使用curr_state .

In the API there is a tf.nn.state_saving_rnn but the documentation is kinda vague. My question: How to reuse curr_state within a training batch.

推荐答案

您基本上就在那儿,只需要用curr_state更新state:

You are basically there, just need to update state with curr_state:

state_update = tf.assign(state, curr_state)

然后,确保您在state_update本身上调用run或以state_update作为依赖项的操作,否则分配实际上不会发生.例如:

Then, make sure you either call run on state_update itself or an operation that has state_update as a dependency, or the assignment will not actually happen. For example:

with tf.control_dependencies([state_update]):
    model_output = ...

如评论中所建议,RNN的典型情况是您有一个批处理,其中第一维(0)是序列数,第二维(1)是每个序列的最大长度(如果通过time_major=True当您构建RNN时,这两个被交换了.理想情况下,为了获得良好的性能,您可以将多个序列堆叠为一批,然后按时间拆分该批.但这确实是一个不同的话题.

As suggested in the comments, the typical case for RNNs is that you have a batch where the first dimension (0) is the number of sequences and the second dimension (1) is the maximum length of each sequence (if you pass time_major=True when you build the RNN these two are swapped). Ideally, in order to get good performance, you stack multiple sequences into one batch, and then split that batch time-wise. But that's all a different topic really.

这篇关于Tensorflow-批次内的LSTM状态重用的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆