如何使用 tf.gather_nd 在 tensorflow 中对张量进行切片? [英] How to use tf.gather_nd to slice a tensor in tensorflow?

查看:29
本文介绍了如何使用 tf.gather_nd 在 tensorflow 中对张量进行切片?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在 numpy 中寻找与以下代码等效的 tensorflow.aidx_2 给出.目标是构造b.

I am looking for the tensorflow equivalent of the following code in numpy. a and idx_2 are given. The goal is to construct b.

# A float Tensor obtained somehow
a = np.arange(3*5).reshape(3,5)                    

# An int Tensor obtained somehow
idx_2 = np.array([[1,2,3,4],[0,2,3,4],[0,2,3,4]])  

# An int Tensor, constructed for indexing
idx_1 = np.arange(a.shape[0]).reshape(-1,1)        

# The goal
b = a[idx_1, idx_2]

print(b)
>>> [[ 1  2  3  4]
     [ 5  7  8  9]
     [10 12 13 14]]

我曾尝试直接索引张量并使用 tf.gather_nd 但我不断收到错误,所以我决定在这里询问如何做.我到处寻找答案的人都使用 tf.gather_nd(因此是标题)来解决类似的问题,但要应用这个函数,我必须以某种方式重塑索引,以便它们可用于对第一维进行切片.我该怎么做呢?请帮忙.

I have tried directly indexing the tensors and using tf.gather_nd but I keep getting errors so I decided to ask how to do it here. Everywhere I look for answers people use tf.gather_nd (hence the title) to solve similar problems, but to apply this functions I have to somehow reshape the indexes such that they can be used to slice the first dimension. How do I do this? Please help.

推荐答案

当谈到 NumPy 中非常简单和 Pythonic 的事情时,Tensorflow 可能非常丑陋.这是我如何使用 tf.gather_nd 在 TensorFlow 中重现您的问题.不过,可能有更好的方法来做到这一点.

Tensorflow can be quite ugly when it comes to things that are very simple and Pythonic in NumPy. Here is how I used tf.gather_nd to recreate your problem in TensorFlow. There is probably a much better way to do it though.

import tensorflow as tf
import numpy as np

with tf.Session() as sess:
    # Define 'a'
    a = tf.reshape(tf.range(15),(3,5))
    # Define both index tensors 
    idx_1 = tf.reshape(tf.range(a.get_shape().as_list()[0]),(-1,1)).eval()
    idx_2 = tf.constant([[1,2,3,4],[0,2,3,4],[0,2,3,4]]).eval()
    # get indices for use with gather_nd
    gather_idx = tf.constant([(x[0],y) for (i,x) in enumerate(idx_1) for y in idx_2[i]])
    # extract elements and reshape to desired dimensions
    b = tf.gather_nd(a, gather_idx)
    b = tf.reshape(b,(idx_1.shape[0], idx_2.shape[1]))
    print(sess.run(b))

[[ 1  2  3  4]
[ 5  7  8  9]
[10 12 13 14]]

这篇关于如何使用 tf.gather_nd 在 tensorflow 中对张量进行切片?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆