On fit_generator() / fit() and thread-safety


Question

Context

In order to use fit_generator() in Keras I use a generator-function like this pseudocode-one:

def generator(data: np.array) -> (np.array, np.array):
    """Simple generator yielding some samples and targets"""

    while True:
        for batch in range(number_of_batches):
            yield data[batch * length_sequence], data[(batch + 1) * length_sequence]

In Keras' fit_generator() function I want to use workers=4 and use_multiprocessing=True - Hence, I need a threadsafe generator.

In answers on stackoverflow like here or here or in the Keras docs, I read about creating a class inheriting from Keras.utils.Sequence() like this:

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        return ...

By using Sequences, Keras does not throw any warning when using multiple workers and multiprocessing; the generator is supposed to be threadsafe.

Anyhow, since I am using my custom function, I stumbled upon Omer Zohar's code provided on GitHub, which allows making my generator() threadsafe by adding a decorator. The code looks like:

import threading

class threadsafe_iter:
    """
    Takes an iterator/generator and makes it thread-safe by
    serializing call to the `next` method of given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return self.it.__next__()


def threadsafe_generator(f):
    """A decorator that takes a generator function and makes it thread-safe."""
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))

    return g

Now I can do:

@threadsafe_generator
def generator(data):
    ...
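As a quick sanity check outside Keras (my own minimal sketch, not part of the original post; the decorator is reproduced so the example is self-contained), several threads can drain one wrapped generator. Without the lock, CPython may raise "ValueError: generator already executing" when two threads hit __next__() concurrently; with it, every value is yielded exactly once:

```python
import threading

class threadsafe_iter:
    """Wraps an iterator/generator and serializes __next__ with a lock."""
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)

def threadsafe_generator(f):
    """Decorator that wraps a generator function in threadsafe_iter."""
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g

@threadsafe_generator
def numbers(n):
    for i in range(n):
        yield i

gen = numbers(10_000)          # one shared, locked generator
seen = []
seen_lock = threading.Lock()   # separate lock just for the result list

def consume():
    for item in gen:           # StopIteration ends each thread's loop
        with seen_lock:
            seen.append(item)

threads = [threading.Thread(target=consume) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# All 10,000 values arrive exactly once across the four threads.
print(len(seen), sorted(seen) == list(range(10_000)))
```

Note that the lock only makes iteration safe; it does not control which thread receives which batch.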

The thing is: using this version of a threadsafe generator, Keras still emits a warning that the generator has to be threadsafe when using workers > 1 and use_multiprocessing=True, and that this can be avoided by using Sequences.


My questions now are:

  1. Does Keras emit this warning only because the generator is not inheriting Sequences, or does Keras also check if a generator is threadsafe in general?
  2. Is using the approach I chose as threadsafe as using the generatorClass(Sequence)-version from the Keras docs?
  3. Are there any other approaches leading to a thread-safe-generator Keras can deal with which are different from these two examples?

---

In newer tensorflow/keras versions (tf > 2), fit_generator() is deprecated. Instead, it is recommended to use fit() with the generator. However, the question still applies to fit() using a generator as well.

Answer

During my research on this I came across some information answering my questions.

Note: As mentioned in the question, in newer tensorflow/keras versions (tf > 2) fit_generator() is deprecated. Instead, it is recommended to use fit() with the generator. However, the answer still applies to fit() using a generator as well.

1. Does Keras emit this warning only because the generator is not inheriting Sequences, or does Keras also check if a generator is threadsafe in general?

In Keras' git repo (training_generators.py) I found the following in lines 46-52:

use_sequence_api = is_sequence(generator)
if not use_sequence_api and use_multiprocessing and workers > 1:
    warnings.warn(
        UserWarning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence'
                    ' class.'))

The definition of is_sequence() taken from training_utils.py in lines 624-635 is:

def is_sequence(seq):
    """Determine if an object follows the Sequence API.
    # Arguments
        seq: a possible Sequence object
    # Returns
        boolean, whether the object follows the Sequence API.
    """
    # TODO Dref360: Decide which pattern to follow. First needs a new TF Version.
    return (getattr(seq, 'use_sequence_api', False)
            or set(dir(Sequence())).issubset(set(dir(seq) + ['use_sequence_api'])))

Regarding this piece of code, Keras only checks whether a passed generator is a Keras Sequence (or rather uses Keras' Sequence API) and does not check whether a generator is threadsafe in general.
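The duck-typing check can be reproduced outside Keras with a stub standing in for keras.utils.Sequence (a sketch under the assumption that only the API shape matters; Sequence, MySequence and plain_generator below are my own stand-ins, not Keras code):

```python
class Sequence:
    """Stub with the same method surface as keras.utils.Sequence."""
    def __getitem__(self, index):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError
    def on_epoch_end(self):
        pass

def is_sequence(seq):
    # Same duck-typing test as Keras' training_utils.py: does `seq`
    # expose everything an instance of Sequence exposes?
    return (getattr(seq, 'use_sequence_api', False)
            or set(dir(Sequence())).issubset(set(dir(seq) + ['use_sequence_api'])))

def plain_generator():
    while True:
        yield 1

class MySequence(Sequence):
    def __len__(self):
        return 1
    def __getitem__(self, idx):
        return idx

# A plain generator lacks __len__/on_epoch_end, so the check fails
# and Keras would warn; a Sequence subclass passes it.
print(is_sequence(plain_generator()), is_sequence(MySequence()))
```

This also shows why the decorator from above cannot silence the warning: threadsafe_iter still does not expose the Sequence API.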

2. Is using the approach I chose as threadsafe as using the generatorClass(Sequence)-version from the Keras docs?

As Omer Zohar has shown on GitHub, his decorator is threadsafe - I don't see any reason why it shouldn't be just as threadsafe for Keras (even though Keras will warn, as shown in 1.). The implementation of threading.Lock() can be considered threadsafe according to the docs:

A factory function that returns a new primitive lock object. Once a thread has acquired it, subsequent attempts to acquire it block, until it is released; any thread may release it.

The generator is also picklable, which can be tested like this (see this SO-Q&A here for further information):

import pickle

# Dump the yielded data in order to check that it is picklable
with open("test.pickle", "wb") as outfile:
    for yielded_data in generator(data):
        pickle.dump(yielded_data, outfile, protocol=pickle.HIGHEST_PROTOCOL)

Resuming this, I would even suggest implementing threading.Lock() when you extend Keras' Sequence(), like:

import threading

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.lock = threading.Lock()   #Set self.lock

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        with self.lock:                #Use self.lock
            batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
            batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

            return ...

Edit 24/04/2020:

By using self.lock = threading.Lock() you might run into the following error:

TypeError: can't pickle _thread.lock objects

In case this happens, try to replace the with self.lock: inside __getitem__ with with threading.Lock():, and comment out / delete the self.lock = threading.Lock() inside __init__.

It seems there are some problems when storing the lock-object inside a class (see for example this Q&A).
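An alternative that keeps the lock inside the class is to exclude it from the pickled state via __getstate__/__setstate__ and recreate it on unpickling (my own sketch, not from the original post; LockedBatches is a hypothetical name standing in for a Sequence-like class):

```python
import pickle
import threading

class LockedBatches:
    """Hypothetical class holding data plus an unpicklable lock."""
    def __init__(self, data):
        self.data = data
        self.lock = threading.Lock()

    def __getstate__(self):
        # Drop the unpicklable lock before pickling.
        state = self.__dict__.copy()
        del state['lock']
        return state

    def __setstate__(self, state):
        # Recreate a fresh lock in the new process.
        self.__dict__.update(state)
        self.lock = threading.Lock()

obj = LockedBatches([1, 2, 3])
clone = pickle.loads(pickle.dumps(obj))
print(clone.data)  # the data round-trips; the lock is a fresh instance
```

Each unpickled copy gets its own fresh lock, which is usually what you want when workers run in separate processes anyway.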

3. Are there any other approaches leading to a thread-safe-generator Keras can deal with which are different from these two examples?

During my research I did not encounter any other method. Of course I cannot say this with 100% certainty.
