如何从 tensorflow 数据集迭代器两次返回同一批次? [英] How can I return the same batch twice from a tensorflow dataset iterator?

查看:48
本文介绍了如何从 tensorflow 数据集迭代器两次返回同一批次?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在转换一些遗留代码以使用数据集 API - 此代码使用 feed_dict 将一批提供给火车操作(实际上是 3 次),然后重新计算用于显示的损失使用同一批.所以我需要一个迭代器来两次(或多次)返回完全相同的批次.不幸的是,我似乎无法找到一种使用 tensorflow 数据集进行处理的方法 - 可能吗?

I am converting some legacy code to use the Dataset API - this code uses feed_dict to feed one batch to the train operation (actually three times) and then recalculates the losses for display using the same batch. So I need to have an iterator that returns the exact same batch two (or several) times. Unfortunately, I can't seem to find a way of doing it with tensorflow datasets - is it possible?

推荐答案

您可以使用 Dataset.flat_map(), Dataset.from_tensors()Dataset.repeat() 在一起.例如,重复元素两次:

You can repeat individual elements of a Dataset using Dataset.flat_map(), Dataset.from_tensors() and Dataset.repeat() together. For example, to repeat elements twice:

NUM_REPEATS = 2
dataset = tf.data.Dataset.range(10)  # ...or the output of `.batch()`, etc.

# Repeat each element of `dataset` NUM_REPEATS times.
dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(NUM_REPEATS))

这篇关于如何从 tensorflow 数据集迭代器两次返回同一批次?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆