Keras: network doesn't train with fit_generator()

Question

I'm using Keras on a large dataset (music autotagging with the MagnaTagATune dataset), so I've tried to use the fit_generator() function with a custom data generator. But the values of the loss function and the metrics don't change during training. It looks like my network doesn't train at all.

When I use the fit() function instead of fit_generator(), everything is OK, but I can't keep the whole dataset in memory.

I've tried both the Theano and TensorFlow backends.

Main code:

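import numpy as np

import models     # the asker's own module (defines FCN4)
import mttutils   # the asker's own module (defines generator_v2)

if __name__ == '__main__':
    # csv_path, melgrams_dir, output_history and output_model are set elsewhere in the script.
    # Keras 1.x API: the 'precision'/'recall' metrics and the samples_per_epoch,
    # nb_epoch and nb_val_samples arguments were removed or renamed in Keras 2.
    model = models.FCN4()
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy', 'categorical_accuracy', 'precision', 'recall'])
    gen = mttutils.generator_v2(csv_path, melgrams_dir)
    history = model.fit_generator(gen.generate(0, 750),                    # label indices 0..749
                                  samples_per_epoch=750,
                                  nb_epoch=80,
                                  validation_data=gen.generate(750, 1000),  # label indices 750..999
                                  nb_val_samples=250)
    # RESULTS SAVING
    np.save(output_history, history.history)
    model.save(output_model)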
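The call above uses Keras 1.x argument names. In Keras 2, samples_per_epoch, nb_epoch and nb_val_samples became steps_per_epoch, epochs and validation_steps, and the per-epoch count is measured in batches rather than samples. The helper below is a hypothetical sketch of that conversion (it is not part of Keras) and assumes every batch the generator yields has the same size:

```python
import math


def keras2_fit_generator_kwargs(samples_per_epoch, nb_epoch, nb_val_samples, batch_size):
    """Translate Keras 1.x fit_generator() counts into Keras 2 argument names.

    Hypothetical helper: Keras 1 counted samples per epoch, Keras 2 counts
    batches (steps). Assumes a fixed batch size for every yielded batch.
    """
    return {
        'steps_per_epoch': math.ceil(samples_per_epoch / batch_size),
        'epochs': nb_epoch,
        'validation_steps': math.ceil(nb_val_samples / batch_size),
    }


# The question's generator yields one sample per batch, so steps == samples here
print(keras2_fit_generator_kwargs(750, 80, 250, batch_size=1))
```

With batch_size=1, as in the question, the numbers carry over unchanged; with larger batches the step counts shrink accordingly.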

Class generator_v2:

import csv
import os

import numpy as np


class generator_v2:

    genres = ['guitar', 'classical', 'slow', 'techno', 'strings', 'drums', 'electronic', 'rock', 'fast',
              'piano', 'ambient', 'beat', 'violin', 'vocal', 'synth', 'female', 'indian', 'opera', 'male', 'singing',
              'vocals', 'no vocals', 'harpsichord', 'loud', 'quiet', 'flute', 'woman', 'male vocal', 'no vocal',
              'pop', 'soft', 'sitar', 'solo', 'man', 'classic', 'choir', 'voice', 'new age', 'dance', 'male voice',
              'female vocal', 'beats', 'harp', 'cello', 'no voice', 'weird', 'country', 'metal', 'female voice', 'choral']

    def __init__(self, csv_path, melgrams_dir):

        def get_dict_vals(dictionary, keys):
            # Values of `keys` from one CSV row, in the same order as `keys`
            return [dictionary[key] for key in keys]

        self.melgrams_dir = melgrams_dir
        with open(csv_path, newline='') as csvfile:
            reader = csv.DictReader(csvfile, dialect='excel-tab')
            self.labels = []
            for row in reader:
                # '0'/'1' tag flags -> int array of shape (1, len(genres));
                # np.int was removed in NumPy 1.24, so use the builtin int
                labels_arr = np.array(get_dict_vals(row, self.genres)).astype(int)
                labels_arr = labels_arr.reshape((1, labels_arr.shape[0]))
                # Keep only clips that carry at least one of the selected tags
                if np.sum(labels_arr) > 0:
                    self.labels.append((row['mp3_path'], labels_arr))
            self.size = len(self.labels)

    def generate(self, begin, end):
        # Loop forever: fit_generator() expects a generator that never runs out
        while True:
            for count in range(begin, end):
                try:
                    item = self.labels[count]
                    mels = np.load(os.path.join(self.melgrams_dir, item[0] + '.npy'))
                    tags = item[1]
                    # Each yield is one batch holding a single sample
                    yield (mels, tags)
                except FileNotFoundError:
                    # Skip clips whose mel-spectrogram file is missing
                    continue
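The label-parsing step in __init__ can be tried in isolation. The sketch below uses a tiny hypothetical tab-separated annotation file with only two tag columns plus an mp3_path column (the layout is illustrative, not the full MagnaTagATune file) and shows that rows with no selected tag are dropped and that each label row keeps a leading axis of size 1:

```python
import csv
import io

import numpy as np

# Hypothetical two-tag subset and a tiny tab-separated annotation file
GENRES = ['guitar', 'piano']
ANNOTATIONS = (
    "clip_id\tguitar\tpiano\tmp3_path\n"
    "1\t1\t0\ta/clip1.mp3\n"
    "2\t0\t0\tb/clip2.mp3\n"
    "3\t1\t1\tc/clip3.mp3\n"
)


def parse_labels(csvfile, genres):
    """Return (mp3_path, labels) pairs for rows carrying at least one selected tag."""
    labels = []
    for row in csv.DictReader(csvfile, dialect='excel-tab'):
        # '0'/'1' strings -> int array with a leading axis: shape (1, len(genres))
        labels_arr = np.array([row[g] for g in genres]).astype(int).reshape(1, -1)
        if labels_arr.sum() > 0:
            labels.append((row['mp3_path'], labels_arr))
    return labels


labels = parse_labels(io.StringIO(ANNOTATIONS), GENRES)
for path, arr in labels:
    print(path, arr.shape, arr.tolist())
```

Clip 2 has no selected tag and is filtered out, so only clips 1 and 3 remain.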

To prepare arrays for the fit() function, I use this code:

import numpy as np


def TEST_get_data_array(csv_path, melgrams_dir):
    gen = generator_v2(csv_path, melgrams_dir).generate(0, 100)
    xs, ys = [], []
    for _ in range(100):
        # next(gen), not next(gen.training): gen is the generator object itself
        mels, tags = next(gen)
        xs.append(mels)
        ys.append(tags)
    # Every sample carries a leading axis of size 1, so join them along axis 0
    x = np.concatenate(xs, axis=0)
    y = np.concatenate(ys, axis=0)
    return x, y
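The same array preparation can be written without special-casing the first item: take exactly n samples with itertools.islice and concatenate once. A sketch, with a hypothetical dummy generator standing in for generator_v2.generate() and assuming every x and y has a leading axis of size 1:

```python
from itertools import islice

import numpy as np


def collect(sample_gen, n):
    """Take the first n (x, y) pairs from sample_gen and stack them along axis 0."""
    xs, ys = zip(*islice(sample_gen, n))
    return np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)


def dummy_samples():
    # Hypothetical stand-in for generator_v2.generate(): endless (x, y) pairs,
    # each with a leading axis of size 1
    i = 0
    while True:
        yield np.full((1, 3, 4), i, dtype=np.float32), np.array([[i % 2]])
        i += 1


x, y = collect(dummy_samples(), 100)
print(x.shape, y.shape)  # (100, 3, 4) (100, 1)
```

islice stops after n items, so it is safe to use on the endless generator that fit_generator() needs.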

Sorry if my code style isn't great, and thank you!

UPD 1: I've tried using return (X, y) instead of yield (X, y), but nothing changes.

Part of my new generator class:

def generate(self):
    # Restart at self.begin whenever the index leaves [self.begin, self.end)
    if self.count < self.begin or self.count >= self.end:
        self.count = self.begin
    item = self.labels[self.count]
    mels = np.load(os.path.join(self.melgrams_dir, item[0] + '.npy'))
    tags = item[1]
    self.count += 1
    return (mels, tags)   # one (x, y) batch per call

def __next__(self):   # fit_generator() uses this method
    return self.generate()

The fit_generator() call:

history = model.fit_generator(tr_gen,
                              samples_per_epoch=tr_gen.size,
                              nb_epoch=120,
                              validation_data=val_gen,
                              nb_val_samples=val_gen.size)

Log:

Epoch 1/120
10554/10554 [==============================] - 545s - loss: 1.7240 - acc: 0.8922 
Epoch 2/120
10554/10554 [==============================] - 526s - loss: 1.8922 - acc: 0.8820 
Epoch 3/120
10554/10554 [==============================] - 526s - loss: 1.8922 - acc: 0.8820 
Epoch 4/120
10554/10554 [==============================] - 526s - loss: 1.8922 - acc: 0.8820 
... etc (loss is always 1.8922; acc is always 0.8820)

Recommended answer

I had the same problem as you with the yield approach, so I stored the current index and returned one batch per call with a return statement.

In other words, I used return (X, y) instead of yield (X, y), and it worked. I'm not sure why; it would be great if someone could shed some light on this.
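A minimal sketch of the index-based approach described above, assuming the samples fit in an in-memory list (in the question they would be loaded with np.load instead). The class name BatchIterator and the batch_size parameter are hypothetical; defining __iter__ to return self makes the object a proper Python iterator:

```python
import numpy as np


class BatchIterator:
    """Index-based iterator: each next() call returns one (x, y) batch.

    Hypothetical sketch. `samples` is an in-memory list of (x, y) pairs,
    each with a leading axis of size 1.
    """

    def __init__(self, samples, begin, end, batch_size=1):
        self.samples = samples
        self.begin, self.end = begin, end
        self.batch_size = batch_size
        self.count = begin

    def __iter__(self):
        return self

    def __next__(self):
        xs, ys = [], []
        for _ in range(self.batch_size):
            # Wrap around so the iterator never runs out
            if self.count < self.begin or self.count >= self.end:
                self.count = self.begin
            x, y = self.samples[self.count]
            xs.append(x)
            ys.append(y)
            self.count += 1
        return np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)


samples = [(np.full((1, 2), i), np.array([[i]])) for i in range(10)]
it = BatchIterator(samples, begin=0, end=5, batch_size=3)
for _ in range(3):
    x, y = next(it)
    print(y.ravel().tolist())
```

With begin=0, end=5 and batch_size=3 the three calls return samples [0, 1, 2], [3, 4, 0] and [1, 2, 3]: the index wraps back to begin instead of raising StopIteration.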

You need to pass the generator object itself to fit_generator(), rather than only calling a function. Something like this:

model.fit_generator(gen,
                    samples_per_epoch=750,
                    nb_epoch=80,
                    validation_data=gen,  # ideally a separate validation iterator
                    nb_val_samples=250)

Keras will call your object's __next__ method while training on the data.
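The mechanism behind this is plain Python: the built-in next(obj) dispatches to obj.__next__(), and a generator object (what you get by calling a function that contains yield) implements __next__ as well, whereas the bare function does not. A small illustration with hypothetical names:

```python
import types


class Counter:
    """Hypothetical iterator object that a caller drives through next()."""

    def __init__(self):
        self.calls = 0

    def __iter__(self):
        return self

    def __next__(self):
        self.calls += 1
        return self.calls


def generate(begin, end):
    # A generator *function*: calling it returns a generator object
    while True:
        for i in range(begin, end):
            yield i


c = Counter()
print(next(c), next(c))                           # 1 2  (next() calls Counter.__next__)

g = generate(0, 2)
print(isinstance(g, types.GeneratorType))         # True: a generator object next() can advance
print(isinstance(generate, types.GeneratorType))  # False: the function itself is not one
print([next(g) for _ in range(3)])                # [0, 1, 0]
```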
