TensorFlow dynamic_rnn state


Question


My question is about the TensorFlow method tf.nn.dynamic_rnn. It returns the output of every time step and the final state.


I would like to know if the returned final state is the state of the cell at the maximum sequence length or if it is determined individually by the sequence_length argument.


For better understanding an example: I have 3 sequences with length [10,20,30] and getting back the final state [3,512] (if the hidden state of the cell has the length 512).

Are the three returned hidden states for the three sequences the state of the cell at time step 30, or am I getting back the states at time steps [10, 20, 30]?

Answer

tf.nn.dynamic_rnn returns two tensors: outputs and states.


The outputs tensor holds the outputs of all cells for all sequences in the batch. So if a particular sequence is shorter and padded with zeros, the outputs at the time steps past its length will be zero.


The states tensor holds the last cell state per sequence, or equivalently the last non-zero output per sequence (if you're using BasicRNNCell).

Here is an example:

import numpy as np
import tensorflow as tf

n_steps = 2
n_inputs = 3
n_neurons = 5

# X: [batch_size, n_steps, n_inputs]; seq_length holds each sequence's true length
X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])

basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
])
seq_length_batch = np.array([2, 1])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val, states_val = sess.run([outputs, states], 
                                     feed_dict={X: X_batch, seq_length: seq_length_batch})

  print('outputs:')
  print(outputs_val)
  print('\nstates:')
  print(states_val)

This prints something like:

outputs:
[[[-0.85381496 -0.19517037  0.36011398 -0.18617202  0.39162001]
  [-0.99998015 -0.99461144 -0.82241321  0.93778896  0.90737367]]

 [[-0.99849552 -0.88643843  0.20635395  0.157896    0.76042926]
  [ 0.          0.          0.          0.          0.        ]]]  # because len=1

states:
[[-0.99998015 -0.99461144 -0.82241321  0.93778896  0.90737367]
 [-0.99849552 -0.88643843  0.20635395  0.157896    0.76042926]]

Note that states holds the same vectors as appear in outputs: they are the last non-zero outputs per batch instance.
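The equivalence can be checked by hand: given outputs and the per-sequence lengths, gathering each instance's output at its last valid time step reproduces states. A minimal NumPy sketch, reusing the sample numbers printed above (the actual values vary from run to run with the random weight initialization):

```python
import numpy as np

# outputs as printed above: shape [batch_size, n_steps, n_neurons]
outputs_val = np.array([
    [[-0.85381496, -0.19517037,  0.36011398, -0.18617202,  0.39162001],
     [-0.99998015, -0.99461144, -0.82241321,  0.93778896,  0.90737367]],
    [[-0.99849552, -0.88643843,  0.20635395,  0.157896,    0.76042926],
     [ 0.,          0.,          0.,          0.,          0.        ]],
])
seq_length_batch = np.array([2, 1])

# For each batch instance i, pick outputs_val[i, seq_length_batch[i] - 1]:
batch_idx = np.arange(len(seq_length_batch))
last_valid = outputs_val[batch_idx, seq_length_batch - 1]
print(last_valid)
```

Each row of last_valid matches the corresponding row of states_val above, which is exactly what sequence_length buys you: the state is taken at each sequence's own final step, not at the padded maximum length.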

