等式比较在 TensorFlow 2.0 tf.function() 中不起作用 [英] Equality comparison does not work inside TensorFlow 2.0 tf.function()

查看:38
本文介绍了等式比较在 TensorFlow 2.0 tf.function() 中不起作用的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

在关于 TensorFlow 2.0 AutoGraphs 的讨论之后,我一直在玩,并注意到>< 等不等式比较是直接指定的,而等式比较则使用 tf.equal 表示.

Following the discussion on TensorFlow 2.0 AutoGraphs, I've been playing around and noticed that inequality comparisons such as > and < are specified directly, whereas equality comparisons are represented using tf.equal.

这是一个演示示例.此函数使用 > 运算符,并且在调用时运行良好:

Here's an example to demonstrate. This function uses > operator and works well when called:

@tf.function
def greater_than_zero(value):
    return value > 0

greater_than_zero(tf.constant(1))
#  <tf.Tensor: id=1377, shape=(), dtype=bool, numpy=True>
greater_than_zero(tf.constant(-1))
# <tf.Tensor: id=1380, shape=(), dtype=bool, numpy=False>

这是另一个使用相等比较的函数,但不起作用:

Here is another function that uses equality comparison, but does not work:

@tf.function
def equal_to_zero(value):
    return value == 0

equal_to_zero(tf.constant(1))
# <tf.Tensor: id=1389, shape=(), dtype=bool, numpy=False>  # OK...

equal_to_zero(tf.constant(0))
# <tf.Tensor: id=1392, shape=(), dtype=bool, numpy=False>  # WHAT?

如果我将 == 相等比较更改为 tf.equal,它将起作用.

If I change the == equality comparison to tf.equal, it will work.

@tf.function
def equal_to_zero2(value):
    return tf.equal(value, 0)

equal_to_zero2(tf.constant(0))
# <tf.Tensor: id=1402, shape=(), dtype=bool, numpy=True>

我的问题是:为什么在 tf.function 函数中使用不等式比较运算符可以,而等式比较不行?

My question is: Why does using inequality comparison operators work inside tf.function functions, whereas equality comparisons do not?

推荐答案

我在文章的第 3 部分分析了这种行为 "分析 tf.function 以发现 Autograph 的优势和微妙之处"(我强烈建议阅读所有 3 部分以了解如何正确在使用 tf.function 装饰之前编写一个函数 - 答案底部的链接).

I analyzed this behavior in part 3 of the article "Analysing tf.function to discover Autograph strengths and subtleties" (and I highly recommend reading all the 3 parts to understand how to correctly write a function before decorating it with tf.function - links at the bottom of the answer).

对于 __eq__tf.equal 问题,答案是:

For the __eq__ and tf.equal question, the answer is:

简而言之:__eq__ 运算符(用于 tf.Tensor)已被覆盖,但该运算符不使用 tf.equal检查 Tensor 相等性,它只检查 Python 变量标识(如果您熟悉 Java 编程语言,这与用于字符串对象的 == 运算符完全相同).原因是 tf.Tensor 对象需要是可散列的,因为它在 Tensorflow 代码库中的任何地方都被用作 dict 对象的键.

In short: the __eq__ operator (for tf.Tensor) has been overridden, but the operator does not use tf.equal to check for the Tensor equality, it just checks for the Python variable identity (if you are familiar with the Java programming language, this is precisely like the == operator used on string objects). The reason is that the tf.Tensor object needs to be hashable since it is used everywhere in the Tensorflow codebase as key for dict objects.

对于所有其他运算符,答案是 AutoGraph 不会将 Python 运算符转换为 TensorFlow 逻辑运算符.在部分 AutoGraph(不)如何转换运算符 我展示了每个 Python 运算符都被转换为一个总是被评估为假的图形表示.

While for all the other operators, the answer is that AutoGraph doesn't convert Python operators to TensorFlow logical operators. In the section How AutoGraph (don’t) converts the operators I showed that every Python operator gets converted to a graph representation that is always evaluated as false.

事实上,下面的例子产生作为输出wat"

In fact, the following example produces as output "wat"

@tf.function
def if_elif(a, b):
  if a > b:
    tf.print("a > b", a, b)
  elif a == b:
    tf.print("a == b", a, b)
  elif a < b:
    tf.print("a < b", a, b)
  else:
    tf.print("wat")
x = tf.constant(1)
if_elif(x,x)

在实践中,AutoGraph 无法将 Python 代码转换为图形代码;我们必须仅使用 TensorFlow 原语来帮助它.在这种情况下,您的代码将按预期工作.

In practice, AutoGraph is unable to convert Python code to graph code; we have to help it using only the TensorFlow primitives. In that case, your code will work as you expect.

@tf.function
def if_elif(a, b):
  if tf.math.greater(a, b):
    tf.print("a > b", a, b)
  elif tf.math.equal(a, b):
    tf.print("a == b", a, b)
  elif tf.math.less(a, b):
    tf.print("a < b", a, b)
  else:
    tf.print("wat")

我把所有三篇文章的链接都放在这里,我想你会发现它们很有用:

I let here the links to all the three articles, I guess you'll find them usefult:

第 1 部分, 第 2 部分, 第 3 部分

这篇关于等式比较在 TensorFlow 2.0 tf.function() 中不起作用的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆