在 Tensorflow 中每行选择一个元素的优雅方式 [英] Elegant Way to Select one Element per Row in Tensorflow
问题描述
鉴于...
- 形状为
[m, n]
的矩阵 - 形状为
[m]
的张量
A
I
我想从 A
中获取元素的列表 J
,其中J[i] = A[i, I[i]]
.
I want to get a list J
of elements from A
where
J[i] = A[i, I[i]]
.
也就是说,I
保存了要从 A
中的每一行中选择的元素的索引.
That is, I
holds the index of the element to select from each row in A
.
上下文:我已经有了 argmax(A, 1)
,现在我还想要 max
.我知道我可以只使用 reduce_max
.在尝试了一段时间之后,我也想出了这个:
Context: I already have the argmax(A, 1)
and now I also want the max
.
I know that I can just use reduce_max
.
And after trying around for a bit I also came up with this:
J = tf.gather_nd(A,
tf.transpose(tf.pack([tf.to_int64(tf.range(A.get_shape()[0])), I])))
需要 to_int64
的地方,因为 range 只产生 int32
而 argmax
只产生 int64
.
Where the to_int64
is needed because range only produces int32
and argmax
only produces int64
.
两者都没有让我觉得特别优雅.一个有运行时开销(可能是关于因素 n
),另一个有未知因素认知开销.我在这里遗漏了什么吗?
None of the two strike me as particularly elegant.
One has runtime overhead (probably about factor n
) and the other has an unknown factor cognitive overhead. Am I missing something here?
推荐答案
gather()
函数提供了一种方法:
The gather()
function provides a way to do it:
r = tf.random.uniform([4,5],0, 9, dtype=tf.int32)
i = tf.random.uniform([4], 0, 4, dtype=tf.int32)
tf.gather(r, i, axis=1, batch_dims=1)
这篇关于在 Tensorflow 中每行选择一个元素的优雅方式的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!