Tensorflow:如何通过 tf.gather 传播梯度? [英] Tensorflow: How to Propagate Gradient Through tf.gather?

查看:48
本文介绍了Tensorflow:如何通过 tf.gather 传播梯度?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我在尝试针对表示聚集索引的变量传播我的损失函数的梯度时遇到了一些问题,类似于在空间变换器网络 (https://github.com/tensorflow/models/blob/master/transformer/spatial_transformer.py).我觉得我可能遗漏了一些非常简单的东西.这是我想要做的一个简化的玩具示例:

I'm having some issues trying to propagate the gradient of my loss function with respect to a variable that represents the gather index, similar to what is done in spatial transformer networks (https://github.com/tensorflow/models/blob/master/transformer/spatial_transformer.py). I feel like I might be missing something very simple. Here is a simplified toy example of what I would want to do:

import tensorflow as tf
import numpy as np
lf = np.array([1.0,2.0,3.0])
lf_b = 2.0
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, shape=(3))
pt = tf.Variable(0, name='point')
y_ = tf.placeholder(tf.float32, shape=())
sess.run(tf.initialize_all_variables())
y = tf.gather(x, pt)
data_loss = tf.reduce_mean(tf.squared_difference(y,y_))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(data_loss)

目前,这会返回错误:ValueError:没有为任何变量提供梯度

Currently, this returns the error: ValueError: No gradients provided for any variable

推荐答案

问题是你无法区分,因为 pt 是一个整数.它在 x 占位符中选择一个索引,因此它没有导数.通常,当您执行此操作时,您会输入一个整数并使用它来选择一个浮点值.你是反过来做的.

The problem is that you can't differentiate since pt is an integer. It is selecting one index in the x placeholder so it does not have a derivative. Normally when you do this you would input an integer and use this to select a floating point value. You are doing it the other way around.

这篇关于Tensorflow:如何通过 tf.gather 传播梯度?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆