TensorFlow, Keras: Replace Activation layer in pretrained model

Problem Description

I'm trying to replace the swish activation with relu activation in the pretrained TF model EfficientNetB0. EfficientNetB0 uses swish activation in both its Conv2D and its Activation layers. This SO post is very similar to what I'm looking for. I also found an answer that works for models without skip connections. Below is the code:

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import ReLU

def replace_swish_with_relu(model):
    '''
    Modify passed model by replacing swish activation with relu
    '''
    for layer in tuple(model.layers):
        layer_type = type(layer).__name__
        if hasattr(layer, 'activation') and layer.activation.__name__ == 'swish':
            print(layer_type, layer.activation.__name__)
            if layer_type == "Conv2D":
                # conv layer with swish activation.
                # Do something
                layer.activation = ReLU() # This didn't work
            else:
                # activation layer
                # Do something
                layer = tf.keras.layers.Activation('relu', name=layer.name + "_relu") # This didn't work
    return model

# load pretrained efficientNet
model = tf.keras.applications.EfficientNetB0(
    include_top=True, weights='imagenet', input_tensor=None,
    input_shape=(224, 224, 3), pooling=None, classes=1000,
    classifier_activation='softmax')

# convert swish activation to relu activation
model = replace_swish_with_relu(model)
model.save("efficientNet-relu")

How can I modify replace_swish_with_relu to replace the swish activations with relu in the passed model?

Thanks for any pointers/help.

Answer

layer.activation points to the tf.keras.activations.swish function object. We can rebind it to point to tf.keras.activations.relu instead. Below is the modified replace_swish_with_relu:

def replace_swish_with_relu(model):
    '''
    Modify the passed model by replacing swish activations with relu.
    '''
    for layer in tuple(model.layers):
        layer_type = type(layer).__name__
        if hasattr(layer, 'activation') and layer.activation.__name__ == 'swish':
            print(layer_type, layer.activation.__name__)
            # Both Conv2D layers (fused activation) and standalone
            # Activation layers store the callable in their `activation`
            # attribute, so rebinding it handles both cases.
            layer.activation = tf.keras.activations.relu
    return model
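To sanity-check the replacement, you can re-scan the layers and confirm that nothing still points at swish, then run a forward pass. A minimal sketch, reusing the model and the numpy import from the question:

model = replace_swish_with_relu(model)

# Confirm no layer still points at the swish function.
remaining = [layer.name for layer in model.layers
             if hasattr(layer, 'activation')
             and layer.activation.__name__ == 'swish']
print("Layers still using swish:", remaining)  # expected: []

# The forward pass still works; outputs differ from the original
# model because relu and swish disagree for negative inputs.
dummy = np.random.rand(1, 224, 224, 3).astype("float32")
print(model.predict(dummy).shape)  # (1, 1000)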

Note: If you modify the activation function, you need to retrain the model to work with the new activation. Related.
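As a rough sketch of that retraining step (not part of the original answer): train_ds below is a hypothetical tf.data.Dataset of (image, label) batches matching the model's (224, 224, 3) input and 1000-class output.

# Hypothetical fine-tuning sketch; `train_ds` is an assumed
# tf.data.Dataset yielding (image, label) batches.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
model.fit(train_ds, epochs=5)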
