使用@tffunction 的 Tensorflow2 警告 [英] Tensorflow2 warning using @tffunction

查看:27
本文介绍了使用@tffunction 的 Tensorflow2 警告的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

此示例代码来自 Tensorflow 2

This example code from Tensorflow 2

writer = tf.summary.create_file_writer("/tmp/mylogs/tf_function")

@tf.function
def my_func(step):
  with writer.as_default():
    # other model code would go here
    tf.summary.scalar("my_metric", 0.5, step=step)

for step in range(100):
  my_func(step)
  writer.flush()

但它正在抛出警告.

警告:tensorflow:在最近 5 次触发 tf.function 回溯的调用中有 5 次.追踪成本高昂并且过多的跟踪可能是由于传递了python对象而不是张量.此外, tf.function 有Experiment_relax_shapes=放松参数形状的真选项这样可以避免不必要的回溯.请参阅https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_argshttps://www.tensorflow.org/api_docs/python/tf/function 了解更多详情.

WARNING:tensorflow:5 out of the last 5 calls to triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.

有没有更好的方法来做到这一点?

Is there a better way to do this?

推荐答案

tf.function 有一些特殊性".我强烈推荐阅读这篇文章:https://www.tensorflow.org/tutorials/customization/performance

tf.function has some "peculiarities". I highly recommend reading this article: https://www.tensorflow.org/tutorials/customization/performance

在这种情况下,问题在于每次您使用不同的输入签名调用该函数时都会回溯"(即构建一个新图).对于张量,输入签名指的是 shape 和 dtype,但对于 Python 数字,每个新值都被解释为不同的".在这种情况下,因为您使用每次更改的 step 变量调用该函数,该函数也会每次都被回溯.对于真实"代码(例如在函数内部调用模型),这将非常慢.

In this case, the problem is that the function is "retraced" (i.e. a new graph is built) every time you call with a different input signature. For tensors, input signature refers to shape and dtype, but for Python numbers, every new value is interpreted as "different". In this case, because you call the function with a step variable that changes every time, the function is retraced every single time as well. This will be extremely slow for "real" code (e.g. calling a model inside the function).

您可以通过简单地将 step 转换为张量来修复它,在这种情况下,不同的值将算作新的输入签名:

You can fix it by simply converting step to a tensor, in which case the different values will not count as a new input signature:

for step in range(100):
    step = tf.convert_to_tensor(step, dtype=tf.int64)
    my_func(step)
    writer.flush()

或使用 tf.range 直接获取张量:

or use tf.range to get tensors directly:

for step in tf.range(100):
    step = tf.cast(step, tf.int64)
    my_func(step)
    writer.flush()

这不应该产生警告(而且速度要快得多).

This should not produce warnings (and be much faster).

这篇关于使用@tffunction 的 Tensorflow2 警告的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆