TensorFlow:tf.where() 出错 [英] TensorFlow: Error with tf.where()
问题描述
我不确定为什么 tf.where() 不能按计划工作.我想使用 a
的值,其中 yt
小于 5,否则使用 b
.
I am not sure why tf.where() does not work as planned. I want to use the values of a
where yt
is less that 5, otherwise use b
.
tf.InteractiveSession()
yt = tf.constant([10,1,10])
a = tf.constant([1,2,3])
b = tf.constant([3,4,5])
tf.where(tf.less(yt,[5]), a, b).eval()
给出错误
where() takes at most 2 arguments (3 given)
你能告诉我为什么我会收到这个错误吗?有没有其他方法可以做到这一点?
Can you tell me why I am getting this error? Is there any other way to do this?
推荐答案
tf.where()
的语法在 TensorFlow 0.10 之间发生了变化(当它接受两个参数并返回两个输出)和TensorFlow 0.12+(现在接受三个张量参数并返回一个输出,替换之前的tf.select()
).
The syntax for tf.where()
was changed between TensorFlow 0.10 (when it took two arguments and returned two outputs) and TensorFlow 0.12+ (it now takes three tensor arguments and returns a single output, replacing the former tf.select()
).
正如 Himaprasoon 建议,升级到最新版本的 TensorFlow 应该可以解决你的问题.
As Himaprasoon suggests, upgrading to the latest version of TensorFlow should fix your problem.
这篇关于TensorFlow:tf.where() 出错的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!