探索Qwen2ForCausalLM 架构上进行微调

简述

试验参考了mini_qwen 的开源实现

GitHub - qiufengqijun/mini_qwen: 这是一个从头训练大语言模型的项目,包括预训练、微调和直接偏好优化,模型拥有1B参数,支持中英文。这是一个从头训练大语言模型的项目,包括预训练、微调和直接偏好优化,模型拥有1B参数,支持中英文。. Contribute to qiufengqijun/mini_qwen development by creating an account on GitHub.https://github.com/qiufengqijun/mini_qwen

分词器使用Qwen/Qwen2.5-0.5B-Instruct,通过扩充模型隐藏状态层数、嵌入层维度和注意力头数,增加参数量到1B,使用flash_attention_2进行加速

主要特点:

  • 低资源需求:预训练和微调仅需 12GB 显存,DPO 训练需要 14GB 显存

  • 训练数据:使用来自 BAAI(北京智源人工智能研究院)的数据集,包括用于预训练的 160 亿 tokens、用于微调的 900 万条样本,以及用于偏好优化的 6 万条样本。

数据集

魔塔社区 BAAI/IndustryCorpus2 数据集,根据需要下载


# 下载预训练数据集
modelscope download --dataset 'BAAI/IndustryCorpus2' --include 'film_entertainment/*/high*' --local_dir 'data/pt' # 数据量较大,英文文件选择前3个文件modelscope download --dataset 'BAAI/IndustryCorpus2' --include 'computer_programming_code/*/high*' --local_dir 'data/pt'
modelscope download --dataset 'BAAI/IndustryCorpus2' --include 'computer_communication/*/high*' --local_dir 'data/pt' # 数据量较大,英文文件选择前3个文件modelscope download --dataset 'BAAI/IndustryCorpus2' --include 'tourism_geography/*/high*' --local_dir 'data/pt'
modelscope download --dataset 'BAAI/IndustryCorpus2' --include 'artificial_intelligence_machine_learning/*/high*' --local_dir 'data/pt'
modelscope download --dataset 'BAAI/IndustryCorpus2' --include 'news_media/*/high*' --local_dir 'data/pt'
modelscope download --dataset 'BAAI/IndustryCorpus2' --include 'literature_emotion/*/high*' --local_dir 'data/pt' # 数据量较大,英文文件选择前3个文件modelscope download --dataset 'BAAI/IndustryCorpus2' --include 'accommodation_catering_hotel/*/high*' --local_dir 'data/pt'
modelscope download --dataset 'BAAI/IndustryCorpus2' --include 'current_affairs_government_administration/*/high*' --local_dir 'data/pt' # 数据量较大,英文文件选择前3个文件modelscope download --dataset 'BAAI/IndustryCorpus2' --include 'mathematics_statistics/*/high*' --local_dir 'data/pt'

查看下数据

dataset = load_dataset("parquet", data_files="mini_data/pt/accommodation_catering_hotel/chinese/high/rank_00000.parquet", split="train")
print(dataset[0])

# 大概这么个结构
{"text": "马亮:如何破解外卖骑手的\"生死劫\"\n在消费至上的今天,企业不应道德绑架消费者,让消费者为企业的伪善埋单。。。。。。","alnum_ratio": 0.9146919431,"avg_line_length": 158.25,"char_rep_ratio": 0.044444444400000005,"flagged_words_ratio": 0.0,"max_line_length": 223,"num_words": 404,"perplexity": 858.6,"quality_score": 4.0625,"special_char_ratio": 0.1000526593,"word_rep_ratio": 0.088772846,"_id": 200200005357,"industry_type": "住宿_餐饮_酒店"
}

这个结构是一个经过质量分析或过滤的训练样本的 JSON 表示,用于语言模型训练前的数据评估或筛选阶段。它除了包含原始文本(text)外还包含了一系列用来衡量数据质量的统计特征指标,用于判断该样本是否值得保留用于训练。

🔹 质量指标字段说明:

字段名说明
alnum_ratio字母数字字符所占比例。用于判断文本是否主要为自然语言(而非乱码或表格类数据)
avg_line_length平均每行字符数。可能反映文本结构是否合理(过长/过短)
char_rep_ratio字符重复率。例如“哈哈哈哈哈哈”这种重复率就很高
flagged_words_ratio敏感词或不良词汇占比(0 表示未检测到敏感词)
max_line_length最长一行的字符数。可用于过滤极端异常格式文本
num_words词数总计。用于衡量样本长度
perplexity使用某个语言模型评估的困惑度(perplexity)。数值越低,表示文本越“正常”或模型越容易预测它
quality_score综合质量评分。可能是上述特征加权后的结果,衡量样本是否值得用于训练
special_char_ratio特殊字符(如 #¥%&* 等)在文本中的占比
word_rep_ratio单词重复率(如“外卖外卖外卖平台”)

训练逻辑

加载数据集

# 加载数据集并进行预处理
directories = ["accommodation_catering_hotel","artificial_intelligence_machine_learning","computer_communication","computer_programming_code","film_entertainment","literature_emotion","news_media","tourism_geography","current_affairs_government_administration","mathematics_statistics",
]
data_files = find_files(directories)
dataset = load_dataset("parquet", data_files=data_files, split="train", columns=["text"]) # 只保留text字段
dataset = dataset.shuffle(seed=42)

数据清洗,将原始文本 → 添加结束符 → 分词 → 拼接成一长串 → 按 block_size 切成多个训练用的样本块(每块长度一致),给每条文本加上自定义的“结束符” <|im_end|>,把所有样本的 token 串接在一起(例如把多个 [101,102] 合并为 [101,102,103,104,...]),这是因为 GPT 模型的预训练目标是连续预测序列,所以训练输入是一个“连续的 token 流”。

计算总长度并对齐

  • 得到拼接后 token 总长度(例如 10,356)

  • 只保留整除 block_size(1024)的部分,截断掉尾部多余部分,例如:10356 → 10240(保留完整的 10 块)切成 1024 个 token 一块的样本,每隔 1024 个 token 分一块,生成多个训练样本,输出结构:

{"input_ids": [[token1...token1024], [token1025...token2048], ...],"attention_mask": 同理
}

参考预训练代码

def preprocess_dataset(examples):"""预处理预训练数据集,将文本分词并分块"""eos_token = "<|im_end|>"text_examples = [text + eos_token for text in examples["text"]]  # 添加结束符tokenized_examples = tokenizer(text_examples, add_special_tokens=False)# 将分词结果拼接并分块concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])block_size = 1024  # 分块大小total_length = (total_length // block_size) * block_size  # 对齐块大小result = {k: [t[i : i + block_size] for i in range(0, total_length, block_size)]for k, t in concatenated_examples.items()}return result# 应用预处理函数
train_dataset = dataset.map(preprocess_dataset,batched=True,batch_size=5000,remove_columns=dataset.column_names,num_proc=16,
)

,原来有很多条文本,现在经过这段预处理函数:

  • 文本 → 拼接 → 分词 → 连续 token 序列 → 按块切分

  • 输出的每个样本都是 1024 个 token 的一段,可直接送入语言模型进行训练(如 GPT)

数据预处理

训练配置

预训练

accelerate_config.yaml 文件包含了用于配置训练环境的参数。以下是各个配置项的含义:

  • compute_environment: 指定计算环境,这里为本地机器 (LOCAL_MACHINE)。
  • debug: 调试模式,设置为 false 表示不启用调试。
  • deepspeed_config: 包含与 DeepSpeed 相关的配置:
    • gradient_accumulation_steps: 梯度累积步数,这里设置为 16。
    • gradient_clipping: 梯度裁剪值,防止梯度过大,这里为 1.0。
    • offload_optimizer_device: 优化器的卸载设备,这里为 none 表示不卸载。
    • offload_param_device: 参数的卸载设备,这里为 none
    • zero3_init_flag: 是否启用 ZeRO-3 初始化,这里为 false
    • zero_stage: ZeRO 优化的阶段,这里设置为 2。
  • distributed_type: 分布式训练类型,这里为 DEEPSPEED
  • downcast_bf16: 是否降低 bf16 精度,这里设置为 'no'。
  • enable_cpu_affinity: 是否启用 CPU 亲和性,设置为 false
  • machine_rank: 当前机器在分布式训练中的排名,这里为 0。
  • main_training_function: 主训练函数的名称,这里为 main
  • mixed_precision: 混合精度训练,这里使用 bf16。
  • num_machines: 参与训练的机器数量,这里为 1。
  • num_processes: 每台机器上的进程数量,这里为 2。
  • rdzv_backend: rendezvous 后端,这里为 static
  • same_network: 是否在同一网络中,设置为 true
  • tpu_env: TPU 环境配置,这里为空。
  • tpu_use_cluster: 是否使用 TPU 集群,设置为 false
  • tpu_use_sudo: 是否使用 sudo 权限,设置为 false
  • use_cpu: 是否使用 CPU 进行训练,设置为 false

这些配置项帮助用户设置和优化模型训练过程,尤其是在使用 DeepSpeed 进行分布式训练时, 配置参考

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:gradient_accumulation_steps: 16gradient_clipping: 1.0offload_optimizer_device: noneoffload_param_device: nonezero3_init_flag: falsezero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

训练逻辑

# 已经下载了Qwen2.5-0.5B-Instruct地址
model_path = "./models/Qwen2.5-0.5B-Instruct"
config = AutoConfig.from_pretrained(model_path)# 调整模型配置
config.num_attention_heads = 16
config.num_key_value_heads = 4
config.hidden_size = 1024
config.num_hidden_layers = 48# 加载模型
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")# 加载分词器
tokenizer = AutoTokenizer.from_pretrained(model_path)# 训练参数配置
training_args = TrainingArguments(output_dir=output_path,overwrite_output_dir=True,learning_rate=1e-4,warmup_ratio=0.1,lr_scheduler_type="cosine",num_train_epochs=1,per_device_train_batch_size=12,gradient_accumulation_steps=16,save_steps=100_000,save_total_limit=3,bf16=True,# save_only_model=True,logging_steps=20,
)# 初始化Trainer
trainer = Trainer(model=model,args=training_args,data_collator=collator,train_dataset=train_dataset,
)

使用accelerate 加速训练

模型参数量: 576851712
表示模型总共有 576,851,712 个参数,也就是约 5.77 亿参数(≈ 577M)。

和之前用的MiniMind架构的模型比训练速度要慢很多,所以直接跳过预训练使用基座进行微调

SFT微调

 trl 配置

# 训练参数配置
training_args = SFTConfig(output_dir=output_path,                 # 训练完成后模型保存的目录overwrite_output_dir=True,             # 如果目录已存在则覆盖原模型learning_rate=1e-5,                    # 学习率,SFT阶段建议小一点warmup_ratio=0.1,                      # 热身步数比例,用于逐渐增加学习率lr_scheduler_type="cosine",            # 学习率调度策略:余弦退火num_train_epochs=3,                    # 训练轮数per_device_train_batch_size=12,        # 每张显卡的batch大小(显存不够就调小)gradient_accumulation_steps=16,        # 梯度累计步数,总batch大小 = 12 × 16 = 192save_strategy="epoch",                 # 每轮结束保存一次模型save_total_limit=3,                    # 最多保存3个checkpoint,旧的自动删掉bf16=True,                             # 使用 bfloat16 进行训练(比 fp16 更稳定,NVIDIA A100/H100 支持)logging_steps=20,                      # 每20步打印一次日志
)# 初始化Trainer
trainer = SFTTrainer(model=model,                           # 使用的模型(已初始化)args=training_args,                    # 上面定义的训练参数train_dataset=dataset,                 # 训练数据集tokenizer=tokenizer,                   # 分词器formatting_func=formatting_prompts_func, # 格式化数据的函数,把样本转换成 prompt + completiondata_collator=collator,               # 数据整理器(例如自动填充、构建input_ids等)
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/web/81032.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

hysAnalyser特色的TS流编辑、剪辑和转存MP4功能说明

摘要 hysAnalyser 是一款特色的 MPEG-TS 数据分析工具&#xff0c;融合了常规TS文件的剪辑&#xff0c;转存功能&#xff0c;可用于平常的视频开发和测试。 本文详细阐述了对MPEG-TS 流的节目ID&#xff0c;名称&#xff0c;PID&#xff0c;时间戳&#xff0c;流类型&#xff…

前端[插件化]设计思想_Vue、React、Webpack、Vite、Element Plus、Ant Design

前端插件化设计思想旨在提升应用的可扩展性、可维护性和模块化程度。这种思想不仅体现在框架&#xff08;如 Vue、React&#xff09;中&#xff0c;也广泛应用于构建工具&#xff08;如 Webpack、Vite&#xff09;以及 UI 库&#xff08;如 Element Plus、Ant Design&#xff0…

2025年高防IP与游戏盾深度对比:如何选择最佳防护方案?

2025年&#xff0c;随着DDoS攻击规模的指数级增长和混合攻击的常态化&#xff0c;高防IP与游戏盾成为企业网络安全的核心选择。然而&#xff0c;两者在功能定位、技术实现及适用场景上存在显著差异。本文结合最新行业实践与技术趋势&#xff0c;全面解析两者的优劣&#xff0c;…

日志根因分析:Elastic Observability 的异常检测与日志分类功能

作者&#xff1a;来自 Elastic Bahubali Shetti Elastic Observability 不仅提供日志聚合、指标分析、APM 和分布式追踪&#xff0c;Elastic 的机器学习能力还能帮助分析问题的根因&#xff0c;让你将时间专注于最重要的任务。 随着越来越多的应用程序迁移到云端&#xff0c;收…

Linux火墙管理及优化

网络环境配置 使用3个新的虚拟机【配置好软件仓库和网络的】 F1 192.168.150.133 NAT F2 192.168.150.134 192.168.10.20 NAT HOST-ONLY 网络适配仅主机 F3 192.168.10.30 HOST-ONLY 网络适配仅主机 1 ~]# hostnamectl hostname double1.timinglee.org 【更…

java配置webSocket、前端使用uniapp连接

一、这个管理系统是基于若依框架&#xff0c;配置webSocKet的maven依赖 <!--websocket--><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId></dependency> 二、配…

基于Yolov8+PyQT的老人摔倒识别系统源码

概述 ​​基于Yolov8PyQT的老人摔倒识别系统​​&#xff0c;该系统通过深度学习算法实时检测人体姿态&#xff0c;精准识别站立、摔倒中等3种状态&#xff0c;为家庭或养老机构提供及时预警功能。 主要内容 ​​完整可运行代码​​ 项目采用Yolov8目标检测框架结合PyQT5开发…

Oracle 创建外部表

找别人要一下数据&#xff0c;但是他发来一个 xxx.csv 文件&#xff0c;怎么办&#xff1f; 1、使用视图化工具导入 使用导入工具导入&#xff0c;如 DBeaver&#xff0c;右击要导入的表&#xff0c;选择导入数据。 选择对应的 csv 文件&#xff0c;下一步就行了&#xff08;如…

【华为OD- B卷 01 - 传递悄悄话 100分(python、java、c、c++、js)】

【华为OD- B卷 01 - 传递悄悄话 100分(python、java、c、c++、js)】 题目 给定一个二叉树,每个节点上站一个人,节点数字表示父节点到该节点传递悄悄话需要花费的时间。 初始时,根节点所在位置的人有一个悄悄话想要传递给其他人,求二叉树所有节点上的人都接收到悄悄话花…

房贷利率计算前端小程序

利率计算前端小程序 视图效果展示如下&#xff1a; 在这里插入代码片 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0&qu…

自制操作系统day8 (鼠标数据取得、通往32位模式之路、A20GATE、切换到保护模式、控制寄存器cr0-cr4以及cr8、ALIGNB)

day8 鼠标数据取得方法 fifo8_init(&mousefifo, 128, mousebuf); for (;;) { io_cli(); if (fifo8_status(&keyfifo) fifo8_status(&mousefifo) 0) { io_stihlt(); } else { if (fifo8_status(&keyfifo) ! 0) { i fifo8_get(&keyfifo); io_sti(); spr…

IP大科普:住宅IP、机房IP、原生IP、双ISP

不同类型的IP在跨境电商、广告营销、网络技术、数据收集等领域都有广泛应用&#xff0c;比如常见的住宅IP、机房IP、原生IP、双ISP等&#xff0c;这些IP分别都有什么特点&#xff0c;发挥什么作用&#xff0c;适合哪些业务场景&#xff1f; 一、IP类型及其作用 1.住宅IP 住宅…

Elasticsearch面试题带答案

Elasticsearch面试题带答案 Elasticsearch面试题及答案【最新版】Elasticsearch高级面试题大全(2025版),发现网上很多Elasticsearch面试题及答案整理都没有答案,所以花了很长时间搜集,本套Elasticsearch面试题大全,Elasticsearch面试题大汇总,有大量经典的Elasticsearch面…

Eigen与OpenCV矩阵操作全面对比:最大值、最小值、平均值

功能对比总表 功能Eigen 方法OpenCV 方法主要区别最大值mat.maxCoeff(&row, &col)cv::minMaxLoc(mat, NULL, &maxVal, NULL, &maxLoc)Eigen需要分开调用&#xff0c;OpenCV一次获取最小值mat.minCoeff(&row, &col)cv::minMaxLoc(mat, &minVal, NU…

echarts之双折线渐变图

vue3echarts实现双折线渐变图 echarts中文官网&#xff1a;https://echarts.apache.org/examples/zh/index.html 效果图展示&#xff1a; 整体代码如下&#xff1a; <template><div id"lineChart" style"width:100%;height:400px;"></di…

MD编辑器推荐【Obsidian】含下载安装和实用教程

为什么推荐 Obsidian &#xff1f; 免费 &#xff08;Typora 开始收费了&#xff09;Typora 实现的功能&#xff0c;它都有&#xff01;代码块可一键复制 文件目录支持文件夹 大纲支持折叠、搜索 特色功能 – 白板 特色功能 – 关系图谱 下载 https://pan.baidu.com/s/1I1fSly…

vue 鼠标经过时显示/隐藏其他元素

方式一&#xff1a; 使用纯css方式 , :hover是可以控制其他元素 1、 当两个元素是父子关系 <div class"all_" ><div> <i class"iconfont icon-sun sun"></i></div> </div> .all_{} .sun {display: none; /* 默认…

静态网站部署:如何通过GitHub免费部署一个静态网站

GitHub提供的免费静态网站托管服务可以无需担心昂贵的服务器费用和复杂的设置步骤&#xff0c;本篇文章中将一步步解如何通过GitHub免费部署一个静态网站&#xff0c;帮助大家将创意和作品快速展现给世界。 目录 了解基础情况 创建基础站点 在线调试站点 前端项目部署 部署…

Pytorch里面多任务Loss是加起来还是分别backward? | Pytorch | 深度学习

当你在深度学习中进入“多任务学习(Multi-task Learning)”的领域,第一道关卡可能不是设计网络结构,也不是准备数据集,而是:多个Loss到底是加起来一起backward,还是分别backward? 这个问题看似简单,却涉及PyTorch计算图的构建逻辑、自动求导机制、内存管理、任务耦合…

基于DPABI提取nii文件模板的中心点坐标

基于DPABI提取nii文件模板的中心点坐标 在使用DPABI&#xff08;Data Processing Assistant for Resting-State fMRI&#xff09;处理NIfTI&#xff08;.nii&#xff09;文件时&#xff0c;可以通过以下步骤提取模板中每个坐标点的中心点坐标&#xff1a;https://wenku.csdn.n…