tensorflow:简单LSTM网络的共享变量错误 [英] tensorflow: shared variables error with simple LSTM network

查看:259
本文介绍了tensorflow:简单LSTM网络的共享变量错误的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试建立一个最简单的LSTM网络.只是希望它预测序列np_input_data中的下一个值.

I am trying to build a simplest possible LSTM network. Just want it to predict the next value in the sequence np_input_data.

import tensorflow as tf
from tensorflow.python.ops import rnn_cell
import numpy as np

num_steps = 3
num_units = 1
np_input_data = [np.array([[1.],[2.]]), np.array([[2.],[3.]]), np.array([[3.],[4.]])]

batch_size = 2

graph = tf.Graph()

with graph.as_default():
    tf_inputs = [tf.placeholder(tf.float32, [batch_size, 1]) for _ in range(num_steps)]

    lstm = rnn_cell.BasicLSTMCell(num_units)
    initial_state = state = tf.zeros([batch_size, lstm.state_size])
    loss = 0

    for i in range(num_steps-1):
        output, state = lstm(tf_inputs[i], state)
        loss += tf.reduce_mean(tf.square(output - tf_inputs[i+1]))

with tf.Session(graph=graph) as session:
    tf.initialize_all_variables().run()

    feed_dict={tf_inputs[i]: np_input_data[i] for i in range(len(np_input_data))}

    loss = session.run(loss, feed_dict=feed_dict)

    print(loss)

解释器返回:

ValueError: Variable BasicLSTMCell/Linear/Matrix already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at:
    output, state = lstm(tf_inputs[i], state)

我该怎么办?

推荐答案

此处lstm的调用:

for i in range(num_steps-1):
  output, state = lstm(tf_inputs[i], state)

除非另有说明,否则

会在每次迭代时尝试创建具有相同名称的变量.您可以使用tf.variable_scope

will try to create variables with the same name each iteration unless you tell it otherwise. You can do this using tf.variable_scope

with tf.variable_scope("myrnn") as scope:
  for i in range(num_steps-1):
    if i > 0:
      scope.reuse_variables()
    output, state = lstm(tf_inputs[i], state)     

第一次迭代将创建代表您的LSTM参数的变量,并且随后的每次迭代(在调用reuse_variables之后)都将在范围内按名称查找它们.

The first iteration creates the variables that represent your LSTM parameters and every subsequent iteration (after the call to reuse_variables) will just look them up in the scope by name.

这篇关于tensorflow:简单LSTM网络的共享变量错误的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆