Tensorflow:使用 argmax 对张量进行切片 [英] Tensorflow: using argmax to slice a tensor

查看:42
本文介绍了Tensorflow:使用 argmax 对张量进行切片的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个形状为 tf.shape(t1) = [1, 1000, 400] 的张量,我使用 max_ind = tf.argmax(t1, axis=-1) 形状为 [1, 1000].现在我有第二个与 t1 形状相同的张量:tf.shape(t2) = [1, 1000, 400].

I have a tensor with shape tf.shape(t1) = [1, 1000, 400] and I obtain the indices of the maxima on the 3rd dimension using max_ind = tf.argmax(t1, axis=-1) which has shape [1, 1000]. Now I have a second tensor that has the same shape as t1: tf.shape(t2) = [1, 1000, 400].

我想使用 t1 中的最大值索引来切片 t2,因此输出具有形式

I want to use the maxima indices from t1 to slice t2 so the output has the form

[1, 1000]

更直观的描述:生成的张量应该类似于 tf.reduce_max(t2,axis=-1) 的结果,但最大值的位置在 t1>

A more visual description: The resulting tensor should be like the result of tf.reduce_max(t2, axis=-1) but with the location of the maxima in t1

推荐答案

你可以通过 tf.gather_nd,虽然它不是很简单.例如,

You can achieve this through tf.gather_nd, although it is not really straightforward. For example,

shape = t1.shape.as_list()
xy_ind = np.stack(np.mgrid[:shape[0], :shape[1]], axis=-1)
gather_ind = tf.concat([xy_ind, max_ind[..., None]], axis=-1)
sliced_t2 = tf.gather_nd(t2, gather_ind)

另一方面,如果您的输入形状在图形构建时间中未知,您可以使用

If on the other hand the shape of your input is unknown as graph construction time, you could use

shape = tf.shape(t1)
xy_ind = tf.stack(tf.meshgrid(tf.range(shape[0]), tf.range(shape[1]),
                              indexing='ij'), axis=-1)

余数同上.

这篇关于Tensorflow:使用 argmax 对张量进行切片的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆