Fetch argument <tf.Tensor 'batch:0' shape=(128, 56, 56, 3) dtype=float32> cannot be interpreted as a Tensor.
Problem description
I wrote some test code, and when I run it, it says "Fetch argument ... cannot be interpreted as a Tensor". I really don't know what's going on. Can somebody tell me how to fix it? Thank you very much. Here is the code:
# coding=utf-8
from color_1 import read_and_decode, get_batch, get_test_batch
import color_inference
import cv2
import os
import time
import numpy as np
import tensorflow as tf
import color_train
import math

EVAL_INTERVAL_SECS = 10
batch_size = 128
num_examples = 10000
crop_size = 56


def test(test_x, test_y):
    with tf.Graph().as_default() as g:
        image_holder = tf.placeholder(tf.float32, [batch_size, 56, 56, 3], name='x-input')
        label_holder = tf.placeholder(tf.int32, [batch_size], name='y-input')
        y = color_inference.inference(image_holder)
        # float() avoids Python 2 integer division before ceil()
        num_iter = int(math.ceil(num_examples / float(batch_size)))
        true_count = 0
        total_sample_count = num_iter * batch_size
        saver = tf.train.Saver()
        top_k_op = tf.nn.in_top_k(y, label_holder, 1)
        # Re-evaluate the latest checkpoint every EVAL_INTERVAL_SECS seconds
        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(color_train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    image_batch, label_batch = sess.run([test_x, test_y])
                    predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch,
                                                                  label_holder: label_batch})
                    true_count += np.sum(predictions)
                    precision = true_count * 1.0 / total_sample_count
                    # Use % formatting; a comma-separated print prints a tuple in Python 2
                    print("After %s training steps, precision = %g" % (global_step, precision))
                else:
                    print('No checkpoint file found')
                    return
            time.sleep(EVAL_INTERVAL_SECS)


def main(argv=None):
    test_image, test_label = read_and_decode('val.tfrecords')
    test_images, test_labels = get_test_batch(test_image, test_label, batch_size, crop_size)
    test(test_images, test_labels)


if __name__ == '__main__':
    tf.app.run()
Here is the error:
  File "/home/vrview/tensorflow/example/char/tfrecords/color_test.py", line 57, in <module>
    tf.app.run()
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/vrview/tensorflow/example/char/tfrecords/color_test.py", line 54, in main
    test(test_images, test_labels)
  File "/home/vrview/tensorflow/example/char/tfrecords/color_test.py", line 39, in test
    image_batch, label_batch = sess.run([test_x, test_y])
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 767, in run
    run_metadata_ptr)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 952, in _run
    fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 408, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 230, in for_fetch
    return _ListFetchMapper(fetch)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 337, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 238, in for_fetch
    return _ElementFetchMapper(fetches, contraction_fn)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 274, in __init__
    'Tensor. (%s)' % (fetch, str(e)))
ValueError: Fetch argument <tf.Tensor 'batch:0' shape=(128, 56, 56, 3) dtype=float32> cannot be interpreted as a Tensor. (Tensor Tensor("batch:0", shape=(128, 56, 56, 3), dtype=float32) is not an element of this graph.)
Recommended answer
You focused on the wrong part of the error message. The relevant part is:
Tensor Tensor("batch:0", shape=(128, 56, 56, 3), dtype=float32) is not an element of this graph.
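The problem is that test creates a new graph g, and that is not the graph in which test_x and test_y (the batch tensors passed in as arguments) were created. Inside with tf.Graph().as_default() as g, every op you build belongs to g, and the tf.Session() you open there runs g. That session can only fetch tensors from g, while test_x and test_y live in the default graph that main used, hence "not an element of this graph".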
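The easiest solution is to create the graph in main instead: build the input pipeline inside it, call test from within the same with block, and delete the line that creates g in test (dedenting its body), so that the model, the Saver and the Session are created in that same graph as well: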
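def main(argv=None):
    # One graph for everything: read_and_decode also creates ops, so it must
    # run inside this block too, not before it.
    with tf.Graph().as_default():
        test_image, test_label = read_and_decode('val.tfrecords')
        test_images, test_labels = get_test_batch(test_image, test_label,
                                                  batch_size, crop_size)
        # test() must no longer open its own tf.Graph()
        test(test_images, test_labels)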