Tensorflow,如何访问 RNN 的所有中间状态,而不仅仅是最后一个状态 [英] Tensorflow, how to access all the middle states of an RNN, not just the last state

查看:30
本文介绍了Tensorflow,如何访问 RNN 的所有中间状态,而不仅仅是最后一个状态的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我的理解是 tf.nn.dynamic_rnn 在每个时间步以及最终状态返回 RNN 单元(例如 LSTM)的输出.如何在所有时间步骤中访问单元格状态,而不仅仅是最后一个?例如,我希望能够对所有隐藏状态进行平均,然后在后续层中使用.

My understanding is that tf.nn.dynamic_rnn returns the output of an RNN cell (e.g. LSTM) at each time step as well as the final state. How can I access cell states in all time steps not just the last one? For example, I want to be able to average all the hidden states and then use it in the subsequent layer.

以下是我如何定义 LSTM 单元,然后使用 tf.nn.dynamic_rnn 展开它.但这仅给出了 LSTM 的最后一个单元状态.

The following is how I define an LSTM cell and then unroll it using tf.nn.dynamic_rnn. But this only gives the last cell state of the LSTM.

import tensorflow as tf
import numpy as np

# [batch-size, sequence-length, dimensions] 
X = np.random.randn(2, 10, 8)
X[1,6:] = 0
X_lengths = [10, 6]

cell = tf.contrib.rnn.LSTMCell(num_units=64, state_is_tuple=True)

outputs, last_state = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())                                 
out, last = sess.run([outputs, last_state], feed_dict=None)

推荐答案

这样的事情应该可行.

import tensorflow as tf
import numpy as np


class CustomRNN(tf.contrib.rnn.LSTMCell):
    def __init__(self, *args, **kwargs):
        kwargs['state_is_tuple'] = False # force the use of a concatenated state.
        returns = super(CustomRNN, self).__init__(*args, **kwargs) # create an lstm cell
        self._output_size = self._state_size # change the output size to the state size
        return returns
    def __call__(self, inputs, state):
        output, next_state = super(CustomRNN, self).__call__(inputs, state)
        return next_state, next_state # return two copies of the state, instead of the output and the state

X = np.random.randn(2, 10, 8)
X[1,6:] = 0
X_lengths = [10, 10]

cell = CustomRNN(num_units=64)

outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())                                 
states, last_state = sess.run([outputs, last_states], feed_dict=None)

这使用连接状态,因为我不知道您是否可以存储任意数量的元组状态.状态变量的形状为 (batch_size, max_time_size, state_size).

This uses concatenated states, as I don't know if you can store an arbitrary number of tuple states. The states variable is of shape (batch_size, max_time_size, state_size).

这篇关于Tensorflow,如何访问 RNN 的所有中间状态,而不仅仅是最后一个状态的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆