How do you send arguments to a generator function using tf.data.Dataset.from_generator()?

Question

I would like to create several tf.data.Dataset objects using the from_generator() function, and I would like to pass an argument to the generator function (raw_data_gen). The idea is that the generator function yields different data depending on the argument it receives, so that raw_data_gen can provide training, validation or test data.

training_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([1]))

validation_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([2]))

test_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([3]))

The error message I get when I call from_generator() in this way is:

TypeError: from_generator() got an unexpected keyword argument 'args'

Here is the raw_data_gen function, although I'm not sure you will need it, as my hunch is that the problem lies in the call to from_generator():

def raw_data_gen(train_val_or_test):

    if train_val_or_test == 1:        
        #For every filename collected in the list
        for filename, lab in training_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    elif train_val_or_test == 2:
        #For every filename collected in the list
        for filename, lab in validation_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    elif train_val_or_test == 3:
        #For every filename collected in the list
        for filename, lab in test_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    else:
        raise ValueError("generator function called with an argument not in [1, 2, 3]")

Recommended answer

You need to define a new function, based on raw_data_gen, that takes no arguments. You can use the lambda keyword to do this.

training_dataset = tf.data.Dataset.from_generator(lambda: raw_data_gen(train_val_or_test=1), (tf.float32, tf.uint8), ([None, 1], [None]))
...

Now we are passing from_generator a function that takes no arguments and simply behaves as raw_data_gen with its argument fixed to 1. The same scheme works for the validation and test sets, passing 2 and 3 respectively.
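
The wrapping idea can be tried in plain Python, without TensorFlow. The sketch below is illustrative only. Its raw_data_gen is a hypothetical stand-in that reads from a small made-up SPLITS dictionary rather than from audio files. It shows two ways of building a zero-argument callable (lambda and functools.partial), and one pitfall to watch for if the three callables are built in a loop: a lambda looks up its variable when it is called, not when it is defined.

```python
from functools import partial

# Hypothetical stand-in data: (sample, label) pairs per split
# (1 = training, 2 = validation, 3 = test). Not from the original post.
SPLITS = {
    1: [([0.1, 0.2], 0), ([0.3], 1)],
    2: [([0.4], 1)],
    3: [([0.5, 0.6], 0)],
}


def raw_data_gen(train_val_or_test):
    """Simplified stand-in for the question's generator function."""
    if train_val_or_test not in SPLITS:
        raise ValueError("generator function called with an argument not in [1, 2, 3]")
    for sample, label in SPLITS[train_val_or_test]:
        yield sample, label


# Two ways to build a zero-argument callable around raw_data_gen.
training_gen = lambda: raw_data_gen(train_val_or_test=1)
validation_gen = partial(raw_data_gen, 2)

# Each call to the wrapper returns a fresh generator, so the data can be
# iterated more than once.
first_pass = list(training_gen())
second_pass = list(training_gen())

# Pitfall: lambdas created in a loop all read `split` when they are called,
# by which time it holds its final value (3).
late_bound = [lambda: raw_data_gen(split) for split in (1, 2, 3)]

# partial captures the value at creation time, so each callable keeps its own split.
early_bound = [partial(raw_data_gen, split) for split in (1, 2, 3)]
```

Either training_gen or validation_gen could be passed to from_generator in place of the lambda shown in the answer. When the wrappers are created in a loop, partial (or a lambda with a default argument such as lambda s=split: raw_data_gen(s)) avoids every dataset silently reading the same split.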
