使用Keras ImageDataGenerator时多输入模型中的内存错误 [英] Memory error in multi-inputs model when using Keras ImageDataGenerator

查看:80
本文介绍了使用Keras ImageDataGenerator时多输入模型中的内存错误的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

使用Keras ImageDataGenerator时不起作用.任何人都可以告诉我这些代码的问题,谢谢.

It's not working when using Keras ImageDataGenerator. Anyone could tell me the problem of these codes, thx.

  • keras:2.1.5
  • TFgpu:1.4.0
  • 操作系统:Win 10

错误如下:

Epoch 1/50

Epoch 1/50

98/27100 [..............................]

98/27100 [..............................]

............

............

MemoryError

MemoryError

如何解决此MemoryError?

How can I resolve this MemoryError?

X = {
    'anc_input': anc_ins,
    'pos_input': pos_ins,
    'neg_input': neg_ins
}

anc_ins_te = te_pairs[:, 0]
pos_ins_te = te_pairs[:, 1]
neg_ins_te = te_pairs[:, 2]

X_te = {
    'anc_input': anc_ins_te,
    'pos_input': pos_ins_te,
    'neg_input': neg_ins_te
}

# ------------------------------------------
# self.model.fit(
#     X, np.ones(len(anc_ins)),
#     batch_size=32,
#     epochs=50,
#     validation_data=[X_te, np.ones(len(anc_ins_te))],
#     # verbose=1,
#     callbacks=self.callbacks)
# ------------------------------------------
aug = ImageDataGenerator(rotation_range=5,
                         zoom_range=0.15,
                         width_shift_range=0.2,
                         height_shift_range=0.2,
                         fill_mode="constant",
                         cval=0)
batch_size = 2
y = np.ones(batch_size)

def gen_flow_multi_inputs(X, y):
    while True:
        XX = {}
        for k, X_ in X.items():
            gen_X_ = aug.flow(X_, batch_size=batch_size, seed=7)
            XX[k] = gen_X_.next()
        yield XX, y

self.model.fit_generator(gen_flow_multi_inputs(X, y),
                         validation_data=[X_te, np.ones(len(anc_ins_te))],
                         steps_per_epoch=len(anc_ins) // batch_size,
                         epochs=50,
                         callbacks=self.callbacks)

推荐答案

我已解决问题:)

batch_size = 32
# y = np.ones(batch_size)
aug.fit(X['anc_input'])

def gen_flow_multi_inputs(X):
    gen_X_ = {}
    for k, X_ in X.items():
        gen_X_[k] = aug.flow(X_, batch_size=batch_size, seed=7)
    while True:
        XX = {}
        for k, X_ in X.items():
            XX[k] = gen_X_[k].next()
        N = len(XX['anc_input'])
        yield XX, np.ones(N)

self.model.fit_generator(gen_flow_multi_inputs(X),
                         validation_data=[X_te, np.ones(len(anc_ins_te))],
                         steps_per_epoch=len(anc_ins) // batch_size,
                         epochs=50,
                         callbacks=self.callbacks)

这篇关于使用Keras ImageDataGenerator时多输入模型中的内存错误的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆