TensorFlow获取特定列的每一行的元素 [英] TensorFlow getting elements of every row for specific columns

查看:612
本文介绍了TensorFlow获取特定列的每一行的元素的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

如果A是这样的TensorFlow变量

If A is a TensorFlow variable like so

A = tf.Variable([[1, 2], [3, 4]])

index是另一个变量

index = tf.Variable([0, 1])

我想使用此索引来选择每一行中的列.在这种情况下,第一行的项目0,第二行的项目1.

I want to use this index to select columns in each row. In this case, item 0 from first row and item 1 from second row.

如果A是一个Numpy数组,则可以获取索引中提到的相应行的列,我们可以这样做

If A was a Numpy array then to get the columns of corresponding rows mentioned in index we can do

x = A[np.arange(A.shape[0]), index]

结果将是

[1, 4]

TensorFlow的等效操作是什么?我知道TensorFlow不支持许多索引操作.如果不能直接完成该怎么办?

What is the TensorFlow equivalent operation/operations for this? I know TensorFlow doesn't support many indexing operations. What would be the work around if it cannot be done directly?

推荐答案

闲逛了一段时间.我发现两个有用的功能.

After dabbling around for quite a while. I found two functions that could be useful.

一个是tf.gather_nd(),如果您可以生成张量,则可能很有用 形式为[[0, 0], [1, 1]],因此您可以

One is tf.gather_nd() which might be useful if you can produce a tensor of the form [[0, 0], [1, 1]] and thereby you could do

index = tf.constant([[0, 0], [1, 1]])

tf.gather_nd(A, index)

如果由于某种原因,您无法生成形式为[[0, 0], [1, 1]]的向量(我无法生成此向量,因为在我的情况下,行数取决于占位符),那么我发现的解决方法是使用tf.py_func().这是有关如何完成此操作的示例代码

If you are unable to produce a vector of the form [[0, 0], [1, 1]](I couldn't produce this as the number of rows in my case was dependent on a placeholder) for some reason then the work around I found is to use the tf.py_func(). Here is an example code on how this can be done

import tensorflow as tf 
import numpy as np 

def index_along_every_row(array, index):
    N, _ = array.shape 
    return array[np.arange(N), index]

a = tf.Variable([[1, 2], [3, 4]], dtype=tf.int32)
index = tf.Variable([0, 1], dtype=tf.int32)
a_slice_op = tf.py_func(index_along_every_row, [a, index], [tf.int32])[0]
session = tf.InteractiveSession()

a.initializer.run()
index.initializer.run()
a_slice = a_slice_op.eval() 

a_slice将是一个numpy数组[1, 4]

a_slice will be a numpy array [1, 4]

这篇关于TensorFlow获取特定列的每一行的元素的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆