TensorFlow:tf.where() 出错 [英] TensorFlow: Error with tf.where()

查看:28
本文介绍了TensorFlow:tf.where() 出错的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我不确定为什么 tf.where() 不能按计划工作.我想使用 a 的值,其中 yt 小于 5,否则使用 b.

I am not sure why tf.where() does not work as planned. I want to use the values of a where yt is less that 5, otherwise use b.

tf.InteractiveSession()
yt = tf.constant([10,1,10])
a = tf.constant([1,2,3])
b = tf.constant([3,4,5])
tf.where(tf.less(yt,[5]), a, b).eval()

给出错误

where() takes at most 2 arguments (3 given)

你能告诉我为什么我会收到这个错误吗?有没有其他方法可以做到这一点?

Can you tell me why I am getting this error? Is there any other way to do this?

推荐答案

tf.where() 的语法在 TensorFlow 0.10 之间发生了变化(当它接受两个参数并返回两个输出)和TensorFlow 0.12+(现在接受三个张量参数并返回一个输出,替换之前的tf.select()).

The syntax for tf.where() was changed between TensorFlow 0.10 (when it took two arguments and returned two outputs) and TensorFlow 0.12+ (it now takes three tensor arguments and returns a single output, replacing the former tf.select()).

正如 Himaprasoon 建议,升级到最新版本的 TensorFlow 应该可以解决你的问题.

As Himaprasoon suggests, upgrading to the latest version of TensorFlow should fix your problem.

这篇关于TensorFlow:tf.where() 出错的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆