在Tensorflow中,如何将tf.gather()用于最后一个维度? [英] In Tensorflow, how to use tf.gather() for the last dimension?

查看:291
本文介绍了在Tensorflow中,如何将tf.gather()用于最后一个维度?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试收集最后张量的张量切片,以实现层之间的部分连接。因为输出张量的形状是 [batch_size,h,w,深度] ,所以我想根据最后一个维度选择切片,例如

I am trying to gather slices of a tensor in terms of the last dimension for partial connection between layers. Because the output tensor's shape is [batch_size, h, w, depth], I want to select slices based on the last dimension, such as

# L is intermediate tensor
partL = L[:, :, :, [0,2,3,8]]

但是, tf.gather(L,[0,2,3,8])似乎仅适用于第一个维度(对吗?)有人可以告诉我该怎么做吗?

However, tf.gather(L, [0, 2,3,8]) seems to only work for the first dimension (right?) Can anyone tell me how to do it?

推荐答案

此处有一个跟踪错误可支持此用例: https://github.com/tensorflow/tensorflow/issues/206

There's a tracking bug to support this use-case here: https://github.com/tensorflow/tensorflow/issues/206

您可以:


  1. 对矩阵进行转置,以便首先收集要收集的维数(转置非常昂贵)

  1. transpose your matrix so that dimension to gather is first (transpose is expensive)

将张量重整为1d(重整便宜),然后在线性索引处将收集列的索引转换为单个元素索引的列表,然后重整

reshape your tensor into 1d (reshape is cheap) and turn your gather column indices into a list of individual element indices at linear indexing, then reshape back

这篇关于在Tensorflow中,如何将tf.gather()用于最后一个维度?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆