NotImplementedError: Layers with arguments in `__init__` must override `get_config`


Question

I'm trying to save my TensorFlow model using model.save(), but I get this error.

The model summary is provided here: Model Summary

Code for the transformer model:

def transformer(vocab_size, num_layers, units, d_model, num_heads, dropout, name="transformer"):
    inputs = tf.keras.Input(shape=(None,), name="inputs")
    dec_inputs = tf.keras.Input(shape=(None,), name="dec_inputs")

    enc_padding_mask = tf.keras.layers.Lambda(
        create_padding_mask, output_shape=(1, 1, None),
        name='enc_padding_mask')(inputs)
    # mask the future tokens for decoder inputs at the 1st attention block
    look_ahead_mask = tf.keras.layers.Lambda(
        create_look_ahead_mask,
        output_shape=(1, None, None),
        name='look_ahead_mask')(dec_inputs)
    # mask the encoder outputs for the 2nd attention block
    dec_padding_mask = tf.keras.layers.Lambda(
        create_padding_mask, output_shape=(1, 1, None),
        name='dec_padding_mask')(inputs)

    enc_outputs = encoder(
        vocab_size=vocab_size,
        num_layers=num_layers,
        units=units,
        d_model=d_model,
        num_heads=num_heads,
        dropout=dropout,
    )(inputs=[inputs, enc_padding_mask])

    dec_outputs = decoder(
        vocab_size=vocab_size,
        num_layers=num_layers,
        units=units,
        d_model=d_model,
        num_heads=num_heads,
        dropout=dropout,
    )(inputs=[dec_inputs, enc_outputs, look_ahead_mask, dec_padding_mask])

    outputs = tf.keras.layers.Dense(units=vocab_size, name="outputs")(dec_outputs)

    return tf.keras.Model(inputs=[inputs, dec_inputs], outputs=outputs, name=name)

I don't understand why I get this error, since the model trains perfectly fine. Any help would be appreciated.

My saving code, for reference:

print("Saving the model.")
saveloc = "C:/tmp/solar.h5"
model.save(saveloc)
print("Model saved to: " + saveloc + " successfully.")

Recommended answer

It's not a bug, it's a feature.

This error tells you that TF can't save your model, because it wouldn't be able to load it afterwards.
Specifically, it wouldn't be able to reinstantiate your custom Layer classes: encoder and decoder.

To solve this, override their get_config method so that it includes the arguments you've added to __init__.

A layer config is a Python dictionary (serializable) containing the configuration of a layer. The same layer can be reinstantiated later (without its trained weights) from this configuration.


For example, if your encoder class looks something like this:

class encoder(tf.keras.layers.Layer):

    def __init__(
        self,
        vocab_size, num_layers, units, d_model, num_heads, dropout,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.units = units
        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout = dropout

    # Other methods etc.

then you only need to override this method:

    def get_config(self):

        config = super().get_config().copy()
        config.update({
            'vocab_size': self.vocab_size,
            'num_layers': self.num_layers,
            'units': self.units,
            'd_model': self.d_model,
            'num_heads': self.num_heads,
            'dropout': self.dropout,
        })
        return config

Once both classes (encoder and decoder) define this method, you will be able to save the model.

This is because, when the model is loaded, TF can now reinstantiate the same layer from its config.
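The round trip described above can be checked on a single layer, without building a whole model. The following is a minimal sketch, assuming TensorFlow 2.x is installed; the ClippedScale layer and its factor and max_value arguments are hypothetical and exist only for illustration.

```python
import tensorflow as tf


class ClippedScale(tf.keras.layers.Layer):
    """Hypothetical layer with two constructor arguments (illustration only)."""

    def __init__(self, factor, max_value, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor
        self.max_value = max_value

    def call(self, inputs):
        # Scale the inputs, then cap them at max_value.
        return tf.minimum(inputs * self.factor, self.max_value)

    def get_config(self):
        # Start from the base config (name, dtype, ...) and add our own arguments.
        config = super().get_config()
        config.update({
            'factor': self.factor,
            'max_value': self.max_value,
        })
        return config


layer = ClippedScale(factor=2.0, max_value=6.0, name="clip")
config = layer.get_config()

# from_config rebuilds an equivalent layer (without any trained weights).
clone = ClippedScale.from_config(config)
```

Every key that get_config adds must match a parameter name in __init__, since the config is passed back to the constructor as keyword arguments.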

Layer.from_config's source code may give a better sense of how it works:

@classmethod
def from_config(cls, config):
  return cls(**config)
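Putting the pieces together, here is a hedged sketch of a full save and load cycle. It assumes TensorFlow 2.x with HDF5 support; the Scale layer, the temporary file path, and the toy model are all hypothetical stand-ins for the encoder/decoder model in the question. When loading, the custom class is passed through custom_objects so that Keras can find it and rebuild the layer from its saved config.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class Scale(tf.keras.layers.Layer):
    """Hypothetical custom layer standing in for encoder/decoder."""

    def __init__(self, factor, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        config = super().get_config()
        config.update({'factor': self.factor})
        return config


# A toy functional model that uses the custom layer.
inputs = tf.keras.Input(shape=(3,))
outputs = Scale(factor=2.0)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Saving works because Scale overrides get_config.
path = os.path.join(tempfile.mkdtemp(), "model.h5")
model.save(path)

# Loading needs to know which Python class the name "Scale" refers to.
restored = tf.keras.models.load_model(path, custom_objects={"Scale": Scale})

x = np.ones((1, 3), dtype="float32")
result = restored(x).numpy()
```

Without the custom_objects argument (or an equivalent registration of the class), loading would typically fail because Keras cannot map the saved class name back to a Python class.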
