Filter a Dataset to get only images from specific classes

Problem description

I want to prepare the Omniglot dataset for n-shot learning, so I need 5 samples from each of 10 classes (alphabets).

Code to reproduce

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

builder = tfds.builder("omniglot")
builder.download_and_prepare()
# Load data from disk as tf.data.Datasets
datasets = builder.as_dataset()
dataset, test_dataset = datasets['train'], datasets['test']


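def resize(example):
    image = example['image']
    # Shrink to 28x28 and convert the RGB image to a single grayscale channel
    image = tf.image.resize(image, [28, 28])
    image = tf.image.rgb_to_grayscale(image)
    image = image / 255
    # Placeholder label tensor; not yet derived from the example
    one_hot_label = np.zeros((51, 10))
    return image, one_hot_label, example['alphabet']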


def stack(image, label, alphabet):
    return (image, label), label[-1]

def filter_func(image, label, alphabet):
    # Goal: keep images from every alphabet in this array, not just alphabet 2
    arr = np.array([2, 3, 4, 5])
    # tf.equal against a single value only matches one alphabet
    result = tf.reshape(tf.equal(alphabet, 2), [])
    return result

# correct size
dataset = dataset.map(resize)
# now filter the dataset for the batch
dataset = dataset.filter(filter_func)
# infinite stream of batches (classes*samples + 1)
dataset = dataset.repeat().shuffle(1024).batch(51)
# stack the images together
dataset = dataset.map(stack)
dataset = dataset.shuffle(buffer_size=1000)
dataset = dataset.batch(32)

for i, (image, label) in enumerate(tfds.as_numpy(dataset)):
    print(i, image[0].shape)

Now I want to filter the images in the dataset using the filter function. tf.equal only lets me filter by one class; I want something like a "tensor in array" check.

Is there a way to do this with the filter function? Or is this the wrong approach, and is there a much simpler way?

I want to create a batch of 51 images and corresponding labels, all drawn from the same N=10 classes. From each class I need K=5 different images, plus one additional image (the one I need to classify). Every batch of N*K+1 (51) images should come from 10 new, randomly chosen classes.
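The episode structure described above (N classes, K support images per class, plus one query image) can be sketched as index sampling. This is a minimal NumPy sketch that is not part of the original question or answer. It assumes the class of every example is available as a 1-D integer array (whichever field is used as the class) and that each class has at least K+1 examples. The helper name sample_episode and the synthetic labels are hypothetical.

```python
import numpy as np

def sample_episode(labels, n_way=10, k_shot=5, rng=None):
    """Return (support_idx, query_idx) for one N-way K-shot episode.

    labels: 1-D integer array; labels[i] is the class of example i.
    The query is one extra example from one of the N sampled classes,
    distinct from that class's support examples.
    """
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    # Only classes with at least K+1 examples can supply K support + 1 query
    eligible = classes[counts >= k_shot + 1]
    # Pick N distinct classes (raises ValueError if fewer than N are eligible)
    chosen = rng.choice(eligible, size=n_way, replace=False)
    query_class = rng.choice(chosen)
    support_idx = []
    query_idx = None
    for c in chosen:
        members = np.flatnonzero(labels == c)
        # Draw K distinct examples, or K+1 for the class that also supplies the query
        n_draw = k_shot + 1 if c == query_class else k_shot
        picked = rng.choice(members, size=n_draw, replace=False)
        support_idx.extend(picked[:k_shot].tolist())
        if c == query_class:
            query_idx = int(picked[k_shot])
    return np.array(support_idx), query_idx

# Hypothetical labels: 30 classes with 20 examples each
labels = np.repeat(np.arange(30), 20)
support, query = sample_episode(labels, rng=np.random.default_rng(0))
episode = np.append(support, query)  # query index placed last
print(len(episode))  # 51
```

Putting the query index last would line up with the question's stack function, which treats label[-1] as the target. Sampling by index is an alternative to the filter/shuffle/batch pipeline in the question: shuffling and batching 51 filtered images does not guarantee that a batch contains exactly 10 classes with 5 images each.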

Thank you very much.

Recommended answer

tf.equal() supports broadcasting, so a scalar can be compared with a tensor of rank > 0.
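Here is a minimal sketch of that broadcasting behaviour. It assumes TensorFlow 2.x with eager execution and uses made-up label values. Comparing one scalar against a 1-D tensor of allowed labels yields one boolean per allowed label. tf.reduce_any, used here in place of the answer's reduce_sum/greater combination, then collapses them into the single scalar boolean that Dataset.filter needs. The name is_allowed is hypothetical.

```python
import tensorflow as tf

# Hypothetical allowed labels; both sides of tf.equal must share a dtype
allowed_labels = tf.constant([2, 3, 4, 5], dtype=tf.int64)

def is_allowed(label):
    # Scalar vs. shape-[4] tensor broadcasts to a shape-[4] boolean tensor
    matches = tf.equal(allowed_labels, label)
    # True if the label equals any of the allowed values
    return tf.reduce_any(matches)

print(tf.equal(allowed_labels, tf.constant(3, dtype=tf.int64)).numpy())
# [False  True False False]
print(is_allowed(tf.constant(3, dtype=tf.int64)).numpy())  # True
print(is_allowed(tf.constant(7, dtype=tf.int64)).numpy())  # False
```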

To KEEP only specific labels, use this predicate:

dataset = datasets['train']

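def predicate(x, allowed_labels=tf.constant([0., 1., 2.])):
    label = x['label']
    # Broadcast: compare the scalar label with every allowed label
    isallowed = tf.equal(allowed_labels, tf.cast(label, tf.float32))
    # Count the matches; keep the example if there is at least one
    reduced = tf.reduce_sum(tf.cast(isallowed, tf.float32))
    return tf.greater(reduced, tf.constant(0.))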

dataset = dataset.filter(predicate).batch(20)

for i, x in enumerate(tfds.as_numpy(dataset)):
    print(x['label'])
# [1 0 0 1 2 1 1 2 1 0 0 1 2 0 1 0 2 2 0 1]
# [1 0 2 2 0 2 1 2 1 2 2 2 0 2 0 2 1 2 1 1]
# [2 1 2 1 0 1 1 0 1 2 2 0 2 0 1 0 0 0 0 0]

allowed_labels specifies the labels you want to keep. Every example whose label is not in this tensor is filtered out.
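The same idea can be applied to the 'alphabet' field the question filters on. This is a minimal sketch, not the answerer's code. A small in-memory dataset built with tf.data.Dataset.from_tensor_slices stands in for the TFDS split so it runs without a download. The values, keep_allowed_alphabets and ALLOWED_ALPHABETS are hypothetical.

```python
import tensorflow as tf

# Hypothetical stand-in for the omniglot split: dict elements with an
# integer 'alphabet' field (the real split also carries 'image', 'label', ...)
examples = {
    "alphabet": tf.constant([0, 2, 3, 7, 5, 2, 9, 4], dtype=tf.int64),
    "label": tf.constant([10, 11, 12, 13, 14, 15, 16, 17], dtype=tf.int64),
}
dataset = tf.data.Dataset.from_tensor_slices(examples)

# The alphabets to keep (the arr from the question)
ALLOWED_ALPHABETS = tf.constant([2, 3, 4, 5], dtype=tf.int64)

def keep_allowed_alphabets(example):
    # Match dtypes so tf.equal can compare the two tensors
    alphabet = tf.cast(example["alphabet"], tf.int64)
    # Scalar vs. vector comparison broadcasts; reduce_any yields one scalar bool
    return tf.reduce_any(tf.equal(ALLOWED_ALPHABETS, alphabet))

filtered = dataset.filter(keep_allowed_alphabets)
kept = [int(x["alphabet"]) for x in filtered]
print(kept)  # [2, 3, 5, 2, 4]
```

In the question's pipeline, resize returns (image, one_hot_label, alphabet). There, the same reduce_any(tf.equal(...)) check would go inside filter_func(image, label, alphabet) in place of tf.equal(alphabet, 2).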
