TensorFlow DataSet API causes graph size to explode


Problem description

I have a very big data set for training.

I'm using the data set API like so:

self._dataset = tf.contrib.data.Dataset.from_tensor_slices((self._images_list, self._labels_list))

self._dataset = self._dataset.map(self.load_image)

self._dataset = self._dataset.batch(batch_size)
self._dataset = self._dataset.shuffle(buffer_size=shuffle_buffer_size)
self._dataset = self._dataset.repeat()

self._iterator = self._dataset.make_one_shot_iterator()

If I use only a small amount of the data for training, then all is well. If I use all my data, then TensorFlow crashes with this error: ValueError: GraphDef cannot be larger than 2GB.

It seems like TensorFlow tries to load all the data instead of loading only the data that it needs... not sure...
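
One way to check that suspicion (this diagnostic is my addition, not something from the original question) is to measure the serialized GraphDef right after the input pipeline is built; if the arrays really are being embedded into the graph, this number will track the size of the training data:

import tensorflow as tf

# Inspect how large the graph definition has become after building the
# dataset and iterator above. If the input arrays are being embedded as
# constants, this size will grow with the amount of training data.
graph_def = tf.get_default_graph().as_graph_def()
print("GraphDef size: %d bytes" % graph_def.ByteSize())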

Any advice would be great!

Update... found a solution/workaround

As per this post: Tensorflow Dataset API doubles graph protobuff filesize

I replaced the make_one_shot_iterator() with make_initializable_iterator() and of course called the iterator initializer after creating the session:

init = tf.global_variables_initializer()
sess.run(init)
sess.run(train_data._iterator.initializer)
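
For completeness, here is a minimal, self-contained sketch of that change (TF 1.x API; the toy arrays, shapes, and variable names are stand-ins of mine, not the actual data from the question):

import numpy as np
import tensorflow as tf

# Stand-in data; in the real code these would be the images and labels.
images = np.random.rand(10, 4).astype(np.float32)
labels = np.random.randint(0, 2, size=10).astype(np.int64)

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.batch(2).shuffle(buffer_size=5).repeat()

# The only change from the original pipeline: an initializable iterator
# instead of a one-shot iterator. It must be explicitly initialized in
# the session before the first get_next() call.
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(iterator.initializer)
    print(sess.run(next_batch))

If the linked post's explanation applies here, the one-shot iterator serializes an extra copy of the dataset definition (including any embedded constants) into the graph, which would be why swapping the iterator type alone already shrinks it.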

But I'm leaving the question open, as to me it seems like a workaround and not a solution...

Recommended answer

https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays

Note that the above code snippet will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory---because the contents of the array will be copied multiple times---and can run into the 2GB limit for the tf.GraphDef protocol buffer. As an alternative, you can define the Dataset in terms of tf.placeholder() tensors, and feed the NumPy arrays when you initialize an Iterator over the dataset.

Instead of using

dataset = tf.data.Dataset.from_tensor_slices((features, labels))

use

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
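
Putting the guide's pattern together, a minimal runnable sketch looks like the following (TF 1.x API; the array shapes and the shuffle/batch parameters are placeholders of mine, not from the question). The arrays are fed exactly once, when the iterator is initialized, so they never end up inside the GraphDef:

import numpy as np
import tensorflow as tf

# Stand-in arrays; in practice these would be the full training data.
features = np.random.rand(1000, 64).astype(np.float32)
labels = np.random.randint(0, 10, size=1000).astype(np.int64)

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

# The dataset is defined over placeholders, so the graph only records
# shapes and dtypes; the actual data stays outside the GraphDef.
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
dataset = dataset.shuffle(buffer_size=1000).batch(32).repeat()

iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    # Feed the NumPy arrays once, at iterator initialization time.
    sess.run(iterator.initializer,
             feed_dict={features_placeholder: features,
                        labels_placeholder: labels})
    batch_features, batch_labels = sess.run(next_batch)

In the question's class, the same idea would mean building self._dataset from placeholders and running self._iterator.initializer with a feed_dict, instead of the plain sess.run(train_data._iterator.initializer) shown in the update above.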
