张量流 - LSTM - 'Tensor'对象不可迭代 [英] Tensor Flow - LSTM - 'Tensor' object not iterable
问题描述
我正在为 lstm rnn 单元使用以下函数.
def LSTM_RNN(_X, _istate, _weights, _biases):# 函数从给定参数返回一个张量流 LSTM (RNN) 人工神经网络.# 注意,这个笔记本的一些代码是从一个稍微不同的# 在另一个数据集上使用的 RNN 架构:# https://tensorhub.com/aymericdamien/tensorflow-rnn#(注意:这一步可以通过对数据集进行一次整形来大大优化# 输入形状:(batch_size, n_steps, n_input)_X = tf.transpose(_X, [1, 0, 2]) # 置换 n_steps 和 batch_size# 重塑以准备隐藏激活的输入_X = tf.reshape(_X, [-1, n_input]) # (n_steps*batch_size, n_input)# 线性激活_X = tf.matmul(_X, _weights['hidden']) + _biases['hidden']# 用 tensorflow 定义一个 lstm 单元lstm_cell = rnn_cell.BasicLSTMCell(n_hidden,forget_bias=1.0)# 拆分数据,因为 rnn cell 需要一个用于 RNN 内循环的输入列表_X = tf.split(0, n_steps, _X) # n_steps * (batch_size, n_hidden)# 获取 lstm 单元格输出输出,状态 = rnn.rnn(lstm_cell, _X, initial_state=_istate)# 线性激活# 获取内循环最后输出返回 tf.matmul(outputs[-1], _weights['out']) + _biases['out']
函数的输出存储在 pred 变量下.
pred = LSTM_RNN(x, itate, weights, bias)
但它显示以下错误.(这表明张量对象不可迭代.)
这是错误图片链接 -
从 tensorflow 版本 r0.11(或主版本)开始,state_is_tuple
的默认设置被设置为 True
.请参阅此处的
如果你安装了 r0.11 或 tensorflow 的主版本,尝试将 BasicLSTMCell
初始化行更改为:lstm_cell = rnn_cell.BasicLSTMCell(n_hidden,forget_bias=1.0,state_is_tuple=False)
.您遇到的错误应该会消失.虽然,他们的页面确实说 state_is_tuple=False
行为将很快被弃用.
Hi I am using the following function for lstm rnn cell.
def LSTM_RNN(_X, _istate, _weights, _biases):
# Function returns a tensorflow LSTM (RNN) artificial neural network from given parameters.
# Note, some code of this notebook is inspired from an slightly different
# RNN architecture used on another dataset:
# https://tensorhub.com/aymericdamien/tensorflow-rnn
# (NOTE: This step could be greatly optimised by shaping the dataset once
# input shape: (batch_size, n_steps, n_input)
_X = tf.transpose(_X, [1, 0, 2]) # permute n_steps and batch_size
# Reshape to prepare input to hidden activation
_X = tf.reshape(_X, [-1, n_input]) # (n_steps*batch_size, n_input)
# Linear activation
_X = tf.matmul(_X, _weights['hidden']) + _biases['hidden']
# Define a lstm cell with tensorflow
lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
# Split data because rnn cell needs a list of inputs for the RNN inner loop
_X = tf.split(0, n_steps, _X) # n_steps * (batch_size, n_hidden)
# Get lstm cell output
outputs, states = rnn.rnn(lstm_cell, _X, initial_state=_istate)
# Linear activation
# Get inner loop last output
return tf.matmul(outputs[-1], _weights['out']) + _biases['out']
The function's output is stored under pred variable.
pred = LSTM_RNN(x, istate, weights, biases)
But its showing the following error. (which states that tensor object is not iterable.)
Here is the ERROR image link - http://imgur.com/a/NhSFK
Please help me with this and I apologize if this question seems silly as I am fairly new to the lstm and tensor flow library.
Thanks.
The error happened when it's trying to unpack state
with statement c, h=state
. Depending on which version of tensorflow you are using (you can check the version info by typing import tensorflow; tensorflow.__version__
in python interpreter), in version prior to r0.11, the default setting for the state_is_tuple
argument when you initialize the rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
is set to be False
. See the documentation here.
Since tensorflow version r0.11 (or the master version), the default setting for state_is_tuple
is set to be True
. See the documentation here.
If you installed r0.11 or the master version of tensorflow, try change the BasicLSTMCell
initialization line into:
lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=False)
. The error you are encountering should go away. Although, their page does say that the state_is_tuple=False
behavior will be deprecated soon.
这篇关于张量流 - LSTM - 'Tensor'对象不可迭代的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!