如何在Tensorflow中更新2D张量的子集? [英] How to update a subset of 2D tensor in Tensorflow?

查看:103
本文介绍了如何在Tensorflow中更新2D张量的子集?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想用值0更新2D张量中的索引.所以数据是2D张量,其第二行第二列索引值将被0代替.但是,我遇到类型错误.有人可以帮我吗?

I want to update an index in a 2D tensor with value 0. So data is a 2D tensor whose 2nd row 2nd column index value is to be replaced by 0. However, I am getting a type error. Can anyone help me with it?

TypeError:"ScatterUpdate"操作的输入"ref"需要输入左值

TypeError: Input 'ref' of 'ScatterUpdate' Op requires l-value input

data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
data2 = tf.reshape(data, [-1])
sparse_update = tf.scatter_update(data2, tf.constant([7]), tf.constant([0]))
#data = tf.reshape(data, [N,S])
init_op = tf.initialize_all_variables()

sess = tf.Session()
sess.run([init_op])
print "Values before:", sess.run([data])
#sess.run([updated_data_subset])
print "Values after:", sess.run([sparse_update])

推荐答案

tf.scatter_update仅适用于Variable类型.代码中的dataVariable,而data2不是,因为tf.reshape的返回类型是Tensor.

tf.scatter_update could only be applied to Variable type. data in your code IS a Variable, while data2 IS NOT, because the return type of tf.reshape is Tensor.

解决方案:

v1.0之后的张量流

data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
row = tf.gather(data, 2)
new_row = tf.concat([row[:2], tf.constant([0]), row[3:]], axis=0)
sparse_update = tf.scatter_update(data, tf.constant(2), new_row)

用于v1.0之前的张量流

data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
row = tf.gather(data, 2)
new_row = tf.concat(0, [row[:2], tf.constant([0]), row[3:]])
sparse_update = tf.scatter_update(data, tf.constant(2), new_row)

这篇关于如何在Tensorflow中更新2D张量的子集?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆