大模型微调与高效训练

随着预训练大模型(如BERT、GPT、ViT、LLaMA、CLIP等)的崛起,人工智能进入了一个新的范式:预训练-微调(Pre-train, Fine-tune)。这些大模型在海量数据上学习到了通用的、强大的表示能力和世界知识。然而,要将这些通用模型应用于特定的下游任务或领域,通常还需要进行微调(Fine-tuning)

微调的核心在于调整预训练模型的参数,使其更好地适应目标任务的数据分布和特定需求。但大模型通常拥有数十亿甚至数万亿的参数,直接进行全参数微调会带来巨大的计算资源和存储挑战。本章将深入探讨大模型微调的策略,以及如何采用高效训练技术来应对这些挑战。

4.1 大模型微调:从通用到专精

4.1.1 为什么需要微调?

尽管预训练大模型具有强大的泛化能力,但它们在预训练阶段看到的数据通常是通用的、领域无关的。当我们需要它们完成特定领域的任务时,例如医疗文本分类、法律问答、特定风格的图像生成等,通用知识可能不足以满足需求。微调的目的是:

  • 适应任务特异性: 调整模型,使其更好地理解和处理特定任务的输入输出格式及语义。
  • 适应数据分布: 将模型知识迁移到目标任务的特定数据分布上,提高模型在目标数据上的性能。
  • 提升性能: 通常,经过微调的模型在特定下游任务上的表现会显著优于直接使用预训练模型。
  • 提高效率: 相较于从头开始训练一个新模型,微调一个预训练大模型通常更快、更有效。
4.1.2 全参数微调 (Full Fine-tuning)

核心思想: 全参数微调是最直接的微调方法,它解冻(unfreeze)预训练模型的所有参数,并使用目标任务的标注数据对其进行端到端(end-to-end)的训练。

原理详解: 在全参数微调中,我们加载一个预训练模型的权重,然后像训练一个普通神经网络一样,使用新的数据集和损失函数来训练它。由于模型的所有层都参与梯度计算和参数更新,理论上模型可以最大程度地适应新任务。

优点:

  • 性能潜力大: 如果资源允许且数据集足够大,全参数微调通常能达到最佳性能。
  • 概念简单: 实现起来相对直接。

缺点:

  • 计算资源需求巨大: 对于拥有数十亿参数的大模型,全参数微调需要大量的GPU显存和计算时间。
  • 存储成本高昂: 每个下游任务都需要存储一套完整的模型参数,不便于多任务部署。
  • 灾难性遗忘(Catastrophic Forgetting): 在小规模数据集上进行微调时,模型可能会“遗忘”在预训练阶段学到的通用知识,导致在其他任务上的性能下降。

Python示例:简单文本分类的全参数微调

我们将使用一个预训练的BERT模型进行情感分类任务的全参数微调。

Python

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset, Dataset
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support# 1. 加载预训练模型和分词器
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2) # 2个类别:正面/负面情感# 2. 准备数据集 (使用Hugging Face datasets库加载一个情感分析数据集)
# 这里使用 'imdb' 数据集作为示例
# 如果是第一次运行,会自动下载
print("Loading IMDb dataset...")
dataset = load_dataset("imdb")# 预处理数据
def preprocess_function(examples):return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)tokenized_imdb = dataset.map(preprocess_function, batched=True)# 重命名标签列为 'labels' 以符合Trainer的要求
tokenized_imdb = tokenized_imdb.rename_columns({"label": "labels"})
# 移除原始文本列
tokenized_imdb = tokenized_imdb.remove_columns(["text"])
# 设置格式为PyTorch tensors
tokenized_imdb.set_format("torch")# 划分训练集和测试集
small_train_dataset = tokenized_imdb["train"].shuffle(seed=42).select(range(2000)) # 使用小部分数据进行演示
small_eval_dataset = tokenized_imdb["test"].shuffle(seed=42).select(range(500))# 3. 定义评估指标
def compute_metrics(p):predictions, labels = ppredictions = np.argmax(predictions, axis=1)precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')acc = accuracy_score(labels, predictions)return {'accuracy': acc,'f1': f1,'precision': precision,'recall': recall}# 4. 配置训练参数
training_args = TrainingArguments(output_dir="./results_full_finetune",num_train_epochs=3,              # 训练轮次per_device_train_batch_size=16,  # 训练批次大小per_device_eval_batch_size=16,   # 评估批次大小warmup_steps=500,                # 学习率预热步数weight_decay=0.01,               # 权重衰减logging_dir='./logs_full_finetune', # 日志目录logging_steps=100,evaluation_strategy="epoch",     # 每个epoch结束后评估save_strategy="epoch",           # 每个epoch结束后保存模型load_best_model_at_end=True,     # 训练结束后加载最佳模型metric_for_best_model="f1",      # 衡量最佳模型的指标report_to="none"                 # 不上传到任何在线平台
)# 5. 初始化Trainer并开始训练
trainer = Trainer(model=model,args=training_args,train_dataset=small_train_dataset,eval_dataset=small_eval_dataset,compute_metrics=compute_metrics,
)print("\n--- BERT 全参数微调示例 ---")
# 检查是否有GPU可用
if torch.cuda.is_available():print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:print("No GPU available, training on CPU (will be slow).")trainer.train()print("\nFull fine-tuning completed. Evaluation results:")
eval_results = trainer.evaluate()
print(eval_results)

代码说明:

  • 我们使用了Hugging Face transformers库的Trainer API,它极大地简化了训练过程。
  • AutoTokenizerAutoModelForSequenceClassification 自动加载BERT模型和对应的分词器。
  • load_dataset("imdb") 用于获取情感分类的示例数据。
  • preprocess_function 将文本转换为模型可以理解的token ID序列。
  • TrainingArguments 用于配置各种训练参数,如学习率、批次大小、保存策略等。
  • compute_metrics 定义了用于评估模型性能的指标。
  • trainer.train() 启动训练过程。
4.1.3 参数高效微调 (Parameter-Efficient Fine-tuning, PEFT)

核心思想: PEFT旨在解决全参数微调的缺点,它通过只微调预训练模型中少量新增或现有参数,同时冻结大部分预训练参数,从而大大降低计算和存储成本,并有效避免灾难性遗忘。

PEFT方法可以分为几大类:

  1. 新增适配器模块: 在预训练模型的中间层或输出层插入小型的可训练模块(Adapter)。
  2. 软提示: 优化输入中少量连续的、可学习的“软提示”或“前缀”,而不是修改模型参数。
  3. 低秩适应: 通过低秩分解来近似全参数更新,减少可训练参数。

我们将重点介绍其中最流行且高效的几种方法。

4.1.3.1 LoRA (Low-Rank Adaptation)

  • 显著减少可训练参数: 大幅降低显存消耗和训练时间。
  • 避免灾难性遗忘: 预训练权重冻结,保护了通用知识。
  • 部署高效: 可以在推理时合并权重,不增加额外延迟。
  • 多任务部署: 针对不同任务,只需存储和加载很小的 (BA) 矩阵。

Python示例:使用peft库进行LoRA微调

我们将使用Hugging Face的peft库对预训练的GPT-2模型进行LoRA微调,用于文本生成任务。

首先,确保安装peft库:

Bash

<

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/diannao/84496.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

编程技能:字符串函数10,strchr

专栏导航 本节文章分别属于《Win32 学习笔记》和《MFC 学习笔记》两个专栏&#xff0c;故划分为两个专栏导航。读者可以自行选择前往哪个专栏。 &#xff08;一&#xff09;WIn32 专栏导航 上一篇&#xff1a;编程技能&#xff1a;字符串函数09&#xff0c;strncmp 回到目录…

动态规划-53.最大子数组和-力扣(LeetCode)

一、题目解析 在给定顺序的数组中找出一段具有最大和的连续子数组&#xff0c;且大小最小为1. 二、算法原理 1.状态表示 我们可以意一一枚举出所有的子数组&#xff0c;但我们想要的是最大子数组&#xff0c;所以f[i]表示&#xff1a;以i位置为结尾&#xff0c;所有子数组的最…

C++ queue对象创建、queue赋值操作、queue入队、出队、获得队首、获得队尾操作、queue大小操作、代码练习

对象创建&#xff0c;代码见下 #include<iostream> #include<queue>using namespace std;int main() {// 1 默认构造函数queue<int> q1;// 2 拷贝构造函数queue<int> q2(q1);return 0;} queue赋值操作&#xff0c;代码见下 #include<iostream>…

全链路解析:影刀RPA+Coze API自动化工作流实战指南

在数字化转型加速的今天&#xff0c;如何通过RPA与API的深度融合实现业务自动化提效&#xff0c;已成为企业降本增效的核心命题。本文以「影刀RPA」与「Coze API」的深度协作为例&#xff0c;系统性拆解从授权配置、数据交互到批量执行的完整技术链路&#xff0c;助你快速掌握跨…

php本地 curl 请求证书问题解决

错误: cURL error 60: SSL certificate problem: unable to get local issuer certificate (see https://curl.haxx.se/libcurl/c/libcurl-errors.html) for 解决方案 在php目录下创建证书文件夹, 执行下面生成命令, 然后在php.ini 文件中配置证书路径; 重启环境 curl --eta…

【图数据库】--Neo4j 安装

目录 1.Neo4j --概述 2.JDK安装 3.Neo4j--下载 3.1.下载资源包 3.2.创建环境变量 3.3.运行 Neo4j 是目前最流行的图形数据库(Graph Database)&#xff0c;它以节点(Node)、关系(Relationship)和属性(Property)的形式存储数据&#xff0c;专门为处理高度连接的数据而设计。…

MIT 6.S081 2020Lab5 lazy page allocation 个人全流程

文章目录 零、写在前面一、Eliminate allocation from sbrk()1.1 说明1.2 实现 二、Lazy allocation2.1 说明2.2 实现 三、Lazytests and Usertests3.1 说明3.2 实现3.2.1 lazytests3.2.2 usertests 零、写在前面 可以阅读下4.6页面错误异常 像应用程序申请内存&#xff0c;内…

(Git) 稀疏检出(Sparse Checkout) 拉取指定文件

文章目录 &#x1f3ed;作用&#x1f3ed;指令总览&#x1f477;core.sparseCheckout&#x1f477;sparse-checkout 文件 &#x1f3ed;实例演示⭐END&#x1f31f;交流方式 &#x1f3ed;作用 类似于 .gitignore 进行文件的规则匹配。 一般在需要拉取大型项目指定的某些文件…

docker初学

加载镜像&#xff1a;docker load -i ubuntu.tar 导出镜像&#xff1a;docker save -o ubuntu1.tar ubuntu 运行&#xff1a; docker run -it --name mu ubuntu /bin/bash ocker run -dit --name mmus docker.1ms.run/library/ubuntu /bin/bash 进入容器&#xff1a;docke…

Docker系列(二):开机自启动与基础配置、镜像加速器优化与疑难排查指南

引言 docker 的快速部署与高效运行依赖于两大核心环节&#xff1a;基础环境搭建与镜像生态优化。本期博文从零开始&#xff0c;系统讲解 docker 服务的管理配置与镜像加速实践。第一部分聚焦 docker 服务的安装、权限控制与自启动设置&#xff0c;确保环境稳定可用&#xff1b…

计算机视觉(图像算法工程师)学习路线

计算机视觉学习路线 Python基础 常量与变量 列表、元组、字典、集合 运算符 循环 条件控制语句 函数 面向对象与类 包与模块Numpy Pandas Matplotlib numpy机器学习 回归问题 线性回归 Lasso回归 Ridge回归 多项式回归 决策树回归 AdaBoost GBDT 随机森林回归 分类问题 逻辑…

工业软件国产化:构建自主创新生态,赋能制造强国建设

随着全球产业环境的变化和技术的发展&#xff0c;建立自主可控的工业体系成为我国工业转型升级、走新型工业化道路、推动国家制造业竞争水平提升的重要抓手。 市场倒逼与政策护航&#xff0c;国产化进程双轮驱动 据中商产业研究院预测&#xff0c;2025年中国工业软件市场规模…

OpenCV CUDA 模块图像过滤------创建一个高斯滤波器函数createGaussianFilter()

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 cv::cuda::createGaussianFilter 是 OpenCV CUDA 模块中的一个工厂函数&#xff0c;用于创建一个高斯滤波器。这个滤波器可以用来平滑图像&#…

【RocketMQ 生产者和消费者】- 生产者发送故障延时策略

文章目录 1. 前言2. FaultItem3. LatencyFaultToleranceImpl 容错集合处理类3.1 updateFaultItem 更新容错集合3.2 isAvailable 判断 broker 是否可用3.3 pickOneAtLeast 至少选出一个故障 broker 4. MQFaultStrategy 故障策略类4.1 属性4.2 updateFaultItem 更新延迟故障容错信…

【HarmonyOS 5】Map Kit 地图服务之应用内地图加载

#HarmonyOS SDK应用服务&#xff0c;#Map Kit&#xff0c;#应用内地图 目录 前期准备 AGC 平台创建项目并创建APP ID 生成调试证书 生成应用证书 p12 与签名文件 csr 获取 cer 数字证书文件 获取 p7b 证书文件 配置项目签名 项目开发 配置Client ID 开通地图服务 配…

(1-6-1)Java 集合

目录 0.知识概述&#xff1a; 1.集合 1.1 集合继承关系类图 1.2 集合遍历的三种方式 1.3 集合排序 1.3.1 Collections实现 1.3.2 自定义排序类 2 List 集合概述 2.1 ArrayList &#xff08;1&#xff09;特点 &#xff08;2&#xff09;常用方法 2.2 LinkedList 3…

Vue.extend

Vue.extend 是 Vue 2 中的一个重要 API&#xff0c;用于基于一个组件配置对象创建一个“可复用的组件构造函数”。它是 Vue 内部构建组件的底层机制之一&#xff0c;适用于某些高级用法&#xff0c;比如手动挂载组件、弹窗动态渲染等。 ⚠️ 在 Vue 3 中已被移除&#xff0c;V…

【MySQL系列】SQL 分组统计与排序

博客目录 引言一、基础语法解析二、GROUP BY 的底层原理三、ORDER BY 的排序机制四、NULL 值的处理策略五、性能优化建议六、高级变体查询 引言 在现代数据分析和数据库管理中&#xff0c;分组统计是最基础也是最核心的操作之一。无论是业务报表生成、用户行为分析还是系统性能…

spring中的InstantiationAwareBeanPostProcessor接口详解

一、接口定位与核心功能 InstantiationAwareBeanPostProcessor是Spring框架中扩展Bean生命周期的关键接口&#xff0c;继承自BeanPostProcessor。它专注于Bean的实例化阶段&#xff08;对象创建和属性注入&#xff09;的干预&#xff0c;而非父接口的初始化阶段&#xff08;如…

uniapp使用sse连接后端,接收后端推过来的消息(app不支持!!)

小白终成大白 文章目录 小白终成大白前言一、什么是SSE呢&#xff1f;和websocket的异同点有什么&#xff1f;相同点不同点 二、直接上实现代码总结 前言 一般的请求就是前端发 后端回复 你一下我一下 如果需要有什么实时性的 后端可以主动告诉前端的技术 我首先会想到 webso…