Using Keras APIs, how can I import images in batches with exactly K instances of each ID in a given batch?

Question

I'm trying to implement batch hard triplet loss, as seen in Section 3.2 of https://arxiv.org/pdf/2004.06271.pdf.

I need to import my images so that each batch has exactly K instances of each ID in a particular batch. Therefore, each batch must be a multiple of K.

I have a directory of images too large to fit into memory, so I am using ImageDataGenerator.flow_from_directory() to import them, but I can't see any parameter of this function that would give me the behaviour I need.

How can I achieve this batching behaviour with Keras?

Answer

You can try merging several data streams together in a controlled manner.

Given you have K instances of tf.data.Dataset (it does not matter how you instantiate them), each responsible for supplying the training instances of a particular ID, you can concatenate them to get an even distribution inside a mini-batch:

ds1 = ...  # Training instances with ID == 1
ds2 = ...  # Training instances with ID == 2
...
dsK = ...  # Training instances with ID == K

# Interleave one element from each ID stream, then batch so that each
# mini-batch holds N instances of each of the K IDs.
train_dataset = (tf.data.Dataset.zip((ds1, ds2, ..., dsK))
                 .flat_map(concat_datasets)
                 .batch(batch_size=N * K))

where concat_datasets is the merging function:

def concat_datasets(*datasets):
    # Each argument is one element drawn from the corresponding ID stream;
    # wrap each in a single-element dataset and chain them in order.
    ds = tf.data.Dataset.from_tensors(datasets[0])
    for i in range(1, len(datasets)):
        ds = ds.concatenate(tf.data.Dataset.from_tensors(datasets[i]))
    return ds
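
The answer leaves the construction of ds1 ... dsK open. As one possible way to build those per-ID streams from a directory too large for memory, here is a minimal sketch assuming each identity's JPEG images live in their own subdirectory; the make_id_dataset helper, the images/id_* paths, and IMG_SIZE are hypothetical and not part of the original answer:

import tensorflow as tf

IMG_SIZE = (224, 224)  # hypothetical target size

def make_id_dataset(pattern, label):
    # Endless stream of (image, label) pairs for a single ID, loaded lazily
    # from disk so the full directory never has to fit into memory.
    def load(path):
        image = tf.io.read_file(path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, IMG_SIZE)
        return image, tf.constant(label, dtype=tf.int32)

    return (tf.data.Dataset.list_files(pattern, shuffle=True)
            .map(load, num_parallel_calls=tf.data.AUTOTUNE)
            .repeat())  # repeat so zip() is not truncated by the smallest ID

ds1 = make_id_dataset("images/id_001/*.jpg", label=1)
ds2 = make_id_dataset("images/id_002/*.jpg", label=2)

The .repeat() matters because tf.data.Dataset.zip stops as soon as its shortest input is exhausted; with repeated per-ID streams, every batch of size N * K keeps exactly N instances of each ID.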
