为什么Keras在这种情况下抱怨输入形状不兼容? [英] Why is Keras complaining about incompatible input shape in this case?
问题描述
我使用以下输入层训练了基于Keras的自动编码器模型:
I have trained a Keras-based autoencoder model with the following input layer:
depth = 1
width = height = 100
input_shape = (height, width, depth)
inputs = Input(shape=input_shape)
# rest of network definition ...
我的训练图像的宽度和高度是100像素灰度,因此深度为1.现在,我想在另一个脚本中加载训练后的模型,在其中加载图像,调整大小并将其发送到Keras模型:
Width and height of my training images were 100 pixels in grayscale, thus with a depth of 1. Now I want to load my trained model in another script, load an image there, resize and send it to the Keras model:
size = 100
image = cv2.imread(args.image, cv2.IMREAD_GRAYSCALE)
image = cv2.resize(image, (size, size), interpolation=cv2.INTER_AREA)
image = image.astype("float32") / 255.0
image = np.expand_dims(image, axis=-1)
# at this point image.shape = (100, 100, 1)
recon = autoencoder.predict(image)
但是,对 autoencoder.predict(image)
的调用导致以下错误:
However, the call to autoencoder.predict(image)
leads to the following error:
WARNING:tensorflow:Model was constructed with shape (None, 100, 100, 1) for input KerasTensor(type_spec=TensorSpec(shape=(None, 100, 100, 1), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'"), but it was called on an input with incompatible shape (None, 100, 1, 1).
我不明白这一点,因为调用 predict()
时图像的形状是(100,100,1)
,对我来说很好.为什么Keras抱怨(None,100,1,1)
的输入形状不兼容?
I don't understand this, as the shape of the image when calling predict()
is (100, 100, 1)
, which looks fine to me. Why is Keras complaining about an incompatible input shape of (None, 100, 1, 1)
?
推荐答案
这些简单的代码行会产生错误
These simple lines of code generate the error
X = np.random.uniform(0,1, (100,100,1))
inp = Input((100,100,1))
out = Dense(1)(Flatten()(inp))
model = Model(inp, out)
model.predict(X)
这是因为您的Keras模型期望数据采用以下格式(n_sample,100,100,1)
This is because your Keras model expects data in this format (n_sample, 100, 100, 1)
当您预测单个图像时,简单的重塑就能达到目的
A simple reshape when you predict a single image does the trick
model.predict(X.reshape(1,100,100,1))
这篇关于为什么Keras在这种情况下抱怨输入形状不兼容?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!