How to stop gradient for some entries of a tensor in TensorFlow

Problem description

I am trying to implement an embedding layer. The embedding will be initialized with pre-trained GloVe embeddings. For words that can be found in GloVe, the embedding should be fixed; for words that don't appear in GloVe, it should be initialized randomly and be trainable. How do I do this in TensorFlow? I am aware that tf.stop_gradient works on a whole tensor; is there any kind of stop_gradient API for this scenario, or is there a workaround? Any suggestion is appreciated.

Recommended answer

The idea is to use a mask together with tf.stop_gradient to solve this problem:

res_matrix = tf.stop_gradient(mask_h * E) + mask * E,

where in the matrix mask, 1 marks an entry whose gradient should flow and 0 marks an entry whose gradient should be blocked (set to 0); mask_h is the inverse of mask (1 flipped to 0, 0 flipped to 1). The forward value is unchanged, since mask_h * E + mask * E = E, but during backpropagation the gradient only flows through the mask * E term. We can then fetch results from res_matrix. Here is the testing code:

import tensorflow as tf
import numpy as np

def entry_stop_gradients(target, mask):
    # Entries where mask == 1 keep their gradient; entries where
    # mask == 0 only pass through tf.stop_gradient, so their gradient is 0.
    mask_h = tf.abs(mask - 1)
    return tf.stop_gradient(mask_h * target) + mask * target

# 1 marks entries whose gradient should flow, 0 marks blocked entries.
mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])

emb = tf.constant(np.ones([10, 5]))

# Broadcast the per-row mask across all 5 columns of emb.
matrix = entry_stop_gradients(emb, tf.expand_dims(mask, 1))

parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)

loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)     # zero in the masked (0) rows
grad2 = tf.gradients(loss, matrix)  # dense in every row
print(matrix)
with tf.Session() as sess:
    print(sess.run(loss))
    print(sess.run([grad1, grad2]))
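
To tie this back to the embedding use case from the question, below is a minimal sketch, assuming hypothetical names (pretrained, oov_mask, word_ids) and the same TF1-style graph API as the test code above: rows found in GloVe keep their pretrained values and receive no gradient, while out-of-vocabulary rows are randomly initialized and remain trainable.

import tensorflow as tf
import numpy as np

def entry_stop_gradients(target, mask):
    mask_h = tf.abs(mask - 1)
    return tf.stop_gradient(mask_h * target) + mask * target

vocab_size, emb_dim = 6, 5

# Stand-in for real pretrained GloVe vectors.
pretrained = np.random.randn(vocab_size, emb_dim)

# 1 -> out-of-vocabulary word (trainable), 0 -> word found in GloVe (frozen).
oov_mask = np.array([0., 0., 1., 0., 1., 0.]).reshape(-1, 1)

# Randomly re-initialize the OOV rows before building the variable.
init = pretrained.copy()
oov_rows = oov_mask[:, 0] == 1
init[oov_rows] = 0.1 * np.random.randn(int(oov_rows.sum()), emb_dim)

emb_var = tf.Variable(init, dtype=tf.float64)
embedding_matrix = entry_stop_gradients(emb_var, tf.constant(oov_mask))

# Look up word vectors through the masked matrix as usual.
word_ids = tf.constant([1, 2, 4])
embedded = tf.nn.embedding_lookup(embedding_matrix, word_ids)

loss = tf.reduce_sum(embedded)
grad = tf.gradients(loss, emb_var)[0]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Only rows 2 and 4 (the OOV, trainable rows) get a nonzero gradient.
    print(sess.run(grad))

Because the forward value of embedding_matrix equals emb_var, training behaves exactly as if the GloVe rows were constants while the remaining rows are updated normally.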
