如何在Keras中获取模型的可训练参数的数量? [英] How can I get the number of trainable parameters of a model in Keras?

查看:625
本文介绍了如何在Keras中获取模型的可训练参数的数量?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在通过Model API实现的所有层中都设置trainable=False,但是我想验证一下是否有效. model.count_params()返回参数总数,但是除了查看model.summary()的最后几行之外,有什么方法可以获取可训练参数的总数?

I am setting trainable=False in all my layers, implemented through the Model API, but I want to verify whether that is working. model.count_params() returns the total number of parameters, but is there any way in which I can get the total number of trainable parameters, other than looking at the last few lines of model.summary()?

推荐答案

from keras import backend as K

trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
non_trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))

以上代码段可以在 layer_utils.print_summary() 定义,其中 summary() 正在呼叫.

The above snippet can be discovered in the end of layer_utils.print_summary() definition, which summary() is calling.

Keras的最新版本具有助手功能

more recent version of Keras has a helper function count_params() for this purpose:

from keras.utils.layer_utils import count_params

trainable_count = count_params(model.trainable_weights)
non_trainable_count = count_params(model.non_trainable_weights)

这篇关于如何在Keras中获取模型的可训练参数的数量?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆