KeyError : 张量变量,引用不存在的张量 [英] KeyError : The tensor variable , Refer to the tensor which does not exists

查看:43
本文介绍了KeyError : 张量变量,引用不存在的张量的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

使用 LSTMCell 我训练了一个模型来生成文本.我启动了 tensorflow 会话并使用 tf.global_variables_initializer() 保存了所有的 tensorflow 变量.

Using LSTMCell i trained a model to do text generation . I started the tensorflow session and save all the tensorflow varibles using tf.global_variables_initializer() .

import tensorflow as tf
sess = tf.Session()
//code blocks
run_init_op = tf.global_variables_intializer()
sess.run(run_init_op)
saver = tf.train.Saver()
#varible that makes prediction
prediction = tf.nn.softmax(tf.matmul(last,weight)+bias)
#feed the inputdata into model and trained
#saved the model
#save the tensorflow model
save_path= saver.save(sess,'/tmp/text_generate_trained_model.ckpt')
print("Model saved in the path : {}".format(save_path))

模型得到训练并保存其所有会话.查看整个代码的链接 lstm_rnn.py

The model get trained and saved all its session . Link to review the whole code lstm_rnn.py

现在我加载了存储的模型并尝试为文档生成文本.所以,我用以下代码恢复了模型

Now i loaded the stored model and tried to do text generation for the document . So,i restored the model with following code

tf.reset_default_graph()
imported_data = tf.train.import_meta_graph('text_generate_trained_model.ckpt.meta')
with tf.Session() as sess:
    imported_meta.restore(sess,tf.train.latest_checkpoint('./'))

    #accessing the default graph which we restored
    graph = tf.get_default_graph()

    #op that we can be processed to get the output
    #last is the tensor that is the prediction of the network
    y_pred = graph.get_tensor_by_name("prediction:0")
    #generate characters
    for i in range(500):
        x = np.reshape(pattern,(1,len(pattern),1))
        x = x / float(n_vocab)
        prediction = sess.run(y_pred,feed_dict=x)
        index = np.argmax(prediction)
        result = int_to_char[index]
        seq_in = [int_to_char[value] for value in pattern]
        sys.stdout.write(result)
        patter.append(index)
        pattern = pattern[1:len(pattern)]

    print("\n Done...!")
sess.close()

我开始知道图中不存在预测变量.

I came to know that the prediction variable does not exist in the graph.

KeyError:名称‘预测:0’指的是一个张量,它不存在.图中不存在预测"操作."

KeyError: "The name 'prediction:0' refers to a Tensor which does not exist. The operation, 'prediction', does not exist in the graph."

完整代码可在此处text_generation.py

虽然我保存了所有 tensorflow 变量,但预测张量并未保存在 tensorflow 计算图中.我的 lstm_rnn.py 文件有什么问题.

Though i saved all tensorflow varibles , the prediction tensor is not saved in the tensorflow computation graph . whats wrong in my lstm_rnn.py file .

谢谢!

推荐答案

要使 graph.get_tensor_by_name("prediction:0") 工作,您应该在创建它时为其命名.这就是你如何命名它

For graph.get_tensor_by_name("prediction:0") to work you should have named it when you created it. This is how you can name it

prediction = tf.nn.softmax(tf.matmul(last,weight)+bias, name="prediction")

如果您已经训练了模型并且无法重命名张量,您仍然可以通过其默认名称获取该张量,

If you have already trained the model and can't rename the tensor, you can still get that tensor by its default name as in,

y_pred = graph.get_tensor_by_name("Reshape_1:0")

如果 Reshape_1 不是张量的实际名称,您必须查看图中的名称并找出答案.你可以用

If Reshape_1 is not the actual name of the tensor, you'll have to look at the names in the graph and figure it out. You can inspect that with

for op in graph.get_operations():
    print(op.name)

这篇关于KeyError : 张量变量,引用不存在的张量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆