TensorFlow:`tf.data.Dataset.from_generator()` 不适用于 Python 3.x 上的字符串 [英] TensorFlow: `tf.data.Dataset.from_generator()` does not work with strings on Python 3.x

查看:25
本文介绍了TensorFlow:`tf.data.Dataset.from_generator()` 不适用于 Python 3.x 上的字符串的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我需要遍历大量图像文件并将数据提供给 tensorflow.我通过生成器函数创建了一个 Dataset,该函数将文件路径名生成为字符串,然后使用 map 将字符串路径转换为图像数据.但它失败了,因为生成字符串值不起作用,如下所示.是否有解决方法或解决方法?

I need to iterate through large number of image files and feed the data to tensorflow. I created a Dataset back by a generator function that produces the file path names as strings and then transform the string path to image data using map. But it failed as generating string values won't work, as shown below. Is there a fix or work around for this?

2017-12-07 15:29:05.820708: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
producing data/miniImagenet/val/n01855672/n0185567200001000.jpg
2017-12-07 15:29:06.009141: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type str
2017-12-07 15:29:06.009215: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type str
     [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_STRING], token="pyfunc_1"](arg0)]]
Traceback (most recent call last):
  File "/Users/me/.tox/tf2/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1323, in _do_call
    return fn(*args)
  File "/Users/me/.tox/tf2/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1302, in _run_fn
    status, run_metadata)
  File "/Users/me/.tox/tf2/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.UnimplementedError: Unsupported object type str
     [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_STRING], token="pyfunc_1"](arg0)]]
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,21168]], output_types=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

测试代码如下所示.它可以与 from_tensor_slices 一起正常工作,或者首先将文件名列表放入张量中.但是,任何一种解决方法都会耗尽 GPU 内存.

The test codes are shown below. It can work correctly with from_tensor_slices or by first putting the the file name list in a tensor. however, either work around would exhaust GPU memory.

import tensorflow as tf

if __name__ == "__main__":
    file_names = ['data/miniImagenet/val/n01855672/n0185567200001000.jpg',
                  'data/miniImagenet/val/n01855672/n0185567200001005.jpg']
    # note: converting the file list to tensor and returning an index from generator works
    # path_to_indexes = {p: i for i, p in enumerate(file_names)}
    # file_names_tensor = tf.convert_to_tensor(file_names)

    def dataset_producer():
        for s in file_names:
            print('producing', s)
            yield s
    dataset = tf.data.Dataset.from_generator(dataset_producer, output_types=(tf.string),
                                             output_shapes=(tf.TensorShape([])))

    # note: this would also work
    # dataset = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(file_names))

    def read_image(filename):
        # filename = file_names_tensor[filename_index]
        image_file = tf.read_file(filename, name='read_file')
        image = tf.image.decode_jpeg(image_file, channels=3)
        image.set_shape((84,84,3))
        image = tf.reshape(image, [21168])
        image = tf.cast(image, tf.float32) / 255.0
        return image

    dataset = dataset.map(read_image)
    dataset = dataset.batch(2)
    data_iterator = dataset.make_one_shot_iterator()
    images = data_iterator.get_next()
    print('images', images)
    max_value = tf.argmax(images)
    with tf.Session() as session:
        result = session.run(max_value)
        print(result)

推荐答案

这是一个影响 Python 3.x 的错误,它是 已修复 TensorFlow 1.4 版本后.TensorFlow 1.5 及以后的所有版本都包含此修复程序.

This is a bug affecting Python 3.x that was fixed after the TensorFlow 1.4 release. All releases of TensorFlow from 1.5 onwards contain the fix.

如果您只是使用早期版本,解决方法是在从生成器返回字符串之前将字符串转换为 bytes.以下代码应该可以工作:

If you just use an earlier version, the workaround is to convert the strings to bytes before returning them from the generator. The following code should work:

def dataset_producer():
    for s in file_names:
        print('producing', s)
        yield s.encode('utf-8')  # Convert `s` to `bytes`.

dataset = tf.data.Dataset.from_generator(dataset_producer, output_types=(tf.string),
                                         output_shapes=(tf.TensorShape([])))

这篇关于TensorFlow:`tf.data.Dataset.from_generator()` 不适用于 Python 3.x 上的字符串的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆