为什么我会遇到Keras形状不匹配的情况? [英] Why am I getting Keras shape mismatch?

查看:70
本文介绍了为什么我会遇到Keras形状不匹配的情况?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在为初学者遵循Keras mnist示例.我试图更改标签以适合自己的数据,该数据具有3个不同的文本分类.我正在使用"to_categorical"来实现这一目标.形状对我来说看起来不错,但是合身"会出现错误:

I am following a Keras mnist example for beginners. I have tried to change the labels to suit my own data which has 3 distinct text classifications. I am using "to_categorical" to achieve this. The shape looks right to me, but "fit" gets an error:

train_labels = keras.utils.to_categorical(train_labels, num_classes=3)

print(train_images.shape)
print(train_labels.shape)

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(3, activation=tf.nn.softmax)
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_images, train_labels, epochs=5)

(7074,28,28)

(7074, 28, 28)

(7074,3)

Blockquote 块引用 追溯(最近一次通话):文件 "C:/Users/lawrence/PycharmProjects/tester2019/KeraTest.py",第131行, 在 model.fit(train_images,train_labels,epochs = 5)文件"C:\ Users \ lawrence \ PycharmProjects \ tester2019 \ venv \ lib \ site-packages \ tensorflow \ python \ keras \ engine \ training.py", 1536行,适合 validate_split = validation_split)文件"C:\ Users \ lawrence \ PycharmProjects \ tester2019 \ venv \ lib \ site-packages \ tensorflow \ python \ keras \ engine \ training.py", _standardize_user_data中的第992行 class_weight,batch_size)文件"C:\ Users \ lawrence \ PycharmProjects \ tester2019 \ venv \ lib \ site-packages \ tensorflow \ python \ keras \ engine \ training.py", 第1154行,在_standardize_weights中 exception_prefix ='target')文件"C:\ Users \ lawrence \ PycharmProjects \ tester2019 \ venv \ lib \ site-packages \ tensorflow \ python \ keras \ engine \ training_utils.py", 第332行,位于standardize_input_data中 ',但数组的形状为'+ str(data_shape))ValueError:检查目标时出错:预期density_1具有形状(1,),但得到 形状为(3,)的数组

Blockquote Blockquote Traceback (most recent call last): File "C:/Users/lawrence/PycharmProjects/tester2019/KeraTest.py", line 131, in model.fit(train_images, train_labels, epochs=5) File "C:\Users\lawrence\PycharmProjects\tester2019\venv\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1536, in fit validation_split=validation_split) File "C:\Users\lawrence\PycharmProjects\tester2019\venv\lib\site-packages\tensorflow\python\keras\engine\training.py", line 992, in _standardize_user_data class_weight, batch_size) File "C:\Users\lawrence\PycharmProjects\tester2019\venv\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1154, in _standardize_weights exception_prefix='target') File "C:\Users\lawrence\PycharmProjects\tester2019\venv\lib\site-packages\tensorflow\python\keras\engine\training_utils.py", line 332, in standardize_input_data ' but got array with shape ' + str(data_shape)) ValueError: Error when checking target: expected dense_1 to have shape (1,) but got array with shape (3,)

推荐答案

由于标签是一种热编码的,因此您需要使用categorical_crossentropy而不是sparse_categorical_crossentropy作为损失.

You need to use categorical_crossentropy instead of sparse_categorical_crossentropy as loss since your labels are one hot encoded.

或者,如果您不对标签进行热编码,则可以使用sparse_categorical_crossentropy.在这种情况下,标签的形状应为(batch_size, 1).

Alternatively, you can use sparse_categorical_crossentropy if you do not one hot encode the labels. In that case, the labels should have shape (batch_size, 1).

这篇关于为什么我会遇到Keras形状不匹配的情况?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆