使用@tffunction 的 Tensorflow2 警告 [英] Tensorflow2 warning using @tffunction
问题描述
此示例代码来自 Tensorflow 2
This example code from Tensorflow 2
writer = tf.summary.create_file_writer("/tmp/mylogs/tf_function")
@tf.function
def my_func(step):
with writer.as_default():
# other model code would go here
tf.summary.scalar("my_metric", 0.5, step=step)
for step in range(100):
my_func(step)
writer.flush()
但它正在抛出警告.
警告:tensorflow:在最近 5 次触发 tf.function 回溯的调用中有 5 次.追踪成本高昂并且过多的跟踪可能是由于传递了python对象而不是张量.此外, tf.function 有Experiment_relax_shapes=放松参数形状的真选项这样可以避免不必要的回溯.请参阅https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args和 https://www.tensorflow.org/api_docs/python/tf/function 了解更多详情.
WARNING:tensorflow:5 out of the last 5 calls to triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
有没有更好的方法来做到这一点?
Is there a better way to do this?
推荐答案
tf.function
有一些特殊性".我强烈推荐阅读这篇文章:https://www.tensorflow.org/tutorials/customization/performance
tf.function
has some "peculiarities". I highly recommend reading this article: https://www.tensorflow.org/tutorials/customization/performance
在这种情况下,问题在于每次您使用不同的输入签名调用该函数时都会回溯"(即构建一个新图).对于张量,输入签名指的是 shape 和 dtype,但对于 Python 数字,每个新值都被解释为不同的".在这种情况下,因为您使用每次更改的 step
变量调用该函数,该函数也会每次都被回溯.对于真实"代码(例如在函数内部调用模型),这将非常慢.
In this case, the problem is that the function is "retraced" (i.e. a new graph is built) every time you call with a different input signature. For tensors, input signature refers to shape and dtype, but for Python numbers, every new value is interpreted as "different". In this case, because you call the function with a step
variable that changes every time, the function is retraced every single time as well. This will be extremely slow for "real" code (e.g. calling a model inside the function).
您可以通过简单地将 step
转换为张量来修复它,在这种情况下,不同的值将不算作新的输入签名:
You can fix it by simply converting step
to a tensor, in which case the different values will not count as a new input signature:
for step in range(100):
step = tf.convert_to_tensor(step, dtype=tf.int64)
my_func(step)
writer.flush()
或使用 tf.range
直接获取张量:
or use tf.range
to get tensors directly:
for step in tf.range(100):
step = tf.cast(step, tf.int64)
my_func(step)
writer.flush()
这不应该产生警告(而且速度要快得多).
This should not produce warnings (and be much faster).
这篇关于使用@tffunction 的 Tensorflow2 警告的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!