如何打印 Tensorflow V2 keras 模型的所有激活形状(比 summary() 更详细)? [英] How to print all activation shapes (more detailed than summary()) for Tensorflow V2 keras model?
问题描述
我在 Tensorflow v.0 和 v.1 上花了很多时间,现在我正在尝试 Tensorflow v.2 keras 模型.model.summary()
看起来简单方便,但缺乏细节.
I've spent a lot of time with Tensorflow v.0 and v.1, and now I'm trying Tensorflow v.2 keras model. model.summary()
looked easy and convenient, but lack details.
这是一个玩具示例.假设我定义了如下自定义层和模型(函数式 API 样式和子类样式).
Here's a toy example. Let's say I define custom layers and models as below (a functional API style and a subclass syle).
请看下文.我想在自定义层中看到原始层,但 .summary()
只显示浅层信息(仅直接子层).
Please see below. I wanted to see primitive layers inside the custom layers, but .summary()
only shows shallow information (only direct children layers).
玩具自定义层(层只是玩具定义):
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import (Dense, Conv2D, BatchNormalization)
class LayerA(tf.keras.layers.Layer):
def __init__(self, num_outputs, **kwargs):
super(LayerA, self).__init__(**kwargs)
self.num_outputs = num_outputs
self.dense1 = Dense(64)
self.dense2 = Dense(128)
self.dense3 = Dense(num_outputs)
self.bn = BatchNormalization(trainable=True)
def call(self, inputs, training=True):
x = self.dense1(inputs)
x = self.bn(x, training=training)
x = self.dense1(inputs)
x = self.bn(x, training=training)
x = self.dense1(inputs)
x = self.bn(x, training=training)
return x
class LayerB(tf.keras.layers.Layer):
def __init__(self, num_outputs, **kwargs):
super(LayerB, self).__init__(**kwargs)
self.num_outputs = num_outputs
self.dense = Dense(64)
self.bn = BatchNormalization(trainable=True)
def call(self, inputs, training=True):
x = self.dense(inputs)
x = self.bn(x, training=training)
return x
使用函数式 API 的模型定义:
inputs = tf.keras.Input(shape=(28), name='input')
x = LayerA(7, name='layer_a')(inputs)
x = LayerB(13, name='layer_b')(x)
x = tf.reduce_max(x, 1)
model_func = keras.Model(inputs=inputs, outputs=x, name='model')
model_func.summary()
# Results:
# Layer (type) Output Shape # Param #
# =================================================================
# input (InputLayer) [(None, 28)] 0
# _________________________________________________________________
# layer_a (LayerA) (None, 64) 2112
# layer_b (LayerB) (None, 64) 4416
# tf_op_layer_Max_1 (TensorFlo [(None,)] 0
# =================================================================
# Total params: 6,528
# Trainable params: 6,272
# Non-trainable params: 256
模型定义子类化:
class ModelA(tf.keras.Model):
def __init__(self):
super(ModelA, self).__init__()
self.block_1 = LayerA(7, name='layer_a')
self.block_2 = LayerB(13, name='layer_b')
def call(self, inputs):
x = self.block_1(inputs)
x = self.block_2(x)
x = tf.reduce_max(x, 1)
return x
model_subclass = ModelA()
y = model_subclass(inputs)
model_subclass.summary()
### Result:
# Layer (type) Output Shape Param #
# =================================================================
# layer_a (LayerA) (None, 64) 2112
# layer_b (LayerB) (None, 64) 4416
# =================================================================
# Total params: 6,528
# Trainable params: 6,016
# Non-trainable params: 512
如何打印模型中 Conv
和 Dense
层的所有激活形状?例如,
How can I prints all the activation shapes of Conv
and Dense
layers in the model? For example,
layer_a/dense_1 (None, ...)
layer_a/dense_2 (None, ...)
layer_b/dense_1 (None, ...)
layer_b/maybe-even-deeper-layer/conv2d_1 (None, ...)
... etc ...
在 Tensorflow v.0 或 v.1 中,我会这样做:
In Tensorflow v.0 or v.1, I would do something like:
for n in tf.get_default_graph().as_graph_def().node:
print(n.name, n.shape)
当我有 keras 模型时,有没有办法打印更多细节?
Is there a way to prints more details when I have a keras model?
推荐答案
摘要不会自动执行此操作,因此您必须进行调整.例如,您可以创建一个循环摘要:
The summary will not do this automatically, so you have to adapt. You can, for instance, create a recurrent summary:
def full_summary(layer):
#check if this layer has layers
if hasattr(layer, 'layers'):
print('summary for ' + layer.name)
layer.summary()
print('\n\n')
for l in layer.layers:
full_summary(l)
将其用作:
full_summary(my_model)
这篇关于如何打印 Tensorflow V2 keras 模型的所有激活形状(比 summary() 更详细)?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!