Tensorflow 2.0:自定义 keras 指标导致 tf.function 回溯警告 [英] Tensorflow 2.0: custom keras metric caused tf.function retracing warning

查看:79
本文介绍了Tensorflow 2.0:自定义 keras 指标导致 tf.function 回溯警告的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

当我使用以下自定义指标(keras 风格)时:

When I use the following custom metric (keras-style):

from sklearn.metrics import classification_report, f1_score
from tensorflow.keras.callbacks import Callback

class Metrics(Callback):
    def __init__(self, dev_data, classifier, dataloader):
        self.best_f1_score = 0.0
        self.dev_data = dev_data
        self.classifier = classifier
        self.predictor = Predictor(classifier, dataloader)
        self.dataloader = dataloader

    def on_epoch_end(self, epoch, logs=None):
        print("start to evaluate....")
        _, preds = self.predictor(self.dev_data)
        y_trues, y_preds = [self.dataloader.label_vector(v["label"]) for v in self.dev_data], preds
        f1 = f1_score(y_trues, y_preds, average="weighted")
        print(classification_report(y_trues, y_preds,
                                    target_names=self.dataloader.vocab.labels))
        if f1 > self.best_f1_score:
            self.best_f1_score = f1
            self.classifier.save_model()
            print("best metrics, save model...")

我收到以下警告:

W1106 10:49:14.171694 4745115072 def_function.py:474] 在 0x14a3f9d90 处对 .distributed_function 的最后 11 次调用中有 6 次触发了 tf.function 回溯.跟踪是昂贵的,过多的跟踪可能是由于传递了 python 对象而不是张量.此外, tf.function 有 Experiment_relax_shapes=True 选项,可以放宽参数形状,以避免不必要的回溯.请参考 https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_argshttps://www.tensorflow.org/api_docs/python/tf/function 了解更多详情.

W1106 10:49:14.171694 4745115072 def_function.py:474] 6 out of the last 11 calls to .distributed_function at 0x14a3f9d90> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.

推荐答案

当一个 TF 函数被回溯时会出现这个警告,因为它的参数在形状或 dtype(对于张量)甚至值(Python 或 np 对象或变量)上发生了变化.

This warning occurs when a TF function is retraced because its arguments change in shape or dtype (for Tensors) or even in value (Python or np objects or variables).

在一般情况下,修复方法是在定义您传递给 Keras 或 TF 某处的自定义函数之前使用 @tf.function(experimental_relax_shapes=True).这会尝试检测并避免不必要的回溯,但不能保证解决问题.

In the general case, the fix is to use @tf.function(experimental_relax_shapes=True) before the definition of the custom function that you pass to Keras or TF somewhere. This tries to detect and avoid unnecessary retracing, but is not guaranteed to solve the issue.

在你的例子中,我猜 Predictor 类是一个自定义类,所以将 @tf.function(experimental_relax_shapes=True) 放在 Predictor.predict() 的定义之前.

In your case, i guess the Predictor class is a custom class, so place @tf.function(experimental_relax_shapes=True) before the definition of Predictor.predict().

这篇关于Tensorflow 2.0:自定义 keras 指标导致 tf.function 回溯警告的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆