Does BertForSequenceClassification classify on the CLS vector?


Question


I'm using the Huggingface Transformer package and BERT with PyTorch. I'm trying to do 4-way sentiment classification and am using BertForSequenceClassification to build a model that eventually leads to a 4-way softmax at the end.


My understanding from reading the BERT paper is that the final dense vector for the input CLS token serves as a representation of the whole text string:


The first token of every sequence is always a special classification token ([CLS]). The final hidden state corresponding to this token is used as the aggregate sequence representation for classification tasks.


So, does BertForSequenceClassification actually train and use this vector to perform the final classification?


The reason I ask is because when I print(model), it is not obvious to me that the CLS vector is being used.

from transformers import BertForSequenceClassification

# model_config holds the name or path of a pretrained BERT model (e.g. 'bert-base-uncased')
model = BertForSequenceClassification.from_pretrained(
    model_config,
    num_labels=num_labels,
    output_attentions=False,
    output_hidden_states=False
)

print(model)

Here is the bottom of the output:

        (11): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=768, out_features=4, bias=True)


I see that there is a pooling layer BertPooler leading to a Dropout leading to a Linear which presumably performs the final 4-way softmax. However, the use of the BertPooler is not clear to me. Is it operating on only the hidden state of CLS, or is it doing some kind of pooling over hidden states of all the input tokens?

Thanks for your help.

Answer


The short answer: Yes, you are correct. Indeed, they use the CLS token (and only that) for BertForSequenceClassification.


Looking at the implementation of BertPooler reveals that it uses the first hidden state, which corresponds to the [CLS] token. I briefly checked one other model (RoBERTa) to see whether this is consistent across models. There, too, classification only takes place based on the [CLS] token, albeit less obviously (see lines 539-542 of the RoBERTa implementation).
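
For illustration, here is a condensed sketch of that code path (paraphrased rather than copied verbatim from the library; the sizes and the dummy input below are just for demonstration). The pooler slices out the hidden state of the first token, i.e. [CLS], and the classification head is just dropout plus a linear layer on top of that one vector:

import torch
import torch.nn as nn

class BertPooler(nn.Module):
    """Condensed sketch of the pooler used by BERT in transformers."""
    def __init__(self, hidden_size=768):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # "Pooling" is simply taking the hidden state of the first token
        # in the sequence, which is the [CLS] token.
        first_token_tensor = hidden_states[:, 0]        # (batch, hidden_size)
        return self.activation(self.dense(first_token_tensor))

# Sketch of the classification head in BertForSequenceClassification:
# pooled [CLS] vector -> Dropout -> Linear(hidden_size, num_labels).
hidden_size, num_labels = 768, 4
pooler = BertPooler(hidden_size)
dropout = nn.Dropout(p=0.1)
classifier = nn.Linear(hidden_size, num_labels)

# Dummy encoder output with shape (batch, seq_len, hidden_size).
sequence_output = torch.randn(2, 16, hidden_size)
logits = classifier(dropout(pooler(sequence_output)))  # shape (2, 4)

So the Dropout and the Linear(in_features=768, out_features=4) you see at the bottom of print(model) operate only on that single pooled [CLS] vector; the 4-way softmax is applied to the resulting logits (typically inside the cross-entropy loss during training).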

