如何用另一个张量切片 PyTorch 张量? [英] How can I slice a PyTorch tensor with another tensor?

查看:43
本文介绍了如何用另一个张量切片 PyTorch 张量?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有:

inp =  torch.randn(4, 1040, 161)

我还有另一个名为 indices 的张量,带有值:

and I have another tensor called indices with values:

tensor([[124, 583, 158, 529],
        [172, 631, 206, 577]], device='cuda:0')

我想要相当于:

inp0 = inp[:,124:172,:]
inp1 = inp[:,583:631,:]
inp2 = inp[:,158:206,:]
inp3 = inp[:,529:577,:]

除了所有加在一起外,具有 [4, 48, 161] 的 .size.我怎样才能做到这一点?

Except all added together, to have a .size of [4, 48, 161]. How can I accomplish this?

目前,我的解决方案是一个 for 循环:

Currently, my solution is a for loop:

            left_indices = torch.empty(inp.size(0), self.side_length, inp.size(2))
            for batch_index in range(len(inp)):
                print(left_indices_start[batch_index].item())
                left_indices[batch_index] = inp[batch_index, left_indices_start[batch_index].item():left_indices_end[batch_index].item()]

推荐答案

Here you go ( 您可能需要在执行以下操作之前使用 tensor=tensor.cpu() 将张量复制到 cpu):

Here you go ( you probably need to copy tensors to cpu using tensor=tensor.cpu() before doing following operations):

index = tensor([[124, 583, 158, 529],
    [172, 631, 206, 577]], device='cuda:0')
#create a concatenated list of ranges of indices you desire to slice
indexer = np.r_[tuple([np.s_[i:j] for (i,j) in zip(index[0,:],index[1,:])])]
#slice using numpy indexing
sliced_inp = inp[:, indexer, :]

这是它的工作原理:

np.s_[i:j] 创建一个从 start=i 到 end=j 索引的切片对象(只是一个范围)>.

np.s_[i:j] creates a slice object (simply a range) of indices from start=i to end=j.

np.r_[i:j, k:m] 创建切片 (i,j)(k,m) 中的所有索引列表(您可以将更多切片传递给 np.r_ 以将它们一次连接在一起.这是仅连接两个切片的示例.)

np.r_[i:j, k:m] creates a list ALL indices in slices (i,j) and (k,m) (You can pass more slices to np.r_ to concatenate them all together at once. This is an example of concatenating only two slices.)

因此,indexer 通过连接切片列表(每个切片是一个索引范围)来创建所有索引的列表.

Therefore, indexer creates a list of ALL indices by concatenating a list of slices (each slice is a range of indices).

更新:如果您需要删除间隔重叠和排序间隔:

UPDATE: If you need to remove interval overlaps and sort intervals:

indexer = np.unique(indexer)

如果您想删除间隔重叠但不排序并保持原始顺序(以及第一次出现的重叠)

if you want to remove interval overlaps but not sort and keep original order (and first occurrences of overlaps)

uni = np.unique(indexer, return_index=True)[1]
indexer = [indexer[index] for index in sorted(uni)]

这篇关于如何用另一个张量切片 PyTorch 张量?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆