TensorFlow 中张量值的条件赋值 [英] Conditional assignment of tensor values in TensorFlow
问题描述
我想在 tensorflow
中复制以下 numpy
代码.例如,我想为之前值为 1
的所有张量索引分配一个 0
.
I want to replicate the following numpy
code in tensorflow
. For example, I want to assign a 0
to all tensor indices that previously had a value of 1
.
a = np.array([1, 2, 3, 1])
a[a==1] = 0
# a should be [0, 2, 3, 0]
如果我在 tensorflow
中编写类似的代码,我会收到以下错误.
If I write similar code in tensorflow
I get the following error.
TypeError: 'Tensor' object does not support item assignment
方括号中的条件应该是任意的,如a[a<1] = 0
.
The condition in the square brackets should be arbitrary as in a[a<1] = 0
.
有没有办法在tensorflow
中实现这种条件赋值"(因为没有更好的名字)?
Is there a way to realize this "conditional assignment" (for lack of a better name) in tensorflow
?
推荐答案
比较运算符,例如 大于 在 TensorFlow API 中可用.
Comparison operators such as greater than are available within TensorFlow API.
然而,在直接操作张量方面,没有什么能与简洁的 NumPy 语法等效.您必须使用单独的 comparison
、where
和 assign
运算符来执行相同的操作.
However, there is nothing equivalent to the concise NumPy syntax when it comes to manipulating the tensors directly. You have to make use of individual comparison
, where
and assign
operators to perform the same action.
与您的 NumPy 示例等效的代码如下:
Equivalent code to your NumPy example is this:
import tensorflow as tf
a = tf.Variable( [1,2,3,1] )
start_op = tf.global_variables_initializer()
comparison = tf.equal( a, tf.constant( 1 ) )
conditional_assignment_op = a.assign( tf.where (comparison, tf.zeros_like(a), a) )
with tf.Session() as session:
# Equivalent to: a = np.array( [1, 2, 3, 1] )
session.run( start_op )
print( a.eval() )
# Equivalent to: a[a==1] = 0
session.run( conditional_assignment_op )
print( a.eval() )
# Output is:
# [1 2 3 1]
# [0 2 3 0]
打印语句当然是可选的,它们只是为了证明代码正确执行.
The print statements are of course optional, they are just there to demonstrate the code is performing correctly.
这篇关于TensorFlow 中张量值的条件赋值的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!