Keras Dense layer's input is not flattened

Question

Here is my test code:

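from keras import layers

input1 = layers.Input((2, 3))       # each sample is a (2, 3) tensor
output = layers.Dense(4)(input1)    # Dense layer with 4 units
print(output)                       # prints the symbolic output tensor and its shape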

The output is:

<tf.Tensor 'dense_2/add:0' shape=(?, 2, 4) dtype=float32>

But what is actually happening here?

The documentation says:

Note: if the input to the layer has a rank greater than 2, then it is flattened prior to the initial dot product with kernel.

So is the output reshaped back afterwards?

Answer

Currently, contrary to what is stated in the documentation, the Dense layer is applied on the last axis of the input tensor:

Contrary to the documentation, we don't actually flatten it. It's applied on the last axis independently.

In other words, if a Dense layer with m units is applied on an input tensor of shape (n_dim1, n_dim2, ..., n_dimk) it would have an output shape of (n_dim1, n_dim2, ..., m).
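
To make the shape rule concrete, here is a small NumPy sketch. It is an emulation, not Keras itself, and it assumes a Dense layer computes `x @ kernel + bias` contracted over the last axis only; the names `dense_last_axis`, `kernel` and `bias` are hypothetical stand-ins for the layer's weights. It also checks that this equals merging all leading axes into one batch axis, doing an ordinary 2-D matrix product, and reshaping back, which is why no separate reshape of the output is visible.

```python
import numpy as np

def dense_last_axis(x, kernel, bias):
    """Emulate Dense: contract the last axis of x with kernel, then add bias.

    x: (..., n_in), kernel: (n_in, units), bias: (units,) -> (..., units)
    """
    # np.matmul broadcasts over all leading axes, so only the last axis is contracted
    return x @ kernel + bias

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 2, 3))       # batch of 8 inputs of shape (2, 3), like Input((2, 3))
kernel = rng.standard_normal((3, 4))     # hypothetical weights for Dense(4)
bias = rng.standard_normal(4)

out = dense_last_axis(x, kernel, bias)
print(out.shape)  # (8, 2, 4)

# Equivalent formulation: merge leading axes, do a 2-D matmul, restore the shape
flat = x.reshape(-1, x.shape[-1]) @ kernel + bias      # (16, 4)
restored = flat.reshape(x.shape[:-1] + (kernel.shape[-1],))
print(np.allclose(out, restored))  # True
```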

As a side note: this makes TimeDistributed(Dense(...)) and Dense(...) equivalent to each other.
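
The equivalence can be checked numerically with another NumPy emulation, again assuming Dense computes `x @ kernel + bias`. Here `time_distributed` is a hypothetical helper that mimics what the TimeDistributed wrapper does (apply the same weights to each time step separately); it is not the Keras API. Its result matches a single contraction over the last axis.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((4, 20, 5))    # (batch, timesteps, features)
kernel = rng.standard_normal((5, 10))  # shared weights for a hypothetical Dense(10)
bias = rng.standard_normal(10)

def time_distributed(x, kernel, bias):
    # Apply the same Dense weights to each timestep slice x[:, t, :] separately
    steps = [x[:, t, :] @ kernel + bias for t in range(x.shape[1])]
    return np.stack(steps, axis=1)     # back to (batch, timesteps, units)

direct = x @ kernel + bias             # Dense applied on the last axis
looped = time_distributed(x, kernel, bias)
print(direct.shape, np.allclose(direct, looped))  # (4, 20, 10) True
```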

Another side note: be aware that this has the effect of shared weights. For example, consider this toy network:

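from keras.models import Sequential
from keras.layers import Dense

model = Sequential()
model.add(Dense(10, input_shape=(20, 5)))  # 10 units applied to inputs of shape (20, 5)

model.summary()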

Model summary:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 20, 10)            60        
=================================================================
Total params: 60
Trainable params: 60
Non-trainable params: 0
_________________________________________________________________

As you can see, the Dense layer has only 60 parameters. How? Each of the 10 units in the Dense layer is connected to the 5 elements of each row of the input, and the same weights are reused for all 20 rows. Therefore the count is 10 * 5 (kernel weights) + 10 (one bias per unit) = 60.
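
The arithmetic generalizes: because the kernel only spans the last axis, a Dense layer's parameter count depends on the size of that axis and the number of units, not on the other input dimensions. The plain-Python sketch below (function names are illustrative) compares it with the count you would get if the input really were flattened first, as the documentation note suggests.

```python
def dense_params(last_dim: int, units: int) -> int:
    # Kernel of shape (last_dim, units) plus one bias per unit
    return last_dim * units + units

def flattened_dense_params(input_shape: tuple, units: int) -> int:
    # Hypothetical count if the whole (non-batch) input were flattened first
    flat = 1
    for d in input_shape:
        flat *= d
    return flat * units + units

print(dense_params(5, 10))                   # 60, matches model.summary()
print(flattened_dense_params((20, 5), 10))   # 1010
```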

Update. Here is a visual illustration of the example above:
