Keras模型输入形状错误 [英] Keras model input shape wrong
问题描述
我有一个布局如下的keras模型
I've got a keras model with the layout below
def keras_model(x_train, y_train, x_test, y_test):
model = Sequential()
model.add(Dense(128, input_dim=x_train.shape[1], activation='relu'))
model.add(Dense(256,activation='relu'))
model.add(Dense(512,activation='relu'))
model.add(Dense(256,activation='relu'))
model.add(Dense(128,activation='relu'))
#model.add(Dense(10,activation='relu'))
model.add(Dense(y_train.shape[1], activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
monitor = EarlyStopping(monitor='val_loss', min_delta=1e-3, patience=5, verbose=1, mode='auto')
checkpointer = ModelCheckpoint(filepath="best_weights.hdf5", verbose=0, save_best_only=True) # save best model
model.fit(x_train ,y_train, validation_data=(x_test, y_test),callbacks=[monitor,checkpointer], verbose=2,epochs=1000)
model.load_weights('best_weights.hdf5') # load weights from best model
return model
训练了来自开放式健身房卡塔普尔游戏的数据,并保存了模型. 下一步是使用经过训练的模型进行预测
Trained on data from open-ai gym cartpole game and the model saved. The next step is to use the trained model to make predictions
from keras.models import load_model
model = load_model('data/model-v0.h5')
action = random.randrange(0,2)
import gym
env = gym.make("CartPole-v0")
env.reset()
>>> array([ 0.02429215, -0.00674185, -0.03713565, -0.0046836 ])
import random
action = random.randrange(0,2)
observation, reward, done, info = env.step(action)
observation.shape
>>> (4,)
new_observation, reward, done, info = env.step(action)
prev_obs = new_observation
prev_obs
>>> array([-0.00229585, 0.15330146, 0.02160273, -0.30723955])
prev_obs.shape
>>> (4,)
model.predict(prev_obs)
>>>
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-25-943f2f44ed64> in <module>()
----> 1 model.predict(prev_obs)
c:\users\samuel\appdata\local\programs\python\python35\lib\site-packages\keras\engine\training.py in predict(self, x, batch_size, verbose, steps)
1145 'argument.')
1146 # Validate user data.
-> 1147 x, _, _ = self._standardize_user_data(x)
1148 if self.stateful:
1149 if x[0].shape[0] > batch_size and x[0].shape[0] % batch_size != 0:
c:\users\samuel\appdata\local\programs\python\python35\lib\site-packages\keras\engine\training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
747 feed_input_shapes,
748 check_batch_axis=False, # Don't enforce the batch size.
--> 749 exception_prefix='input')
750
751 if y is not None:
c:\users\samuel\appdata\local\programs\python\python35\lib\site-packages\keras\engine\training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
135 ': expected ' + names[i] + ' to have shape ' +
136 str(shape) + ' but got array with shape ' +
--> 137 str(data_shape))
138 return data
139
ValueError: Error when checking input: expected dense_1_input to have shape (4,) but got array with shape (1,)
观测值的形状与所使用的训练数据的形状相似,即使您看到的是observation
且prev_observation
的形状为(4,)
,但问题仍然存在,但是当输入模型进行预测时引发错误并声称输入的形状为(1,)
.
The shape of the observation is similar that of the training data used, the issue is even as you can see as the observation
and the prev_observation
has a shape of (4,)
but when fed into the model to predict throws an error and claims the input has a shape of (1,)
.
我甚至尝试用它重塑
prev_obs.shape = (4,)
prev_obs.reshape((4,))
但它仍然会引发相同的错误.
but it still throws the same error.
推荐答案
keras
的API始终假定您以批处理或数组的形式提供数据,以便从中提取批处理.因此,尽管模型的第一层要求输入形状为(4,)
,但您仍必须重塑数据以使其形状为(1,4)
.
The API of keras
always assumes that you supply the data in batches or in an array from which it can extract batches. Therefore, eventhough the first layer of your model requires an input shape of (4,)
, you have to reshape the data to have the shape (1,4)
.
model.predict(prev_obs.reshape((1, -1)
这告诉模型对1个数据样本进行预测,该样本由4维矢量(观察值)组成.
This tells the model to make a prediction on 1 data sample, which consists out of a 4-dimensional vector (the observation).
这篇关于Keras模型输入形状错误的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!