How does TensorFlow handle non-differentiable nodes during gradient calculation?

Problem Description

I understand the concept of automatic differentiation, but I couldn't find any explanation of how TensorFlow calculates the error gradient for non-differentiable functions, for example tf.where in my loss function or tf.cond in my graph. It works just fine, but I would like to understand how TensorFlow backpropagates the error through such nodes, since there is no formula to calculate the gradient from them.

Recommended Answer

In the case of tf.where, you have a function with three inputs (the condition C, the value on true T, and the value on false F) and one output Out. The gradient function receives one incoming gradient and has to return three values. Currently, no gradient is computed for the condition (that would hardly make sense), so you only need to produce gradients for T and F. Assuming the inputs and the output are vectors, imagine C[0] is True. Then Out[0] comes from T[0], and its gradient should propagate back; on the other hand, F[0] was discarded, so its gradient should be made zero. Conversely, if C[1] is False, then the gradient for F[1] should propagate but not the one for T[1]. In short, for T you should pass through the given gradient where C is True and zero it where C is False, and do the opposite for F. If you look at the implementation of the gradient of tf.where (the Select operation), it does exactly that:

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops


@ops.RegisterGradient("Select")
def _SelectGrad(op, grad):
  c = op.inputs[0]
  x = op.inputs[1]
  zeros = array_ops.zeros_like(x)
  # No gradient for the condition; the incoming gradient is routed to T
  # where the condition is True and to F where it is False.
  return (None,
          array_ops.where(c, grad, zeros),
          array_ops.where(c, zeros, grad))
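
To see this routing in action, here is a minimal sketch (not part of the original answer) that checks the behaviour with the eager GradientTape API; the names c, t and f are just illustrative, and under graph mode the Select gradient above produces the same routing.

import tensorflow as tf

c = tf.constant([True, False])
t = tf.Variable([1.0, 2.0])
f = tf.Variable([10.0, 20.0])

with tf.GradientTape() as tape:
    out = tf.where(c, t, f)   # selects t[0] and f[1]
    loss = tf.reduce_sum(out)

grad_t, grad_f = tape.gradient(loss, [t, f])
print(grad_t.numpy())  # [1. 0.] -- zeroed where C is False
print(grad_f.numpy())  # [0. 1.] -- zeroed where C is True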

Note that the input values themselves are not used in this computation; that is handled by the gradients of the operations producing those inputs. For tf.cond, the code is a bit more complicated, because the same operation (Merge) is used in different contexts, and tf.cond also uses Switch operations internally. However, the idea is the same. Essentially, a Switch operation is used for each input, so the input that was activated (the first if the condition was True, the second otherwise) gets the received gradient, while the other input gets a "switched off" gradient (like None) and does not propagate back any further.
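
As a rough illustration (again an assumed sketch using the eager GradientTape API, not code from the answer), you can observe that only the branch tf.cond actually executed contributes a gradient:

import tensorflow as tf

x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    # The condition is True, so only the tf.square branch runs and
    # only it contributes to the gradient.
    y = tf.cond(x > 0, lambda: tf.square(x), lambda: tf.zeros_like(x))

print(tape.gradient(y, x).numpy())  # 6.0, i.e. d(x**2)/dx at x = 3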
