使用子类模型时,model.summary()无法打印输出形状 [英] model.summary() can't print output shape while using subclass model

查看:1629
本文介绍了使用子类模型时,model.summary()无法打印输出形状的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

这是创建keras模型的两种方法,但是两种方法的摘要结果的output shapes不同.显然,前者可以打印更多信息,并且可以更轻松地检查网络的正确性.

This is the two methods for creating a keras model, but the output shapes of the summary results of the two methods are different. Obviously, the former prints more information and makes it easier to check the correctness of the network.

import tensorflow as tf
from tensorflow.keras import Input, layers, Model

class subclass(Model):
    def __init__(self):
        super(subclass, self).__init__()
        self.conv = layers.Conv2D(28, 3, strides=1)

    def call(self, x):
        return self.conv(x)


def func_api():
    x = Input(shape=(24, 24, 3))
    y = layers.Conv2D(28, 3, strides=1)(x)
    return Model(inputs=[x], outputs=[y])

if __name__ == '__main__':
    func = func_api()
    func.summary()

    sub = subclass()
    sub.build(input_shape=(None, 24, 24, 3))
    sub.summary()

输出:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 24, 24, 3)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 22, 22, 28)        784       
=================================================================
Total params: 784
Trainable params: 784
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            multiple                  784       
=================================================================
Total params: 784
Trainable params: 784
Non-trainable params: 0
_________________________________________________________________

那么,我应该如何使用子类方法在summary()处获取output shape?

So, how should I use the subclass method to get the output shape at the summary()?

推荐答案

我使用此方法来解决此问题,我不知道是否有更简单的方法.

I have used this method to solve this problem, I don't know if there is an easier way.

class subclass(Model):
    def __init__(self):
        ...
    def call(self, x):
        ...

    def model(self):
        x = Input(shape=(24, 24, 3))
        return Model(inputs=[x], outputs=self.call(x))



if __name__ == '__main__':
    sub = subclass()
    sub.model().summary()

这篇关于使用子类模型时,model.summary()无法打印输出形状的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆