在tensorflow中的张量对象上进行非连续索引切片(高级索引如numpy) [英] non continuous index slicing on tensor object in tensorflow (Advanced indexing like numpy)

查看:154
本文介绍了在tensorflow中的张量对象上进行非连续索引切片(高级索引如numpy)的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我研究了张量流切片的不同方式,即tf.gathertf.gather_nd. 在tf.gather中,它只是在一个维度上切片,在tf.gather_nd中,它也只接受一个indices应用于输入张量.

I have looked at different ways of slicing in tensorflow, namely, tf.gather and tf.gather_nd. In tf.gather, it just slices over a dimension, and also in tf.gather_nd it just accepts one indices to be applied over the input tensor.

我需要的是不同的,我想使用两个不同的张量在输入张量上切片;一个切片在行上,第二切片在列上,它们不一定具有相同的形状.

What I need is different, I want to slice over the input tensor using two different tensor;one slices over the rows the second slices over the column and they are not in the same shape necessarily.

例如:

假设这是我的输入张量,我想在其中提取部分张量.

suppose this is my input tensor in which I want to extract part of it.

input_tf = tf.Variable([ [9.968594,  8.655439,  0.,        0.       ],
                         [0.,        8.3356,    0.,        8.8974   ],
                         [0.,        0.,        6.103182,  7.330564 ],
                         [6.609862,  0.,        3.0614321, 0.       ],
                         [9.497023,  0.,        3.8914037, 0.       ],
                         [0.,        8.457685,  8.602337,  0.       ],
                         [0.,        0.,        5.826657,  8.283971 ],
                         [0.,        0.,        0.,        0.       ]])

第二个是:

 rows_tf = tf.constant (
[[1, 2, 5],
 [1, 2, 5],
 [1, 2, 5],
 [1, 4, 6],
 [1, 4, 6],
 [2, 3, 6],
 [2, 3, 6],
 [2, 4, 7]])

第三个张量:

columns_tf = tf.constant(
[[1],
 [2],
 [3],
 [2],
 [3],
 [2],
 [3],
 [2]])

现在,我想使用rows_tfcolumns_tfinput_tf进行切片.行中的索引[1 2 5]columns_tf中的索引[1].同样,在columns_tf中的[1 2 5]行和[2]行.

Now, I want to slice over input_tf using rows_tf and columns_tf. index [1 2 5] in rows and [1] in columns_tf. Again, rows [1 2 5] with [2] in columns_tf.

或者,[1 4 6][2].

总体而言,rows_tf中的每个索引与columns_tf中的相同索引都将提取input_tf的一部分.

Overall, each index in the rows_tf, with the same index in columns_tfwill extract part of the input_tf.

因此,预期输出为:

[[8.3356,    0.,        8.457685 ],
 [0.,        6.103182,  8.602337 ],
 [8.8974,    7.330564,  0.       ],
 [0.,        3.8914037, 5.826657 ],
 [8.8974,    0.,        8.283971 ],
 [6.103182,  3.0614321, 5.826657 ],
 [7.330564,  0.,        8.283971 ],
 [6.103182,  3.8914037, 0.       ]]

例如,此处使用

rows in rows_tf [1,2,5] and column in columns_tf [1](row 1 and column 1, row 2 and column 1 and row 5 and column 1 in the input_tf)

关于张量流切片的问题很多,尽管他们使用了tf.gathertf.gather_ndtf.stack,但它们没有给出我想要的输出.

There were a couple of questions regarding slicing in tensorflow, though they used tf.gather or tf.gather_nd and tf.stack which it did not give my desired output.

无需提及,在numpy中,我们可以通过调用input_tf[rows_tf, columns_tf]轻松地做到这一点.

No need to mention that in numpy we can easily do that by calling: input_tf[rows_tf, columns_tf].

我还查看了这个高级索引,它试图模拟numpy中可用的高级索引,但是它仍然不像numpy flexible

I also, looked at this advanced indexing which tries to simulate the advanced indexing available in numpy, however it still is not like numpy flexible https://github.com/SpinachR/ubuntuTest/blob/master/beautifulCodes/tensorflow_advanced_index_slicing.ipynb

这是我尝试过的不正确的方法:

This is what I have tried which is not correct:

tf.gather(tf.transpose(tf.gather(input_tf,rows_tf)),columns_tf)

此代码的尺寸输出为(8, 1, 3, 8),这完全不正确.

the dimension output of this code is (8, 1, 3, 8) which is incorrect totally.

提前谢谢!

推荐答案

想法是首先获取稀疏索引(通过将行索引和列索引连接在一起)作为列表.然后,您可以使用gather_nd检索值.

The idea is to first get the sparse indices (by concatenating row index and column index) as a list. Then you can use gather_nd to retrieve the values.


tf.reset_default_graph()
input_tf = tf.Variable([ [9.968594,  8.655439,  0.,        0.       ],
                         [0.,        8.3356,    0.,        8.8974   ],
                         [0.,        0.,        6.103182,  7.330564 ],
                         [6.609862,  0.,        3.0614321, 0.       ],
                         [9.497023,  0.,        3.8914037, 0.       ],
                         [0.,        8.457685,  8.602337,  0.       ],
                         [0.,        0.,        5.826657,  8.283971 ],
                         [0.,        0.,        0.,        0.       ]])
rows_tf = tf.constant (
[[1, 2, 5],
 [1, 2, 5],
 [1, 2, 5],
 [1, 4, 6],
 [1, 4, 6],
 [2, 3, 6],
 [2, 3, 6],
 [2, 4, 7]])
columns_tf = tf.constant(
[[1],
 [2],
 [3],
 [2],
 [3],
 [2],
 [3],
 [2]])
rows_tf = tf.reshape(rows_tf, shape=[-1, 1])
columns_tf = tf.reshape(
    tf.tile(columns_tf, multiples=[1, 3]), 
    shape=[-1, 1])
sparse_indices = tf.reshape(
    tf.concat([rows_tf, columns_tf], axis=-1), 
    shape=[-1, 2])

v = tf.gather_nd(input_tf, sparse_indices)
v = tf.reshape(v, [-1, 3])

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  #print 'rows\n', sess.run(rows_tf)
  #print 'columns\n', sess.run(columns_tf)
  print sess.run(v)

结果将是:

[[ 8.3355999   0.          8.45768547]
 [ 0.          6.10318184  8.60233688]
 [ 8.8973999   7.33056402  0.        ]
 [ 0.          3.89140368  5.82665682]
 [ 8.8973999   0.          8.28397083]
 [ 6.10318184  3.06143212  5.82665682]
 [ 7.33056402  0.          8.28397083]
 [ 6.10318184  3.89140368  0.        ]]

这篇关于在tensorflow中的张量对象上进行非连续索引切片(高级索引如numpy)的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆