BERT 模型微调与传统机器学习的对比

BERT 微调与传统机器学习的区别和联系:

传统机器学习流程

传统机器学习处理文本分类通常包含以下步骤:

  1. 特征工程:手动设计特征(如 TF-IDF、词袋模型)
  2. 模型训练:使用分类器(如 SVM、随机森林、逻辑回归)
  3. 特征和模型调优:反复调整特征和超参数

BERT 微调流程

BERT 微调的典型流程:

  1. 预训练:使用大规模无标注数据预训练 BERT 模型
  2. 数据准备:将文本转换为 BERT 输入格式(tokenize、添加特殊标记)
  3. 模型微调:冻结大部分 BERT 层,只训练分类头(或少量 BERT 层)
  4. 评估与部署:在验证集上评估,保存模型

两者的主要区别

对比项传统机器学习BERT 微调
特征表示手动设计特征(如 TF-IDF)自动学习上下文相关表示
模型复杂度简单到中等(如 SVM、RF)非常复杂(Transformer 架构)
数据依赖需要大量标注数据可以用较少数据达到好效果
领域适应性迁移到新领域需要重新设计特征可以快速适应新领域(通过微调)
计算资源通常较低需要 GPU/TPU

您代码中的 BERT 微调关键点

  1. 数据预处理

    • 使用BertTokenizer将文本转换为 token IDs
    • 添加特殊标记([CLS]、[SEP])
    • 填充和截断到固定长度
  2. 模型架构

    • 基础模型:BERT 预训练模型(bert-base-chinese)
    • 分类头:在 BERT 顶部添加一个全连接层(num_labels=6
    • 微调策略:更新整个模型的权重
  3. 训练优化

    • 使用AdamW优化器(带权重衰减的 Adam)
    • 小学习率(2e-5)避免灾难性遗忘
    • 批量训练(batch_size=2)处理长序列
  4. 优势

    • 利用预训练模型捕获的语言知识
    • 自动学习文本的上下文表示
    • 对训练数据量要求较低
    • 迁移到新领域更容易

何时选择 BERT 微调而非传统方法?

  1. 当您有足够的计算资源时
  2. 当任务数据量有限时
  3. 当需要处理复杂语义理解时
  4. 当需要快速适应新领域时

BERT 微调入门示例,展示了如何将预训练语言模型应用于特定的分类任务。随着 Transformer 架构的普及,这种方法已经成为 NLP 任务的主流解决方案。

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd# 1. 准备数据
data = {'text': ["我想预订明天的机票", "查询今天的天气", "帮我设置闹钟","播放周杰伦的歌曲", "今天有什么新闻", "推荐几部科幻电影"],'label': [0, 1, 2, 3, 4, 5]  # 0:订票, 1:天气, 2:闹钟, 3:音乐, 4:新闻, 5:电影
}
df = pd.DataFrame(data)# 2. 数据集划分
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)# 3. 创建数据集类
class TextClassificationDataset(Dataset):def __init__(self, texts, labels, tokenizer, max_len=128):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_len = max_lendef __len__(self):return len(self.texts)def __getitem__(self, idx):text = str(self.texts[idx])label = self.labels[idx]encoding = self.tokenizer(text,add_special_tokens=True,max_length=self.max_len,return_token_type_ids=False,padding='max_length',truncation=True,return_attention_mask=True,return_tensors='pt')return {'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'label': torch.tensor(label, dtype=torch.long)}# 4. 初始化tokenizer和模型
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese',num_labels=6
)# 5. 创建数据加载器
train_dataset = TextClassificationDataset(train_df['text'].values,train_df['label'].values,tokenizer
)
val_dataset = TextClassificationDataset(val_df['text'].values,val_df['label'].values,tokenizer
)train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2)# 6. 训练模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)optimizer = AdamW(model.parameters(), lr=2e-5)
epochs = 3for epoch in range(epochs):model.train()train_loss = 0for batch in train_loader:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['label'].to(device)optimizer.zero_grad()outputs = model(input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.losstrain_loss += loss.item()loss.backward()optimizer.step()avg_train_loss = train_loss / len(train_loader)print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_train_loss:.4f}")# 7. 评估模型
model.eval()
predictions = []
true_labels = []with torch.no_grad():for batch in val_loader:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['label'].to(device)outputs = model(input_ids, attention_mask=attention_mask)preds = torch.argmax(outputs.logits, dim=1)predictions.extend(preds.cpu().numpy())true_labels.extend(labels.cpu().numpy())accuracy = accuracy_score(true_labels, predictions)
print(f"Validation Accuracy: {accuracy:.4f}")# 8. 保存模型
model.save_pretrained('./intent_classifier')
tokenizer.save_pretrained('./intent_classifier')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/news/909171.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(12)-Fiddler抓包-Fiddler设置IOS手机抓包

1.简介 Fiddler不但能截获各种浏览器发出的 HTTP 请求,也可以截获各种智能手机发出的HTTP/ HTTPS 请求。 Fiddler 能捕获Android 和 Windows Phone 等设备发出的 HTTP/HTTPS 请求。同理也可以截获iOS设备发出的请求,比如 iPhone、iPad 和 MacBook 等苹…

芯科科技Tech Talks技术培训重磅回归:赋能物联网创新,共筑智能互联未来

聚焦于Matter、蓝牙、Wi-Fi、LPWAN、AI/ML五大热门无线协议与技术 为年度盛会Works With大会赋能先行 随着物联网(IoT)和人工智能(AI)技术的飞速发展,越来越多的企业和个人开发者都非常关注最新的无线连接技术和应用…

docker-compose容器单机编排

docker-compose容器单机编排 开篇前言 随着网站架构的升级,容器的使用也越来越频繁,应用服务和容器之间的关系也越发的复杂。 这个就要求研发人员能更好的方法去管理数量较多的服务器,而不能手动挨个管理。 例如一个LNMP 架构,就…

LeetCode--29.两数相除

解题思路: 1.获取信息: 给定两个整数,一个除数,一个被除数,要求返回商(商取整数) 限制条件:(1)不能使用乘法,除法和取余运算 (2&#…

中山大学GaussianFusion:首个将高斯表示引入端到端自动驾驶多传感器融合的新框架

摘要 近年来由于端到端自动驾驶极大简化了原有传统自动驾驶模块化的流程,吸引了来自工业界和学术界的广泛关注。然而,现有的端到端智驾算法通常采用单一传感器,使其在处理复杂多样和具有挑战性的驾驶场景中受到了限制。而多传感器融合可以很…

《哈希算法》题集

1、模板题集 满足差值的数字对 2、课内题集 字符统计 字符串统计 优质数对 3、课后题集 2006 Equations k倍区间 可结合的元素对 满足差值的数字对 异常频率 神秘数对 费里的语言 连连看 本题集为作者(英雄哪里出来)在抖音的独家课程《英雄C入门到精…

Cordova移动应用对云端服务器数据库的跨域访问

Cordova移动应用对云端服务器数据库的跨域访问 当基于类似 Cordova这样的跨平台开发框架进行移动应用的跨平台开发时,往往需要访问部署在公网云端服务器上的数据库,这时就涉及到了跨域数据访问的问题。 文章目录 Cordova移动应用对云端服务器数据库的跨…

mysql知识点3--创建和使用数据库

mysql知识点3–创建数据库 创建数据库 在MySQL中创建数据库使用CREATE DATABASE语句。语法如下: CREATE DATABASE database_name;其中database_name为自定义的数据库名称。例如创建名为test_db的数据库: CREATE DATABASE test_db;可以添加字符集和排…

林业资源多元监测技术守护绿水青山

在云南高黎贡山的密林中,无人机群正以毫米级精度扫描古树年轮;福建武夷山保护区,卫星遥感数据实时追踪着珍稀动植物的栖息地变化;海南热带雨林里,AI算法正从亿万条数据中预测下一场山火的风险……这些科幻场景&#xf…

一阶/二阶Nomoto模型(野本模型)为何“看不到”船速对回转角速度/角加速度的影响?

提问 图中的公式反映的是舵角和力矩之间的关系, 其中可以看到力矩(可以理解为角加速度)以及相应导致的回转角速度和当前的舵速(主要由船速贡献)有关,那么为什么一阶Nomoto模型(一阶野本&#xf…

深入剖析 C++ 默认函数:拷贝构造与赋值运算符重载

目录 1. 简单认识C 类的默认函数 1.1 默认构造函数 1.2 析构函数 1.3 拷贝构造函数 2. 拷贝构造函数的深入理解 拷贝构造的特点: 实际运用 3. 赋值运算符重载的深入理解 3.1.运算符重载 3.2样例 1.比较运算符重载 2.算术运算符重载 3.自增和自减运算符重载 4.输…

板凳-------Mysql cookbook学习 (十--3)

5.16 用短语来进行fulltext查询 mysql> select count(*) from kjv where match(vtext) against(God); ---------- | count(*) | ---------- | 0 | ---------- 1 row in set (0.00 sec)mysql> select count(*) from kjv where match(vtext) against(sin); -------…

python爬虫ip封禁应对办法

目录 一、背景现象 二、准备工作 三、代码实现 一、背景现象 最近在做爬虫项目时,爬取的网站,如果发送请求太频繁的话,对方网站会先是响应缓慢,最后是封禁一段时间。一直是拒绝连接,导致程序无法正常预期的爬取数据…

【AIGC】Qwen3-Embedding:Embedding与Rerank模型新标杆

Qwen3-Embedding:Embedding与Rerank模型新标杆 一、引言二、技术架构与核心创新1. 模型结构与训练策略(1)多阶段训练流程(2)高效推理设计(3)多语言与长上下文支持 2. 与经典模型的性能对比 三、…

算法竞赛阶段二-数据结构(32)数据结构简单介绍

数据结构的基本概念 数据结构是计算机存储、组织数据的方式,旨在高效地访问和修改数据。它是算法设计的基础,直接影响程序的性能。数据结构可分为线性结构和非线性结构两大类。 线性数据结构 线性结构中,数据元素按顺序排列,每…

Windows桌面图标修复

新建文本文件,粘入以下代码,保存为.bat文件,管理员运行这个文件 duecho off taskkill /f /im explorer.exe CD /d %userprofile%\AppData\Local DEL IconCache.db /a start explorer.exe echo 执行完成上面代码作用是删除桌面图标缓存库&…

13.react与next.js的特性和原理

🟡 一句话总结 React 专注于构建组件,而 Next.js 是基于 React 的全栈框架,提供了页面路由、服务端渲染和全栈能力,让你能快速开发现代 Web 应用。 React focuses on building UI components, while Next.js is a full-stack fra…

全栈监控系统架构

全栈监控系统架构 可观测性从数据层面可分为三类: 指标度量(Metrics):记录系统的总体运行状态。事件日志(Logs):记录系统运行期间发生的离散事件。链路追踪(Tracing):记录一个请求接入到结束的处理过程,主要用于排查…

云服务运行安全创新标杆:阿里云飞天洛神云网络子系统“齐天”再次斩获奖项

引言 为认真落实工信部《工业和信息化部办公厅关于印发信息通信网络运行安全管理年实施方案的通知》,2025年5月30日中国信息通信研究院于浙江杭州举办了“云服务运行安全高质量发展交流会”,推动正向引导,巩固云服务安全专项治理成果。会上&a…

刀客doc:WPP走下神坛

一、至暗时刻? 6月11日,快消巨头玛氏公司宣布其价值17 亿美元,在全球70个市场的广告业务交给阳狮集团,这其中包括M&Ms、士力架、宝路等知名品牌。 此前,玛氏公司一直是WPP的大客户。早在今年3月,WPP就…