我如何在张量流中使用索引数组? [英] How can I use the index array in tensorflow?

查看:143
本文介绍了我如何在张量流中使用索引数组?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

如果给定形状为(5,3)的矩阵a和形状为(5,)的索引数组b,我们可以很容易地获得相应的向量c

If given a matrix a with shape (5,3) and index array b with shape (5,), we can easily get the corresponding vector c through,

c = a[np.arange(5), b]

但是,我不能用张量流做同样的事情,

However, I cannot do the same thing with tensorflow,

a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, [5,])
# this line throws error
c = a[tf.range(5), b]

回溯(最近一次通话最近):文件",第1行,在 文件 〜/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", _SliceHelper中的第513行 name = name)

Traceback (most recent call last): File "", line 1, in File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 513, in _SliceHelper name=name)

文件 〜/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", 第671行,在strided_slice中 rink_axis_mask = shrink_axis_mask)文件〜/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", 第3688行,在strided_slice中 rink_axis_mask = shrink_axis_mask,名称=名称)文件〜/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", 第763行,位于apply_op中 op_def = op_def)文件〜/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", 第2397行,位于create_op中 set_shapes_for_outputs(ret)文件〜/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", 第1757行,在set_shapes_for_outputs中 形状= shape_func(op)文件〜/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", 第1707行,在call_with_requiring中 返回call_cpp_shape_fn(op,require_shape_fn = True)文件〜/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", 第610行,在call_cpp_shape_fn中 debug_python_shape_fn,require_shape_fn)文件〜/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", _call_cpp_shape_fn_impl中的第675行 引发ValueError(err.message)ValueError:形状"必须为等级1,但对于带有输入的"strided_slice_14"(操作数:"StridedSlice"),其等级应为2 形状:[5,3],[2,5],[2,5],[2].

File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 671, in strided_slice shrink_axis_mask=shrink_axis_mask) File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3688, in strided_slice shrink_axis_mask=shrink_axis_mask, name=name) File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op op_def=op_def) File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2397, in create_op set_shapes_for_outputs(ret) File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1757, in set_shapes_for_outputs shapes = shape_func(op) File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1707, in call_with_requiring return call_cpp_shape_fn(op, require_shape_fn=True) File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn debug_python_shape_fn, require_shape_fn) File "~/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 675, in _call_cpp_shape_fn_impl raise ValueError(err.message) ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice_14' (op: 'StridedSlice') with input shapes: [5,3], [2,5], [2,5], [2].

我的问题是,如果使用上述方法无法像numpy那样在tensorflow中产生预期的结果,该怎么办?

My question is, if I cannot produce the expected result in tensorflow as in numpy using the above mentioned method, what should I do?

推荐答案

TensorFlow当前未实现此功能. GitHub issue#4638 正在跟踪NumPy样式的高级"索引的实现.但是,您可以使用 tf.gather_nd() 运算符来实现您的程序:

This feature is not currently implemented in TensorFlow. GitHub issue #4638 is tracking the implementation of NumPy-style "advanced" indexing. However, you can use the tf.gather_nd() operator to implement your program:

a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, (5,))

row_indices = tf.range(5)

# `indices` is a 5 x 2 matrix of coordinates into `a`.
indices = tf.transpose([row_indices, b])

c = tf.gather_nd(a, indices)

这篇关于我如何在张量流中使用索引数组?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆