Keras model input shape wrong


Problem description

I have a Keras model with the layout below:

from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping, ModelCheckpoint

def keras_model(x_train, y_train, x_test, y_test):
    # Fully connected classifier: input width and output width come from the data.
    model = Sequential()
    model.add(Dense(128, input_dim=x_train.shape[1], activation='relu'))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(512, activation='relu'))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(128, activation='relu'))
    #model.add(Dense(10,activation='relu'))
    model.add(Dense(y_train.shape[1], activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    # Stop once val_loss has not improved by at least 1e-3 for 5 epochs.
    monitor = EarlyStopping(monitor='val_loss', min_delta=1e-3, patience=5, verbose=1, mode='auto')
    checkpointer = ModelCheckpoint(filepath="best_weights.hdf5", verbose=0, save_best_only=True) # save best model

    model.fit(x_train, y_train, validation_data=(x_test, y_test), callbacks=[monitor, checkpointer], verbose=2, epochs=1000)
    model.load_weights('best_weights.hdf5') # load weights from best model

    return model

The model was trained on data from the OpenAI Gym CartPole game and then saved. The next step is to use the trained model to make predictions:

from keras.models import load_model
model = load_model('data/model-v0.h5')

import gym
env = gym.make("CartPole-v0")
env.reset()
>>> array([ 0.02429215, -0.00674185, -0.03713565, -0.0046836 ])

import random
action = random.randrange(0,2)
observation, reward, done, info = env.step(action)
observation.shape
>>> (4,)

new_observation, reward, done, info = env.step(action)
prev_obs = new_observation
prev_obs
>>> array([-0.00229585,  0.15330146,  0.02160273, -0.30723955])

prev_obs.shape
>>> (4,)

model.predict(prev_obs)
>>>
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-25-943f2f44ed64> in <module>()
----> 1 model.predict(prev_obs)

c:\users\samuel\appdata\local\programs\python\python35\lib\site-packages\keras\engine\training.py in predict(self, x, batch_size, verbose, steps)
   1145                              'argument.')
   1146         # Validate user data.
-> 1147         x, _, _ = self._standardize_user_data(x)
   1148         if self.stateful:
   1149             if x[0].shape[0] > batch_size and x[0].shape[0] % batch_size != 0:

c:\users\samuel\appdata\local\programs\python\python35\lib\site-packages\keras\engine\training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
    747             feed_input_shapes,
    748             check_batch_axis=False,  # Don't enforce the batch size.
--> 749             exception_prefix='input')
    750 
    751         if y is not None:

c:\users\samuel\appdata\local\programs\python\python35\lib\site-packages\keras\engine\training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    135                             ': expected ' + names[i] + ' to have shape ' +
    136                             str(shape) + ' but got array with shape ' +
--> 137                             str(data_shape))
    138     return data
    139 

ValueError: Error when checking input: expected dense_1_input to have shape (4,) but got array with shape (1,)

The observation has the same shape as the training samples. As you can see, both observation and prev_obs have shape (4,), yet when prev_obs is fed to the model for prediction, it throws an error claiming the input has shape (1,).

I even tried reshaping it with

prev_obs.shape = (4,)
prev_obs.reshape((4,))

but it still throws the same error.

Recommended answer

The Keras API always assumes that you supply the data in batches, or as an array from which it can extract batches. Therefore, even though the first layer of your model requires an input shape of (4,), you have to reshape the data to have the shape (1, 4), where the leading axis is the batch axis.

model.predict(prev_obs.reshape((1, -1)))

This tells the model to make a prediction on 1 data sample, which consists of a 4-dimensional vector (the observation).
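To see where the (1,) in the error message likely comes from, the shapes can be checked with NumPy alone. This is a minimal sketch, not the Keras internals: the expand_dims step is an assumption about how older Keras versions treat a 1-D input (first axis as the batch axis, so 4 samples of 1 feature each), and prev_obs is rebuilt here from the values printed in the question.

```python
import numpy as np

# Observation rebuilt from the values printed in the question (illustrative only).
prev_obs = np.array([-0.00229585, 0.15330146, 0.02160273, -0.30723955])

# Reshaping to (4,) changes nothing: the array is still 1-D and has no batch axis.
same = prev_obs.reshape((4,))

# Assumption: older Keras expands a 1-D input along axis 1,
# which yields 4 samples with 1 feature each, hence "shape (1,)".
as_keras_sees_it = np.expand_dims(prev_obs, 1)

# Three equivalent ways to add a leading batch axis of size 1.
batch_a = prev_obs.reshape((1, -1))
batch_b = np.expand_dims(prev_obs, axis=0)
batch_c = prev_obs[np.newaxis, :]

# Several observations can be stacked into one batch of shape (n, 4).
batch_many = np.stack([prev_obs, prev_obs * 2.0])

print(same.shape, as_keras_sees_it.shape, batch_a.shape, batch_many.shape)
```

Any of batch_a, batch_b, or batch_c can be passed to model.predict; the result then has one row per sample, so the prediction for a single observation is its first row.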
