使用子类模型时,model.summary() 无法打印输出形状 [英] model.summary() can't print output shape while using subclass model

查看:33
本文介绍了使用子类模型时,model.summary() 无法打印输出形状的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

这是两种创建keras模型的方法,但是两种方法的汇总结果的输出形状是不同的.显然,前者打印的信息更多,更容易检查网络的正确性.

This is the two methods for creating a keras model, but the output shapes of the summary results of the two methods are different. Obviously, the former prints more information and makes it easier to check the correctness of the network.

import tensorflow as tf
from tensorflow.keras import Input, layers, Model

class subclass(Model):
    def __init__(self):
        super(subclass, self).__init__()
        self.conv = layers.Conv2D(28, 3, strides=1)

    def call(self, x):
        return self.conv(x)


def func_api():
    x = Input(shape=(24, 24, 3))
    y = layers.Conv2D(28, 3, strides=1)(x)
    return Model(inputs=[x], outputs=[y])

if __name__ == '__main__':
    func = func_api()
    func.summary()

    sub = subclass()
    sub.build(input_shape=(None, 24, 24, 3))
    sub.summary()

输出:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 24, 24, 3)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 22, 22, 28)        784       
=================================================================
Total params: 784
Trainable params: 784
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            multiple                  784       
=================================================================
Total params: 784
Trainable params: 784
Non-trainable params: 0
_________________________________________________________________

那么,我应该如何使用子类方法在summary()处获取输出形状?

So, how should I use the subclass method to get the output shape at the summary()?

推荐答案

我用这个方法解决了这个问题,不知道有没有更简单的方法.

I have used this method to solve this problem, I don't know if there is an easier way.

class subclass(Model):
    def __init__(self):
        ...
    def call(self, x):
        ...

    def model(self):
        x = Input(shape=(24, 24, 3))
        return Model(inputs=[x], outputs=self.call(x))



if __name__ == '__main__':
    sub = subclass()
    sub.model().summary()

这篇关于使用子类模型时,model.summary() 无法打印输出形状的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆