什么是"model.trainable = False"?在Keras中是什么意思 [英] What does "model.trainable = False" mean in Keras?

查看:1152
本文介绍了什么是"model.trainable = False"?在Keras中是什么意思的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想在Keras中冻结一个预先训练的网络.我在文档中找到了base.trainable = False.但是我不明白它是如何工作的. 使用len(model.trainable_weights)时,我发现有30个可训练的砝码.怎么可能?网络显示可训练的参数总数:16,812,353. 冻结后,我有4个可训练的砝码. 也许我不了解参数和权重之间的区别.不幸的是,我是深度学习的初学者.也许有人可以帮助我.

I want to freeze a pre-trained network in Keras. I found base.trainable = False in the documentation. But I didn't understand how it works. With len(model.trainable_weights) I found out that I have 30 trainable weights. How can that be? The network shows total trainable params: 16,812,353. After freezing I have 4 trainable weights. Maybe I don't understand the difference between params and weights. Unfortunately I am a beginner in Deep Learning. Maybe someone can help me.

推荐答案

Keras Model

A Keras Model is trainable by default - you have two means of freezing all the weights:

  1. model.trainable = False 之前编译模型
  2. for layer in model.layers: layer.trainable = False-在&之前工作编译后
  1. model.trainable = False before compiling the model
  2. for layer in model.layers: layer.trainable = False - works before & after compiling

(1)必须在编译之前完成,因为Keras在编译时会将model.trainable视为布尔标志,并在后台执行(2).完成以上任一操作后,您应该会看到:

(1) must be done before compilation since Keras treats model.trainable as a boolean flag at compiling, and performs (2) under the hood. After doing either of the above, you should see:

print(model.trainable_weights)
# [] 


关于文档,可能已过时-请参见上面的链接源代码,最新.


Regarding the docs, likely outdated - see linked source code above, up-to-date.

这篇关于什么是"model.trainable = False"?在Keras中是什么意思的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆