On fit_generator() and thread-safety

Question

Context

In order to use fit_generator() in Keras, I use a generator function like this pseudocode one:

def generator(data: np.array) -> (np.array, np.array):
    """Simple generator yielding some samples and targets"""

    while True:
        for batch in range(number_of_batches):
            yield data[batch * length_sequence], data[(batch + 1) * length_sequence]

In Keras' fit_generator() function I want to use workers=4 and use_multiprocessing=True - hence, I need a thread-safe generator.
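
The intended call would look roughly like the sketch below (not part of the original question); model, epochs and number_of_batches are assumed placeholders for a compiled Keras model and the values used in the pseudocode above:

model = ...  # assumed: a compiled Keras model
model.fit_generator(
    generator(data),
    steps_per_epoch=number_of_batches,  # plain generators need an explicit epoch length
    epochs=10,
    workers=4,
    use_multiprocessing=True,
)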

In answers on Stack Overflow like here or here, or in the Keras docs, I read about creating a class inheriting from keras.utils.Sequence() like this:

import numpy as np
from keras.utils import Sequence

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        return ...

By using Sequences, Keras does not throw any warning when using multiple workers and multiprocessing; the generator is supposed to be threadsafe.
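
A minimal usage sketch (not from the original post), assuming model, x_data and y_data are placeholders for a compiled Keras model and two numpy arrays:

train_seq = generatorClass(x_data, y_data, batch_size=32)
model.fit_generator(
    train_seq,                 # len(train_seq) provides the epoch length
    epochs=10,
    workers=4,
    use_multiprocessing=True,  # no thread-safety warning for Sequence objects
)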

Anyhow, since I am using my custom function, I stumbled upon Omer Zohar's code provided on GitHub, which allows making my generator() threadsafe by adding a decorator. The code looks like:

import threading

class threadsafe_iter:
    """
    Takes an iterator/generator and makes it thread-safe by
    serializing call to the `next` method of given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return self.it.__next__()


def threadsafe_generator(f):
    """A decorator that takes a generator function and makes it thread-safe."""
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))

    return g

Now I can do:

@threadsafe_generator
def generator(data):
    ...

The thing is: using this version of a threadsafe generator, Keras still emits a warning that the generator has to be threadsafe when using workers > 1 and use_multiprocessing=True, and that this can be avoided by using Sequences.


My questions now are:

  1. Does Keras emit this warning only because the generator is not inheriting Sequences, or does Keras also check if a generator is threadsafe in general?
  2. Is using the approach I chose as threadsafe as using the generatorClass(Sequence) version from the Keras docs?
  3. Are there any other approaches leading to a thread-safe generator Keras can deal with, which are different from these two examples?

Answer

During my research on this I came across some information answering my questions.

1. Does Keras emit this warning only because the generator is not inheriting Sequences, or does Keras also check if a generator is threadsafe in general?

Taken from Keras' git repo (training_generators.py), I found the following in lines 46-52:

use_sequence_api = is_sequence(generator)
if not use_sequence_api and use_multiprocessing and workers > 1:
    warnings.warn(
        UserWarning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence'
                    ' class.'))

The definition of is_sequence(), taken from training_utils.py in lines 624-635, is:

def is_sequence(seq):
    """Determine if an object follows the Sequence API.
    # Arguments
        seq: a possible Sequence object
    # Returns
        boolean, whether the object follows the Sequence API.
    """
    # TODO Dref360: Decide which pattern to follow. First needs a new TF Version.
    return (getattr(seq, 'use_sequence_api', False)
            or set(dir(Sequence())).issubset(set(dir(seq) + ['use_sequence_api'])))

Regarding this piece of code: Keras only checks whether a passed generator is a Keras Sequence (or rather uses Keras' Sequence API), and does not check whether a generator is threadsafe in general.
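
For illustration only (not part of the original answer): re-using the is_sequence() check quoted above on the two objects from the question - assuming generatorClass, the decorated generator() and some data are in scope - would behave like this:

seq = generatorClass(x_set, y_set, batch_size=32)  # follows the Sequence API
gen = generator(data)                              # decorated generator -> a threadsafe_iter wrapper

print(is_sequence(seq))  # True  -> no warning
print(is_sequence(gen))  # False -> warning, regardless of actual thread-safety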


2. Is using the approach I chose as threadsafe as using the generatorClass(Sequence) version from the Keras docs?

As Omer Zohar has shown on GitHub, his decorator is threadsafe - I don't see any reason why it shouldn't be just as threadsafe for Keras (even though Keras will warn as shown in 1.). The implementation based on threading.Lock() can be considered threadsafe according to the docs:

A factory function that returns a new primitive lock object. Once a thread has acquired it, subsequent attempts to acquire it block, until it is released; any thread may release it.
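
As a standalone illustration (not from the original answer) of how the lock in threadsafe_iter serializes the next() calls: several threads can consume the same wrapped generator and every item still comes out exactly once. The snippet only uses the standard library plus the threadsafe_iter class defined above:

import threading

def counter(n):
    """Plain generator yielding 0 .. n-1."""
    for i in range(n):
        yield i

safe = threadsafe_iter(counter(10000))  # threadsafe_iter as defined earlier
seen = []

def consume():
    for item in safe:    # each __next__ call is guarded by the lock
        seen.append(item)

threads = [threading.Thread(target=consume) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert sorted(seen) == list(range(10000))  # nothing lost, nothing duplicated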

The generator is also picklable, which can be tested like this (see this SO Q&A for further information):

import pickle

# Dump yielded data in order to check if it is picklable
with open("test.pickle", "wb") as outfile:
    for yielded_data in generator(data):
        pickle.dump(yielded_data, outfile, protocol=pickle.HIGHEST_PROTOCOL)

Resuming this, I would even suggest implementing threading.Lock() when you extend Keras' Sequence(), like:

import threading

import numpy as np
from keras.utils import Sequence

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.lock = threading.Lock()   #Set self.lock

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        with self.lock:                #Use self.lock
            batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
            batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

            return ...

Edit 24/04/2020:

By using self.lock = threading.Lock() you might run into the following error:

TypeError: can't pickle _thread.lock objects

In case this happens, try to replace with self.lock: inside __getitem__ with with threading.Lock():, and comment out / delete the self.lock = threading.Lock() inside __init__.

It seems there are some problems when storing the lock object inside a class (see for example this Q&A).
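
A minimal sketch of that workaround (same placeholders as before; the batch return is filled in only for illustration) - no lock object is stored on the instance, so it stays picklable:

import threading

import numpy as np
from keras.utils import Sequence

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        # no self.lock here -> nothing unpicklable is stored on the instance

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        with threading.Lock():  # lock created per call, as suggested above
            batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
            batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

            return batch_x, batch_y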


3. Are there any other approaches leading to a thread-safe generator Keras can deal with, which are different from these two examples?

During my research I did not encounter any other method. Of course, I cannot say this with 100% certainty.
