Slicing a 4D tensor with a 3D tensor-index in PyTorch


Problem description


I have a 4D tensor (which happens to be a stack of three batches of 56x56 images where each batch has 16 images) with the size of [16, 3, 56, 56]. My goal is to select the correct one of those three batches (with my index map that has the size of [16, 56, 56]) for each pixel and get the images that I want.

Now, I want to select the particular batches of images inside those three batches, using an index tensor (indices) which has values such as

       [[[ 0,  0,  2,  ...,  0,  0,  0],
         [ 0,  0,  2,  ...,  0,  0,  0],
         [ 0,  0,  0,  ...,  0,  0,  0],
         ...,
         [ 0,  0,  0,  ...,  0,  0,  0],
         [ 0,  2,  0,  ...,  0,  0,  0],
         [ 0,  2,  2,  ...,  0,  0,  0]],

        [[ 0,  2,  0,  ...,  1,  1,  0],
         [ 0,  2,  0,  ...,  0,  0,  0],
         [ 0,  0,  0,  ...,  0,  2,  0],
         ...,
         [ 0,  0,  0,  ...,  0,  2,  0],
         [ 0,  0,  2,  ...,  0,  2,  0],
         [ 0,  0,  2,  ...,  0,  0,  0]]]

So for the 0s, the value will be selected from the first batch, where 1 and 2 will mean I want to select the values from the second and the third batch.
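To pin down the intended behavior, here is a minimal reference sketch with explicit loops on a small random tensor. The function name select_per_pixel and the small sizes are hypothetical and chosen only for illustration. It is slow and only states the rule out[n, h, w] = fourD[n, indices[n, h, w], h, w].

```python
import torch

def select_per_pixel(fourD, indices):
    """Slow reference: out[n, h, w] = fourD[n, indices[n, h, w], h, w]."""
    N, C, H, W = fourD.shape
    out = torch.empty(N, H, W, dtype=fourD.dtype)
    for n in range(N):
        for h in range(H):
            for w in range(W):
                c = int(indices[n, h, w])  # which of the C slices to read
                out[n, h, w] = fourD[n, c, h, w]
    return out

torch.manual_seed(0)
small = torch.randn(2, 3, 4, 4)              # stand-in for the [16, 3, 56, 56] tensor
small_idx = torch.randint(0, 3, (2, 4, 4))   # stand-in for the [16, 56, 56] index map
ref = select_per_pixel(small, small_idx)
print(ref.shape)  # torch.Size([2, 4, 4])
```

Any vectorized approach should reproduce this result on the same inputs.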

Here are some of the visualizations of the indices, each color denoting another batch.

I have tried to transpose the 4D tensor to match the dimensions of my indices, but it did not work. All it does is give me a copy of the dimensions I tried to select. That is,

tposed = torch.transpose(fourD, 0, 1)
print(indices.size(), outs.size(), tposed[:, indices].size())

outputs

torch.Size([16, 56, 56]) torch.Size([16, 3, 56, 56]) torch.Size([3, 16, 56, 56, 56, 56])

while the shape I need is

torch.Size([16, 56, 56]) or torch.Size([16, 1, 56, 56])

As an example, if I try to select the right values for only the first image in the batch with

fourD[0,indices].size()

I get a shape like

torch.Size([16, 56, 56, 56, 56])

Not to mention that I get an out of memory error when I try this on the whole tensor.
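The shapes above follow from how integer-tensor indexing works: indexing one dimension with an integer tensor replaces that dimension with the full shape of the index tensor. The sketch below reproduces the effect at smaller sizes (the names x and idx and the sizes are stand-ins, not from the question) and counts the elements the full-size result would need, which is likely why memory runs out.

```python
import torch

x = torch.randn(4, 3, 5, 5)            # stand-in for fourD, smaller sizes
idx = torch.randint(0, 3, (4, 5, 5))   # stand-in for indices

# The indexed dimension is replaced by the whole shape of idx.
a = x.transpose(0, 1)[:, idx]   # (3, 4, 5, 5): dim 1 becomes (4, 5, 5)
b = x[0, idx]                   # x[0] is (3, 5, 5): dim 0 becomes (4, 5, 5)

print(a.shape)  # torch.Size([3, 4, 5, 5, 5, 5])
print(b.shape)  # torch.Size([4, 5, 5, 5, 5])

# Element count of the full-size result [3, 16, 56, 56, 56, 56]
full_numel = 3 * 16 * 56 ** 4
print(full_numel)  # 472055808, about 1.9 GB if stored as float32
```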

I appreciate any help for using these indices to select either one of these three batches for each pixel in my images.

Note:

I have tried the option

outs[indices[:,None,:,:]].size()

and that returns

torch.Size([16, 1, 56, 56, 3, 56, 56])

Edit: torch.take does not help much since it treats the input tensor as a single dimensional array.
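A small illustration of that behavior, on a made-up 2x3 tensor: the indices passed to torch.take address the flattened input, so using it for per-pixel selection would require computing flat offsets by hand.

```python
import torch

x = torch.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]

# Indices refer to positions in the flattened tensor, not to (row, column).
picked = torch.take(x, torch.tensor([0, 5]))
print(picked)  # tensor([0, 5])
```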

Solution

Turns out there is a function in PyTorch that has the functionality I was searching for.

torch.gather(fourD, 1, indices.unsqueeze(1)) 

did the job.
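A runnable sketch of that call with the shapes from the question, using random data as a stand-in for the real images and index map. The result keeps a singleton dimension of size 1, which squeeze(1) removes if the [16, 56, 56] shape is preferred.

```python
import torch

torch.manual_seed(0)
fourD = torch.randn(16, 3, 56, 56)             # random stand-in for the image stack
indices = torch.randint(0, 3, (16, 56, 56))    # int64 values in {0, 1, 2}

# The index needs the same number of dimensions as the input, hence unsqueeze(1).
picked = torch.gather(fourD, 1, indices.unsqueeze(1))
print(picked.shape)              # torch.Size([16, 1, 56, 56])
print(picked.squeeze(1).shape)   # torch.Size([16, 56, 56])

# Spot check one pixel against direct indexing.
n, h, w = 5, 10, 20
assert picked[n, 0, h, w] == fourD[n, int(indices[n, h, w]), h, w]
```

Note that torch.gather expects the index tensor to have dtype int64.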

Here is a beautiful explanation of what gather does.
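In short, for dim=1 gather reads, at every output position, the element of the input whose dim-1 coordinate is given by the index tensor and whose other coordinates match the output position. A tiny 2D example with made-up values, where out[i][j] = src[i][idx[i][j]]:

```python
import torch

src = torch.tensor([[10, 11, 12],
                    [20, 21, 22]])
idx = torch.tensor([[2],    # row 0: take column 2
                    [0]])   # row 1: take column 0

out = torch.gather(src, 1, idx)
print(out.tolist())  # [[12], [20]]
```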

