文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

Bert的pooler_output是什么?

2023-09-15 15:47

关注

BERT的两个输出

在学习bert的时候,我们知道bert是输出每个token的embeding。但在使用hugging face的bert模型时,发现除了last_hidden_state还多了一个pooler_output输出。

例如:

from transformers import AutoTokenizer, AutoModeltokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")model = AutoModel.from_pretrained("bert-base-uncased")inputs = tokenizer("I'm caixunkun. I like singing, dancing, rap and basketball.", return_tensors="pt")outputs = model(**inputs)print("last_hidden_state shape:", outputs.last_hidden_state.size())print("pooler_output shape:", outputs.pooler_output.size())
last_hidden_state shape: torch.Size([1, 20, 768])pooler_output shape: torch.Size([1, 768])

许多人可能以为pooler_output[CLS]token的embedding,但使用last_hidden_state shape[:, 0]比较后,发现又不是,然后就很奇怪。

Bert的Pooler_output

先说一下结论: pooler_output可以理解成该句子语义的特征向量表示

那它是怎么来的?和[CLS]token的embedding区别在哪?

我们将Bert模型打印一下,会发现最后还有一个BertPooler层,pooler_output就是从这来的。如下所示:

BertModel((embedding): BertEmbeddings(....)(encoder): BertEncoder(... # 12层TransformerEncoder)(pooler): BertPooler(    (dense): Linear(in_features=768, out_features=768, bias=True)    (activation): Tanh()))

其中encoder就是将BERT的所有token经过12个TransformerEncoder进行embedding。pooler就是将[CLS]这个token再过一下全连接层+Tanh激活函数,作为该句子的特征向量

我们可以从Bert源码中验证以上结论。在transformers.models.bert.modeling_bert.BertModel.forward方法中这么一行代码:

# sequence_output就是last_hidden_state# self.pooler就是上面的BertPoolerpooled_output = self.pooler(sequence_output) if self.pooler is not None else None

我们再来看看transformers.models.bert.modeling_bert.BertPooler的源码:

class BertPooler(nn.Module):    def __init__(self, config):        super().__init__()        self.dense = nn.Linear(config.hidden_size, config.hidden_size)        self.activation = nn.Tanh()    def forward(self, hidden_states):# hidden_states的第一个维度是batch_size。所以用[:, 0]取所有句子的[CLS]的embedding        first_token_tensor = hidden_states[:, 0]        pooled_output = self.dense(first_token_tensor)        pooled_output = self.activation(pooled_output)        return pooled_output

从上面的源码可以看出,pooler_output 就是[CLS]embedding又经历了一次全连接层的输出。我们可以通过以下代码进行验证:

print("pooler:", model.pooler)my_pooler_output = model.pooler(outputs.last_hidden_state)print(my_pooler_output[0, :5])print(outputs.pooler_output[0, :5])
pooler: BertPooler(  (dense): Linear(in_features=768, out_features=768, bias=True)  (activation): Tanh())tensor([-0.8129, -0.6216, -0.9810,  0.8090,  0.9032], grad_fn=)tensor([-0.8129, -0.6216, -0.9810,  0.8090,  0.9032], grad_fn=)

Bert的Pooler_output的由来

我们知道,BERT的训练包含两个任务:MLM和NSP任务(Next Sentence Prediction)。 对这两个任务不熟悉的朋友可以参考:BERT源码实现与解读(Pytorch)【论文阅读】BERT 两篇文章。

其中MLM就是挖空,然后让bert预测这个空是什么。做该任务是使用token embedding进行预测。

而Next Sentence Prediction就是预测bert接受的两句话是否为一对。例如:窗前明月光,疑是地上霜 为 True,窗前明月光,李白打开窗为False。

所以,NSP任务需要句子的语义信息来预测,但是我们看下源码是怎么做的。transformers.models.bert.modeling_bert.BertForNextSentencePrediction的部分源码如下:

class BertForNextSentencePrediction(BertPreTrainedModel):    def __init__(self, config):        super().__init__(config)        self.bert = BertModel(config)        self.cls = BertOnlyNSPHead(config)# 这个就是一个 nn.Linear(config.hidden_size, 2)...def forward(...):...outputs = self.bert(...)pooled_output = outputs[1] # 取pooler_outputseq_relationship_scores = self.cls(pooled_output)# 使用pooler_ouput送给后续的全连接层进行预测...

从上面的源码可以看出,在NSP任务训练时,并不是直接使用[CLS]token的embedding作为句子特征传给后续分类头的,而是使用的是pooler_output。个人原因可能是因为直接使用[CLS]的embedding效果不够好。

但在MLM任务时,是直接使用的是last_hidden_state,有兴趣可以看一下transformers.models.bert.modeling_bert.BertForMaskedLM的源码。

来源地址:https://blog.csdn.net/zhaohongfei_358/article/details/127960742

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯