Fetch argument <tf.Tensor 'batch:0' shape=(128, 56, 56, 3) dtype=float32> cannot be interpreted as a Tensor.

Problem description

I wrote a test script, and when I run it, it says the fetch argument cannot be interpreted as a Tensor. I really don't know what's going on. Can somebody tell me how to fix it? Thank you very much. Here is the code:

# coding=utf-8
from  color_1 import read_and_decode, get_batch, get_test_batch
import color_inference
import cv2
import os
import time
import numpy as np
import tensorflow as tf
import color_train
import math

EVAL_INTERVAL_SECS=10
batch_size=128
num_examples = 10000
crop_size=56
def test(test_x, test_y):
    with tf.Graph().as_default() as g:
        image_holder = tf.placeholder(tf.float32, [batch_size, 56, 56, 3], name='x-input')
        label_holder = tf.placeholder(tf.int32, [batch_size], name='y-input')

        y=color_inference.inference(image_holder)

        num_iter = int(math.ceil(num_examples / batch_size))
        true_count = 0
        total_sample_count = num_iter * batch_size
        saver=tf.train.Saver()
        top_k_op = tf.nn.in_top_k(y, label_holder, 1)
        while True:
            with tf.Session() as sess:
                ckpt=tf.train.get_checkpoint_state(color_train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess,ckpt.model_checkpoint_path)
                    global_step=ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    image_batch, label_batch = sess.run([test_x, test_y])
                    predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch,
                                                                  label_holder: label_batch})
                    true_count += np.sum(predictions)
                    precision = true_count * 1.0 / total_sample_count
                    print("After %s training step, the prediction is: %g" % (global_step, precision))
                else:
                    print('No checkpoint file found')
                    return
            time.sleep(EVAL_INTERVAL_SECS)

def main(argv=None):
    test_image, test_label = read_and_decode('val.tfrecords')
    test_images, test_labels = get_test_batch(test_image, test_label, batch_size, crop_size)
    test(test_images, test_labels)

if __name__=='__main__':
    tf.app.run()

Here is the error:

File "/home/vrview/tensorflow/example/char/tfrecords/color_test.py", line 57, in <module>
    tf.app.run()
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/vrview/tensorflow/example/char/tfrecords/color_test.py", line 54, in main
    test(test_images, test_labels)
  File "/home/vrview/tensorflow/example/char/tfrecords/color_test.py", line 39, in test
    image_batch, label_batch = sess.run([test_x, test_y])
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 767, in run
    run_metadata_ptr)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 952, in _run
    fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 408, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 230, in for_fetch
    return _ListFetchMapper(fetch)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 337, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 238, in for_fetch
    return _ElementFetchMapper(fetches, contraction_fn)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 274, in __init__
    'Tensor. (%s)' % (fetch, str(e)))
ValueError: Fetch argument <tf.Tensor 'batch:0' shape=(128, 56, 56, 3) dtype=float32> cannot be interpreted as a Tensor. (Tensor Tensor("batch:0", shape=(128, 56, 56, 3), dtype=float32) is not an element of this graph.)


Recommended answer

You focused on the wrong part of the error message. The relevant part is


Tensor is not an element of this graph.

The problem is that your function test creates a new graph g, which is not the graph in which test_x and test_y (the batch tensors passed in as arguments) were created. The session opened inside test runs g, so it cannot fetch tensors that belong to a different graph.
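To see the mechanism in isolation, here is a minimal sketch using the tf.compat.v1 graph/session API (so it should run on TF 1.15 or 2.x; the question itself uses TF 1.x on Python 2.7). The tensor named "batch" stands in for test_x and the "model_stub" op stands in for the model that test builds; both are hypothetical. The stub is there because TensorFlow refuses to run a session on an empty graph. Fetching "batch" from a session on the second graph should raise the same ValueError, while adding the new ops to batch.graph avoids it.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x graph/session API, also available in TF 2.x


def fetch_from_other_graph():
    """Reproduce the error: tensor built in one graph, session runs another."""
    g1 = tf.Graph()
    with g1.as_default():
        batch = tf.zeros([2, 56, 56, 3], name="batch")  # stands in for test_x

    g2 = tf.Graph()  # like `with tf.Graph().as_default() as g:` inside test()
    with g2.as_default():
        tf.zeros([1], name="model_stub")  # hypothetical stand-in for the model
        with tf1.Session() as sess:  # this session runs g2, the default graph here
            try:
                sess.run(batch)
            except ValueError as e:
                return str(e)
    return None


def fetch_from_same_graph():
    """Fix: add the new ops to the graph the input tensor already lives in."""
    g1 = tf.Graph()
    with g1.as_default():
        batch = tf.zeros([2, 56, 56, 3], name="batch")

    # Reuse batch.graph instead of creating a fresh graph.
    with batch.graph.as_default():
        doubled = batch * 2.0
        with tf1.Session() as sess:  # runs g1, so `doubled` can be fetched
            return sess.run(doubled).shape


print(fetch_from_other_graph())
print(fetch_from_same_graph())
```

The same idea gives an alternative fix for the question: inside test, entering test_x.graph.as_default() instead of tf.Graph().as_default() should build the model in the graph that already holds the input batch.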

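The easiest solution is to create the graph in main instead, and to remove the "with tf.Graph().as_default() as g:" line from test (dedenting the code under it), so that test builds the model in the same graph as the input batch: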

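def main(argv=None):
    # Build the input pipeline and the model in one graph. read_and_decode is
    # called inside the block too, since it presumably creates graph ops as well.
    with tf.Graph().as_default():
        test_image, test_label = read_and_decode('val.tfrecords')
        test_images, test_labels = get_test_batch(test_image, test_label,
                                                  batch_size, crop_size)
        test(test_images, test_labels)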
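A follow-up caveat, assuming get_test_batch in color_1 (not shown) is built on tf.train.batch or tf.train.shuffle_batch: once everything is in one graph, sess.run([test_x, test_y]) will likely block forever unless the queue runners are started, and the test function above never calls tf.train.start_queue_runners. The sketch below shows the usual Coordinator pattern with tf.compat.v1; the constant image and label and the reduce_mean "model" are hypothetical stand-ins for read_and_decode and inference.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x queue/session API, also available in TF 2.x


def evaluate_one_batch(batch_size=4):
    # One graph for the input pipeline AND the model, as in the fixed main().
    with tf.Graph().as_default():
        # Hypothetical stand-ins for read_and_decode(): one decoded example.
        image = tf.ones([56, 56, 3], dtype=tf.float32)
        label = tf.constant(7, dtype=tf.int32)
        # Queue-based batching, assumed similar to what get_test_batch() does.
        images, labels = tf1.train.batch([image, label], batch_size=batch_size)
        # Trivial op consuming the batch (placeholder for the real model).
        mean_pixel = tf.reduce_mean(images)

        with tf1.Session() as sess:
            coord = tf1.train.Coordinator()
            # Without this call, running a queue-fed batch blocks indefinitely.
            threads = tf1.train.start_queue_runners(sess=sess, coord=coord)
            try:
                image_batch, label_batch, mean_val = sess.run(
                    [images, labels, mean_pixel])
            finally:
                # Stop the enqueue threads and wait for them to exit.
                coord.request_stop()
                coord.join(threads)
    return image_batch.shape, label_batch.tolist(), float(mean_val)


print(evaluate_one_batch())
```

In the question's loop, the Coordinator and start_queue_runners call would go right after the tf.Session() is opened, before the first sess.run([test_x, test_y]).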
