如何展平嵌套模型?(keras 函数式 API) [英] How to flatten a nested model? (keras functional API)

查看:49
本文介绍了如何展平嵌套模型?(keras 函数式 API)的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我已经使用 keras 模型功能 API 定义了一个简单的模型.它的一个层是一个完全顺序模型,所以我得到了一个嵌套的层结构(见下图).

I have defined a simple model using the keras Model functional API. One of its layers is a fully sequential model, so I get a nested layer structure (see images below).

如何将这种嵌套层结构转换为平面层结构?(使用脚本,而不是手动...)

How can I convert this nested layer structure into a flat layer structure? (with a script, not manually...)

我所拥有的:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 32, 32, 1)         0         
_________________________________________________________________
sequential_1 (Sequential)    (None, 8, 8, 12)          720       
_________________________________________________________________
flatten_1 (Flatten)          (None, 768)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 769       
=================================================================

我想将其转换为:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 32, 32, 1)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 32, 32, 6)         60        
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 16, 16, 6)         0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 16, 16, 6)         330       
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 8, 8, 6)           0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 384)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 385       
=================================================================

<小时>

生成嵌套层结构的代码:


Code to generate nested layer structure:

def create_network_with_one_subnet():
    # define subnetwork
    subnet = keras.models.Sequential()
    subnet.add(keras.layers.Conv2D(6, (3, 3), padding='same'))
    subnet.add(keras.layers.MaxPool2D())
    subnet.add(keras.layers.Conv2D(12, (3, 3), padding='same'))
    subnet.add(keras.layers.MaxPool2D())
    #subnet.summary()


    # define complete network
    input_shape = (32, 32, 1)
    net_in = keras.layers.Input(shape=input_shape)
    net_out = subnet(net_in)
    net_out = keras.layers.Flatten()(net_out)
    net_out = keras.layers.Dense(1)(net_out)
    net_complete = keras.Model(inputs=net_in, outputs=net_out)
    net_complete.compile(loss='binary_crossentropy',
                         optimizer=keras.optimizers.Adam(lr=0.001),
                         metrics=['acc'],
                         )
    net_complete.summary()
    return net_complete

推荐答案

啊,这比预期容易得多.谷歌搜索正确关键字后的解决方案:https://groups.google.com/forum/#!msg/keras-users/lJcVK25YDuc/atB6TfwqBAAJ

Ah, it was much easier than expected. Solution from here after googling the right keywords: https://groups.google.com/forum/#!msg/keras-users/lJcVK25YDuc/atB6TfwqBAAJ

def flatten_model(model_nested):
    layers_flat = []
    for layer in model_nested.layers:
        try:
            layers_flat.extend(layer.layers)
        except AttributeError:
            layers_flat.append(layer)
    model_flat = keras.models.Sequential(layers_flat)
    return model_flat

这篇关于如何展平嵌套模型?(keras 函数式 API)的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆