在TensorFlow中更新权重的子集 [英] Update a subset of weights in TensorFlow
本文介绍了在TensorFlow中更新权重的子集的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
有人知道如何更新前向传播中使用的权重的子集(即仅某些索引)吗?
Does anyone know how to update a subset (i.e. only some indices) of the weights that are used in the forward propagation?
我的猜测是,在按如下所示应用compute_gradients之后,我可能能够做到这一点:
My guess is that I might be able to do that after applying compute_gradients as follows:
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
grads_vars = optimizer.compute_gradients(loss, var_list=[weights, bias_h, bias_v])
...然后对grads_vars
中的元组列表进行处理.
...and then do something with the list of tuples in grads_vars
.
推荐答案
您可以使用gather
和scatter_update
的组合.这是一个示例,该示例将位置0
和2
You could use a combination of gather
and scatter_update
. Here's an example that doubles the values at position 0
and 2
indices = tf.constant([0,2])
data = tf.Variable([1,2,3])
data_subset = tf.gather(data, indices)
updated_data_subset = 2*data_subset
sparse_update = tf.scatter_update(data, indices, updated_data_subset)
init_op = tf.initialize_all_variables()
sess = tf.Session()
sess.run([init_op])
print "Values before:", sess.run([data])
sess.run([sparse_update])
print "Values after:", sess.run([data])
您应该看到
Values before: [array([1, 2, 3], dtype=int32)]
Values after: [array([2, 2, 6], dtype=int32)]
这篇关于在TensorFlow中更新权重的子集的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!
查看全文