Tensorflow dynamic_rnn parameters meaning


Question

I'm struggling to understand the cryptic RNN docs. Any help with the following will be greatly appreciated.

tf.nn.dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, dtype=None, parallel_iterations=None, swap_memory=False, time_major=False, scope=None)

I'm struggling to understand how these parameters relate to the mathematical LSTM equations and RNN definition. Where is the cell unroll size? Is it defined by the 'max_time' dimension of the inputs? Is the batch_size only a convenience for splitting long data, or is it related to minibatch SGD? Is the output state passed across batches?

Answer

tf.nn.dynamic_rnn takes in a batch (with the minibatch meaning) of unrelated sequences.

  • cell is the actual cell that you want to use (LSTM, GRU, ...)
  • inputs has a shape of batch_size x max_time x input_size, in which max_time is the number of steps in the longest sequence (but all sequences could be of the same length)
  • sequence_length is a vector of size batch_size in which each element gives the length of each sequence in the batch (leave it as the default if all your sequences are of the same length). This parameter is the one that defines the cell unroll size.
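To make these shapes and the unroll concrete, here is a plain-numpy sketch of what dynamic_rnn computes. It is an illustrative stand-in, not the TF implementation: it uses a simple tanh RNN cell instead of an LSTM, and `dynamic_rnn_sketch` is a hypothetical helper name.

```python
import numpy as np

def dynamic_rnn_sketch(x, seq_len, W, U, b):
    """Illustrative numpy analogue of tf.nn.dynamic_rnn with a plain
    tanh RNN cell (hypothetical helper, not the TF implementation).

    x:       [batch_size, max_time, input_size]
    seq_len: [batch_size], the true length of each sequence
    Returns (outputs [batch_size, max_time, num_units],
             final_state [batch_size, num_units]).
    """
    batch_size, max_time, _ = x.shape
    num_units = b.shape[0]
    h = np.zeros((batch_size, num_units))           # the initial state
    outputs = np.zeros((batch_size, max_time, num_units))
    for t in range(max_time):                       # unroll over max_time
        h_new = np.tanh(x[:, t] @ W + h @ U + b)
        # past each sequence's own length, freeze its state and emit zeros
        active = (t < seq_len)[:, None]
        h = np.where(active, h_new, h)
        outputs[:, t] = np.where(active, h_new, 0.0)
    return outputs, h
```

Note how the loop runs for max_time steps (the unroll size), but sequence_length decides, per sequence, which of those steps actually update the state: a sequence of length 3 in a max_time-5 batch keeps the state it had at step 3, and its final state equals its last valid output.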

The usual way of handling hidden state is to define an initial state tensor before the dynamic_rnn, like this for instance:

hidden_state_in = cell.zero_state(batch_size, tf.float32) 
output, hidden_state_out = tf.nn.dynamic_rnn(cell, 
                                             inputs,
                                             initial_state=hidden_state_in,
                                             ...)

In the above snippet, both hidden_state_in and hidden_state_out have the same shape [batch_size, ...] (the actual shape depends on the type of cell you use, but the important thing is that the first dimension is the batch size).
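For concreteness: with TF's LSTMCell the state is an LSTMStateTuple (c, h), each part of shape [batch_size, num_units], while a GRUCell has a single [batch_size, num_units] state tensor. A numpy analogue of what cell.zero_state(batch_size, tf.float32) returns (illustrative only):

```python
import numpy as np

batch_size, num_units = 4, 8

# LSTM: the state is a (c, h) pair, each of shape [batch_size, num_units]
lstm_zero_state = (np.zeros((batch_size, num_units), np.float32),
                   np.zeros((batch_size, num_units), np.float32))

# GRU: the state is a single [batch_size, num_units] tensor
gru_zero_state = np.zeros((batch_size, num_units), np.float32)
```

In both cases the first dimension is batch_size, which is what lets dynamic_rnn keep one independent state per sequence in the batch.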

This way, dynamic_rnn has an initial hidden state for each sequence. It will pass on the hidden state from time step to time step for each sequence in the inputs parameter on its own, and hidden_state_out will contain the final output state for each sequence in the batch. No hidden state is passed between sequences of the same batch, but only between time steps of the same sequence.

Usually, when you're training, every batch is unrelated so you don't have to feed back the hidden state when doing a session.run(output).

However, if you're testing and you need the output at each time step (i.e. you have to do a session.run() at every time step), you'll want to evaluate and feed back the output hidden state using something like this:

# bind the fetched values to new names so the graph tensors `output`
# and `hidden_state_out` are not shadowed by the returned numpy arrays
output_val, hidden_state = sess.run([output, hidden_state_out],
                                    feed_dict={hidden_state_in: hidden_state})

Otherwise TensorFlow will just use the default cell.zero_state(batch_size, tf.float32) at each time step, which equates to reinitialising the hidden state at each time step.
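The effect of feeding the state back versus resetting it can be shown with a plain numpy tanh cell (an illustrative stand-in, not TF code): running a sequence in two chunks while carrying the state over gives exactly the same final state as one full pass, while resetting to zeros between chunks does not.

```python
import numpy as np

def rnn_step(h, x_t, W, U, b):
    # one time step of a plain tanh RNN cell (illustrative)
    return np.tanh(x_t @ W + h @ U + b)

def run(x, h0, W, U, b):
    # run a [batch, time, input] chunk from state h0, return the final state
    h = h0
    for t in range(x.shape[1]):
        h = rnn_step(h, x[:, t], W, U, b)
    return h

rng = np.random.default_rng(0)
W, U, b = rng.normal(size=(3, 4)), rng.normal(size=(4, 4)), np.zeros(4)
x = rng.normal(size=(2, 6, 3))
zero = np.zeros((2, 4))

full = run(x, zero, W, U, b)
# feed the first chunk's final state into the second chunk:
carried = run(x[:, 3:], run(x[:, :3], zero, W, U, b), W, U, b)
# reset to zeros between chunks (what the default zero_state does):
reset = run(x[:, 3:], zero, W, U, b)

assert np.allclose(full, carried)      # state carried over: identical
assert not np.allclose(full, reset)    # state reset: diverges
```

This is exactly the situation in the session.run() loop above: hidden_state plays the role of the carried state, and omitting the feed_dict is the "reset" case.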
