How do you save a Tensorflow dataset to a file?


Problem description

There are at least two more questions like this on SO but not a single one has been answered.

I have a dataset of the form:

<TensorSliceDataset shapes: ((512,), (512,), (512,), ()), types: (tf.int32, tf.int32, tf.int32, tf.int32)>

and another of the form:

<BatchDataset shapes: ((None, 512), (None, 512), (None, 512), (None,)), types: (tf.int32, tf.int32, tf.int32, tf.int32)>

I have looked and looked but I can't find the code to save these datasets to files that can be loaded later. The closest I got was this page in the TensorFlow docs, which suggests serializing the tensors using tf.io.serialize_tensor and then writing them to a file using tf.data.experimental.TFRecordWriter.

However, when I tried this using the code:

dataset.map(tf.io.serialize_tensor)
writer = tf.data.experimental.TFRecordWriter('mydata.tfrecord')
writer.write(dataset)

I get an error on the first line:

TypeError: serialize_tensor() takes from 1 to 2 positional arguments but 4 were given

How can I modify the above (or do something else) to accomplish my goal?

Answer

TFRecordWriter seems to be the most convenient option, but unfortunately it can only write datasets with a single tensor per element. Here are a couple of workarounds you can use. First, since all your tensors have the same type and similar shape, you can concatenate them all into one, and split them back later on load:

import tensorflow as tf

# Write
a = tf.zeros((100, 512), tf.int32)
ds = tf.data.Dataset.from_tensor_slices((a, a, a, a[:, 0]))
print(ds)
# <TensorSliceDataset shapes: ((512,), (512,), (512,), ()), types: (tf.int32, tf.int32, tf.int32, tf.int32)>
def write_map_fn(x1, x2, x3, x4):
    return tf.io.serialize_tensor(tf.concat([x1, x2, x3, tf.expand_dims(x4, -1)], -1))
ds = ds.map(write_map_fn)
writer = tf.data.experimental.TFRecordWriter('mydata.tfrecord')
writer.write(ds)

# Read
def read_map_fn(x):
    xp = tf.io.parse_tensor(x, tf.int32)
    # Optionally set shape
    xp.set_shape([1537])  # Do `xp.set_shape([None, 1537])` if using batches
    # Use `x[:, :512], ...` if using batches
    return xp[:512], xp[512:1024], xp[1024:1536], xp[-1]
ds = tf.data.TFRecordDataset('mydata.tfrecord').map(read_map_fn)
print(ds)
# <MapDataset shapes: ((512,), (512,), (512,), ()), types: (tf.int32, tf.int32, tf.int32, tf.int32)>

But, more generally, you can simply have a separate file per tensor and then read them all:

import tensorflow as tf

# Write
a = tf.zeros((100, 512), tf.int32)
ds = tf.data.Dataset.from_tensor_slices((a, a, a, a[:, 0]))
for i, _ in enumerate(ds.element_spec):
    ds_i = ds.map(lambda *args: args[i]).map(tf.io.serialize_tensor)
    writer = tf.data.experimental.TFRecordWriter(f'mydata.{i}.tfrecord')
    writer.write(ds_i)

# Read
NUM_PARTS = 4
parts = []
def read_map_fn(x):
    return tf.io.parse_tensor(x, tf.int32)
for i in range(NUM_PARTS):
    parts.append(tf.data.TFRecordDataset(f'mydata.{i}.tfrecord').map(read_map_fn))
ds = tf.data.Dataset.zip(tuple(parts))
print(ds)
# <ZipDataset shapes: (<unknown>, <unknown>, <unknown>, <unknown>), types: (tf.int32, tf.int32, tf.int32, tf.int32)>

It is possible to have the whole dataset in a single file with multiple separate tensors per element, namely as a file of TFRecords containing tf.train.Examples, but I don't know if there is a way to create those within TensorFlow, that is, without having to get the data out of the dataset into Python and then write it to the records file.
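For completeness, here is a minimal sketch of that last approach done from eager Python: iterate the dataset, pack each element into a tf.train.Example, and write the serialized protos with tf.io.TFRecordWriter. The file name mydata.examples.tfrecord and the feature keys x1..x4 are arbitrary choices for this sketch. One caveat: tf.train.Example only stores int64, float, or bytes features, so the int32 values come back as int64 unless you cast them on read.

```python
import tensorflow as tf

# Same toy dataset as above
a = tf.zeros((100, 512), tf.int32)
ds = tf.data.Dataset.from_tensor_slices((a, a, a, a[:, 0]))

def to_example(x1, x2, x3, x4):
    # Example only supports int64/float/bytes lists, so int32 is stored as int64
    feature = {
        'x1': tf.train.Feature(int64_list=tf.train.Int64List(value=x1.numpy())),
        'x2': tf.train.Feature(int64_list=tf.train.Int64List(value=x2.numpy())),
        'x3': tf.train.Feature(int64_list=tf.train.Int64List(value=x3.numpy())),
        'x4': tf.train.Feature(int64_list=tf.train.Int64List(value=[x4.numpy()])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

# Write: iterate in eager Python, one serialized Example per element
with tf.io.TFRecordWriter('mydata.examples.tfrecord') as w:
    for x1, x2, x3, x4 in ds:
        w.write(to_example(x1, x2, x3, x4).SerializeToString())

# Read: a feature description restores the shapes, so no set_shape is needed
feature_description = {
    'x1': tf.io.FixedLenFeature([512], tf.int64),
    'x2': tf.io.FixedLenFeature([512], tf.int64),
    'x3': tf.io.FixedLenFeature([512], tf.int64),
    'x4': tf.io.FixedLenFeature([], tf.int64),
}

def parse(rec):
    p = tf.io.parse_single_example(rec, feature_description)
    return p['x1'], p['x2'], p['x3'], p['x4']

ds2 = tf.data.TFRecordDataset('mydata.examples.tfrecord').map(parse)
```

The upside over the two workarounds above is a single file with named, separately typed features; the downside is that the write loop runs in Python rather than inside the tf.data pipeline, which is slower for large datasets.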

