BERT(Bidirectional Encoder Representations from Transformers)是Google在2018年提出的预训练语言模型,其核心思想是通过双向Transformer结构捕捉上下文信息,为下游NLP任务提供通用的语义表示。
一、模型架构
BERT基于Transformer的编码器(Encoder)堆叠而成,摒弃了解码器(Decoder)。每个Encoder层包含:
自注意力机制(Self-Attention):计算输入序列中每个词与其他词的关系权重,动态聚合上下文信息。
前馈神经网络(FFN):对注意力输出进行非线性变换。
残差连接与层归一化:缓解深层网络训练中的梯度消失问题。
与传统单向语言模型(如GPT)不同,BERT通过同时观察左右两侧的上下文(双向注意力)捕捉词语的完整语义。
二、预训练任务
BERT通过以下两个无监督任务预训练模型:
1.遮蔽语言模型(Masked Language Model, MLM)
训练过程中,输入句子中一部分词会被 (MASK) 标记替换,模型需根据上下文信息预测这些被遮蔽的词。这种任务迫使模型在训练时同时考虑文本前后信息,学习更丰富的语言表征。具体操作是,随机选择 15% 的词汇用于预测,其中 80% 情况下用 (MASK) 替换,10% 情况下用任意词替换,10% 情况下保持原词汇不变。
2.下一句预测(Next Sentence Prediction, NSP)
旨在训练模型理解句子间的连贯性。训练时,模型接收一对句子作为输入,判断两个句子是否是连续的文本序列。通过该任务,模型能学习到句子乃至篇章层面的语义信息。在实际预训练中,会从文本语料库中随机选择 50% 正确语句对和 50% 错误语句对进行训练。
三、输入表示
BERT的输入由三部分嵌入相加组成:
Token Embeddings:词向量(WordPiece分词)。
Segment Embeddings:区分句子A和B(用于NSP任务)。
Position Embeddings:Transformer本身无位置感知,需显式加入位置编码。
四、微调(Fine-tuning)
预训练后,BERT可通过简单的微调适配下游任务:
分类任务(如情感分析):用(CLS)标记的输出向量接分类层。
序列标注(如NER):用每个Token的输出向量预测标签。
问答任务:用两个向量分别预测答案的起止位置。
微调时只需添加少量任务特定层,大部分参数复用预训练模型。
五、Python实现示例
(环境:Python 3.11,paddle 1.0.2, paddlenlp 2.6.1)
import paddle
from paddlenlp.transformers import BertTokenizer, BertForSequenceClassification# 1. 加载预训练模型和分词器
model_name = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_classes=2) # 2分类任务# 2. 准备示例数据 (正面和负面情感文本)
texts = ["这部电影太棒了,演员表演出色!", "非常糟糕的体验,完全不推荐。"]
labels = [1, 0] # 1表示正面,0表示负面# 3. 数据预处理
encoded_inputs = tokenizer(texts, max_length=128, padding=True, truncation=True, return_tensors='pd')
input_ids = encoded_inputs['input_ids']
token_type_ids = encoded_inputs['token_type_ids']# 转换为Paddle张量
labels = paddle.to_tensor(labels)# 4. 模型前向计算
outputs = model(input_ids, token_type_ids=token_type_ids)
logits = outputs# 5. 计算损失和预测
loss_fct = paddle.nn.CrossEntropyLoss()
loss = loss_fct(logits, labels)# 获取预测结果
predictions = paddle.argmax(logits, axis=1)# 打印结果
print("Loss:", loss.item())
print("Predictions:", predictions.numpy())
print("True labels:", labels.numpy())
End.