极客时间:在 Google Colab 上尝试 Prefix Tuning

  每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领域的领跑者。点击订阅,与未来同行! 订阅:https://rengongzhineng.io/

Prefix Tuning 是当前最酷的参数高效微调(PEFT)方法之一,它可以在无需重新训练整个大模型的前提下对大语言模型(LLM)进行任务适配。为了理解它的工作原理,我们先了解下背景:传统微调需要更新模型的所有参数,成本高、计算密集。随后出现了 Prompting(提示学习),通过巧妙设计输入引导模型输出;Instruction Tuning(指令微调)进一步提升模型对任务指令的理解能力。再后来,LoRA(低秩适配)通过在网络中插入可训练的低秩矩阵实现任务适配,大大减少了可训练参数。

而 Prefix Tuning 则是另一种思路:它不会更改模型本体参数,也不插入额外矩阵,而是学习一小组“前缀向量”,将它们添加到每一层 Transformer 的输入中。这种方法轻巧快速,非常适合在 Google Colab 这样资源受限的环境中实践。

在这篇博客中,我们将一步步地在 Google Colab 上,使用 Hugging Face Transformers 和 peft 库完成 Prefix Tuning 的演示。


第一步:安装运行环境

!pip install transformers peft datasets accelerate bitsandbytes

使用的库包括:

  • transformers: 加载基础模型

  • peft: 实现 Prefix Tuning

  • datasets: 加载示例数据集

  • acceleratebitsandbytes: 优化训练性能


第二步:加载预训练模型和分词器

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, PrefixTuningConfig, TaskTypemodel_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

这里我们使用 GPT-2 作为演示模型,也可以替换为其他因果语言模型。


第三步:配置 Prefix Tuning

peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM,inference_mode=False,num_virtual_tokens=10
)model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

上述配置在每层 Transformer 中加入了 10 个可学习的虚拟前缀 token,我们将对它们进行微调。


第四步:加载并预处理 Yelp 数据集样本

from datasets import load_datasetdataset = load_dataset("yelp_review_full", cache_dir="/tmp/hf-datasets")
dataset = dataset.shuffle(seed=42).select(range(1000))def preprocess(example):tokens = tokenizer(example["text"], truncation=True, padding="max_length", max_length=128)return {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}dataset = dataset.map(preprocess, batched=True)

第五步:使用 Prefix Tuning 训练模型

from transformers import TrainingArguments, Trainertraining_args = TrainingArguments(output_dir="./prefix_model",per_device_train_batch_size=4,num_train_epochs=1,logging_dir="./logs",logging_steps=10
)trainer = Trainer(model=model,args=training_args,train_dataset=dataset
)trainer.train()

第六步:保存并加载 Prefix Adapter

model.save_pretrained("prefix_yelp")

之后加载方法如下:

from peft import PeftModelbase_model = AutoModelForCausalLM.from_pretrained("gpt2")
prefix_model = PeftModel.from_pretrained(base_model, "prefix_yelp")

第七步:推理测试

训练完成后,我们可以使用调优后的模型进行生成测试。

input_text = "This restaurant was absolutely amazing!"
inputs = tokenizer(input_text, return_tensors="pt")output = prefix_model.generate(**inputs, max_new_tokens=50)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("\nGenerated Output:")
print(generated_text)

示例输出

这是在训练 3 个 epoch 并使用 20 个虚拟 token 后的输出示例:

This restaurant was absolutely amazing!, a the the the the the the the the the the the the the the the the the the the the the the the the the the the the the a., and the way. , and the was

虽然模型初步模仿了 Yelp 评论的风格,但输出仍重复性强、连贯性不足。为获得更好效果,可增加训练数据、延长训练周期,或使用更强的基础模型(如 gpt2-medium)。


完整代码

以下是经过改进的完整代码(包含更大前缀尺寸和更多训练轮次):

# 安装依赖
!pip install -U fsspec==2023.9.2
!pip install transformers peft datasets accelerate bitsandbytes# 加载模型与分词器
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, PrefixTuningConfig, TaskTypemodel_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)
model.config.pad_token_id = tokenizer.pad_token_id# 配置 Prefix Tuning
peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM,inference_mode=False,num_virtual_tokens=20  # 使用更多虚拟 token
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()# 加载和预处理数据集
from datasets import load_dataset
try:dataset = load_dataset("yelp_review_full", split="train[:1000]")
except:dataset = load_dataset("yelp_review_full")dataset = dataset["train"].select(range(1000))def preprocess(examples):tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)tokenized["labels"] = [[-100 if mask == 0 else token for token, mask in zip(input_ids, attention_mask)]for input_ids, attention_mask in zip(tokenized["input_ids"], tokenized["attention_mask"])]return tokenizeddataset = dataset.map(preprocess, batched=True, remove_columns=["text", "label"])# 配置训练参数
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(output_dir="./prefix_model",per_device_train_batch_size=4,num_train_epochs=3,  # 增加轮次logging_dir="./logs",logging_steps=10,report_to="none"
)trainer = Trainer(model=model,args=training_args,train_dataset=dataset
)trainer.train()# 保存和加载前缀
model.save_pretrained("prefix_yelp")
from peft import PeftModel
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
prefix_model = PeftModel.from_pretrained(base_model, "prefix_yelp")# 推理
input_text = "This restaurant was absolutely amazing!"
inputs = tokenizer(input_text, return_tensors="pt")
output = prefix_model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))

不同微调技术如何选择?

方法特点适合场景
Prompting零样本/少样本,无需训练快速实验、通用模型调用
Instruction Tuning统一风格指导多个任务多任务模型,提示兼容性强
Full Fine-Tuning全模型更新,效果最好但成本高数据量大、计算资源充足场景
LoRA插入低秩矩阵,性能和效率平衡中等规模适配任务、部署灵活
Prefix Tuning训练前缀向量,模块化且轻量多任务共享底模、小规模快速适配

真实应用案例

  • 客服机器人:为不同产品线训练不同前缀,提高回答准确性

  • 法律/医学摘要:为专业领域调优风格和术语的理解

  • 多语种翻译:为不同语言对训练前缀,重用同一个基础模型

  • 角色对话代理:通过前缀改变语气(如正式、幽默、亲切)

  • SaaS 多租户服务:不同客户使用不同前缀,但共用主模型架构


总结

Prefix Tuning 是一种灵活且资源友好的方法,适合:

  • 有多个任务/用户但希望复用基础大模型的情况

  • 算力有限,但希望实现快速个性化的场景

  • 构建模块化、可热切换行为的 LLM 服务

建议从小任务入手测试,尝试不同 prefix 长度与训练轮次,并结合任务类型进行微调策略选择。

如果你想将此教程发布到 Colab、Hugging Face 或本地部署,欢迎继续交流!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/news/908726.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android设备推送traceroute命令进行网络诊断

文章目录 工作原理下载traceroute for android推送到安卓设备执行traceroutetraceroute www.baidu.com Traceroute(追踪路由) 是一个用于网络诊断的工具,主要用于追踪数据包从源主机到目标主机所经过的路由路径,以及每一跳&#x…

【Linux应用】Linux系统日志上报服务,以及thttpd的配置、发送函数

【Linux应用】Linux系统日志上报服务,以及thttpd的配置、发送函数 文章目录 thttpd服务安装thttpd配置thttpd服务thttpd函数日志效果和文件附录:开发板快速上手:镜像烧录、串口shell、外设挂载、WiFi配置、SSH连接、文件交互(RADX…

Linux 内核内存管理子系统全面解析与体系构建

一、前言: 为什么内存管理是核心知识 内存管理是 Linux 内核最核心也最复杂的子系统之一,其作用包括: 为软件提供独立的虚拟内存空间,实现安全隔离分配/回收物理内存资源,维持系统稳定支持不同类型的内存分配器,最优…

鼠标的拖动效果

1、变量的设置 let isDragging false; let startX; let startY; let endX; let endY; let box null;isDragging : 表示是否推拽startX、startY:表示起始坐标,相对于元素endX、endY:表示结束坐标,相对于元素box&…

SwaggerFuzzer:一款自动化 OpenAPI/Swagger 接口未授权访问测试工具

SwaggerFuzzer 🌐 一款自动化 OpenAPI/Swagger 接口未授权访问测试工具🚀 工具介绍:SwaggerFuzzer✨ 核心功能亮点🚀 快速使用🧰 支持参数 📌 项目结构📥 获取与下载 🌐 一款自动化 …

文献阅读:Exploring Autoencoder-based Error-bounded Compression for Scientific Data

目录 论文简介动机:为什么作者想要解决这个问题?贡献:作者在这篇论文中完成了什么工作(创新点)?规划:他们如何完成工作?离线训练阶段:在线压缩阶段 理由:通过什么实验验证它们的工作…

【业务框架】3C-相机-Cinemachine

概述 插件,做相机需求,等于相机老师傅多年经验总结的工具 Feature Transform:略Control Camera:控制相机参数Noise:增加随机性Blend:CameraBrain的混合列表指定一个虚拟相机到另一个相机的过渡&#xff…

设计一个算法:删除非空单链表L中结点值为x的第一个结点的前驱结点

目录 单链表的存储结构定义如下 快慢指针法 三指针法版本① 三指针法版本② 单链表的存储结构定义如下 typedef struct{Elemtype data;struct Node* next; }LNode,*LinkList; 快慢指针法 void deleteprex(LinkList L, Elemtype e) {if (L NULL || L->next NULL ||…

【Qt】:设置新建类模板

完整的头文件模板 #ifndef %FILENAME%_H #define %FILENAME%_H/*** brief The %CLASSNAME% class* author %USER%* date %DATE%*/ class %CLASSNAME% { public:%CLASSNAME%();~%CLASSNAME%();// 禁止拷贝构造和赋值%CLASSNAME%(const %CLASSNAME%&) delete;%CLASSNAME%&a…

​**​CID字体​**​ 和 ​**​Simple字体​**​

在PDF中,字体类型主要分为 ​​CID字体​​ 和 ​​Simple字体​​ 两大类,它们的主要区别在于编码方式和适用场景。以下是它们的详细对比: ​​1. CID字体(CID-keyed Fonts)​​ CID(Character Identifie…

计组_导学

2025.05.31:老汤讲408计组学习笔记 导学 第1章计算机系统概述:对计算机系统有全局的认识第2章总线系统:简单且独立,不会依赖其他内容,它是被依赖的第3章主存储器:只有了解主存储器的内部结构,才能理解在主存中是如何存储二进制的第4章数据的表示与运算:各种编码以及计算…

【GPT模型训练】第二课:张量与秩:从数学本质到深度学习的基础概念解析

这里写自定义目录标题 张量(Tensor)的定义关键特点:示例: 张量的秩(Rank)示例:“秩”的拼音常见混淆点 总结 张量(Tensor)的定义 在数学和物理学中,张量是一…

RabbitMQ work模型

Work 模型是 RabbitMQ 最基础的消息处理模式,核心思想是 ​​多个消费者竞争消费同一个队列中的消息​​,适用于任务分发和负载均衡场景。同一个消息只会被一个消费者处理。 当一个消息队列绑定了多个消费者,每个消息消费的个数都是平摊的&a…

【Linux操作系统】基础开发工具(yum、vim、gcc/g++)

文章目录 Linux软件包管理器 - yumLinux下的三种安装方式什么是软件包认识Yum与RPMyum常用指令更新软件安装与卸载查找与搜索清理缓存与重建元数据 yum源更新1. 备份现有的 yum 源配置2. 下载新的 repo 文件3. 清理并重建缓存 Linux编辑器 - vim启动vimVim 的三种主要模式常用操…

73常用控件_QFormLayout的使用

目录 代码⽰例: 使⽤ QFormLayout 创建表单. 除了上述的布局管理器之外, Qt 还提供了 QFormLayout , 属于是 QGridLayout 的特殊情况, 专 ⻔⽤于实现两列表单的布局. 这种表单布局多⽤于让⽤⼾填写信息的场景. 左侧列为提⽰, 右侧列为输⼊框 代码⽰例: 使⽤ QFormLayout 创…

兰亭妙微 | 医疗软件的界面设计能有多专业?

从医疗影像系统到手术机器人控制界面,从便携式病原体检测设备到多平台协同操作系统,兰亭妙微为众多医疗设备研发企业,打造了兼具专业性与可用性的交互界面方案。 我们不仅做设计,更深入理解医疗场景的实际需求: 对精…

鸿蒙开发修改版本几个步骤

鸿蒙开发修改版本几个步骤 比如:5.0.4(16)版本改为5.0.2(14)版本 一、项目下的build-profile.json5 "products": [{"name": "default","signingConfig": "default&qu…

Flask 基础与实战概述

一、Flask 基础知识 什么是 Flask? Flask 是一个基于 Python 的轻量级 Web 框架(微框架)。 特点:核心代码简洁,给予开发者更多选择空间。 与 Django 对比: Django 创建空项目生成多个文件,Flask 仅需一个文件即可实现简单应用(如 "Hello, World!")。 Flask …

Linux安全加固:从攻防视角构建系统免疫

Linux安全加固:从攻防视角构建系统免疫 构建坚不可摧的数字堡垒 引言:攻防对抗的新纪元 在日益复杂的网络威胁环境中,Linux系统安全已从被动防御转向主动免疫。2023年全球网络安全报告显示,高级持续性威胁(APT)攻击同比增长65%,平均入侵停留时间缩短至48小时。本章将从…