Why W_q matrix in torch.nn.MultiheadAttention is quadratic

Problem description

I am trying to implement nn.MultiheadAttention in my network. According to the docs,

embed_dim – Total dimension of the model.

However, according to the source file,

embed_dim must be divisible by num_heads

self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))

If I understand properly, this means each head takes only a part of the features of each query, since the matrix is square. Is this a bug in the implementation, or is my understanding wrong?
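For reference, the shapes can be inspected directly. A minimal sketch (assuming a recent PyTorch version, where the default code path stacks the query/key/value projections into a single in_proj_weight; the separate q_proj_weight quoted above is used when the key/value dimensions differ from embed_dim):

import torch.nn as nn

embed_dim, num_heads = 512, 8
attn = nn.MultiheadAttention(embed_dim, num_heads)

# The stacked projection is (3 * embed_dim, embed_dim), so the query block
# alone is a square embed_dim x embed_dim matrix, as in the source quoted above.
print(attn.in_proj_weight.shape)   # torch.Size([1536, 512])
print(embed_dim // num_heads)      # 64 -> number of features each head ends up with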

Recommended answer

Each head uses a different part of the projected query vector. You can imagine it as if the query gets split into num_heads vectors that are independently used to compute the scaled dot-product attention. So, each head operates on a different linear combination of the features in the queries (and keys and values, too). This linear projection is done using the self.q_proj_weight matrix, and the projected queries are passed to the F.multi_head_attention_forward function.
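As a rough illustration of this splitting (a sketch of the idea, not the library code): multiplying by the square W_q and then slicing the projected vector per head is equivalent to giving each head its own (head_dim × embed_dim) block of W_q, i.e. its own set of linear combinations of all input features.

import torch

embed_dim, num_heads = 512, 8
head_dim = embed_dim // num_heads

W_q = torch.randn(embed_dim, embed_dim)        # square projection, as in the source
q = torch.randn(embed_dim)                     # a single query vector

q_proj = W_q @ q                               # projected query, still embed_dim long
q_heads = q_proj.reshape(num_heads, head_dim)  # head i works with q_heads[i]

# The same result, viewed as one (head_dim x embed_dim) projection per head:
W_q_heads = W_q.reshape(num_heads, head_dim, embed_dim)
assert torch.allclose(q_heads, W_q_heads @ q)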

Internally, F.multi_head_attention_forward implements this splitting by reshaping and transposing the query vectors, so that the independent attention for each head can be computed efficiently via matrix multiplication.
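Schematically, the reshape-and-transpose trick looks like this (a simplified sketch assuming the (seq_len, batch, embed_dim) layout; the real function also handles masking, dropout, padding, and so on):

import torch

seq_len, batch, embed_dim, num_heads = 10, 2, 512, 8
head_dim = embed_dim // num_heads

q = torch.randn(seq_len, batch, embed_dim)   # already projected queries
k = torch.randn(seq_len, batch, embed_dim)   # already projected keys

# Fold the heads into the batch dimension so a single bmm covers all heads at once.
q = q.reshape(seq_len, batch * num_heads, head_dim).transpose(0, 1)  # (batch*heads, L, head_dim)
k = k.reshape(seq_len, batch * num_heads, head_dim).transpose(0, 1)

scores = torch.bmm(q, k.transpose(1, 2)) / head_dim ** 0.5           # (batch*heads, L, L)
attn_weights = scores.softmax(dim=-1)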

The attention head sizes are a design decision of PyTorch. In theory, you could have a different head size, so the projection matrix would have a shape of embedding_dim × (num_heads * head_dim). Some implementations of transformers (such as C++-based Marian for machine translation, or Huggingface's Transformers) allow that.
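For comparison, a decoupled head size would look roughly like this (a hypothetical sketch, not how nn.MultiheadAttention is parameterized):

import torch
import torch.nn as nn

embed_dim, num_heads = 512, 8
head_dim = 96                         # chosen independently of embed_dim // num_heads

# Non-square query projection of shape (num_heads * head_dim, embed_dim):
# each head still mixes all input features, but its size is a free choice.
q_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=False)

x = torch.randn(10, embed_dim)        # 10 query vectors
q = q_proj(x).reshape(10, num_heads, head_dim)
print(q_proj.weight.shape, q.shape)   # torch.Size([768, 512]) torch.Size([10, 8, 96])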
