大模型微调示例四之Llama-Factory-DPO

大模型微调示例四之Llama-Factory-DPO

  • 一、强化学习数据处理
  • 二、配置训练文档
  • 三、模型预测

一、强化学习数据处理

原始数据地址:https://nijianmo.github.io/amazon/index.html

第一步:读取 video game 信息

import codecs, json, re
from random import shuffle# 第一步:读取 video game 信息
# key 是 productID,value是 title
games = {}
cc = 0with codecs.open('./data/src_data/meta_Video_Games.json', mode='r') as fin:for line in fin:tmp_info = json.loads(line.strip())# asin - ID of the product# title - name of the productgames[tmp_info["asin"]] = tmp_info["title"]if len(games) % 10000 == 0:print(f'Length of games: {len(games)}')

第二步:读取用户评分信息

# key 是 userid,value 是评价的游戏和评分
user_reviews = {}cc = 0
with codecs.open('./data/src_data/Video_Games_5.json', mode='r') as fin:for line in fin:tmp_info = json.loads(line.strip())# reviewerID - ID of the reviewerreviewer_id = tmp_info["reviewerID"]time_info = re.split(', | ', tmp_info["reviewTime"])review_time = time_info[2] + '-' + time_info[0] + '-' + time_info[1]# asin - ID of the productproduct_id = tmp_info["asin"]# overall - rating of the productrating = tmp_info["overall"]# if cc > 1000:#     break# print(tmp_info)# print(user_reviews)if product_id in games.keys():product_title = games[product_id]if reviewer_id in user_reviews.keys():user_reviews[reviewer_id].append((product_title, rating, review_time))else:user_reviews[reviewer_id] = [(product_title, rating, review_time)]if len(user_reviews) % 10000 == 0:print(f'Length of user_reviews: {len(user_reviews)}')cc += 1user_reviews_sorted = {}
for k, v in user_reviews.items():# 首先去重v = list(set(v))# 然后根据评价时间从小到大排序,表示用户的评价历史v_sorted = sorted(v, key=lambda x: x[2])# 选择具有7个及以上的评论样本if len(v) >= 7:# print(f'v: {v}, v_sorted: {v_sorted}')user_reviews_sorted[k] = v_sorted
print(f'Length of user_reviews_sorted: {len(user_reviews_sorted)}')

第三步 训练数据生成

# 总样本
samples = []
# 指令
instruction = "You are an assistant working on Video Games recommendations. Given the user's history of Video Games they have shopped, which includes the \"Title\" of the Video Games and the \"Rating\" the user rate (the Rating value is like or dislike), please decide whether the user likes to shop the target Video Games by outputting the order of their titles."
samples = []
cc = 0
for k, v in user_reviews_sorted.items():# print('-'*10)# print(v)sample_input = "User shopped Video Games histories (Title and Rating): \n"# 前面的当作对话历史for vv in v[0: -2]:# 当 rating 大于 3.0 的时候设置为 likeif vv[1] > 3.0:rating = 'like'# 当 rating 小于等于 3.0 的时候设置为 dislikeelse:rating = 'dislike'sample_input += "<Title: {}, Rating: {}>\n".format(vv[0], rating)sample_input += "Based on the Video Games histories, please sort the following two Video Games titles. The one in the front is what the user like and should be recommended to user: \n"# 最后两个设置为需要预测的目标sample_input += "<Title: " + v[-2][0] + '>\n'sample_input += "<Title: " + v[-1][0] + '>\n'# print(f'v[-1][1]: {v[-1][1]}, v[-2][1]: {v[-2][1]}')# 保证有一个是 like,有一个是 dislikeif (v[-1][1] > 3.0 and v[-2][1] <= 3.0) or (v[-1][1] <= 3.0 and v[-2][1] > 3.0):# print(f'v[-1][1] != v[-2][1]: {v[-1][1]}, {v[-2][1]}')if v[-1][1] > v[-2][1]:# likeoption1 = v[-1][0]# dislikeoption2 = v[-2][0]else:# likeoption1 = v[-2][0]# dislikeoption2 = v[-1][0]# chosen 是 like 在前面chosen = "<Title: " + option1 + '>\n' + "<Title: " + option2 + '>'# rejected 是 dislike 在前面rejected = "<Title: " + option2 + '>\n' + "<Title: " + option1 + '>'sample = {"instruction": instruction,"input": sample_input,"chosen": chosen,"rejected": rejected}# print(f'--------')# print(v)# print(sample)samples.append(sample)if len(samples) % 10000 == 0:print(f'Length of samples: {len(samples)}')# cc += 1# if cc > 10:#     breakprint(f'Length of samples: {len(samples)}')

第四步 划分 train 和 test 保存样本

# 首先打乱
shuffle(samples)train = samples[:int(len(samples)*0.8)]
test = samples[int(len(samples)*0.8):]print(f'总样本数: {len(samples)},训练集样本数: {len(train)},测试集样本数: {len(test)}')with open("./data/processed/rlhf_train.json", "w", encoding='utf-8') as save_file:json.dump(train, save_file, indent=4)with open("./data/processed/rlhf_test.json", "w", encoding='utf-8') as save_file:json.dump(test, save_file, indent=4) # , sort_keys=True

二、配置训练文档

rlhf_train.yaml

### model
model_name_or_path: /ZhipuAI/glm-4-9b-chat### method
stage: dpo
do_train: true
finetuning_type: lora
lora_target: all
lora_rank: 16
lora_alpha: 32
pref_beta: 0.1
pref_loss: orpo### dataset
dataset: amazon_video_games
template: glm4
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16### output
output_dir: ./saves/amazon_video_games_orpo
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-6
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500

rlhf_inference.yaml

model_name_or_path: /ZhipuAI/glm-4-9b-chat
adapter_name_or_path: ./saves/amazon_video_games_orpo
template: glm4
finetuning_type: lora

三、模型预测

import json
from openai import OpenAI
from tqdm import tqdm# 加载模型
client = OpenAI(api_key="EMPTY",# 需要修改为大模型地址base_url="http://10.114.16.65:8000/v1/"
)
# 加载测试数据
test_file_path = "./data/processed/rlhf_test.json"
with open(test_file_path, "r", encoding='utf-8') as test_file:test_data = json.load(test_file)
print(len(test_data))
# 开始预测
labels = []
predictions = []
cc = 0
for each_test in tqdm(test_data):chat_completion = client.chat.completions.create(messages=[{"role": "system","content": each_test["instruction"]},{"role": "user","content": each_test["input"],}],model="glm4",)predictions.append(chat_completion.choices[0].message.content)labels.append(each_test["chosen"])if len(labels) % 100 == 0:correct = 0wrong = 0for l, p in zip(labels, predictions):l = l.strip()p = p.strip()# print(f'l: {l}, p: {p}')if l == p:correct += 1else:wrong += 1# print(f'\nl: {l}, \np: {p}')print(f'总样本数:{len(labels)},准确数:{correct}, 错误数:{wrong}, 准确率:{correct / len(labels)}')cc += 1# if cc > 100:#     breakassert len(predictions) == len(labels)correct = 0
wrong = 0for l, p in zip(labels, predictions):l = l.strip()p = p.strip()if l == p:correct += 1else:wrong += 1
print(f'总样本数:{len(labels)},准确数:{correct}, 错误数:{wrong}, 准确率:{correct/len(labels)}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/95122.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/95122.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java 将HTML文件、HTML字符串转换为图片

在 Java 开发中&#xff0c;我们经常会遇到将 HTML 内容转换为图片的需求&#xff0c;比如生成网页报告截图、电商商品详情页预览图、在线文档缩略图等。本文将介绍如何使用 Free Spire.Doc for Java 库来实现这一功能。 Free Spire.Doc for Java 是一款免费库且无需任何依赖&a…

(Arxiv-2024)VideoMaker:零样本定制化视频生成,依托于视频扩散模型的内在力量

VideoMaker&#xff1a;零样本定制化视频生成&#xff0c;依托于视频扩散模型的内在力量 paper title&#xff1a;VideoMaker: Zero-shot Customized Video Generation with the Inherent Force of Video Diffusion Models paper是ZJU发布在Arxiv 2024的工作 Code:链接 图1. 我…

录屏、助眠、翻译

01【小熊录屏】 02【全球翻译】 03【声萌助眠】 03 软件获取 小熊录屏&#xff08;点击下载&#xff09; 声萌助眠&#xff08;点击下载&#xff09; 全球-译官&#xff08;点击下载&#xff09;

第17章|PowerShell 安全警报——高分学习笔记(运维实战向)

&#x1f6e1;️ 第17章&#xff5c;PowerShell 安全警报——高分学习笔记&#xff08;运维实战向&#xff09;一句话核心&#xff1a;PowerShell 的“安全设计目标”是——不替你越权&#xff1b;尽量防“误触发不可信脚本”&#xff1b;并非反恶意软件的最后防线。1&#xff…

哈希表性能对比:uthash、hsearch与Linux内核哈希表的深度解析

引言 在网络编程和高性能服务器开发中,高效的数据结构是保证系统性能的关键。本文基于对三种主流哈希表实现(uthash、hsearch和Linux内核哈希表)的深度测试,探讨它们在处理50,000个客户端连接时的性能表现、内存效率及适用场景。 测试环境与方法 测试数据结构 我们使用…

探索 XGBoost 与 LightGBM 的差异:哪个更适合你的项目?

轻松对比&#xff1a;XGBoost 和 LightGBM 的差异与选择指南 在机器学习领域&#xff0c;梯度提升树&#xff08;GBDT&#xff09;是一种广泛使用的算法&#xff0c;而 XGBoost 和 LightGBM 是两款最受欢迎的 GBDT 实现。它们都能够显著提高模型的准确性&#xff0c;但它们之间…

C++链表双杰:list与forward_list

在C容器的世界里&#xff0c;当我们需要频繁地在序列中间进行插入和删除时&#xff0c;基于数组的 vector 会显得力不从心。这时&#xff0c;链表结构就闪亮登场了。STL提供了两种链表容器&#xff1a;功能全面的双向链表 std::list 和极致轻量化的单向链表 std::forward_list。…

Ruoyi-vue-plus-5.x第一篇Sa-Token权限认证体系深度解析:1.4 Sa-Token高级特性实现

&#x1f44b; 大家好&#xff0c;我是 阿问学长&#xff01;专注于分享优质开源项目解析、毕业设计项目指导支持、幼小初高的教辅资料推荐等&#xff0c;欢迎关注交流&#xff01;&#x1f680; Sa-Token高级特性实现 前言 在前面的文章中&#xff0c;我们学习了Sa-Token的…

Linux 服务器初始化解析和ssh密钥交换的介绍

目录 2. SSH 基于密钥交换的介绍和原理 2.1 核心优势 2.2 密钥交换原理&#xff08;非对称加密体系&#xff09; 2.3 基础配置步骤 3. 服务器初始化 3.1 安装 yum 网络源 3.1.1 背景说明 3.1.2 实操步骤 3.2 安装运维的必备工具 3.2.1 工具清单 3.2.2 批量安装命令 …

web渗透ASP.NET(Webform)反序列化漏洞

web渗透ASP.NET(Webform)反序列化漏洞1&#xff09;ASP.NET(Webform)反序列化漏洞ASP.NET(Webform) 反序列化漏洞的核心触发点是 Webform 框架中的VIEWSTATE参数 —— 该参数用于存储页面控件状态数据&#xff0c;默认以 Base64 编码传输&#xff0c;内部包含序列化的对象数据。…

Android FrameWork - 开机启动 SystemServer 进程

基于安卓 12 源码分析相关类&#xff1a;frameworks/base/core/java/com/android/internal/os/ZygoteInit.java frameworks/base/core/java/com/android/internal/os/Zygote.java frameworks/base/core/java/com/android/internal/os/RuntimeInit.java frameworks/base/service…

C++:list容器--模拟实现(下篇)

1. 模拟实现 list 一些常用接口// list.h #pragma once #include <assert.h> #include "Iterator.h"namespace room {template<class T>struct list_node{list_node<T>* _next;list_node<T>* _prev;T _data;list_node(const T& x T()):…

边缘计算:一场由物理定律发起的“计算革命”

专栏引言:在前面的文章中,我们探讨了云计算如何将计算资源变成了“数字水电煤”,构建了一个强大的中心化数字帝国。然而,当这个帝国试图将它的触角伸向物理世界的每一个角落时,却遭遇了两位“上古之神”的无情阻击——光速与带宽。今天,我们将聚焦于一场由物理定律发起的…

量化模型部署工具llama.cpp

量化模型部署工具llama.cppllama.cppllama.cpp 是什么使用场景是什么如何使用&#xff1f;第 1 步&#xff1a;获取量化模型第 2 步&#xff1a;编译 llama.cpp第 3 步&#xff1a;运行推理完整 Demo&#xff1a;与 Llama 3 对话进阶使用&#xff1a;Python 集成总结概念解释1.…

【光照】[光照模型]发展里程碑时间线

【从UnityURP开始探索游戏渲染】专栏-直达 图形学光照模型发展史&#xff1a;技术演进与里程碑 section 基础奠基期(1960s-1970s) 1967 &#xff1a; Lambert模型(漫反射) - Bui Tuong Phong提出1971 &#xff1a; Gouraud着色 - Henri Gouraud发明顶点插值着色1973 &#xf…

【从零开始java学习|第十篇】面向对象

目录 一、面向对象介绍 二、类和对象 1. 类&#xff08;Class&#xff09;&#xff1a;对象的模板 2. 对象&#xff08;Object&#xff09;&#xff1a;类的实例 三、封装 1. 封装的概念 2. 封装的优势 四、就近原则和 this 关键字 1. 就近原则 2. this 关键字 五、…

Spark算子调优

Spark中可用下面的算子对数据计算进行优化处理&#xff0c;包括&#xff1a; mapPartition&#xff1a;一次处理一个分区数据&#xff0c;能够使用mapPartition的尽量使用&#xff0c;但是使用时会一次性读取整个分区数据到内存&#xff0c;占内存很大&#xff0c;同理还有fore…

码农特供版《消费者权益保护法》逆向工程指北——附源码级注释与异常处理方案

尊敬的审核&#xff1a; 本人文章《码农特供版〈消费者权益保护法〉逆向工程指北——附源码级注释与异常处理方案》 1. 纯属技术交流&#xff0c;无任何违法内容 2. 所有法律引用均来自公开条文 3. 请依据《网络安全法》第12条“不得无故删除合法内容”处理 附&#xff1a;本文…

MQTT 连接建立与断开流程详解(二)

三、核心机制与最佳实践&#xff08;一&#xff09;会话管理与 QoS 保障Clean Session vs 持久会话&#xff1a;在 MQTT 连接中&#xff0c;会话管理是一个重要的概念&#xff0c;其中 Clean Session 和持久会话是两种不同的会话模式。Clean Session&#xff0c;当设置为 1 时&…

[光学原理与应用-332]:ZEMAX - 序列模式与非序列模式的本质、比较

序列模式&#xff08;Sequential Mode&#xff09;与非序列模式&#xff08;Non-Sequential Mode&#xff09;是ZEMAX光学设计软件中的两种核心设计模式&#xff0c;二者在光路定义、分析工具、应用场景等方面存在本质差异。以下是两者的详细比较&#xff1a;一、本质差异光路定…