Create this client once for each process, and reuse it for all operations. It is a common mistake to create a new client for each request, which is very inefficient.
This points at the mistake in __getitem__.
However, it is unclear how my class can detect that Keras has created a new process.
Here is the pseudocode of the last variant of my Sequence implementation (new connection on each request):
import pickle

import numpy as np
import pymongo
from keras.utils import Sequence
from keras.utils.np_utils import to_categorical

# IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS are defined elsewhere


class MongoSequence(Sequence):
    def __init__(self, train_set, batch_size, server=None, database="database", collection="full_set"):
        self._train_set = train_set
        self._server = server
        self._db = database
        self._collection_name = collection
        self._batch_size = batch_size
        query = {}  # train_set query
        self._object_ids = [smp["_id"] for uid in train_set
                            for smp in self._connect().find(query, {'_id': True})]

    def _connect(self):
        # A new client (and therefore a new connection) is created on every call
        client = pymongo.MongoClient(self._server)
        db = client[self._db]
        return db[self._collection_name]

    def __len__(self):
        return int(np.ceil(len(self._object_ids) / float(self._batch_size)))

    def __getitem__(self, item):
        oids = self._object_ids[item * self._batch_size: (item + 1) * self._batch_size]
        X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
        y = np.empty((len(oids), 2), dtype=np.float32)
        for i, oid in enumerate(oids):
            smp = self._connect().find({'_id': oid}).next()
            X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
            y[i] = to_categorical(not smp['result'], 2)
        return X, y
That is, on object construction, I retrieve all relevant ObjectIds forming the train set according to the criteria. The actual objects are retrieved from the database in calls to __getitem__; their ObjectIds are determined from a list slice.
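The batch arithmetic above (ceiling division in __len__, slicing in __getitem__) can be checked in isolation. A small sketch with plain integers standing in for Mongo ObjectIds, assuming 25 ids and a batch size of 10:

```python
import numpy as np

object_ids = list(range(25))  # stand-ins for Mongo ObjectIds
batch_size = 10

# Same formula as __len__: ceiling division over the id list
n_batches = int(np.ceil(len(object_ids) / float(batch_size)))

# Same slicing as __getitem__: the last batch may be shorter
batches = [object_ids[i * batch_size:(i + 1) * batch_size] for i in range(n_batches)]

print(n_batches)                  # 3
print([len(b) for b in batches])  # [10, 10, 5]
```

Every id is covered exactly once, and the ragged final batch is handled by the slice simply running short.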
The code that calls model.fit_generator(generator=MongoSequence(train_ids, batch_size=10), ... ) spawns several Python processes, each of which initializes the TensorFlow backend according to the log messages, and training starts.
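That each Keras worker runs in its own process, with its own pid, can be illustrated with the standard multiprocessing module. This sketch assumes a POSIX system (it uses the "fork" start method, so it will not run as-is on Windows, where the question's error occurred); Keras workers behave analogously.

```python
import multiprocessing as mp
import os

def worker_pid(_):
    # Runs inside a worker process; returns that process's pid
    return os.getpid()

ctx = mp.get_context("fork")  # POSIX only; sketch of worker spawning
with ctx.Pool(2) as pool:
    worker_pids = set(pool.map(worker_pid, range(4)))

print(os.getpid() in worker_pids)  # False -- workers are separate processes
```

This is why `os.getpid()` is a usable signal for "am I in a freshly created worker?": the pid seen inside __getitem__ differs from the one recorded in the constructor.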
But eventually an exception is thrown from the function called connect, somewhere deep inside pymongo.
Unfortunately, I haven't stored the call stack. The error is described above; I repeat: [WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted.
My assumption is that this code creates too many connections to the server; therefore, connecting in __getitem__ is wrong.
Connecting in the constructor is also wrong, since it is performed in the main process, and the MongoDB docs explicitly advise against sharing a client across forked processes.
There is one more method in the Sequence class, on_epoch_end. But I need a connection at epoch begin, not end.
Quote from Keras docs:
If you want to modify your dataset between epochs you may implement on_epoch_end
So, are there any recommendations? The docs are not very specific here.
Recommended Answer
Looks like I've found a solution. The solution is: track the process id and reconnect when it changes.
import os
import pickle

import numpy as np
import pymongo
from keras.utils import Sequence
from keras.utils.np_utils import to_categorical


class MongoSequence(Sequence):
    def __init__(self, batch_size, train_set, query=None, server=None, database="database", collection="full_set"):
        self._server = server
        self._db = database
        self._collection_name = collection
        self._batch_size = batch_size
        self._query = query
        self._collection = self._connect()
        self._object_ids = [smp["_id"] for uid in train_set
                            for smp in self._collection.find(self._query, {'_id': True})]
        self._pid = os.getpid()
        del self._collection  # to be sure that we've disconnected
        self._collection = None

    def _connect(self):
        client = pymongo.MongoClient(self._server)
        db = client[self._db]
        return db[self._collection_name]

    def __len__(self):
        return int(np.ceil(len(self._object_ids) / float(self._batch_size)))

    def __getitem__(self, item):
        # Reconnect if we are in a new (forked) process, or dropped the connection
        if self._collection is None or self._pid != os.getpid():
            self._collection = self._connect()
            self._pid = os.getpid()
        oids = self._object_ids[item * self._batch_size: (item + 1) * self._batch_size]
        X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
        y = np.empty((len(oids), 2), dtype=np.float32)
        for i, oid in enumerate(oids):
            smp = self._collection.find({'_id': oid}).next()
            X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
            y[i] = to_categorical(not smp['result'], 2)
        return X, y
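The pid-tracking trick in __getitem__ can be isolated and tested without any database. A minimal sketch with a pluggable factory (the class name and factory are hypothetical; a counter stands in for the pymongo connection):

```python
import os

class PerProcessResource:
    """Lazily (re)create a resource whenever the current process
    has not created it yet, e.g. a freshly forked Keras worker."""

    def __init__(self, factory):
        self._factory = factory
        self._resource = None
        self._pid = None

    def get(self):
        # Same condition as __getitem__ above: no resource yet, or new pid
        if self._resource is None or self._pid != os.getpid():
            self._resource = self._factory()
            self._pid = os.getpid()
        return self._resource

calls = []
res = PerProcessResource(lambda: calls.append(1) or len(calls))
first = res.get()   # factory runs once
second = res.get()  # same pid: cached resource is reused
print(first, second, len(calls))  # 1 1 1
```

Within one process the factory runs exactly once; in a forked worker the pid check fails and a fresh connection is made, which is exactly what the Sequence above relies on.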