TensorFlow - Read all examples from a TFRecords file at once?

Problem description

How do you read all examples from a TFRecords file at once?

I've been using tf.parse_single_example to read out individual examples, with code similar to the read_and_decode method in the fully_connected_reader example. However, I want to run the network against my entire validation dataset at once, so I would like to load all of the examples in one go instead.

I'm not entirely sure, but the documentation seems to suggest that I can use tf.parse_example instead of tf.parse_single_example to load the entire TFRecords file at once. I can't seem to get this to work, though. I suspect it has to do with how I specify the features, but I don't know how to state in the feature specification that there are multiple examples.

In other words, my attempt at something like this:

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_example(serialized_example, features={
    'image_raw': tf.FixedLenFeature([], tf.string),
    'label': tf.FixedLenFeature([], tf.int64),
})

isn't working, and I assume it's because the features aren't expecting multiple examples at once (but again, I'm not sure). [It fails with ValueError: Shape () must have rank 1]

Is this the proper way to read all the records at once? And if so, what do I need to change to actually read the records? Thank you very much!

Recommended answer

Just for clarity: I have a few thousand images in a single .tfrecords file; they're 720 x 720 RGB PNG files. Each label is one of 0, 1, 2, 3.
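
A rough size check is useful before loading everything at once. Assuming each image is stored as raw uint8 pixels (which is what tf.decode_raw(..., tf.uint8) in the code below reads) and taking the 2053 records that the loop below reads, the decoded images occupy

$$
720 \times 720 \times 3 = 1\,555\,200 \ \text{bytes per image}, \qquad 2053 \times 1\,555\,200 = 3\,192\,825\,600 \ \text{bytes} \approx 2.97 \ \text{GiB}.
$$

That is worth comparing against available memory before materializing the whole set as one tensor; a float32 copy would be four times larger.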

I also tried using parse_example and couldn't make it work, but this solution works with parse_single_example.
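
On the parse_example route: the ValueError in the question most likely comes from the input shape. tf.parse_example expects a rank-1 batch of serialized Example strings, while reader.read returns a single scalar string (shape ()). In the queue-based API, TFRecordReader.read_up_to(filename_queue, num_records) returns such a batch. The minimal sketch below assumes TensorFlow 2.x with eager execution instead: it batches a tf.data.TFRecordDataset and parses the whole batch with one tf.io.parse_example call. It uses only the image_raw and label keys from the question, assumes the file fits in memory, and read_all_records and max_records are hypothetical names.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Feature spec from the question: one scalar value per key in every record.
FEATURES = {
    'image_raw': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64),
}


def read_all_records(path, max_records=1_000_000):
    """Parse every Example in a TFRecord file with a single parse_example call.

    max_records is an arbitrary upper bound: when the batch size exceeds the
    number of records, the whole file arrives as one (partial) batch.
    Records beyond max_records are ignored.
    """
    batches = tf.data.TFRecordDataset(path).batch(max_records)
    for serialized in batches:
        # serialized has shape [num_records] (rank 1), which is what
        # tf.io.parse_example expects, unlike the scalar from reader.read.
        return tf.io.parse_example(serialized, FEATURES)
    raise ValueError('no records found in %s' % path)
```

Only an upper bound on the record count is needed here, and nothing wraps around to the first record; each entry of the returned dict holds one value per record.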

The downside is that right now I have to know how many records are in each .tfrecords file, which is kind of a bummer. If I find a better way, I'll update the answer. Also, be careful not to read past the number of records in the .tfrecords file: if the loop goes past the last record, it starts over at the first record.
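
On the record-count problem: a TFRecord file is a sequence of length-prefixed records (an 8-byte little-endian length, a 4-byte masked CRC-32C of that length, the payload, and a 4-byte masked CRC-32C of the payload), so the records can be counted by hopping from header to header. The sketch below does that with the standard library only. count_tfrecords is a hypothetical helper; it assumes an uncompressed file and does not verify the CRCs. With TensorFlow 2.x available, sum(1 for _ in tf.data.TFRecordDataset(path)) gives the same count by reading every record.

```python
import os
import struct


def count_tfrecords(path):
    """Count the records in an uncompressed TFRecord file (CRCs are not checked)."""
    size = os.path.getsize(path)
    count = 0
    with open(path, 'rb') as f:
        while f.tell() < size:
            header = f.read(8)
            if len(header) < 8:
                raise ValueError('truncated length field in record %d' % count)
            (length,) = struct.unpack('<Q', header)
            # Skip the 4-byte length CRC, the payload and the 4-byte payload CRC.
            end = f.tell() + 4 + length + 4
            if end > size:
                raise ValueError('record %d runs past the end of the file' % count)
            f.seek(end)
            count += 1
    return count
```

The result can stand in for the hard-coded 2053 in the loop, e.g. for i in range(count_tfrecords(FILE)).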

The trick was to have the queue runner use a coordinator.
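
The coordinator also enables the usual queue-based idiom that avoids both drawbacks mentioned above. Passing num_epochs=1 to tf.train.string_input_producer closes the filename queue after one pass, so instead of looping a fixed number of times, the loop runs until sess.run raises tf.errors.OutOfRangeError, and nothing wraps around to the first record. The minimal sketch below assumes TensorFlow 2.x driven through the tensorflow.compat.v1 module in graph mode, and reads only the label key; read_until_exhausted is a hypothetical name. Because num_epochs is tracked in a local variable, tf.local_variables_initializer() has to run as well.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # queue runners need graph mode


def read_until_exhausted(path):
    """Return every label in the file, reading each record exactly once."""
    # num_epochs=1: the filename queue is closed after a single pass.
    filename_queue = tf.train.string_input_producer([path], num_epochs=1)
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized, features={'label': tf.FixedLenFeature([], tf.int64)})

    labels = []
    with tf.Session() as sess:
        # The epoch counter is a local variable, so initialize locals too.
        sess.run([tf.global_variables_initializer(),
                  tf.local_variables_initializer()])
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                labels.append(int(sess.run(features['label'])))
        except tf.errors.OutOfRangeError:
            pass  # every record has been read once
        finally:
            coord.request_stop()
        coord.join(threads)
    return labels
```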

I left some code in here to save the images as they're being read in so that you can verify the image is correct.

import os

from PIL import Image
# TensorFlow 1.x queue-based API. On TF 2.x, use
# "import tensorflow.compat.v1 as tf" and call tf.disable_eager_execution().
import tensorflow as tf


def read_and_decode(filename_queue):
    """Read one serialized Example from the queue and decode its fields."""
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since all keys are required.
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
        })
    # decode_raw reinterprets the bytes as uint8 pixels, so image_raw must
    # hold raw pixel data rather than encoded PNG bytes.
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    label = tf.cast(features['label'], tf.int32)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    depth = tf.cast(features['depth'], tf.int32)
    return image, label, height, width, depth


def get_all_records(FILE):
    with tf.Session() as sess:
        filename_queue = tf.train.string_input_producer([FILE])
        image, label, height, width, depth = read_and_decode(filename_queue)
        # tf.stack was called tf.pack before TensorFlow 1.0.
        image = tf.reshape(image, tf.stack([height, width, 3]))
        image.set_shape([720, 720, 3])
        # Replaces the deprecated tf.initialize_all_variables().
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        os.makedirs("output", exist_ok=True)
        # 2053 is the number of records in this particular file; reading past
        # the last record wraps around to the first one.
        for i in range(2053):
            example, l = sess.run([image, label])
            # Save each image so it can be checked visually.
            img = Image.fromarray(example, 'RGB')
            img.save("output/" + str(i) + '-train.png')
            print(example, l)
        coord.request_stop()
        coord.join(threads)


get_all_records('/path/to/train-0.tfrecords')
