Zero predictions despite masking support for zero-padded mini-batch LSTM training in Keras


Question

Problem statement

I'm training a many-to-many LSTM in Keras with the TensorFlow backend (tf version 1.13.1) on tagged text sequences to predict the tag of each element in the sequence using pretrained GloVe embeddings. My training regime involves mini-batch stochastic gradient descent, with each mini-batch matrix zero-padded column-wise to ensure equal-length input to the network.

Crucially, because of custom constraints on my mini-batches due to the nature of the task and the data, I am not using the Keras Embedding layer. My goal is to implement a masking mechanism for my zero-padded cells so that the loss computation does not spuriously treat these cells as genuine data points.

Approach

As explained in the Keras documentation, there are three ways to set up masking in Keras:

  1. Configuring a keras.layers.Embedding layer with mask_zero set to True;
  2. Adding a keras.layers.Masking layer;
  3. Passing a mask argument manually when calling recurrent layers.

Because I am not using an embedding layer to encode my data for training, option (1) with a masked embedding layer is not available to me. Instead, I chose (2) and added a masking layer right after initializing my model. This change, however, does not seem to have had any effect. Not only has the accuracy of my model not improved, but at the prediction stage the model still generates zero predictions. Why does my masking layer not mask zero-padded cells? Could it have to do with the fact that in my dense layer I'm specifying 3 classes rather than 2 (thus including 0 as a separate class)?

Limitations of existing resources

Similar questions have been asked and answered, but I wasn't able to use them to resolve my issue. While this post received no direct response, a linked post mentioned in a comment focuses on how to preprocess data to assign the mask value, which is uncontroversial here. The masking layer initialization, however, is identical to the one used here. This post mentions the same problem (a masking layer has no effect on performance), and the answer defines the masking layer in the same way as I do, but again focuses on converting specific values to mask values. Finally, the answer in this post provides the same layer initialization without elaborating further.

Toy data generation

To reproduce my problem, I have generated a toy 10-batch dataset with two classes (1, 2). A batch is a variable-length sequence post-padded with zeros to a maximum length of 20 embeddings, with each embedding vector consisting of 5 cells, so input_shape=(20,5). Embedding values for the two classes were generated from different but partially overlapping truncated normal distributions to create a learnable but non-trivial problem for the network. I've included the toy data below so you can reproduce the problem.

import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import LSTM, Dense, TimeDistributed, Bidirectional, Dropout, Masking
from keras import optimizers

# *** model initialization ***

model = Sequential()
model.add(Masking(mask_value=0., input_shape=(20, 5))) # <- masking layer here
model.add(Bidirectional(LSTM(20, return_sequences=True), input_shape=(20, 5)))
model.add(Dropout(0.2))
model.add(TimeDistributed(Dense(3, activation='sigmoid')))

sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='sparse_categorical_crossentropy', optimizer=sgd, metrics=['mse'])

# *** model training ***

for epoch in range(10):

    for X,y in data_train:

        X = X.reshape(1, 20, 5)
        y = y.reshape(1, 20, 1)

        history = model.fit(X, y, epochs=1, batch_size=20, verbose=0)

# *** model prediction ***

preds = pd.DataFrame(columns=['true', 'pred'])

for index, (X,y) in enumerate(data_test):
    X = X.reshape(1, 20, 5)
    y = y.reshape(1, 20, 1)

    y_pred = model.predict_classes(X, verbose=0)

    df = pd.DataFrame(columns=['true', 'pred'])

    df['true'] = [y[0, i][0] for i in range(20)]
    df['pred'] = [y_pred[0, i] for i in range(20)]

    preds = preds.append(df, ignore_index=True)

# convert true labels to int & drop padded rows (where y_true=0)
preds['true'] = [int(label) for label in preds['true']]
preds = preds[preds['true']!=0]

This is the summary of the model with masking:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
masking_2 (Masking)          (None, 20, 5)             0         
_________________________________________________________________
bidirectional_4 (Bidirection (None, 20, 40)            4160      
_________________________________________________________________
dropout_4 (Dropout)          (None, 20, 40)            0         
_________________________________________________________________
time_distributed_4 (TimeDist (None, 20, 3)             123       
=================================================================
Total params: 4,283
Trainable params: 4,283
Non-trainable params: 0

I trained one model with the masking layer and one without it, and calculated accuracy using:

np.round(sum(preds['true']==preds['pred'])/len(preds)*100,1)

I got 53.3% accuracy for the model without masking and 33.3% for the model with masking. More surprisingly, I kept getting zero as a predicted label in both models. Why does the masking layer fail to ignore zero-padded cells?

Data for reproducing the problem:

data_train = list(zip(X_batches_train, y_batches_train))
data_test = list(zip(X_batches_test, y_batches_test))

X_batches_train


[array([[-1.00612917,  1.47313952,  2.68021318,  1.54875809,  0.98385996,
          1.49465265,  0.60429106,  1.12396908, -0.24041602,  1.77266187,
          0.1961381 ,  1.28019637,  1.78803092,  2.05151245,  0.93606708,
          0.51554755,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.97596563,  2.04536053,  0.88367922,  1.013342  , -0.16605355,
          3.02994344,  2.04080806, -0.25153046, -0.5964068 ,  2.9607247 ,
         -0.49722121,  0.02734492,  2.16949987,  2.77367066,  0.15628842,
          2.19823207,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.31546283,  3.27420503,  3.23550769, -0.63724013,  0.89150128,
          0.69774266,  2.76627308, -0.58408384, -0.45681779,  1.98843041,
         -0.31850477,  0.83729882,  0.45471165,  3.61974147, -1.45610756,
          1.35217453,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.03329532,  1.97471646,  1.33949611,  1.22857243, -1.46890642,
          1.74105506,  1.40969261,  0.52465603, -0.18895266,  2.81025597,
          2.64901037, -0.83415186,  0.76956826,  1.48730868, -0.16190164,
          2.24389007,  0.        ,  0.        ,  0.        ,  0.        ],
        [-1.0676654 ,  3.08429323,  1.7601179 ,  0.85448051,  1.15537064,
          2.82487842,  0.27891413,  0.57842569, -0.62392063,  1.00343057,
          1.15348843, -0.37650332,  3.37355345,  2.22285473,  0.43444434,
          0.15743873,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[ 1.05258873, -0.17897376, -0.99932932, -1.02854121,  0.85159208,
          2.32349131,  1.96526709, -0.08398597, -0.69474809,  1.32820222,
          1.19514151,  1.56814867,  0.86013263,  1.48342922,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.1920635 , -0.48702788,  1.24353985, -1.3864121 ,  0.16713229,
          3.10134683,  0.61658271, -0.63360643,  0.86000807,  2.74876157,
          2.87604877,  0.16339724,  2.87595396,  3.2846962 ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.1380241 , -0.76783029,  0.18814436, -1.18165209, -0.02981728,
          1.49908113,  0.61521007, -0.98191097,  0.31250199,  1.39015803,
          3.16213211, -0.70891214,  3.83881766,  1.92683533,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.39080778, -0.59179216,  0.80348201,  0.64638205, -1.40144268,
          1.49751413,  3.0092166 ,  1.33099666,  1.43714841,  2.90734268,
          3.09688943,  0.32934884,  1.14592787,  1.58152023,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.77164353,  0.50293096,  0.0717377 ,  0.14487556, -0.90246591,
          2.32612179,  1.98628857,  1.29683166, -0.12399569,  2.60184685,
          3.20136653,  0.44056647,  0.98283455,  1.79026663,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[-0.93359914,  2.31840281,  0.55691601,  1.90930758, -1.58260431,
         -1.05801881,  3.28012523,  3.84105406, -1.2127093 ,  0.00490079,
          1.28149304,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-1.03105486,  2.7703693 ,  0.16751813,  1.12127987, -0.44070271,
         -0.0789227 ,  2.79008301,  1.11456745,  1.13982551, -1.10128658,
          0.87430834,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.69710668,  1.72702833, -2.62599502,  2.34730002,  0.77756661,
          0.16415884,  3.30712178,  1.67331828, -0.44022431,  0.56837829,
          1.1566811 ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.71845983,  1.79908544,  0.37385522,  1.3870915 , -1.48823234,
         -1.487419  ,  3.0879945 ,  1.74617784, -0.91538815, -0.24244522,
          0.81393954,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-1.38501563,  3.73330047, -0.52494265,  2.37133716, -0.24546709,
         -0.28360782,  2.89384717,  2.42891743,  0.40144022, -1.21850571,
          2.00370751,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[ 1.27989188,  1.16254538, -0.06889142,  1.84133355,  1.3234908 ,
          1.29611702,  2.0019294 , -0.03220116,  1.1085194 ,  1.96495985,
          1.68544302,  1.94503544,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.3004439 ,  2.48768923,  0.59809607,  2.38155155,  2.78705889,
          1.67018683,  0.21731778, -0.59277191,  2.87427207,  2.63950475,
          2.39211459,  0.93083423,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 2.39239371,  0.30900383, -0.97307155,  1.98100711,  0.30613735,
          1.12827171,  0.16987791,  0.31959096,  1.30366416,  1.45881023,
          2.45668401,  0.5218711 ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.0826574 ,  2.05100254,  0.013161  ,  2.95120798,  1.15730011,
          0.75537024,  0.13708569, -0.44922143,  0.64834001,  2.50640862,
          2.00349347,  3.35573624,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.47135124,  2.10258532,  0.70212032,  2.56063126,  1.62466971,
          2.64026892,  0.21309489, -0.57752813,  2.21335957,  0.20453233,
          0.03106993,  3.01167822,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[-0.42125521,  0.54016939,  1.63016057,  2.01555253, -0.10961255,
         -0.42549555,  1.55793753, -0.0998756 ,  0.36417335,  3.37126414,
          1.62151191,  2.84084192,  0.10831384,  0.89293054, -0.08671363,
          0.49340353,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.37615411,  2.00581062,  2.30426605,  2.02205839,  0.65871664,
          1.34478836, -0.55379752, -1.42787727,  0.59732227,  0.84969282,
          0.54345723,  0.95849568, -0.17131602, -0.70425277, -0.5337757 ,
          1.78207229,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.13863276,  1.71490034,  2.02677925,  2.60608619,  0.26916522,
          0.35928298, -1.26521844, -0.59859219,  1.19162219,  1.64565259,
          1.16787165,  2.95245196,  0.48681084,  1.66621053,  0.918077  ,
         -1.10583747,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.87763797,  2.38740754,  2.9111822 ,  2.21184069,  0.78091173,
         -0.53270909,  0.40100338, -0.83375593,  0.9860009 ,  2.43898437,
         -0.64499989,  2.95092003, -1.52360727,  0.44640918,  0.78131922,
         -0.24401283,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.92615066,  3.45437746,  3.28808981,  2.87207404, -1.60027223,
         -1.14164941, -1.63807699,  0.33084805,  2.92963629,  3.51170824,
         -0.3286093 ,  2.19108385,  0.97812366, -1.82565766, -0.34034678,
         -2.0485913 ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[ 1.96438618e+00,  1.88104784e-01,  1.61114494e+00,
          6.99567690e-04,  2.55271963e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 2.41578815e+00, -5.70625661e-01,  2.15545894e+00,
         -1.80948908e+00,  1.62049331e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 1.97017040e+00, -1.62556528e+00,  2.49469152e+00,
          4.18785985e-02,  2.61875866e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 3.14277819e+00,  3.01098398e-02,  7.40376369e-01,
          1.76517344e+00,  2.68922918e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00],
        [ 2.06250296e+00,  4.67605528e-01,  1.55927230e+00,
          1.85788889e-01,  1.30359922e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00]]),
 array([[ 1.22152427,  3.74926839,  0.64415552,  2.35268329,  1.98754653,
          2.89384829,  0.44589817,  3.94228743,  2.72405657,  0.86222004,
          0.68681903,  3.89952458,  1.43454512,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [-0.02203262,  0.95065123,  0.71669023,  0.02919391,  2.30714524,
          1.91843002,  0.73611294,  1.20560482,  0.85206836, -0.74221506,
         -0.72886308,  2.39872927, -0.95841402,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.55775319,  0.33773314,  0.79932151,  1.94966883,  3.2113281 ,
          2.70768249, -0.69745554,  1.23208345,  1.66199957,  1.69894081,
          0.13124461,  1.93256147, -0.17787952,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.45089205,  2.62430534, -1.9517961 ,  2.24040577,  1.75642049,
          1.94962325,  0.26796497,  2.28418304,  1.44944487,  0.28723885,
         -0.81081633,  1.54840214,  0.82652939,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.27678173,  1.17204606, -0.24738322,  1.02761617,  1.81060444,
          2.37830861,  0.55260134,  2.50046334,  1.04652821,  0.03467176,
         -2.07336654,  1.2628897 ,  0.61604732,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[ 3.86138405,  2.35068317, -1.90187438,  0.600788  ,  0.18011722,
          1.3469559 , -0.54708828,  1.83798823, -0.01957845,  2.88713217,
          3.1724991 ,  2.90802072,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.26785642,  0.51076756,  0.32070756,  2.33758816,  2.08146669,
         -0.60796736,  0.93777509,  2.70474711,  0.44785738,  1.61720609,
          1.52890594,  3.03072971,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 3.30219394,  3.1515445 ,  1.16550716,  2.07489374,  0.66441859,
          0.97529244,  0.35176367,  1.22593639, -1.80698271,  1.19936482,
          3.34017172,  2.15960657,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 2.34839018,  2.24827352, -1.61070856,  2.81044265, -1.21423372,
          0.24633846, -0.82196609,  2.28616568,  0.033922  ,  2.7557593 ,
          1.16178372,  3.66959512,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.32913219,  1.63231852,  0.58642744,  1.55873546,  0.86354741,
          2.06654246, -0.44036504,  3.22723595,  1.33279468,  0.05975892,
          2.48518999,  3.44690602,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[ 0.61424344, -1.03068819, -1.47929328,  2.91514641,  2.06867196,
          1.90384921, -0.45835234,  1.22054782,  0.67931536,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 2.76480464,  1.12442631, -2.36004758,  2.91912726,  1.67891181,
          3.76873596, -0.93874096, -0.32397781, -0.55732374,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.39953353, -1.26828104,  0.44482517,  2.85604975,  3.08891062,
          2.60268725, -0.15785176,  1.58549879, -0.32948578,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.65156484, -1.56545168, -1.42771206,  2.74216475,  1.8758154 ,
          3.51169147,  0.18353058, -0.14704149,  0.00442783,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.27736372,  0.37407608, -1.25713475,  0.53171176,  1.53714914,
          0.21015523, -1.06850669, -0.09755327, -0.92373834,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]]),
 array([[-1.39160433,  0.21014669, -0.89792475,  2.6702794 ,  1.54610601,
          0.84699037,  2.96726482,  1.84236946,  0.02211578,  0.32842575,
          1.02718924,  1.78447936, -1.20056829,  2.26699318, -0.23156537,
          2.50124959,  1.93372501,  0.10264369, -1.70813962,  0.        ],
        [ 0.38823591, -1.30348049, -0.31599117,  2.60044143,  2.32929389,
          1.40348483,  3.25758736,  1.92210728, -0.34150988, -1.22336921,
          2.3567069 ,  1.75456835,  0.28295694,  0.68114898, -0.457843  ,
          1.83372069,  2.10177851, -0.26664178, -0.26549595,  0.        ],
        [ 0.08540346,  0.71507504,  1.78164285,  3.04418137,  1.52975256,
          3.55159169,  3.21396003,  3.22720346,  0.68147142,  0.12466013,
         -0.4122895 ,  1.97986653,  1.51671949,  2.06096825, -0.6765908 ,
          2.00145086,  1.73723014,  0.50186043, -2.27525744,  0.        ],
        [ 0.00632717,  0.3050794 , -0.33167875,  1.48109172,  0.19653696,
          1.97504239,  2.51595821,  1.74499313, -1.65198805, -1.04424953,
         -0.23786945,  1.18639347, -0.03568057,  3.82541131,  2.84039446,
          2.88325909,  1.79827675, -0.80230291,  0.08165052,  0.        ],
        [ 0.89980086,  0.34690991, -0.60806566,  1.69472308,  1.38043417,
          0.97139487,  0.21977176,  1.01340944, -1.69946943, -0.01775586,
         -0.35851919,  1.81115864,  1.15105661,  1.21410373,  1.50667558,
          1.70155313,  3.1410754 , -0.54806167, -0.51879299,  0.        ]])]

y_batches_train


[array([1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 1., 1., 2., 2., 1., 2., 0.,
        0., 0., 0.]),
 array([1., 1., 1., 1., 1., 2., 2., 1., 1., 2., 2., 1., 2., 2., 0., 0., 0.,
        0., 0., 0.]),
 array([1., 2., 1., 2., 1., 1., 2., 2., 1., 1., 2., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.]),
 array([2., 2., 1., 2., 2., 2., 1., 1., 2., 2., 2., 2., 0., 0., 0., 0., 0.,
        0., 0., 0.]),
 array([1., 2., 2., 2., 1., 1., 1., 1., 2., 2., 1., 2., 1., 1., 1., 1., 0.,
        0., 0., 0.]),
 array([2., 1., 2., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.]),
 array([1., 2., 1., 2., 2., 2., 1., 2., 2., 1., 1., 2., 1., 0., 0., 0., 0.,
        0., 0., 0.]),
 array([2., 2., 1., 2., 1., 1., 1., 2., 1., 2., 2., 2., 0., 0., 0., 0., 0.,
        0., 0., 0.]),
 array([2., 1., 1., 2., 2., 2., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.]),
 array([1., 1., 1., 2., 2., 2., 2., 2., 1., 1., 1., 2., 1., 2., 1., 2., 2.,
        1., 1., 0.])]

X_batches_test


[array([[ 0.74119496,  1.97273418,  1.76675805,  0.51484268,  1.39422086,
          2.97184667, -1.35274514,  2.08825434, -1.2521965 ,  1.11556387,
          0.19776789,  2.38259223, -0.57140597, -0.79010112,  0.17038974,
          1.28075761,  0.696398  ,  3.0920007 , -0.41138503,  0.        ],
        [-1.39081797,  0.41079718,  3.03698894, -2.07333633,  2.05575621,
          2.73222939, -0.98182787,  1.06741172, -1.36310914,  0.20174856,
          0.35323654,  2.70305775,  0.52549713, -0.7786237 ,  1.80857093,
          0.96830907, -0.23610863,  1.28160768,  0.7026651 ,  0.        ],
        [ 1.16357113,  0.43907935,  3.40158623, -0.73923043,  1.484668  ,
          1.52809569, -0.02347205,  1.65349967,  1.79635118, -0.46647772,
         -0.78400883,  0.82695404, -1.34932627, -0.3200281 ,  2.84417045,
          0.01534261,  0.10047148,  2.70769609, -1.42669461,  0.        ],
        [-1.05475682,  3.45578027,  1.58589338, -0.55515227,  2.13477478,
          1.86777473,  0.61550335,  1.05781415, -0.45297406, -0.04317595,
         -0.15255388,  0.74669395, -1.43621979,  1.06229278,  0.99792794,
          1.24391783, -1.86484584,  1.92802343,  0.56148011,  0.        ],
        [-0.0835337 ,  1.89593955,  1.65769335, -0.93622246,  1.05002869,
          1.49675624, -0.00821712,  1.71541053,  2.02408452,  0.59011484,
          0.72719784,  3.44801858, -0.00957537,  0.37176007,  1.93481168,
          2.23125062,  1.67910471,  2.80923862,  0.34516993,  0.        ]]),
 array([[ 0.40691415,  2.31873444, -0.83458005, -0.17018249, -0.39177831,
          1.90353251,  2.98241467,  0.32808584,  3.09429553,  2.27183083,
          3.09576659,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.6862473 ,  1.0690102 , -0.07415598, -0.09846767,  1.14562424,
          2.52211963,  1.71911351,  0.41879894,  1.62787544,  3.50533394,
          2.69963456,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 3.27824216,  2.25067953,  0.40017321, -1.36011162, -1.41010106,
          0.98956203,  2.30881584, -0.29496046,  2.29748247,  3.24940966,
          1.06431776,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 2.80167214,  3.88324559, -0.6984172 ,  0.81889567,  1.86945352,
          3.07554419,  3.10357189,  1.31426767,  0.28163147,  2.75559628,
          2.00866885,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
        [ 1.54574419,  1.00720596, -1.55418837,  0.70823839,  0.14715209,
          1.03747262,  0.82988672, -0.54006372,  1.4960777 ,  0.34578788,
          1.10558132,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ]])]

y_batches_test


[array([1., 2., 2., 1., 2., 2., 1., 2., 1., 1., 1., 2., 1., 1., 2., 2., 1.,
        2., 1., 0.]),
 array([2., 2., 1., 1., 1., 2., 2., 1., 2., 2., 2., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])]

Recommended answer

First problem: your X data after reshaping is not what you expected. If you look at the first sample after reshaping, it is:

array([[[-1.00612917,  1.47313952,  2.68021318,  1.54875809,
          0.98385996],
        [ 1.49465265,  0.60429106,  1.12396908, -0.24041602,
          1.77266187],
        [ 0.1961381 ,  1.28019637,  1.78803092,  2.05151245,
          0.93606708],
        [ 0.51554755,  0.        ,  0.        ,  0.        ,
          0.        ],
        [-0.97596563,  2.04536053,  0.88367922,  1.013342  ,
         -0.16605355],
        [ 3.02994344,  2.04080806, -0.25153046, -0.5964068 ,
          2.9607247 ],
        [-0.49722121,  0.02734492,  2.16949987,  2.77367066,
          0.15628842],
        [ 2.19823207,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 0.31546283,  3.27420503,  3.23550769, -0.63724013,
          0.89150128],
        [ 0.69774266,  2.76627308, -0.58408384, -0.45681779,
          1.98843041],
        [-0.31850477,  0.83729882,  0.45471165,  3.61974147,
         -1.45610756],
        [ 1.35217453,  0.        ,  0.        ,  0.        ,
          0.        ],
        [ 1.03329532,  1.97471646,  1.33949611,  1.22857243,
         -1.46890642],
        [ 1.74105506,  1.40969261,  0.52465603, -0.18895266,
          2.81025597],
        [ 2.64901037, -0.83415186,  0.76956826,  1.48730868,
         -0.16190164],
        [ 2.24389007,  0.        ,  0.        ,  0.        ,
          0.        ],
        [-1.0676654 ,  3.08429323,  1.7601179 ,  0.85448051,
          1.15537064],
        [ 2.82487842,  0.27891413,  0.57842569, -0.62392063,
          1.00343057],
        [ 1.15348843, -0.37650332,  3.37355345,  2.22285473,
          0.43444434],
        [ 0.15743873,  0.        ,  0.        ,  0.        ,
          0.        ]]])

So no timestep is actually masked: the Masking layer only masks timesteps where all features are 0, and none of the 20 timesteps above is entirely 0.
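The reshape problem can be reproduced with NumPy alone. The sketch below is an illustration, not code from the original answer: it assumes each batch is stored as 5 features (rows) by 20 timesteps (columns), as in the question's data, builds a hypothetical batch with 4 padded timesteps, and counts how many timesteps are entirely zero (the condition under which Masking(mask_value=0.) skips a timestep). The helper name count_masked_timesteps is made up for this example, and transposing before reshaping is one possible fix under that layout assumption.

```python
import numpy as np

# Hypothetical toy batch: 5 features (rows) x 20 timesteps (columns),
# with the last 4 timesteps zero-padded, mirroring the layout in the question.
X = np.zeros((5, 20))
X[:, :16] = np.arange(1, 81).reshape(5, 16)  # any nonzero values will do

def count_masked_timesteps(batch):
    """Count timesteps whose features are all zero (the ones a zero mask would skip)."""
    return int(np.all(batch == 0.0, axis=-1).sum())

# Plain reshape reads the array row by row, so each "timestep" mixes
# values from different real timesteps and the padding is scattered.
wrong = X.reshape(1, 20, 5)

# Transposing first gives (timesteps, features), so each row is one real timestep.
right = X.T.reshape(1, 20, 5)

print(count_masked_timesteps(wrong))  # 0
print(count_masked_timesteps(right))  # 4
```

With the transposed input, the four padded timesteps are all-zero rows and can be masked; with the plain reshape, none are.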

To check that the mask produced by the Masking layer is propagated successfully to the output layer, you can do:

for i, l in enumerate(model.layers):
    print(f'layer {i}: {l}')
    print(f'has input mask: {l.input_mask}')
    print(f'has output mask: {l.output_mask}')

layer 0: <tensorflow.python.keras.layers.core.Masking object at 0x6417b7f60>
has input mask: None
has output mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
layer 1: <tensorflow.python.keras.layers.wrappers.Bidirectional object at 0x641e25cf8>
has input mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
has output mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
layer 2: <tensorflow.python.keras.layers.core.Dropout object at 0x641814128>
has input mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
has output mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
layer 3: <tensorflow.python.keras.layers.wrappers.TimeDistributed object at 0x6433b6ba8>
has input mask: Tensor("masking/Identity_1:0", shape=(None, 20), dtype=bool)
has output mask: Tensor("time_distributed/Reshape_3:0", shape=(None, 20), dtype=bool)

So you can see that the final layer also has an output_mask, which means the mask is propagated successfully. You seem to have a misunderstanding of how masking works in Keras. What the Masking layer actually does is generate a mask, a boolean array of shape (None, Timesteps). Since the timestep dimension stays the same throughout your model definition, the mask is propagated to the end without any changes. Then, when Keras calculates the loss (and, of course, the gradients), the timesteps with a mask value of False are ignored. The Masking layer doesn't change the output values, so your model will of course still predict class 0; all it does is produce a boolean array indicating which timesteps should be skipped and pass it to the end (provided all the layers accept the mask).
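To make the effect of the mask on the loss concrete, here is a small NumPy sketch. It is a conceptual illustration only, not Keras's actual implementation (the exact reduction Keras applies may differ between versions): the function masked_sparse_cce and the example values are assumptions made up for this example. It computes a per-timestep sparse categorical cross-entropy and averages it over unmasked timesteps only.

```python
import numpy as np

def masked_sparse_cce(y_true, probs, mask):
    """Mean negative log-likelihood over the timesteps where mask is True."""
    steps = np.arange(len(y_true))
    nll = -np.log(probs[steps, y_true])   # per-timestep loss
    weights = mask.astype(float)          # False -> 0.0, so padded steps drop out
    return float((nll * weights).sum() / weights.sum())

# Hypothetical sequence of 4 timesteps and 2 classes; the last two are padding.
probs = np.array([[0.9, 0.1],
                  [0.2, 0.8],
                  [0.7, 0.3],
                  [0.7, 0.3]])
mask = np.array([True, True, False, False])

# Same real labels, different labels at the padded positions.
loss_a = masked_sparse_cce(np.array([0, 1, 0, 0]), probs, mask)
loss_b = masked_sparse_cce(np.array([0, 1, 1, 1]), probs, mask)

print(loss_a == loss_b)  # True: padded labels do not affect the loss
```

The model still produces an output at every timestep, padded or not; the mask only decides which timesteps count toward the loss.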

So what you can do is change one line of your model definition as follows and shift your y labels down by 1, so that your current classes map as follows:

0 -> 0 (the loss at these timesteps is ignored and does not contribute to training, so whether the label is 0 or 1 doesn't matter)

1 -> 0

2 -> 1

# I would prefer softmax if doing classification
# here we only need to specify 2 classes
# and actually TimeDistributed can be thrown away (at least in recent Keras versions)
model.add(TimeDistributed(Dense(2, activation='softmax')))
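The label shift described above can be sketched in NumPy as follows. This is an illustration rather than part of the original answer: the label sequence and the softmax outputs are hypothetical values, and adding 1 to the predicted class index is simply the inverse of the shift, used to recover the original labels 1 and 2.

```python
import numpy as np

# One label sequence in the question's format: 1 and 2 are real classes, 0 is padding.
y = np.array([1., 2., 2., 1., 0., 0.])

# Shift real labels down by one; padding stays at 0 (its loss is masked anyway).
y_shifted = np.maximum(y - 1, 0).astype(int)
print(y_shifted)  # [0 1 1 0 0 0]

# Hypothetical softmax outputs of shape (timesteps, 2).
probs = np.array([[0.8, 0.2],
                  [0.3, 0.7],
                  [0.4, 0.6],
                  [0.9, 0.1],
                  [0.5, 0.5],
                  [0.5, 0.5]])

# Map predicted class indices back to the original labels (0 -> 1, 1 -> 2).
pred_original = probs.argmax(axis=-1) + 1
print(pred_original[:4])  # [1 2 2 1]
```

Predictions at padded positions are still produced, so they should be dropped before computing accuracy, as the question's code already does by filtering on the true label.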

You can also see my answer at https://stackoverflow.com/a/59313862/11819266 to understand how the loss is calculated with and without masking.
