model.get_weights()在训练后由于NaN屏蔽而返回NaN数组 [英] model.get_weights() returning array of NaNs after training due to NaN masking

查看:100
本文介绍了model.get_weights()在训练后由于NaN屏蔽而返回NaN数组的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试训练LSTM对各种长度的序列进行分类.我想获得此模型的权重,因此可以在模型的有状态版本中使用它们.训练前,体重是正常的.另外,训练似乎已成功进行,并且误差逐渐减小.但是,当我将掩码值从-10更改为np.Nan时,mod.get_weights()开始返回NaN s的数组,并且验证错误突然降至接近零的值.为什么会这样?

I'm trying to train an LSTM to classify sequences of various lengths. I want to get the weights of this model, so I can use them in stateful version of the model. Before training, the weights are normal. Also, the training seems to run successfully, with a gradually decreasing error. However, when I change the mask value from -10 to np.Nan, mod.get_weights() starts returning arrays of NaNs and the validation error drops suddenly to a value close to zero. Why is this occurring?

from keras import models
from keras.layers import Dense, Masking, LSTM
from keras.optimizers import RMSprop
from keras.losses import categorical_crossentropy
from keras.preprocessing.sequence import pad_sequences

import numpy as np
import matplotlib.pyplot as plt


def gen_noise(noise_len, mag):
    return np.random.uniform(size=noise_len) * mag


def gen_sin(t_val, freq):
    return 2 * np.sin(2 * np.pi * t_val * freq)


def train_rnn(x_train, y_train, max_len, mask, number_of_categories):
    epochs = 3
    batch_size = 100

    # three hidden layers of 256 each
    vec_dims = 1
    hidden_units = 256
    in_shape = (max_len, vec_dims)

    model = models.Sequential()

    model.add(Masking(mask, name="in_layer", input_shape=in_shape,))
    model.add(LSTM(hidden_units, return_sequences=False))
    model.add(Dense(number_of_categories, input_shape=(number_of_categories,),
              activation='softmax', name='output'))

    model.compile(loss=categorical_crossentropy, optimizer=RMSprop())

    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
              validation_split=0.05)

    return model


def gen_sig_cls_pair(freqs, t_stops, num_examples, noise_magnitude, mask, dt=0.01):
    x = []
    y = []

    num_cat = len(freqs)

    max_t = int(np.max(t_stops) / dt)

    for f_i, f in enumerate(freqs):
        for t_stop in t_stops:
            t_range = np.arange(0, t_stop, dt)
            t_len = t_range.size

            for _ in range(num_examples):
                sig = gen_sin(f, t_range) + gen_noise(t_len, noise_magnitude)
                x.append(sig)

                one_hot = np.zeros(num_cat, dtype=np.bool)
                one_hot[f_i] = 1
                y.append(one_hot)

    pad_kwargs = dict(padding='post', maxlen=max_t, value=mask, dtype=np.float32)
    return pad_sequences(x, **pad_kwargs), np.array(y)


if __name__ == '__main__':
    noise_mag = 0.01
    mask_val = -10
    frequencies = (5, 7, 10)
    signal_lengths = (0.8, 0.9, 1)
    dt_val = 0.01

    x_in, y_in = gen_sig_cls_pair(frequencies, signal_lengths, 50, noise_mag, mask_val)
    mod = train_rnn(x_in[:, :, None], y_in, int(np.max(signal_lengths) / dt_val), mask_val, len(frequencies))

即使我将网络体系结构更改为return_sequences=True并用TimeDistributed包裹Dense层,也不会删除LSTM层.

This persists even if I change the network architecture to return_sequences=True and wrap the Dense layer with TimeDistributed, nor does removing the LSTM layer.

推荐答案

我遇到了同样的问题.在您的情况下,我可以看到它可能有所不同,但有人可能遇到相同的问题,因此可以从Google那里来.因此,在我的情况下,我将sample_weight参数传递给fit()方法,并且当样本权重中包含一些零时,get_weights()返回具有NaNs的数组.当我省略了sample_weight = 0的样本时(如果sample_weight = 0则它们毫无用处),它开始起作用.

I had the same problem. In your case I can see it was probably something different but someone might have the same problem and come here from Google. So in my case I was passing sample_weight parameter to fit() method and when the sample weights contained some zeros in it, get_weights() was returning an array with NaNs. When I omitted the samples where sample_weight=0 (they were useless anyway if sample_weight=0), it started to work.

这篇关于model.get_weights()在训练后由于NaN屏蔽而返回NaN数组的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆