TensorFlow: how come gather_nd is differentiable?

Question

I'm looking at a TensorFlow network that implements reinforcement learning for the CartPole environment from OpenAI Gym.

The network implements the likelihood ratio approach for a policy gradient agent.
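
For context, here is the standard likelihood-ratio (REINFORCE) identity behind this kind of agent, in generic notation introduced here rather than taken from the code (θ for the network weights, π_θ for the softmax policy, R for the return, N for the number of sampled steps):

```latex
\nabla_\theta\, \mathbb{E}_{a \sim \pi_\theta(\cdot \mid s)}[R]
  = \mathbb{E}_{a \sim \pi_\theta(\cdot \mid s)}\!\big[R\, \nabla_\theta \log \pi_\theta(a \mid s)\big]
  \approx \frac{1}{N} \sum_{i=1}^{N} R_i\, \nabla_\theta \log \pi_\theta(a_i \mid s_i)
```

It follows from ∇θ πθ = πθ ∇θ log πθ. The sampled actions a_i only select which probability gets its log taken; they are never differentiated themselves.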

The thing is, the policy loss is defined using the gather_nd op! Here, look:

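    # ... (earlier layers elided in the original)
    # Softmax over axis 0, so self.y is [num_actions, batch_size]
    self.y = tf.nn.softmax(tf.matmul(self.W3, self.h2) + self.b3, dim=0)
    # One return per time step, used to weight the log-probabilities
    self.curr_reward = tf.placeholder(shape=[None], dtype=tf.float32)
    # Each row is an index pair into self.y, presumably [action_taken, step]
    self.actions_array = tf.placeholder(shape=[None, 2], dtype=tf.int32)
    # Probability of the action actually taken at each step
    self.pai_array = tf.gather_nd(self.y, self.actions_array)
    # Likelihood-ratio loss: -mean(log(pi) * reward)
    self.L = -tf.reduce_mean(tf.log(self.pai_array) * self.curr_reward)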
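
Spelled out, assuming from the softmax over dim=0 that self.y has shape [num_actions, N] and that row i of self.actions_array holds the pair (a_i, i) (the snippet does not show how it is filled), the gather and the loss compute the following, where y_{k,i} is the probability of action k at step i, R_i is self.curr_reward[i], and [k = a_i] is 1 when k = a_i and 0 otherwise:

```latex
\mathrm{pai\_array}_i = y_{a_i,\,i} = \sum_{k} [k = a_i]\, y_{k,i},
\qquad
L = -\frac{1}{N} \sum_{i=1}^{N} R_i \log y_{a_i,\,i}
```

Each gathered value is a linear function of y whose 0/1 coefficients are fixed by the sampled actions.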

And then they take the derivative of this loss with respect to all the parameters of the network:

    self.gradients = tf.gradients(self.L,tf.trainable_variables())

How can this be? I thought the whole point of neural networks is to always work with differentiable ops, like cross-entropy, and never to do something strange like selecting indexes of self.y according to some self.actions_array that is chosen at random and is clearly not differentiable.

What am I missing here? Thanks!

Answer

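The gradient is 1 with respect to each element of the first argument that is gathered and 0 with respect to every element that is not, so during backprop the incoming gradient is simply routed back to the gathered positions. The indices (self.actions_array here) are integers and receive no gradient at all, which is why it does not matter that they were chosen at random: the op only needs to be differentiable with respect to the gathered tensor, and through self.y the gradient reaches W3, b3 and the earlier layers. One use case for the gather operator is to act like a sparse one-hot matrix multiplication. The second argument is the dense representation of the sparse matrix, and you "multiply" it with the first argument by just selecting the right rows.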
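
A minimal NumPy sketch of the answer's two claims, assuming 2-D params and (row, column) index pairs as in the question. It is not TensorFlow's kernel, although TensorFlow's registered gradient for GatherNd likewise scatters the upstream gradient back with scatter_nd and returns no gradient for the indices. All function names (softmax_axis0, gather_nd, gather_nd_grad, one_hot_matrix, policy_loss, policy_loss_grad_logits, numerical_grad) and the example numbers are hypothetical. The script checks that gathering equals multiplying by a one-hot matrix, that the backward scatter is that matrix's transpose, and that a hand-derived gradient of the question's loss with respect to the logits matches finite differences.

```python
import numpy as np


def softmax_axis0(z):
    """Column-wise softmax (like tf.nn.softmax(..., dim=0)), max-shifted for stability."""
    e = np.exp(z - z.max(axis=0, keepdims=True))
    return e / e.sum(axis=0, keepdims=True)


def gather_nd(params, indices):
    """Forward: out[i] = params[indices[i, 0], indices[i, 1]] for 2-D params."""
    return params[indices[:, 0], indices[:, 1]]


def gather_nd_grad(params_shape, indices, upstream):
    """Backward w.r.t. params: d out[i] / d params[r, c] is 1 if (r, c) is the
    pair gathered for row i and 0 otherwise, so the upstream gradient is
    scatter-added into the gathered positions. The integer indices get none."""
    grad = np.zeros(params_shape)
    np.add.at(grad, (indices[:, 0], indices[:, 1]), upstream)
    return grad


def one_hot_matrix(params_shape, indices):
    """The sparse one-hot matrix gather_nd implicitly multiplies by:
    one row per gathered element, with a single 1 at its flattened position."""
    n = len(indices)
    flat = np.ravel_multi_index((indices[:, 0], indices[:, 1]), params_shape)
    onehot = np.zeros((n, int(np.prod(params_shape))))
    onehot[np.arange(n), flat] = 1.0
    return onehot


def policy_loss(logits, indices, reward):
    """L = -mean(log(gather_nd(softmax(logits), indices)) * reward)."""
    y = softmax_axis0(logits)
    return -np.mean(np.log(gather_nd(y, indices)) * reward)


def policy_loss_grad_logits(logits, indices, reward):
    """Backprop by hand: loss -> gather_nd -> column-wise softmax."""
    y = softmax_axis0(logits)
    p = gather_nd(y, indices)
    dy = gather_nd_grad(y.shape, indices, -reward / (len(p) * p))  # dL/dy
    # Softmax backward per column: dz = y * (dy - sum_k(dy * y))
    return y * (dy - (dy * y).sum(axis=0, keepdims=True))


def numerical_grad(f, x, eps=1e-6):
    """Central finite differences, one entry at a time."""
    g = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        bump = np.zeros_like(x)
        bump[idx] = eps
        g[idx] = (f(x + bump) - f(x - bump)) / (2 * eps)
    return g


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    logits = rng.normal(size=(2, 3))              # hypothetical [num_actions, batch]
    indices = np.array([[1, 0], [0, 1], [1, 2]])  # hypothetical (action, step) pairs
    reward = np.array([1.0, 0.5, 2.0])            # hypothetical returns

    y = softmax_axis0(logits)
    onehot = one_hot_matrix(y.shape, indices)
    # Forward: gathering == multiplying by the one-hot matrix.
    print(np.allclose(gather_nd(y, indices), onehot @ y.ravel()))  # True
    # Backward: the scatter is exactly the transpose of that matrix.
    g = np.array([1.0, 2.0, 3.0])
    print(np.allclose(gather_nd_grad(y.shape, indices, g),
                      (onehot.T @ g).reshape(y.shape)))  # True
    # End to end: hand-derived gradient w.r.t. logits vs. finite differences.
    analytic = policy_loss_grad_logits(logits, indices, reward)
    numeric = numerical_grad(lambda z: policy_loss(z, indices, reward), logits)
    print(np.allclose(analytic, numeric, atol=1e-6))  # True
```

np.add.at is used instead of grad[rows, cols] += upstream because it accumulates when the same index pair is gathered twice, matching how tf.scatter_nd sums duplicate indices. When each step j is gathered once, policy_loss_grad_logits reduces to dL/dz_{k,j} = -(R_j/N)([k = a_j] - y_{k,j}): only one entry per column receives gradient from the gather, but the softmax then spreads it to every logit in that column.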