TensorFlow:从RNN获取所有状态 [英] TensorFlow: getting all states from a RNN
问题描述
如何从TensorFlow中的tf.nn.rnn()
或tf.nn.dynamic_rnn()
中获取所有隐藏状态?该API仅给了我最终状态.
How do you get all the hidden states from tf.nn.rnn()
or tf.nn.dynamic_rnn()
in TensorFlow? The API only gives me the final state.
第一种选择是在构建直接在RNNCell上运行的模型时编写一个循环.但是,时间步的数量对我来说不是固定的,它取决于传入的批次.
The first alternative would be to write a loop when building a model that operates directly on RNNCell. However, the number of timesteps is not fixed for me, and depends on the incoming batch.
一些选项是使用GRU或编写我自己的RNNCell来将状态连接到输出.前者的选择不够普遍,而后者听起来太不客气了.
Some options are to either use a GRU or to write my own RNNCell that concatenates the state to the output. The former choice isn't general enough and the latter sounds too hacky.
Another option is to do something like the answers in this question, getting all the variables from an RNN. However, I'm not sure how to separate the hidden states from other variables in a standard fashion here.
在仍然使用库提供的RNN API的情况下,是否有一种从RNN获取所有隐藏状态的好方法?
Is there a nice way to get all the hidden states from an RNN while still using the library-provided RNN APIs?
推荐答案
tf.nn.dynamic_rnn(也tf.nn.static_rnn)有两个返回值; 输出",状态"( https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn )
tf.nn.dynamic_rnn(also tf.nn.static_rnn) has two return values; "outputs", "state" (https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)
正如您所说,状态"是RNN的最终状态,但输出"都是RNN的隐藏状态(形状为[batch_size,max_time,cell.output_size])
As you said, "state" is the final state of RNN, but "outputs" are all hidden states of RNN(which shape is [batch_size, max_time, cell.output_size])
您可以将输出"用作RNN的隐藏状态,因为在大多数库提供的RNNCell中,输出"和状态"是相同的. (LSTMCell除外)
You can use "outputs" as hidden states of RNN, because in most library-provided RNNCell, "output" and "state" are same. (except LSTMCell)
- 基本 https://github. com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py#L347
- GRU https://github.com com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py#L441
- Basic https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py#L347
- GRU https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py#L441
这篇关于TensorFlow:从RNN获取所有状态的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!