在 Tensorflow 中,如何将 tf.gather() 用于最后一个维度? [英] In Tensorflow, how to use tf.gather() for the last dimension?
问题描述
我试图根据层之间的部分连接的最后一个维度来收集张量的切片.因为输出张量的形状是[batch_size, h, w,depth]
,所以我想根据最后一个维度来选择切片,比如
I am trying to gather slices of a tensor in terms of the last dimension for partial connection between layers. Because the output tensor's shape is [batch_size, h, w, depth]
, I want to select slices based on the last dimension, such as
# L is intermediate tensor
partL = L[:, :, :, [0,2,3,8]]
然而,tf.gather(L, [0, 2,3,8])
似乎只适用于第一维(对吧?)谁能告诉我怎么做?
However, tf.gather(L, [0, 2,3,8])
seems to only work for the first dimension (right?) Can anyone tell me how to do it?
推荐答案
这里有一个跟踪错误来支持这个用例:https://github.com/tensorflow/tensorflow/issues/206
There's a tracking bug to support this use-case here: https://github.com/tensorflow/tensorflow/issues/206
现在您可以:
转置您的矩阵,以便首先收集维度(转置很昂贵)
transpose your matrix so that dimension to gather is first (transpose is expensive)
将你的张量重塑为 1d(重塑是廉价的)并将你的收集列索引转换为线性索引的单个元素索引列表,然后重塑回来
reshape your tensor into 1d (reshape is cheap) and turn your gather column indices into a list of individual element indices at linear indexing, then reshape back
这篇关于在 Tensorflow 中,如何将 tf.gather() 用于最后一个维度?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!