如何从 Tensorflow 中的张量中获取特定行? [英] How to fetch specific rows from a tensor in Tensorflow?

查看:165
本文介绍了如何从 Tensorflow 中的张量中获取特定行?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个定义如下的张量:

I have a tensor defined as follows:

temp_var = tf.Variable(initial_value=np.asarray([[1, 2, 3],[4, 5, 6],[7, 8, 9],[10, 11, 12]]))

我还有一个要从张量中获取的行索引数组:

I also have an array of indexes of rows to be fetched from tensor:

idx = tf.constant([0, 2])

现在我想在这些索引处取一个 temp_var 的子集,即 idx

Now I want to take a subset of temp_var at those indexes i.e. idx

我知道要获取单个索引或切片,我们可以执行类似的操作

I know that to take a single index or a slice, we can do something like

temp_var[single_row_index, :]

temp_var[start:end, :]

但是如何获取idx 数组指示的行?类似于 temp_var[idx, :] ?

But how to fetch rows indicated by idx array? Something like temp_var[idx, :] ?

推荐答案

tf.gather() op 完全符合您的需要:它从矩阵中选择行(或通常从 N 维张量中选择 (N-1) 维切片).以下是它在您的情况下的工作方式:

The tf.gather() op does exactly what you need: it selects rows from a matrix (or in general (N-1)-dimensional slices from an N-dimensional tensor). Here's how it would work in your case:

temp_var = tf.Variable([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]))
idx = tf.constant([0, 2])

rows = tf.gather(temp_var, idx)

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

print(sess.run(rows))  # ==> [[1, 2, 3], [7, 8, 9]]

这篇关于如何从 Tensorflow 中的张量中获取特定行?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆