TensorFlow dynamic_rnn state


Question


My question is about the TensorFlow method tf.nn.dynamic_rnn. It returns the output of every time step and the final state.


I would like to know if the returned final state is the state of the cell at the maximum sequence length or if it is determined individually by the sequence_length argument.


For better understanding an example: I have 3 sequences with length [10,20,30] and getting back the final state [3,512] (if the hidden state of the cell has the length 512).


Are the three returned hidden states for the three sequences all the state of the cell at time step 30, or am I getting back the states at time steps [10,20,30]?

Answer


tf.nn.dynamic_rnn returns two tensors: outputs and states.


outputs holds the outputs of all cells at every time step for all sequences in the batch. So if a particular sequence is shorter and padded with zeros, the outputs for the padded time steps will be zero.


states holds the last cell state of each sequence, or equivalently the last non-zero output per sequence (if you're using BasicRNNCell, whose state and output coincide).

Here is an example:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

n_steps = 2    # maximum (padded) sequence length
n_inputs = 3   # input features per time step
n_neurons = 5  # hidden state size

# Batch of sequences, each padded to n_steps
X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
# Actual length of each sequence in the batch
seq_length = tf.placeholder(tf.int32, [None])

basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
])
seq_length_batch = np.array([2, 1])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val, states_val = sess.run([outputs, states], 
                                     feed_dict={X: X_batch, seq_length: seq_length_batch})

  print('outputs:')
  print(outputs_val)
  print('\nstates:')
  print(states_val)

This prints the following:

outputs:
[[[-0.85381496 -0.19517037  0.36011398 -0.18617202  0.39162001]
  [-0.99998015 -0.99461144 -0.82241321  0.93778896  0.90737367]]

 [[-0.99849552 -0.88643843  0.20635395  0.157896    0.76042926]
  [ 0.          0.          0.          0.          0.        ]]]  # because len=1

states:
[[-0.99998015 -0.99461144 -0.82241321  0.93778896  0.90737367]
 [-0.99849552 -0.88643843  0.20635395  0.157896    0.76042926]]


Note that states holds the same vectors as the corresponding rows of outputs: they are the last non-zero outputs per batch instance.
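This relationship can be checked without TensorFlow: given the padded outputs array and the sequence lengths, the final states are simply the outputs at index length - 1 for each sequence. A minimal NumPy sketch with illustrative toy values (the helper `final_states` is hypothetical, not part of any TF API):

```python
import numpy as np

def final_states(outputs, seq_lengths):
    """Pick the output at step length-1 for each sequence in the batch,
    mimicking how dynamic_rnn's `states` relates to `outputs`."""
    batch_idx = np.arange(outputs.shape[0])
    return outputs[batch_idx, seq_lengths - 1]

# Toy outputs: batch of 2, 2 time steps, 5 units; instance 1 is padded after step 0
outputs = np.array([
    [[1., 1., 1., 1., 1.], [2., 2., 2., 2., 2.]],  # len = 2 -> state is step 1
    [[3., 3., 3., 3., 3.], [0., 0., 0., 0., 0.]],  # len = 1 -> state is step 0
])
states = final_states(outputs, np.array([2, 1]))
print(states)
# [[2. 2. 2. 2. 2.]
#  [3. 3. 3. 3. 3.]]
```

This is why, in the printout above, the first row of states matches the t = 1 output of instance 0, while the second row matches the t = 0 output of instance 1, not the zero padding.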

