Using keras.utils.Sequence with multiprocessing and a database - when to connect?

Problem description

I'm training a neural network in Keras with the TensorFlow backend. The data set does not fit in RAM, so I store it in a MongoDB database and retrieve batches with a subclass of keras.utils.Sequence.

Everything works fine if I run model.fit_generator() with use_multiprocessing=False.

When I turn on multiprocessing, I get errors either while the workers are being spawned or when connecting to the database.

If I create the connection in __init__, I get an exception whose message says something about errors pickling lock objects (sorry, I don't remember it exactly), and training does not even start.
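That error makes sense: a MongoClient holds internal locks, and lock objects cannot be pickled when the Sequence is sent to the worker processes. A minimal sketch (not my actual class; the exact error text depends on the Python version) that reproduces the same kind of failure:

import pickle
import threading

class Holder:
    def __init__(self):
        # pymongo.MongoClient keeps locks internally, so it behaves like this
        self._lock = threading.Lock()

pickle.dumps(Holder())  # TypeError: cannot pickle '_thread.lock' object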

If I create the connection in __getitem__, training starts and runs for a few epochs, then I get the error [WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted.

According to the pyMongo manuals, MongoClient is not fork-safe, and each child process must create its own connection to the database. I use Windows, which does not fork but spawns processes instead; however, IMHO the difference does not matter here.
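In other words, the pattern pyMongo asks for looks roughly like this (a minimal sketch; the database and collection names are the ones from my code, and the worker function is hypothetical):

import os
import pymongo

def worker():
    # create the client inside the child process, never before spawning/forking
    client = pymongo.MongoClient()
    collection = client["database"]["full_set"]
    print(os.getpid(), collection.count_documents({}))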

This explains why it is impossible to connect in __init__.

Here is a quote from the pyMongo documentation:

Create this client once for each process, and reuse it for all operations. It is a common mistake to create a new client for each request, which is very inefficient.

This explains the errors in __getitem__.

However, it is unclear how my class can detect that Keras has created a new process.

Here is the pseudocode of the latest variant of my Sequence implementation (a new connection on each request):

import pickle

import numpy as np
import pymongo
from keras.utils import Sequence
from keras.utils.np_utils import to_categorical

class MongoSequence(Sequence):
    def __init__(self, train_set, batch_size, server=None, database="database", collection="full_set"):
        self._train_set = train_set
        self._server = server
        self._db = database
        self._collection = collection
        self._batch_size = batch_size

        query = {}  # train_set query
        self._object_ids = [smp["_id"] for uid in train_set for smp in self._connect().find(query, {'_id': True})]

    def _connect(self):
        client = pymongo.MongoClient(self._server)
        db = client[self._db]
        return db[self._collection]

    def __len__(self):
        return int(np.ceil(len(self._object_ids) / float(self._batch_size)))

    def __getitem__(self, item):
        oids = self._object_ids[item * self._batch_size: (item + 1) * self._batch_size]
        # IMAGE_HEIGHT, IMAGE_WIDTH and IMAGE_CHANNELS are defined elsewhere
        X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
        y = np.empty((len(oids), 2), dtype=np.float32)
        for i, oid in enumerate(oids):
            smp = self._connect().find({'_id': oid}).next()
            X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
            y[i] = to_categorical(not smp['result'], 2)
        return X, y

That is, on object construction I retrieve all ObjectIds that form the train set according to my criteria. The actual objects are retrieved from the database in calls to __getitem__, and their ObjectIds are determined from a list slice.

Code that calls model.fit_generator(generator=MongoSequence(train_ids, batch_size=10), ... ) spawns several Python processes, each of which initializes the TensorFlow backend (according to the log messages), and the training starts.
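For reference, the full call looks roughly like this (a sketch; the epoch and worker counts are assumptions, not my real values):

model.fit_generator(
    generator=MongoSequence(train_ids, batch_size=10),
    epochs=10,                 # assumed value
    workers=4,                 # assumed value
    use_multiprocessing=True,  # this is what spawns the worker processes
)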

But eventually an exception is thrown from a function called connect, somewhere deep inside pymongo.

Unfortunately, I have not kept the call stack. The error is the one described above; I repeat: [WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted.

My assumption is that this code creates too many connections to the server; therefore, connecting in __getitem__ is wrong.

Connecting in the constructor is also wrong, since it happens in the main process, and the Mongo docs explicitly advise against that.

There is one more method in the Sequence class, on_epoch_end. But I need a connection at epoch begin, not at its end.

A quote from the Keras docs:

If you want to modify your dataset between epochs you may implement on_epoch_end

So, are there any recommendations? The docs are not very specific here.

Answer

Looks like I've found a solution: track the process id and reconnect when it changes.

import os  # in addition to the imports from the listing above

class MongoSequence(Sequence):
    def __init__(self, batch_size, train_set, query=None, server=None, database="database", collection="full_set"):
        self._server = server
        self._db = database
        self._collection_name = collection
        self._batch_size = batch_size
        self._query = query
        self._collection = self._connect()

        self._object_ids = [smp["_id"] for uid in train_set for smp in self._collection.find(self._query, {'_id': True})]

        self._pid = os.getpid()
        del self._collection  # to be sure that we've disconnected
        self._collection = None

    def _connect(self):
        client = pymongo.MongoClient(self._server)
        db = client[self._db]
        return db[self._collection_name]

    def __len__(self):
        return int(np.ceil(len(self._object_ids) / float(self._batch_size)))

    def __getitem__(self, item):
        # reconnect on the first call inside a freshly spawned worker process
        if self._collection is None or self._pid != os.getpid():
            self._collection = self._connect()
            self._pid = os.getpid()

        oids = self._object_ids[item * self._batch_size: (item + 1) * self._batch_size]
        X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
        y = np.empty((len(oids), 2), dtype=np.float32)
        for i, oid in enumerate(oids):
            smp = self._collection.find({'_id': oid}).next()  # reuse the per-process connection
            X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
            y[i] = to_categorical(not smp['result'], 2)
        return X, y
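The point of the design is that os.getpid() differs in every spawned worker: self._collection is reset to None before the Sequence gets pickled, so it pickles cleanly, and the stale pid makes each worker lazily open its own MongoClient on its first __getitem__ call, which is exactly the one-client-per-process pattern the pyMongo docs demand.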
