Tensor Flow-LSTM-'Tensor'对象不可迭代 [英] Tensor Flow - LSTM - 'Tensor' object not iterable

查看:201
本文介绍了Tensor Flow-LSTM-'Tensor'对象不可迭代的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在为lstm rnn单元格使用以下功能.

def LSTM_RNN(_X, _istate, _weights, _biases):
    # Function returns a tensorflow LSTM (RNN) artificial neural network from given parameters. 
    # Note, some code of this notebook is inspired from an slightly different 
    # RNN architecture used on another dataset: 
    # https://tensorhub.com/aymericdamien/tensorflow-rnn

    # (NOTE: This step could be greatly optimised by shaping the dataset once
    # input shape: (batch_size, n_steps, n_input)
    _X = tf.transpose(_X, [1, 0, 2])  # permute n_steps and batch_size

    # Reshape to prepare input to hidden activation
    _X = tf.reshape(_X, [-1, n_input]) # (n_steps*batch_size, n_input)

    # Linear activation
    _X = tf.matmul(_X, _weights['hidden']) + _biases['hidden']

    # Define a lstm cell with tensorflow
    lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)


    # Split data because rnn cell needs a list of inputs for the RNN inner loop
    _X = tf.split(0, n_steps, _X) # n_steps * (batch_size, n_hidden)

    # Get lstm cell output
    outputs, states = rnn.rnn(lstm_cell, _X, initial_state=_istate)

    # Linear activation
    # Get inner loop last output
    return tf.matmul(outputs[-1], _weights['out']) + _biases['out']

函数的输出存储在pred变量下.

pred = LSTM_RNN(x, istate, weights, biases)

但是它显示以下错误. (指出张量对象不可迭代.)

这是错误图片链接- http://imgur.com/a/NhSFK

请帮助我解决这个问题,如果这个问题看起来很愚蠢,我道歉,因为我对lstm和张量流库还很陌生.

谢谢.

解决方案

尝试使用语句c, h=state解压缩state时发生错误.根据所使用的tensorflow版本(您可以在python解释器中键入import tensorflow; tensorflow.__version__来检查版本信息),在r0.11之前的版本中,这是初始化rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)state_is_tuple参数的默认设置设置为False.

请参见文档.

从tensorflow版本r0.11(或主版本)开始,state_is_tuple的默认设置设置为True.请参阅此处的文档.

中的alt ="BasicLSTMCell文档

如果您安装了r0.11或tensorflow的主版本,请尝试将BasicLSTMCell初始化行更改为: lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=False).您遇到的错误应该消失.尽管,他们的页面确实说state_is_tuple=False行为将很快被弃用.

Hi I am using the following function for lstm rnn cell.

def LSTM_RNN(_X, _istate, _weights, _biases):
    # Function returns a tensorflow LSTM (RNN) artificial neural network from given parameters. 
    # Note, some code of this notebook is inspired from an slightly different 
    # RNN architecture used on another dataset: 
    # https://tensorhub.com/aymericdamien/tensorflow-rnn

    # (NOTE: This step could be greatly optimised by shaping the dataset once
    # input shape: (batch_size, n_steps, n_input)
    _X = tf.transpose(_X, [1, 0, 2])  # permute n_steps and batch_size

    # Reshape to prepare input to hidden activation
    _X = tf.reshape(_X, [-1, n_input]) # (n_steps*batch_size, n_input)

    # Linear activation
    _X = tf.matmul(_X, _weights['hidden']) + _biases['hidden']

    # Define a lstm cell with tensorflow
    lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)


    # Split data because rnn cell needs a list of inputs for the RNN inner loop
    _X = tf.split(0, n_steps, _X) # n_steps * (batch_size, n_hidden)

    # Get lstm cell output
    outputs, states = rnn.rnn(lstm_cell, _X, initial_state=_istate)

    # Linear activation
    # Get inner loop last output
    return tf.matmul(outputs[-1], _weights['out']) + _biases['out']

The function's output is stored under pred variable.

pred = LSTM_RNN(x, istate, weights, biases)

But its showing the following error. (which states that tensor object is not iterable.)

Here is the ERROR image link - http://imgur.com/a/NhSFK

Please help me with this and I apologize if this question seems silly as I am fairly new to the lstm and tensor flow library.

Thanks.

解决方案

The error happened when it's trying to unpack state with statement c, h=state. Depending on which version of tensorflow you are using (you can check the version info by typing import tensorflow; tensorflow.__version__ in python interpreter), in version prior to r0.11, the default setting for the state_is_tuple argument when you initialize the rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0) is set to be False. See the documentation here.

Since tensorflow version r0.11 (or the master version), the default setting for state_is_tuple is set to be True. See the documentation here.

If you installed r0.11 or the master version of tensorflow, try change the BasicLSTMCell initialization line into: lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=False). The error you are encountering should go away. Although, their page does say that the state_is_tuple=False behavior will be deprecated soon.

这篇关于Tensor Flow-LSTM-'Tensor'对象不可迭代的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆