Pytorch FSDP权重分片保存与合并

注:本文章方法只适用Pytorch FSDP1的模型,且切分策略为SHARDED_STATE_DICT场景。

在使用FSDP训练模型时,为了节省显存通常会把模型权重也进行切分,在保存权重时为了加速保存通常每个进程各自保存自己持有的部分权重,避免先汇聚到主进程再保存浪费大量时间的问题。保存成分片权重后,如果需要推理则还需要将分片权重进行合并。下面提供了保存分片权重以及将分片权重合并的代码示例,代码主要参考accelerate官方源码。

import osimport torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
import torch.distributed.checkpoint.format_utils as dist_cp_format_utilsdef save_fsdp_model(model: FSDP, fsdp_ckpt_path: str):# refer accelerate/utils/fsdp_utils.py:save_fsdp_modelwith FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):os.makedirs(fsdp_ckpt_path, exist_ok=True)state_dict = {"model": model.state_dict()}dist_cp.save(state_dict=state_dict,storage_writer=dist_cp.FileSystemWriter(fsdp_ckpt_path),planner=DefaultSavePlanner(),)def merge_fsdp_weights(fsdp_ckpt_path: str, save_path: str):# refer accelerate/utils/fsdp_utils.py:merge_fsdp_weightsstate_dict = {}dist_cp_format_utils._load_state_dict(state_dict,storage_reader=dist_cp.FileSystemReader(fsdp_ckpt_path),planner=dist_cp_format_utils._EmptyStateDictLoadPlanner(),no_dist=True,)# To handle if state is a dict like {model: {...}}if len(state_dict.keys()) == 1:state_dict = state_dict[list(state_dict)[0]]torch.save(state_dict, save_path)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/95412.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/95412.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

IDEA自动生成Mapper、XML和实体文件

1. 引入插件 <build><finalName>demo</finalName><plugins><plugin><groupId>org.mybatis.generator</groupId><artifactId>mybatis-generator-maven-plugin</artifactId><version>1.3.5</version><depe…

单例模式的理解

目录单例模式1.饿汉式(线程安全)2.懒汉式(通过synchronized修饰获取实例的方法保证线程安全)3.双重校验锁的方式实现单例模式4.静态内部类方式实现单例模式【推荐】单例模式 1.饿汉式(线程安全) package 并发的例子.单例模式; // 饿汉式单例模式&#xff08;天然线程安全&…

NLP---IF-IDF案例分析

一案例 - 红楼梦1首先准备语料库http://www.dxsxs.com这个网址去下载2 任务一&#xff1a;拆分提取import os import redef split_hongloumeng():# 1. 配置路径&#xff08;关键&#xff1a;根据实际文件位置修改&#xff09; # 脚本所在文件夹&#xff08;自动获取&#xff0…

LaTeX(排版系统)Texlive(环境)Vscode(编辑器)环境配置与安装

LaTeX、Texlive 和 Vscode 三者之间的关系&#xff0c;可以把它们理解成语言、工具链和编辑器的配合关系。 1.下载Texlive 华为镜像网站下载 小编这边下载的是texlive2025.iso最新版的&#xff0c;下载什么版本看自己需求&#xff0c;只要下载后缀未.iso的即可。为避免错误&am…

【深入浅出STM32(1)】 GPIO 深度解析:引脚特性、工作模式、速度选型及上下拉电阻详解

GPIO 深度解析&#xff1a;引脚特性、工作模式、速度选型及上下拉电阻详解一、GPIO概述二、GPIO的工作模式1、简述&#xff08;1&#xff09;4种输入模式&#xff08;2&#xff09;4种输出模式&#xff08;3&#xff09;4种最大输出速度2、引脚速度&#xff08;1&#xff09;输…

第1节 大模型分布式推理基础与技术体系

前言:为什么分布式推理是大模型时代的核心能力? 当我们谈论大模型时,往往首先想到的是训练阶段的千亿参数、千卡集群和数月的训练周期。但对于商业落地而言,推理阶段的技术挑战可能比训练更复杂。 2025年,某头部AI公司推出的130B参数模型在单机推理时面临两个选择:要么…

《软件工程导论》实验报告一 软件工程文档

目 录 一、实验目的 二、实验环境 三、实验内容与步骤 四、实验心得 一、实验目的 1. 理解软件工程的基本概念&#xff0c;熟悉软件&#xff0c;软件生命周期&#xff0c;软件生存周期过程和软件生命周期各阶段的定义和内容。 2. 了解软件工程文档的类别、内容及撰写软件工…

基于elk实现分布式日志

1.基本介绍 1.1 什么是分布式日志 在分布式应用中&#xff0c;日志被分散在储存不同的设备上。如果你管理数十上百台服务器&#xff0c;你还在使用依次登录每台机器的传统方法查阅日志。这样是不是感觉很繁琐和效率低下。所以我们使用集中化的日志管理&#xff0c;分布式日志…

多模态RAG赛题实战之策略优化--Datawhale AI夏令营

科大讯飞AI大赛&#xff08;多模态RAG方向&#xff09; - Datawhale 项目流程图 1、升级数据解析方案&#xff1a;从 fitz 到 MinerU PyMuPDF&#xff08;fitz&#xff09;是基于规则的方式提取pdf里面的数据&#xff1b;MinerU是基于深度学习模型通过把PDF内的页面看成是图片…

09--解密栈与队列:数据结构核心原理

1. 栈 1.1. 栈的简介 栈 是一种 特殊的线性表&#xff0c;具有数据 先进后出 特点。 注意&#xff1a; stack本身 不支持迭代器操作 主要原因是因为stack不支持数据的随机访问&#xff0c;必须保证数据先进后出的特点。stack在CPP库中实现为一种 容器适配器 所谓容器适配器&a…

打造专属 React 脚手架:从 0 到 1 开发 CLI 工具

前言: 在前端开发中&#xff0c;重复搭建项目环境是个低效的事儿。要是团队技术栈固定&#xff08;比如 React AntD Zustand TS &#xff09;&#xff0c;每次从零开始配路由、状态管理、UI 组件&#xff0c;既耗时又容易出错。这时候&#xff0c;自定义 CLI 脚手架 就派上…

Python day43

浙大疏锦行 Python day43 import torch import numpy as np import pandas as pd import torchvision import torchvision.transforms as transforms import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torch.utils.data import Da…

python基于Hadoop的超市数据分析系统

前端开发框架:vue.js 数据库 mysql 版本不限 后端语言框架支持&#xff1a; 1 java(SSM/springboot)-idea/eclipse 2.NodejsVue.js -vscode 3.python(flask/django)–pycharm/vscode 4.php(thinkphp/laravel)-hbuilderx 数据库工具&#xff1a;Navicat/SQLyog等都可以 摘要&…

如何用 COLMAP 制作 Blender 格式的数据集

如何用 COLMAP 制作 Blender 格式的数据集并划分出 transforms_train.json、transforms_val.json 和 transforms_test.json。 一、什么是 Blender 格式数据集? Blender 格式数据集是 Nerf 和 Nerfstudio 常用的输入格式,其核心是包含了相机内外参的 JSON 文件,一般命名为:…

[GESP202309 六级] 2023年9月GESP C++六级上机题题解,附带讲解视频!

本文为GESP 2023年9月 六级的上机题目详细题解和讲解视频&#xff0c;觉得有帮助或者写的不错可以点个赞。 题目一讲解视频 GESP2023年9月六级上机题一题目二讲解视频 题目一:小羊买饮料 B3873 [GESP202309 六级] 小杨买饮料 - 洛谷 题目大意: 现在超市一共有n种饮料&#…

linux 操作ppt

目录 方法1&#xff1a;用 libreoffice 打开PPT文件 播放脚本&#xff1a; 方法2&#xff1a;用 python-pptx 创建和编辑PPT 方法3&#xff1a;其他方法 在Linux中&#xff0c;可以使用Python通过python-pptx库来创建和编辑PPT文件&#xff0c;但直接播放PPT文件需要借助其…

元数据管理与数据治理平台:Apache Atlas 基本搜索 Basic Search

文中内容仅限技术学习与代码实践参考&#xff0c;市场存在不确定性&#xff0c;技术分析需谨慎验证&#xff0c;不构成任何投资建议。 Apache Atlas 框架是一套可扩展的核心基础治理服务&#xff0c;使企业能够有效、高效地满足 Hadoop 中的合规性要求&#xff0c;并支持与整个…

LangChain4J-(1)-Hello World

一、LangChain4J是什么&#xff1f; LangChain4J 是一个专为 Java 生态系统设计的开源框架&#xff0c;用于简化与大语言模型&#xff08;LLM&#xff0c;如 OpenAI 的 GPT 系列、Google 的 Gemini、Anthropic 的 Claude 等&#xff09;的集成和交互。它借鉴了 Python 生态中 L…

HTTPS应用层协议-中间攻击人

HTTPS应用层协议-中间攻击人 • Man-in-the-MiddleAttack&#xff0c;简称“MITM 攻击” 确实&#xff0c;在方案 2/3/4 中&#xff0c;客户端获取到公钥 S 之后&#xff0c;对客户端形成的对称秘钥 X 用服务端给客户端的公钥 S 进行加密&#xff0c;中间人即使窃取到了数据&am…

利用 Makefile 高效启动 VIVADO 软件:深入解析与实践

利用 Makefile 高效启动 VIVADO 软件&#xff1a;深入解析与实践 系列文章目录 1、VMware Workstation Pro安装指南&#xff1a;详细步骤与配置选项说明 2、VMware 下 Ubuntu 操作系统下载与安装指南 3.基于 Ubuntu 的 Linux 系统中 Vivado 2020.1 下载安装教程 文章目录利用 …