只更新 Tensorflow 中词嵌入矩阵的一部分 [英] Update only part of the word embedding matrix in Tensorflow

查看:45
本文介绍了只更新 Tensorflow 中词嵌入矩阵的一部分的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

假设我想在训练时更新一个预先训练好的词嵌入矩阵,有没有办法只更新词嵌入矩阵的一个子集?

Assuming that I want to update a pre-trained word-embedding matrix during training, is there a way to update only a subset of the word embedding matrix?

我查看了 Tensorflow API 页面并找到了这个:

I have looked into the Tensorflow API page and found this:

# Create an optimizer.
opt = GradientDescentOptimizer(learning_rate=0.1)

# Compute the gradients for a list of variables.
grads_and_vars = opt.compute_gradients(loss, <list of variables>)

# grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
# need to the 'gradient' part, for example cap them, etc.
capped_grads_and_vars = [(MyCapper(gv[0]), gv[1])) for gv in grads_and_vars]

# Ask the optimizer to apply the capped gradients.
opt.apply_gradients(capped_grads_and_vars)

但是我如何将其应用于词嵌入矩阵.假设我这样做:

However how do I apply that to the word-embedding matrix. Suppose I do:

word_emb = tf.Variable(0.2 * tf.random_uniform([syn0.shape[0],s['es']], minval=-1.0, maxval=1.0, dtype=tf.float32),name='word_emb',trainable=False)

gather_emb = tf.gather(word_emb,indices) #assuming that I pass some indices as placeholder through feed_dict

opt = tf.train.AdamOptimizer(1e-4)
grad = opt.compute_gradients(loss,gather_emb)

然后我如何使用 opt.apply_gradientstf.scatter_update 来更新原始嵌入矩阵?(另外,如果 compute_gradient 的第二个参数不是 tf.Variable,tensorflow 会抛出错误)

How do I then use opt.apply_gradients and tf.scatter_update to update the original embeddign matrix? (Also, tensorflow throws an error if the second argument of compute_gradient is not a tf.Variable)

推荐答案

TL;DR: opt.minimize(loss),TensorFlow 将为 sparse 更新>word_emb 只修改参与前向传递的 word_emb 行.

TL;DR: The default implementation of opt.minimize(loss), TensorFlow will generate a sparse update for word_emb that modifies only the rows of word_emb that participated in the forward pass.

tf.gather(word_emb, indices) 的梯度 相对于 word_emb 的操作是一个 tf.IndexedSlices 对象(有关更多详细信息,请参阅实现).该对象表示一个稀疏张量,除了indices 选择的行外,它处处为零.调用 opt.minimize(loss) 调用 AdamOptimizer._apply_sparse(word_emb_grad, word_emb),它调用 tf.scatter_sub(word_emb, ...)* 只更新word_embindices 选择.

The gradient of the tf.gather(word_emb, indices) op with respect to word_emb is a tf.IndexedSlices object (see the implementation for more details). This object represents a sparse tensor that is zero everywhere, except for the rows selected by indices. A call to opt.minimize(loss) calls AdamOptimizer._apply_sparse(word_emb_grad, word_emb), which makes a call to tf.scatter_sub(word_emb, ...)* that updates only the rows of word_emb that were selected by indices.

另一方面,如果您想修改 opt.compute_gradients(loss, word_emb),你可以对其indices执行任意TensorFlow操作>values 属性,并创建一个新的 tf.IndexedSlices 可以传递给 opt.apply_gradients([(word_emb, ...)]).例如,您可以使用 MyCapper()(如示例中所示)使用以下调用来限制渐变:

If on the other hand you want to modify the tf.IndexedSlices that is returned by opt.compute_gradients(loss, word_emb), you can perform arbitrary TensorFlow operations on its indices and values properties, and create a new tf.IndexedSlices that can be passed to opt.apply_gradients([(word_emb, ...)]). For example, you could cap the gradients using MyCapper() (as in the example) using the following calls:

grad, = opt.compute_gradients(loss, word_emb)
train_op = opt.apply_gradients(
    [tf.IndexedSlices(MyCapper(grad.values), grad.indices)])

同样,您可以通过创建具有不同索引的新 tf.IndexedSlices 来更改将要修改的索引集.

Similarly, you could change the set of indices that will be modified by creating a new tf.IndexedSlices with a different indices.

* 一般来说,如果你只想更新 TensorFlow 中的一部分变量,你可以使用 tf.scatter_update()tf.scatter_add()tf.scatter_sub()> 运算符,分别设置、加(+=)或减去(-=)先前存储在变量中的值.

* In general, if you want to update only part of a variable in TensorFlow, you can use the tf.scatter_update(), tf.scatter_add(), or tf.scatter_sub() operators, which respectively set, add to (+=) or subtract from (-=) the value previously stored in a variable.

这篇关于只更新 Tensorflow 中词嵌入矩阵的一部分的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆