On fit_generator() / fit() and thread-safety
Problem description

Context
In order to use fit_generator() in Keras, I use a generator function like this pseudocode one:
def generator(data: np.array) -> (np.array, np.array):
    """Simple generator yielding some samples and targets"""
    while True:
        for batch in range(number_of_batches):
            yield data[batch * length_sequence], data[(batch + 1) * length_sequence]
In Keras' fit_generator() function I want to use workers=4 and use_multiprocessing=True - hence, I need a threadsafe generator.
In answers on stackoverflow like here or here, or in the Keras docs, I read about creating a class inheriting from keras.utils.Sequence(), like this:
class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
        return ...
When using Sequences, Keras does not throw any warning about multiple workers and multiprocessing; the generator is supposed to be threadsafe.
Anyhow, since I am using my custom function, I stumbled upon Omer Zohar's code provided on github, which allows making my generator() threadsafe by adding a decorator. The code looks like:
import threading

class threadsafe_iter:
    """
    Takes an iterator/generator and makes it thread-safe by
    serializing call to the `next` method of given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return self.it.__next__()

def threadsafe_generator(f):
    """A decorator that takes a generator function and makes it thread-safe."""
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g
Now I can do:

@threadsafe_generator
def generator(data):
    ...
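To convince myself the decorator actually serializes concurrent access, here is a small self-contained sanity check; the counter generator and the thread counts are made up for illustration and are not part of the Keras setup (without the lock, concurrent __next__ calls on one shared generator can raise "ValueError: generator already executing"):

```python
import threading

class threadsafe_iter:
    """Wraps an iterator and serializes calls to its __next__ method."""
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)

def threadsafe_generator(f):
    """Decorator wrapping a generator function in a threadsafe_iter."""
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g

@threadsafe_generator
def counter(n):
    """Hypothetical stand-in generator yielding 0..n-1."""
    for i in range(n):
        yield i

# Four threads drain one shared generator; the lock guarantees each
# item is handed to exactly one thread.
shared = counter(10000)
results = []

def consume():
    for value in shared:
        results.append(value)  # list.append is atomic under the GIL

threads = [threading.Thread(target=consume) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(len(results))  # every item yielded exactly once, none duplicated
```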
The thing is: using this version of a threadsafe generator, Keras still emits a warning that the generator has to be threadsafe when using workers > 1 and use_multiprocessing=True, and that this can be avoided by using Sequences.
My questions now are:
- Does Keras emit this warning only because the generator is not inheriting Sequence, or does Keras also check if a generator is threadsafe in general?
- Is using the approach I chose as threadsafe as using the generatorClass(Sequence) version from the Keras docs?
- Are there any other approaches leading to a thread-safe generator Keras can deal with which are different from these two examples?
In newer tensorflow/keras versions (tf > 2), fit_generator() is deprecated. Instead, it is recommended to use fit() with the generator. However, the question still applies to fit() using a generator as well.
Recommended answer
During my research on this I came across some information answering my questions.
Note: As updated in the question, in newer tensorflow/keras versions (tf > 2) fit_generator() is deprecated. Instead, it is recommended to use fit() with the generator. However, the answer still applies to fit() using a generator as well.
1. Does Keras emit this warning only because the generator is not inheriting Sequence, or does Keras also check if a generator is threadsafe in general?
Taken from Keras' git repo (training_generators.py), I found the following in lines 46-52:
use_sequence_api = is_sequence(generator)
if not use_sequence_api and use_multiprocessing and workers > 1:
    warnings.warn(
        UserWarning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence'
                    ' class.'))
The definition of is_sequence(), taken from training_utils.py in lines 624-635, is:
def is_sequence(seq):
    """Determine if an object follows the Sequence API.

    # Arguments
        seq: a possible Sequence object

    # Returns
        boolean, whether the object follows the Sequence API.
    """
    # TODO Dref360: Decide which pattern to follow. First needs a new TF Version.
    return (getattr(seq, 'use_sequence_api', False)
            or set(dir(Sequence())).issubset(set(dir(seq) + ['use_sequence_api'])))
Regarding this piece of code: Keras only checks if a passed generator is a Keras sequence (or rather uses Keras' sequence API) and does not check if a generator is threadsafe in general.
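To see what this duck-typing check does in isolation, here is a minimal sketch; the Sequence class below is a dummy stand-in for keras.utils.Sequence invented for illustration (the real class has more members, but the subset logic behaves the same way):

```python
class Sequence:
    """Dummy stand-in for keras.utils.Sequence (illustration only)."""
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def on_epoch_end(self):
        pass

def is_sequence(seq):
    """Same subset-of-dir() check as in Keras' training_utils.py."""
    return (getattr(seq, 'use_sequence_api', False)
            or set(dir(Sequence())).issubset(set(dir(seq) + ['use_sequence_api'])))

class MySequence(Sequence):
    """Follows the Sequence API by inheriting from the base class."""
    def __getitem__(self, index):
        return index

    def __len__(self):
        return 1

def plain_generator():
    yield 0

# A Sequence subclass exposes every attribute of the base class, so the
# subset check passes; a plain generator lacks e.g. on_epoch_end, so it
# fails the check - regardless of whether it is actually threadsafe.
print(is_sequence(MySequence()))      # True
print(is_sequence(plain_generator())) # False
```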
2. Is using the approach I chose as threadsafe as using the generatorClass(Sequence) version from the Keras docs?

As Omer Zohar has shown on gitHub, his decorator is threadsafe - I don't see any reason why it shouldn't be as threadsafe for Keras (even though Keras will warn as shown in 1.).
The implementation of threading.Lock() can be considered threadsafe according to the docs:

A factory function that returns a new primitive lock object. Once a thread has acquired it, subsequent attempts to acquire it block, until it is released; any thread may release it.
The generator is also picklable, which can be tested like this (see this SO-Q&A here for further information):
import pickle

# Dump yielded data in order to check if picklable
with open("test.pickle", "wb") as outfile:
    for yielded_data in generator(data):
        pickle.dump(yielded_data, outfile, protocol=pickle.HIGHEST_PROTOCOL)
Resuming this, I would even suggest implementing threading.Lock() when you extend Keras' Sequence(), like:
import threading

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.lock = threading.Lock()  # Set self.lock

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        with self.lock:  # Use self.lock
            batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
            batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
            return ...
Edit 24/04/2020:

By using self.lock = threading.Lock() you might run into the following error:

TypeError: can't pickle _thread.lock objects
In case this happens, try to replace with self.lock: inside __getitem__ with with threading.Lock(): and comment out / delete the self.lock = threading.Lock() inside __init__.
It seems there are some problems when storing the lock object inside a class (see for example this Q&A).
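An alternative to moving the lock into __getitem__ is to keep self.lock but exclude it from pickling via __getstate__/__setstate__. A minimal sketch; LockedBatches is a hypothetical class invented for illustration, not part of the Keras API:

```python
import pickle
import threading

class LockedBatches:
    """Illustrative stand-in for a Sequence-like class holding a lock."""

    def __init__(self, data):
        self.data = data
        self.lock = threading.Lock()

    def __getstate__(self):
        # Drop the lock from the pickled state: lock objects are not
        # picklable and would raise "TypeError: can't pickle _thread.lock".
        state = self.__dict__.copy()
        del state['lock']
        return state

    def __setstate__(self, state):
        # Recreate a fresh lock after unpickling.
        self.__dict__.update(state)
        self.lock = threading.Lock()

    def get(self, idx):
        with self.lock:
            return self.data[idx]

original = LockedBatches([10, 20, 30])
restored = pickle.loads(pickle.dumps(original))  # round-trip now succeeds
print(restored.get(1))
```

This keeps the per-instance lock while making the object safe to hand to a multiprocessing worker.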
3. Are there any other approaches leading to a thread-safe generator Keras can deal with which are different from these two examples?
During my research I did not encounter any other method. Of course, I cannot say this with 100% certainty.