How to register a custom gradient for an operation composed of tf operations


Question

More specifically, I have a simple fprop that is a composition of tf operations. I want to override the TensorFlow gradient computation with my own gradient method using RegisterGradient.

What is wrong with this code?

import tensorflow as tf
from tensorflow.python.framework import ops

@ops.RegisterGradient("MyopGrad")
def frop_grad(op, grad):
    x = op.inputs[0]
    return 0 * x  # zero out to see the difference:

def fprop(x):
    x = tf.sqrt(x)
    out = tf.maximum(x, .2)
    return out

a = tf.Variable(tf.constant([5., 4., 3., 2., 1.], dtype=tf.float32))
h = fprop(a)
h = tf.identity(h, name="Myop")
grad = tf.gradients(h, a)

g = tf.get_default_graph()
with g.gradient_override_map({'Myop': 'MyopGrad'}):
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        result = sess.run(grad)

print(result[0])

I want to see all zeros in the print, but instead I am getting:

[ 0.2236068   0.25000003  0.28867513  0.35355341  0.5       ]
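Those values are exactly the default gradient of the un-overridden graph: for every input here sqrt(a) > 0.2, so tf.maximum passes the gradient straight through and what remains is d/da sqrt(a) = 1 / (2·sqrt(a)). A quick NumPy check (my own sanity check, not part of the original question) confirms the override never took effect:

```python
import numpy as np

a = np.array([5., 4., 3., 2., 1.])
# Default chain rule: d/da max(sqrt(a), 0.2) = 1 / (2 * sqrt(a))
# whenever sqrt(a) > 0.2, which holds for every element here.
print(1.0 / (2.0 * np.sqrt(a)))
```

The printed values match the gradients above element for element, i.e. TensorFlow used its built-in gradients for sqrt and maximum.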

Answer

You need to define the op inside the with g.gradient_override_map({'Myop': 'MyopGrad'}) block, because the override only applies to ops created within that scope.

Also, you need to map Identity rather than the name Myop to your new gradient: gradient_override_map keys on the op type, not the name you gave the op.

Full code:

import tensorflow as tf
from tensorflow.python.framework import ops

@ops.RegisterGradient("MyopGrad")
def frop_grad(op, grad):
    x = op.inputs[0]
    return 0 * x  # zero out to see the difference:

def fprop(x):
    x = tf.sqrt(x)
    out = tf.maximum(x, .2)
    return out

a = tf.Variable(tf.constant([5., 4., 3., 2., 1.], dtype=tf.float32))
h = fprop(a)

g = tf.get_default_graph()
with g.gradient_override_map({'Identity': 'MyopGrad'}):
    h = tf.identity(h, name="Myop")
    grad = tf.gradients(h, a)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    result = sess.run(grad)

print(result[0])

Output:

[ 0.  0.  0.  0.  0.]
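For what it's worth, in TensorFlow 2 the gradient_override_map trick is no longer needed: the same zeroed gradient can be expressed directly with tf.custom_gradient. The sketch below is my own TF2 equivalent, not part of the original answer:

```python
import tensorflow as tf

# tf.custom_gradient replaces the graph-mode override: the wrapped
# function returns its output together with a gradient function.
@tf.custom_gradient
def fprop(x):
    out = tf.maximum(tf.sqrt(x), 0.2)
    def grad(upstream):
        # Zero out the gradient, as in the question, to see the difference.
        return 0.0 * x
    return out, grad

a = tf.Variable([5., 4., 3., 2., 1.])
with tf.GradientTape() as tape:
    h = fprop(a)
result = tape.gradient(h, a)
print(result.numpy())  # prints [0. 0. 0. 0. 0.]
```

Here the custom grad function receives the upstream gradient and returns the gradient with respect to each input of fprop, so no graph scope or op-type mapping is involved.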

