Are Keras custom layer parameters non-trainable by default?

Question

I built a simple custom layer in Keras and was surprised to find that its parameters were not trainable by default. I can get it to work by explicitly setting the trainable attribute, but I can't find an explanation for this in the documentation or the code. Is this the intended behavior, or am I doing something that makes the parameters non-trainable by default? Code:

import tensorflow as tf


class MyDense(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(MyDense, self).__init__(kwargs)
        self.dense = tf.keras.layers.Dense(2, tf.keras.activations.relu)

    def call(self, inputs, training=None):
        return self.dense(inputs)


inputs = tf.keras.Input(shape=10)
outputs = MyDense()(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs, name='test')
model.compile(loss=tf.keras.losses.MeanSquaredError())
model.summary()

Output:

Model: "test"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 10)]              0         
_________________________________________________________________
my_dense (MyDense)           (None, 2)                 22        
=================================================================
Total params: 22
Trainable params: 0
Non-trainable params: 22
_________________________________________________________________

If I change the custom layer creation like this:

outputs = MyDense(trainable=True)(inputs)

the output is what I expect (all parameters are trainable):

=================================================================
Total params: 22
Trainable params: 22
Non-trainable params: 0
_________________________________________________________________

then it works as expected and makes all the parameters trainable. I don't understand why that is needed though.

Answer

No doubt, that's an interesting quirk.

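The cause is the constructor call super(MyDense, self).__init__(kwargs), which passes the kwargs dictionary as a single positional argument. In tf.keras for TF 2.x, the first parameter of Layer.__init__ after self is trainable, so the dictionary is bound to trainable. With MyDense(), that dictionary is empty and therefore falsy, so Keras treats the whole layer, including its nested Dense, as non-trainable. With MyDense(trainable=True), the dictionary is {'trainable': True}, which is truthy, and that is the only reason the workaround appears to work. The fix is to unpack the arguments: super(MyDense, self).__init__(**kwargs). Weights of nested layers are tracked automatically, and the inner Dense object itself was trainable all along. See: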

MyDense().dense.trainable

True
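To see why the positional call matters, here is a pure-Python sketch that needs no TensorFlow. The Layer class in it is a hypothetical, simplified stand-in that copies only the parameter order of the tf.keras Layer constructor in TF 2.x (trainable first, then name); it is not the real Keras class.

```python
class Layer:
    """Hypothetical, simplified stand-in for the TF 2.x tf.keras Layer constructor.

    Only the parameter order matters here: `trainable` comes first after `self`.
    """

    def __init__(self, trainable=True, name=None, **kwargs):
        self.trainable = trainable
        self.name = name


class BuggyLayer(Layer):
    def __init__(self, **kwargs):
        # Bug (as in the question): the dict is passed positionally, so it binds to `trainable`.
        super().__init__(kwargs)


class FixedLayer(Layer):
    def __init__(self, **kwargs):
        # Fix: unpack the dict so each entry becomes a keyword argument.
        super().__init__(**kwargs)


print(BuggyLayer().trainable)                 # {}  (empty dict, falsy)
print(bool(BuggyLayer().trainable))           # False
print(BuggyLayer(trainable=True).trainable)   # {'trainable': True}  (non-empty dict, truthy)
print(FixedLayer().trainable)                 # True
print(FixedLayer(trainable=False).trainable)  # False
```

Any code that later checks "if self.trainable:" sees the empty dict as False, which matches the summary in the question reporting 0 trainable parameters.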

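Creating the weights directly with tf.Variable, as below, also gives trainable parameters by default. Note that this version does not pass the kwargs dictionary positionally to super(MyDense, self).__init__(), so trainable keeps its default value of True: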

import tensorflow as tf


class MyDense(tf.keras.layers.Layer):
    def __init__(self, units=2, input_dim=10, **kwargs):
        # Forward keyword arguments with ** so that trainable keeps its default (True)
        super(MyDense, self).__init__(**kwargs)
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(
            initial_value=w_init(shape=(input_dim, units), dtype="float32"),
            trainable=True,
        )
        b_init = tf.zeros_initializer()
        self.b = tf.Variable(
            initial_value=b_init(shape=(units,), dtype="float32"), trainable=True
        )

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.w) + self.b


inputs = tf.keras.Input(shape=(10,))
outputs = MyDense()(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs, name='test')
model.compile(loss=tf.keras.losses.MeanSquaredError())
model.summary()

Model: "test"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_11 (InputLayer)        [(None, 10)]              0         
_________________________________________________________________
my_dense_18 (MyDense)        (None, 2)                 22        
=================================================================
Total params: 22
Trainable params: 22
Non-trainable params: 0
_________________________________________________________________
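Keeping the nested Dense from the question also works once the keyword arguments are unpacked. The following is a minimal sketch, assuming TensorFlow 2.x; the units parameter is an illustrative addition that is not in the original layer.

```python
import tensorflow as tf


class MyDense(tf.keras.layers.Layer):
    def __init__(self, units=2, **kwargs):
        # Unpack kwargs with ** so `trainable` keeps its default value of True
        super().__init__(**kwargs)
        # Layers assigned as attributes are tracked automatically by Keras
        self.dense = tf.keras.layers.Dense(units, activation="relu")

    def call(self, inputs):
        return self.dense(inputs)


inputs = tf.keras.Input(shape=(10,))
layer = MyDense()
outputs = layer(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs, name="test")

print(layer.trainable)               # True
print(len(layer.trainable_weights))  # 2: kernel and bias of the inner Dense
print(model.count_params())          # 22 = 10 * 2 weights + 2 biases
```

With this constructor, model.summary() should report all 22 parameters as trainable without passing trainable=True explicitly.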
