Keras model subclassing examples

Question

Starting with Keras 2.2.0, a third model-definition API was released: model subclassing.

According to the FAQ:

However, in subclassed models, the model's topology is defined as Python code (rather than as a static graph of layers). That means the model's topology cannot be inspected or serialized. As a result, the following methods and attributes are not available for subclassed models:

model.inputs and model.outputs
model.to_yaml() and model.to_json()
model.get_config() and model.save()

So the only way to save a trained model for inference is the model.save_weights method. However, I have had no luck loading the model back for inference. The error messages I encountered include:

This model has never been called, thus its weights have not yet been created, so no summary can be displayed. Build the model first (e.g. by calling it on some test data).
You are trying to load a weight file containing 4 layers into a model with 0 layers.
NotImplementedError

Can anyone give a complete toy example that creates a subclassed Keras model, trains it, calls save_weights, and then loads the weights back for inference?

Recommended answer

You need to call the tf.keras.Model.build method before you try to save the weights of a subclassed model. Alternatively, call tf.keras.Model.fit, or call the model directly on some inputs (which runs tf.keras.Model.call), before saving the weights. The same applies when loading weights into a newly created instance of your subclassed model: call one of the above-mentioned methods first, then load the weights. The ResNet34 example below shows both saving and loading weights for a subclassed model.
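Before the ResNet34 example, the reason for this ordering can be illustrated without TensorFlow. The sketch below is a framework-free toy, not Keras code: the class LazyDense and its methods are hypothetical stand-ins that only mimic the idea of weights being created once the input shape is known. It shows why loading into a never-built object fails, while loading after a build (or a first call) succeeds.

```python
class LazyDense:
    """Hypothetical toy layer: weights are created only once the input size is known."""

    def __init__(self, units):
        self.units = units
        self.weights = None  # no weights exist until build() or the first call

    def build(self, input_dim):
        # One row per output unit, one column per input feature.
        self.weights = [[0.0] * input_dim for _ in range(self.units)]

    def __call__(self, x):
        # The first call builds the layer from the observed input size.
        if self.weights is None:
            self.build(len(x))
        return [sum(w * v for w, v in zip(row, x)) for row in self.weights]

    def load_weights(self, saved):
        # Loading needs existing weights to check shapes against.
        if self.weights is None:
            raise ValueError("layer has no weights yet; build or call it first")
        if len(saved) != len(self.weights) or len(saved[0]) != len(self.weights[0]):
            raise ValueError("saved weights do not match the layer's shape")
        self.weights = [list(row) for row in saved]


saved = [[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]]

fresh = LazyDense(units=2)
try:
    fresh.load_weights(saved)  # fails: nothing has been built yet
    load_before_build_failed = False
except ValueError:
    load_before_build_failed = True

fresh.build(input_dim=3)       # now the weights exist
fresh.load_weights(saved)      # and loading succeeds
output = fresh([1.0, 1.0, 1.0])
```

The same pattern applies to the real model: create the variables first, then load into them.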

import tensorflow as tf

print('TensorFlow', tf.__version__)

class ResidualBlock(tf.keras.Model):
    def __init__(self, block_type=None, n_filters=None):
        super(ResidualBlock, self).__init__()
        self.n_filters = n_filters
        if block_type == 'identity':
            self.strides = 1
        elif block_type == 'conv':
            self.strides = 2
            # 1x1 strided convolution so the shortcut matches the main path's shape
            self.conv_shortcut = tf.keras.layers.Conv2D(filters=self.n_filters, 
                               kernel_size=1, 
                               padding='same',
                               strides=self.strides,
                               kernel_initializer='he_normal')
            self.bn_shortcut = tf.keras.layers.BatchNormalization(momentum=0.9)
        else:
            # Without this check, self.strides would be undefined below
            raise ValueError("block_type must be 'identity' or 'conv'")

        self.conv_1 = tf.keras.layers.Conv2D(filters=self.n_filters, 
                               kernel_size=3, 
                               padding='same',
                               strides=self.strides,
                               kernel_initializer='he_normal')
        self.bn_1 = tf.keras.layers.BatchNormalization(momentum=0.9)
        self.relu_1 = tf.keras.layers.ReLU()

        self.conv_2 = tf.keras.layers.Conv2D(filters=self.n_filters, 
                               kernel_size=3, 
                               padding='same', 
                               kernel_initializer='he_normal')
        self.bn_2 = tf.keras.layers.BatchNormalization(momentum=0.9)
        self.relu_2 = tf.keras.layers.ReLU()

    def call(self, x, training=False):
        shortcut = x
        if self.strides == 2:
            shortcut = self.conv_shortcut(x)
            shortcut = self.bn_shortcut(shortcut, training=training)
        y = self.conv_1(x)
        # Pass the training flag on so batch normalization uses the right statistics
        y = self.bn_1(y, training=training)
        y = self.relu_1(y)
        y = self.conv_2(y)
        y = self.bn_2(y, training=training)
        y = tf.add(shortcut, y)
        y = self.relu_2(y)
        return y

class ResNet34(tf.keras.Model):
    def __init__(self, include_top=True, n_classes=1000):
        super(ResNet34, self).__init__()

        self.n_classes = n_classes
        self.include_top = include_top
        self.conv_1 = tf.keras.layers.Conv2D(filters=64, 
                                               kernel_size=7, 
                                               padding='same', 
                                               strides=2, 
                                               kernel_initializer='he_normal')
        self.bn_1 = tf.keras.layers.BatchNormalization(momentum=0.9)
        self.relu_1 = tf.keras.layers.ReLU()
        self.maxpool = tf.keras.layers.MaxPool2D(3, 2, padding='same')
        self.residual_blocks = tf.keras.Sequential()
        for n_filters, reps, downscale in zip([64, 128, 256, 512], 
                                              [3, 4, 6, 3], 
                                              [False, True, True, True]):
            for i in range(reps):
                if i == 0 and downscale:
                    self.residual_blocks.add(ResidualBlock(block_type='conv', 
                                                              n_filters=n_filters))
                else:
                    self.residual_blocks.add(ResidualBlock(block_type='identity', 
                                                              n_filters=n_filters))
        self.GAP = tf.keras.layers.GlobalAveragePooling2D()
        self.fc = tf.keras.layers.Dense(units=self.n_classes)

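    def call(self, x, training=False):
        y = self.conv_1(x)
        # Forward the training flag to batch normalization and the residual blocks
        y = self.bn_1(y, training=training)
        y = self.relu_1(y)
        y = self.maxpool(y)
        y = self.residual_blocks(y, training=training)
        if self.include_top:
            y = self.GAP(y)
            y = self.fc(y)
        return y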

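# Saving weights: build the model first so that its variables exist.
# Calling the model once on a dummy batch, e.g.
# model(tf.zeros((1, 224, 224, 3))), is an alternative way to create them.
model = ResNet34()
model.build((1, 224, 224, 3))
model.summary()
# Note: newer Keras releases require the file name to end in '.weights.h5'.
model.save_weights('model_weights.h5')

# Loading saved weights: the new instance must also be built before loading.
model_new = ResNet34()
model_new.build((1, 224, 224, 3))
model_new.load_weights('model_weights.h5')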
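
The question also asked about training and inference, which the ResNet34 snippet does not cover. The following is a minimal sketch of the full round trip on a much smaller model, under a few assumptions: TinyClassifier, the random toy data, and the file name are made up for illustration, and the new instance is given its variables by calling it on a dummy batch rather than through build. The '.weights.h5' suffix is used because newer Keras releases require it, and older TensorFlow 2.x releases should treat it as a regular HDF5 file.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class TinyClassifier(tf.keras.Model):
    """Hypothetical toy model, used only to illustrate the save/load round trip."""

    def __init__(self, n_classes=3):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(16, activation="relu")
        self.out = tf.keras.layers.Dense(n_classes)

    def call(self, x, training=False):
        return self.out(self.hidden(x))


# Random toy data: 64 samples, 8 features, 3 classes.
rng = np.random.default_rng(0)
x_train = rng.normal(size=(64, 8)).astype("float32")
y_train = rng.integers(0, 3, size=(64,))

# Train: fit() calls the model, which creates its variables.
model = TinyClassifier()
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
model.fit(x_train, y_train, epochs=2, batch_size=16, verbose=0)

# Save only the weights; the architecture lives in the Python class.
weights_path = os.path.join(tempfile.mkdtemp(), "tiny_classifier.weights.h5")
model.save_weights(weights_path)

# Load: create a fresh instance, call it once so its variables exist,
# then load the saved weights into them.
restored = TinyClassifier()
_ = restored(tf.zeros((1, 8)))
restored.load_weights(weights_path)

# Inference: both models should now produce the same logits.
x_test = rng.normal(size=(4, 8)).astype("float32")
original_logits = model(x_test, training=False).numpy()
restored_logits = restored(x_test, training=False).numpy()
predictions_match = np.allclose(original_logits, restored_logits, atol=1e-6)
```

If predictions_match is True, the restored instance reproduces the trained model's outputs and can be used for inference on its own.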
