tf.keras.Model保存:"AssertionError:试图导出引用未跟踪对象Tensor的函数". [英] tf.keras.Model save: "AssertionError: Tried to export a function which references untracked object Tensor"

查看:434
本文介绍了tf.keras.Model保存:"AssertionError:试图导出引用未跟踪对象Tensor的函数".的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在运行Tensorflow NMT教程: https: //github.com/tensorflow/addons/blob/master/docs/tutorials/networks_seq2seq_nmt.ipynb

I'm running this Tensorflow NMT tutorial: https://github.com/tensorflow/addons/blob/master/docs/tutorials/networks_seq2seq_nmt.ipynb

当我尝试保存解码器时: decoder.save('decoder'),我得到:

When I try to save the decoder: decoder.save('decoder'), I get:

AssertionError: Tried to export a function which references untracked object Tensor("LuongAttention/memory_layer/Tensordot:0", shape=(1024, 23, 256), dtype=float32).TensorFlow objects (e.g. tf.Variable) captured by functions must be tracked by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly.

我还尝试注册LuongAttention对象,例如:

I also tried registering the LuongAttention object like:

attention_mechanism = tfa.seq2seq.LuongAttention(units=units, memory=None, memory_sequence_length=BATCH_SIZE*[max_length_input])
custom_objects = {"LuongAttention": attention_mechanism}
with tf.keras.utils.custom_object_scope(custom_objects):
  decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE, attention_mechanism)```

推荐答案

为了保存/加载具有自定义图层的模型或子类模型,应覆盖get_config和from_config方法(可选).另外,您应该使用注册自定义对象,以便Keras知道它.

In order to save/load a model with custom-defined layers, or a subclassed model, you should overwrite the get_config and optionally from_config methods. Additionally, you should use register the custom object so that Keras is aware of it.

请参阅此处: https://www.tensorflow.org/guide/keras/save_and_serialize

这篇关于tf.keras.Model保存:"AssertionError:试图导出引用未跟踪对象Tensor的函数".的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆