在 Tensorflow 中每行选择一个元素的优雅方式 [英] Elegant Way to Select one Element per Row in Tensorflow

查看:36
本文介绍了在 Tensorflow 中每行选择一个元素的优雅方式的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

鉴于...

  • 形状为[m, n]
  • 的矩阵A
  • 形状为 [m]
  • 的张量 I

我想从 A 中获取元素的列表 J,其中J[i] = A[i, I[i]].

I want to get a list J of elements from A where J[i] = A[i, I[i]].

也就是说,I 保存了要从 A 中的每一行中选择的元素的索引.

That is, I holds the index of the element to select from each row in A.

上下文:我已经有了 argmax(A, 1),现在我还想要 max.我知道我可以只使用 reduce_max.在尝试了一段时间之后,我也想出了这个:

Context: I already have the argmax(A, 1) and now I also want the max. I know that I can just use reduce_max. And after trying around for a bit I also came up with this:

J = tf.gather_nd(A,
    tf.transpose(tf.pack([tf.to_int64(tf.range(A.get_shape()[0])), I])))

需要 to_int64 的地方,因为 range 只产生 int32argmax 只产生 int64.

Where the to_int64 is needed because range only produces int32 and argmax only produces int64.

两者都没有让我觉得特别优雅.一个有运行时开销(可能是关于因素 n),另一个有未知因素认知开销.我在这里遗漏了什么吗?

None of the two strike me as particularly elegant. One has runtime overhead (probably about factor n) and the other has an unknown factor cognitive overhead. Am I missing something here?

推荐答案

gather() 函数提供了一种方法:

The gather() function provides a way to do it:

r = tf.random.uniform([4,5],0, 9, dtype=tf.int32)
i = tf.random.uniform([4], 0, 4, dtype=tf.int32)
tf.gather(r, i, axis=1, batch_dims=1)

这篇关于在 Tensorflow 中每行选择一个元素的优雅方式的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆