在特定索引后用零填充火炬张量 [英] Filling torch tensor with zeros after certain index

查看:67
本文介绍了在特定索引后用零填充火炬张量的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

给定一个 3d Tenzor,说:batch x 句子长度 x embedding dim

Given a 3d tenzor, say: batch x sentence length x embedding dim

a = torch.rand((10, 1000, 96)) 

以及每个句子的实际长度数组(或张量)

and an array(or tensor) of actual lengths for each sentence

lengths =  torch .randint(1000,(10,))

输出张量([ 370., 502., 652., 859., 545., 964., 566., 576.,1000., 803.])

如何根据张量长度"在维度 1(句子长度)的某个索引后用零填充张量a"?

How to fill tensor ‘a’ with zeros after certain index along dimension 1 (sentence length) according to tensor ‘lengths’ ?

我想要这样的:

a[ : , lengths : , : ]  = 0

一种方法(如果批量足够大,速度会很慢):

One way of doing it (slow if batch size is big enough):

for i_batch in range(10):
    a[ i_batch  , lengths[i_batch ] : , : ]  = 0

推荐答案

您可以使用二进制掩码来完成.
使用 lengths 作为 mask 的列索引,我们指示每个序列的结束位置(注意我们使 mask 长于 a.size(1) 允许全长序列).
使用 cumsum() 我们将 seq len 之后的 mask 中的所有条目设置为 1.

You can do it using a binary mask.
Using lengths as column-indices to mask we indicate where each sequence ends (note that we make mask longer than a.size(1) to allow for sequences with full length).
Using cumsum() we set all entries in mask after the seq len to 1.

mask = torch.zeros(a.shape[0], a.shape[1] + 1, dtype=a.dtype, device=a.device)
mask[(torch.arange(a.shape[0]), lengths)] = 1
mask = mask.cumsum(dim=1)[:, :-1]  # remove the superfluous column
a = a * (1. - mask[..., None])     # use mask to zero after each column

对于 a.shape = (10, 5, 96)lengths = [1, 2, 1, 1, 3, 0, 4, 4, 1, 3].
在每一行为各自的 lengths 分配 1,mask 看起来像:

For a.shape = (10, 5, 96), and lengths = [1, 2, 1, 1, 3, 0, 4, 4, 1, 3].
Assigning 1 to respective lengths at each row, mask looks like:

mask = 
tensor([[0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0.]])

cumsum之后你得到

mask = 
tensor([[0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1.]])

请注意,它在有效序列条目所在的位置正好有零,而在序列长度之外的地方有一个.取 1 - mask 给你你想要的.

Note that it exactly has zeros where the valid sequence entries are and ones beyond the lengths of the sequences. Taking 1 - mask gives you exactly what you want.

享受;)

这篇关于在特定索引后用零填充火炬张量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆