Tensorflow: how to obtain intermediate cell states (c) from LSTMCell using dynamic_rnn?


Problem description

By default, the function dynamic_rnn outputs only the hidden states (known as m) at each time step, which can be obtained as follows:

import tensorflow as tf

cell = tf.contrib.rnn.LSTMCell(100)
rnn_outputs, _ = tf.nn.dynamic_rnn(cell,
                                   inputs=inputs,
                                   sequence_length=sequence_lengths,
                                   dtype=tf.float32)
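
For comparison, the second return value (discarded above) holds only the final state, not the per-step cell states. A minimal sketch, assuming the cell keeps its default state_is_tuple=True so this final state is an LSTMStateTuple:

rnn_outputs, final_state = tf.nn.dynamic_rnn(cell,
                                             inputs=inputs,
                                             sequence_length=sequence_lengths,
                                             dtype=tf.float32)
# final_state is an LSTMStateTuple holding only the last step's values:
final_c = final_state.c  # final cell state c, shape (batch, 100)
final_m = final_state.h  # final hidden state m, shape (batch, 100)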

Is there a way to also get the intermediate (not final) cell states (c)?

A TensorFlow contributor mentions that it can be done with a cell wrapper:

class Wrapper(tf.nn.rnn_cell.RNNCell):
  def __init__(self, inner_cell):
    super(Wrapper, self).__init__()
    self._inner_cell = inner_cell
  @property
  def state_size(self):
    return self._inner_cell.state_size
  @property
  def output_size(self):
    return (self._inner_cell.state_size, self._inner_cell.output_size)
  def call(self, input, state):
    output, next_state = self._inner_cell(input, state)
    emit_output = (next_state, output)
    return emit_output, next_state

However, it doesn't seem to work. Any ideas?

Recommended answer

The proposed solution works for me, but the Layer.call method spec is more general, so the following Wrapper should be more robust to API changes. Try this:

class Wrapper(tf.nn.rnn_cell.RNNCell):
  def __init__(self, inner_cell):
    super(Wrapper, self).__init__()
    self._inner_cell = inner_cell

  @property
  def state_size(self):
    return self._inner_cell.state_size

  @property
  def output_size(self):
    return (self._inner_cell.state_size, self._inner_cell.output_size)

  def call(self, input, *args, **kwargs):
    output, next_state = self._inner_cell(input, *args, **kwargs)
    emit_output = (next_state, output)
    return emit_output, next_state

Here is the test:

import numpy as np
import tensorflow as tf

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
basic_cell = Wrapper(tf.nn.rnn_cell.LSTMCell(num_units=n_neurons, state_is_tuple=False))
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
print(outputs, states)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val = outputs[0].eval(feed_dict={X: X_batch})
  print(outputs_val)

The returned outputs is a tuple of (?, 2, 10) and (?, 2, 5) tensors, which are all the LSTM states and outputs, respectively. Note that I'm using the "graduated" version of LSTMCell, from the tf.nn.rnn_cell package, not tf.contrib.rnn. Also note state_is_tuple=False to avoid dealing with LSTMStateTuple.
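
If you need the cell states on their own, the concatenated state tensor can be split along the last axis. A minimal sketch, assuming (as LSTMCell does with state_is_tuple=False) that the cell state c occupies the first n_neurons columns of the concatenation:

states_seq, outputs_seq = outputs  # (?, 2, 10) states and (?, 2, 5) outputs
c_seq = states_seq[:, :, :n_neurons]  # intermediate cell states c, shape (?, 2, 5)
m_seq = states_seq[:, :, n_neurons:]  # hidden states m, shape (?, 2, 5)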
