Custom activation with parameter


Question

I'm trying to create an activation function in Keras that takes a parameter beta, like so:

from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
from keras.layers import Activation

class Swish(Activation):

    def __init__(self, activation, beta, **kwargs):
        super(Swish, self).__init__(activation, **kwargs)
        self.__name__ = 'swish'
        self.beta = beta


def swish(x):
    return (K.sigmoid(beta*x) * x)

get_custom_objects().update({'swish': Swish(swish, beta=1.)})

It runs fine without the beta parameter, but how can I include the parameter in the activation definition? I also want its value to be saved when I call model.to_json(), as happens for the ELU activation.

Update: I wrote the following code based on @today's answer:

from keras.layers import Layer
from keras import backend as K

class Swish(Layer):
    def __init__(self, beta, **kwargs):
        super(Swish, self).__init__(**kwargs)
        self.beta = K.cast_to_floatx(beta)
        self.__name__ = 'swish'

    def call(self, inputs):
        return K.sigmoid(self.beta * inputs) * inputs

    def get_config(self):
        config = {'beta': float(self.beta)}
        base_config = super(Swish, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

from keras.utils.generic_utils import get_custom_objects
get_custom_objects().update({'swish': Swish(beta=1.)})
gnn = keras.models.load_model("Model.h5")
arch = gnn.to_json()
with open(directory + 'architecture.json', 'w') as arch_file:
    arch_file.write(arch)

However, it currently does not save the beta value in the .json file. How can I make it save the value?

Recommended answer

Since you want the activation function's parameters to be saved when the model is serialized, I think it is better to define the activation function as a layer, like the advanced activations already defined in Keras. You can do it like this:

from keras.layers import Layer
from keras import backend as K

class Swish(Layer):
    def __init__(self, beta, **kwargs):
        super(Swish, self).__init__(**kwargs)
        self.beta = K.cast_to_floatx(beta)

    def call(self, inputs):
        return K.sigmoid(self.beta * inputs) * inputs

    def get_config(self):
        config = {'beta': float(self.beta)}
        base_config = super(Swish, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

Then you can use it the same way you would use any Keras layer:

# ...
model.add(Swish(beta=0.3))
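
For intuition about what beta does inside call(), here is a minimal plain-Python sketch of the same formula, x * sigmoid(beta * x), applied to scalars. The helper names sigmoid and swish are illustrative only and are not part of Keras.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function for a scalar."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def swish(x, beta=1.0):
    """Scalar version of the layer's call(): x * sigmoid(beta * x)."""
    return x * sigmoid(beta * x)


# With beta = 0 the sigmoid is constant at 0.5, so swish is just x / 2.
print(swish(1.0, beta=0.0))  # 0.5

# With a large beta the function behaves much like ReLU:
# close to x for positive inputs, close to 0 for negative inputs.
print(swish(5.0, beta=50.0))
print(swish(-5.0, beta=50.0))
```

Smaller values such as the beta=0.3 used above give a smoother, flatter transition around zero than the default beta=1.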

Since the get_config() method is implemented in the layer's definition, the parameter beta is saved when you use methods like to_json() or save().
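
To check this, one option is a JSON round trip. The following is a minimal sketch, not a definitive recipe: the model architecture is arbitrary, and it assumes a Keras version where keras.Sequential accepts keras.Input and where keras.models.model_from_json is available. It uses keras.activations.sigmoid instead of K.sigmoid, and a plain Python float for beta, only to stay independent of the backend module.

```python
import json

import keras
from keras import layers


class Swish(layers.Layer):
    """Swish activation with a fixed beta: x * sigmoid(beta * x)."""

    def __init__(self, beta=1.0, **kwargs):
        super().__init__(**kwargs)
        self.beta = float(beta)

    def call(self, inputs):
        return keras.activations.sigmoid(self.beta * inputs) * inputs

    def get_config(self):
        # Add beta to the base layer config so it is serialized.
        config = super().get_config()
        config.update({"beta": self.beta})
        return config


# Arbitrary toy architecture, used only to exercise serialization.
model = keras.Sequential([
    keras.Input(shape=(4,)),
    layers.Dense(8),
    Swish(beta=0.3),
    layers.Dense(1),
])

arch = model.to_json()

# Find the Swish entry in the serialized architecture and read its config.
layer_entries = json.loads(arch)["config"]["layers"]
swish_cfg = next(e["config"] for e in layer_entries if e["class_name"] == "Swish")
print(swish_cfg["beta"])

# Rebuild the model; custom_objects maps the class name to the class itself.
restored = keras.models.model_from_json(arch, custom_objects={"Swish": Swish})
restored_beta = [layer.beta for layer in restored.layers if isinstance(layer, Swish)]
print(restored_beta)
```

Note that custom_objects maps the class name 'Swish' to the class, not to an instance such as Swish(beta=1.). Keras rebuilds each layer from its saved config, so the stored beta is what gets restored.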
