如何从.h5文件正确加载带有自定义图层的Keras模型? [英] How to load the Keras model with custom layers from .h5 file correctly?

查看:263
本文介绍了如何从.h5文件正确加载带有自定义图层的Keras模型?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我构建了一个具有自定义图层的Keras模型,并通过回调ModelCheckPoint将其保存到了.h5文件中. 在训练后尝试加载此模型时,出现以下错误消息:

I built a Keras model with a custom layers, and it was saved to a .h5 file by the callback ModelCheckPoint. When I tried to load this model after the training, the error message below showed up:

__init__() missing 1 required positional argument: 'pool_size'

这是自定义图层及其__init__方法的定义:

This is the definition of the custom layer and its __init__ method:

class MyMeanPooling(Layer):
    def __init__(self, pool_size, axis=1, **kwargs):
        self.supports_masking = True
        self.pool_size = pool_size
        self.axis = axis
        self.y_shape = None
        self.y_mask = None
        super(MyMeanPooling, self).__init__(**kwargs)

这是将图层添加到模型中的方法:

This is how I add this layer to my model:

x = MyMeanPooling(globalvars.pool_size)(x)

这是我加载模型的方式:

This is how I load the model:

from keras.models import load_model

model = load_model(model_path, custom_objects={'MyMeanPooling': MyMeanPooling})

这些是完整的错误消息:

These are the full error messages:

Traceback (most recent call last):
  File "D:/My Projects/Attention_BLSTM/script3.py", line 9, in <module>
    model = load_model(model_path, custom_objects={'MyMeanPooling': MyMeanPooling})
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 419, in load_model
    model = _deserialize_model(f, custom_objects, compile)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 225, in _deserialize_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 458, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\layers\__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\utils\generic_utils.py", line 145, in deserialize_keras_object
    list(custom_objects.items())))
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\network.py", line 1022, in from_config
    process_layer(layer_data)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\network.py", line 1008, in process_layer
    custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\layers\__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\utils\generic_utils.py", line 147, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\base_layer.py", line 1109, in from_config
    return cls(**config)
TypeError: __init__() missing 1 required positional argument: 'pool_size'

推荐答案

实际上,我认为您无法加载此模型.

Actually I don't think you can load this model.

最可能的问题是您没有在图层中实现get_config()方法.此方法返回应保存的配置值字典:

The most likely issue is that you did not implement the get_config() method in your layer. This method returns a dictionary of configuration values that should be saved:

def get_config(self):
    config = {'pool_size': self.pool_size,
              'axis': self.axis}
    base_config = super(MyMeanPooling, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

将这种方法添加到图层后,您必须重新训练模型,因为先前保存的模型没有将此图层的配置保存到其中.这就是为什么您不能加载它的原因,进行此更改后需要重新培训.

You have to retrain the model after adding this method to your layer, as the previously saved model does not have the configuration for this layer saved into it. This is why you cannot load it, it requires retraining after making this change.

这篇关于如何从.h5文件正确加载带有自定义图层的Keras模型?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆