使用 Dataset API 生成平衡的小批量 [英] Produce balanced mini batch with Dataset API

查看:24
本文介绍了使用 Dataset API 生成平衡的小批量的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个关于新数据集 API (tensorflow 1.4rc1) 的问题.我有一个不平衡的数据集,用于标记 01.我的目标是在预处理期间创建平衡的小批量.

I've a question about the new dataset API (tensorflow 1.4rc1). I've a unbalanced dataset wrt to labels 0 and 1. My goal is to create balanced mini batches during the preprocessing.

假设我有两个过滤的数据集:

Assume I've two filtered datasets:

ds_pos = dataset.filter(lambda l, x, y, z: tf.reshape(tf.equal(l, 1), []))
ds_neg = dataset.filter(lambda l, x, y, z: tf.reshape(tf.equal(l, 0), [])).repeat()

有没有办法组合这两个数据集,使得结果数据集看起来像 ds = [0, 1, 0, 1, 0, 1]:

Is there a way to combine these two datasets such that the resulting dataset looks like ds = [0, 1, 0, 1, 0, 1]:

像这样:

dataset = tf.data.Dataset.zip((ds_pos, ds_neg))
dataset = dataset.apply(...)
# dataset looks like [0, 1, 0, 1, 0, 1, ...]
dataset = dataset.batch(20)

我目前的做法是:

def _concat(x, y):
   return tf.cond(tf.random_uniform(()) > 0.5, lambda: x, lambda: y)
dataset = tf.data.Dataset.zip((ds_pos, ds_neg))
dataset = dataset.map(_concat)

但我觉得有一种更优雅的方式.

But I've the feeling there is a more elegant way.

提前致谢!

推荐答案

您走对了.下面的例子使用Dataset.flat_map()将每对正例和负例在结果中变成两个连续的例子:

You are on the right track. The following example uses Dataset.flat_map() to turn each pair of a positive example and a negative example into two consecutive examples in the result:

dataset = tf.data.Dataset.zip((ds_pos, ds_neg))

# Each input element will be converted into a two-element `Dataset` using
# `Dataset.from_tensors()` and `Dataset.concatenate()`, then `Dataset.flat_map()`
# will flatten the resulting `Dataset`s into a single `Dataset`.
dataset = dataset.flat_map(
    lambda ex_pos, ex_neg: tf.data.Dataset.from_tensors(ex_pos).concatenate(
        tf.data.Dataset.from_tensors(ex_neg)))

dataset = dataset.batch(20)

这篇关于使用 Dataset API 生成平衡的小批量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆