TensorFlow: getting all states from an RNN

Question

How do I get all the hidden states from tf.nn.rnn() or tf.nn.dynamic_rnn() in TensorFlow? The API only gives me the final state.

The first alternative would be to write my own loop that calls an RNNCell directly when building the model. However, the number of timesteps is not fixed for me; it depends on the incoming batch.

Other options are to use a GRU or to write my own RNNCell that concatenates the state to the output. The former isn't general enough, and the latter sounds too hacky.

Another option is to do something like the answers in this question, getting all the variables from an RNN. However, I'm not sure how to separate the hidden states from the other variables in a standard way.

Is there a nice way to get all the hidden states from an RNN while still using the library-provided RNN APIs?

Recommended answer

tf.nn.dynamic_rnn (and also tf.nn.static_rnn) has two return values, "outputs" and "state" (https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn).

As you said, "state" is the final state of the RNN, but "outputs" holds the hidden state at every time step. For tf.nn.dynamic_rnn with the default time_major=False, its shape is [batch_size, max_time, cell.output_size].
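To make the relationship between "outputs" and "state" concrete, here is a minimal pure-Python sketch of an unrolled RNN over a batch with variable sequence lengths. It is not TensorFlow's implementation: the scalar cell, its weights, and the names basic_rnn_cell and unroll are hypothetical, chosen only for illustration. It assumes the behavior tf.nn.dynamic_rnn has when sequence_length is passed, namely that outputs past an example's length are zeros and the state is carried through unchanged.

```python
import math

def basic_rnn_cell(x, h, w_x=0.5, w_h=0.8, b=0.1):
    """Toy scalar tanh cell (hypothetical weights).

    Returns (output, new_state); for this kind of cell they are the same value.
    """
    new_h = math.tanh(w_x * x + w_h * h + b)
    return new_h, new_h

def unroll(inputs, sequence_length):
    """Unroll the toy cell over a padded batch.

    inputs: list of [max_time] float lists, one per batch example.
    sequence_length: number of valid steps for each example.
    Returns (outputs, final_state), with outputs shaped [batch][max_time].
    """
    outputs, final_state = [], []
    for seq, length in zip(inputs, sequence_length):
        h = 0.0  # zero initial state
        row = []
        for t, x in enumerate(seq):
            if t < length:
                out, h = basic_rnn_cell(x, h)
            else:
                out = 0.0  # past the end: emit zeros and keep the last state
            row.append(out)
        outputs.append(row)
        final_state.append(h)
    return outputs, final_state

inputs = [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]]  # second example is padded
sequence_length = [3, 2]
outputs, final_state = unroll(inputs, sequence_length)

# The final state is just the output at the last valid step of each example.
for b, length in enumerate(sequence_length):
    assert outputs[b][length - 1] == final_state[b]
```

In this sketch every per-step hidden state is already present in outputs, which is why collecting "outputs" is enough for cells whose output equals their state.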

You can use "outputs" as the hidden states of the RNN because, in most library-provided RNNCell implementations, the "output" and the "state" are the same. The exception is LSTMCell: its state is a (c, h) pair, and "outputs" contains only the hidden state h, not the cell state c.

  • BasicRNNCell: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py#L347
  • GRUCell: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py#L441
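The LSTM exception can be illustrated with another small pure-Python sketch. This is a toy scalar LSTM, not the library's LSTMCell: the function toy_lstm_cell and its fixed weights (all gates share one pre-activation) are assumptions made for brevity. The point it shows is only that the per-step output is h, while the cell state c lives in the state pair and is not part of the outputs.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def toy_lstm_cell(x, state):
    """Toy scalar LSTM step with hypothetical unit weights.

    state is a (c, h) pair. Returns (output, new_state), where output is h only.
    """
    c, h = state
    z = x + h  # shared pre-activation for all gates, for brevity
    i = sigmoid(z)        # input gate
    f = sigmoid(z + 1.0)  # forget gate with a bias of 1
    o = sigmoid(z)        # output gate
    g = math.tanh(z)      # candidate value
    new_c = f * c + i * g
    new_h = o * math.tanh(new_c)
    return new_h, (new_c, new_h)

state = (0.0, 0.0)
outputs, cell_states = [], []
for x in [0.5, -1.0, 2.0]:
    out, state = toy_lstm_cell(x, state)
    outputs.append(out)
    cell_states.append(state[0])  # c must be collected separately

# The last output equals the h half of the final state, but not the c half.
assert outputs[-1] == state[1]
assert outputs[-1] != state[0]
```

If the per-step cell state c is also needed, collecting outputs alone is not enough; in this sketch it is gathered in a separate list, which corresponds to the custom-loop or custom-cell options mentioned in the question.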
