在 tensorflow2.3 中加载模型失败有什么解决方案吗? [英] Is there any solution for failing to load model in tensorflow2.3?

查看:63
本文介绍了在 tensorflow2.3 中加载模型失败有什么解决方案吗?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我尝试使用 tf.keras.models.load_model 在 tensorflow 2.3 中加载保存的模型.但是,我遇到了同样的错误https://github.com/tensorflow/tensorflow/issues/41535

I try to use tf.keras.models.load_model to load saved model in tensorflow 2.3. However, I got the same error in https://github.com/tensorflow/tensorflow/issues/41535

这似乎是一个重要的功能.但是这个问题仍然没有解决.有谁知道是否有其他方法可以实现相同的结果?

It seems an important function. But this issue is still not solved. Does anyone know if there is any alternative method to implement the same result?

推荐答案

我找到了一种在 tensorflow 2.3 中加载自定义模型的替代方法.您需要做一些以下更改.我会通过一些代码快照来解释

I found an alternative method to load custom model in tensorflow 2.3. You need to do some following changes. I will explain by some code snapshots

  • 用于自定义模型的__init__().之前,

def __init__(self, mask_ratio=0.1, hyperparam=0.1, **kwargs):
    layers = []
    layer_configs = {}
    if 'layers' in kwargs.keys():
        layer_configs = kwargs['layers']
    for config in layer_configs:
        layer = tf.keras.layers.deserialize(config)
        layers.append(layer)
    super(custom_model, self).__init__(layers)  # custom_model is your custom model class
    self.mask_ratio = mask_ratio
    self.hyperparam = hyperparam
    ...

之后,

def __init__(self, mask_ratio=0.1, hyperparam=0.1, **kwargs):
    super(custom_model, self).__init__()  # custom_model is your custom model class
    self.mask_ratio = mask_ratio
    self.hyperparam = hyperparam
    ...

  • 在自定义模型类中定义两个函数

  • define two functions in your custom model class

    def get_config(self):
        config = {
            'mask_ratio': self.mask_ratio,
            'hyperparam': self.hyperparam
        }
        base_config = super(custom_model, self).get_config()
        return dict(list(config.items()) + list(base_config.items()))
    @classmethod
    def from_config(cls, config):
        #config = cls().get_config()
        return cls(**config)
    

  • 训练完成后,以'h5'格式保存模型

  • After finishing training, save model using 'h5' format

    model.save(file_path, save_format='h5')
    

  • 最后,加载模型如下代码,

  • Finally, load model as following codes,

    model = tf.keras.models.load_model(model_path, compile=False, custom_objects={'custom_model': custom_model})
    

  • 这篇关于在 tensorflow2.3 中加载模型失败有什么解决方案吗?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

    查看全文
    登录 关闭
    扫码关注1秒登录
    发送“验证码”获取 | 15天全站免登陆