无法加载keras训练的模型 [英] Not able to load keras trained model

查看:213
本文介绍了无法加载keras训练的模型的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用以下代码来训练HAN Network. 代码链接

I am using following code to train HAN Network. Code Link

我已经成功地训练了模型,但是当我尝试使用keras load_model加载模型时,出现以下错误- 未知层:AttentionWithContext

I have trained the model successfully but when I tried to load the model using keras load_model it gives me following error- Unknown layer: AttentionWithContext

推荐答案

在AttentionWithContext.py文件中添加以下功能:

Add the following function in the AttentionWithContext.py file:

def create_custom_objects():
    instance_holder = {"instance": None}

    class ClassWrapper(AttentionWithContext):
        def __init__(self, *args, **kwargs):
            instance_holder["instance"] = self
            super(ClassWrapper, self).__init__(*args, **kwargs)

    def loss(*args):
        method = getattr(instance_holder["instance"], "loss_function")
        return method(*args)

    def accuracy(*args):
        method = getattr(instance_holder["instance"], "accuracy")
        return method(*args)
    return {"ClassWrapper": ClassWrapper ,"AttentionWithContext": ClassWrapper, "loss": loss,
            "accuracy":accuracy}

加载模型时:

from AttentionWithContext import create_custom_objects

model = keras.models.load_model(model_path, custom_objects=create_custom_objects())

model.evaluate(X_test, y_test) # or model.predict

这篇关于无法加载keras训练的模型的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆