Automatic device mapping for SFT training with accelerate + DeepSpeed ZeRO Stage 3 offload: training compute on GPU, parameter/optimizer storage on CPU
run_peft_qlora_deepspeed_stage3.sh
#!/bin/bash

export MAX_JOBS=4
export OMP_NUM_THREADS=4
export disable_exllama=True
export CUDA_VISIBLE_DEVICES=0,1
export TORCH_CUDA_ARCH_LIST="8.6"
export TOKENIZERS_PARALLELISM=false
export CUDA_DEVICE_ORDER=PCI_BUS_ID
export TORCH_DISTRIBUTED_DEBUG=DETAIL
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# NCCL optimization
export NCCL_DEBUG=INFO
export NCCL_P2P_DISABLE=0
export NCCL_IGNORE_CPU_AFFINITY=1
export MASTER_ADDR=localhost
export MASTER_PORT=2022
export NCCL_SOCKET_IFNAME=ens34
export NCCL_MIN_NRINGS=8
export NCCL_MAX_NCHANNELS=8
export NCCL_ALGO=Ring
export NCCL_PROTO=Simple

accelerate launch --config_file "deepspeed_config_z3_qlora.yaml" train.py \
--seed 100 \
--model_name_or_path "models/Qwen3-Coder-30B-A3B-Instruct" \
--dataset_name "xxx.json" \
--chat_template_format "qwen3" \
--add_special_tokens False \
--append_concat_token False \
--splits "train,test" \
--max_seq_len 1024 \
--num_train_epochs 3 \
--logging_steps 5 \
--log_level "info" \
--logging_strategy "steps" \
--eval_strategy "epoch" \
--save_strategy "epoch" \
--bf16 True \
--packing False \
--learning_rate 1e-4 \
--lr_scheduler_type "cosine" \
--weight_decay 0.01 \
--warmup_ratio 0.1 \
--max_grad_norm 1.0 \
--output_dir "xxx_adapter" \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 4 \
--gradient_checkpointing True \
--use_reentrant True \
--dataset_text_field "content" \
--use_flash_attn True \
--use_peft_lora True \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization True \
--use_nested_quant True \
--bnb_4bit_compute_dtype "bfloat16" \
--bnb_4bit_quant_storage_dtype "bfloat16"
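The launch script assumes two bf16-capable GPUs (CUDA_VISIBLE_DEVICES=0,1 and --bf16 True). An optional pre-flight check (a small sketch, not part of the original scripts) can confirm this before starting a long run:

# Pre-flight sketch: confirm the visible GPUs and bf16 support before launching.
import torch

assert torch.cuda.is_available(), "CUDA is not available"
print(f"Visible GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    major, minor = torch.cuda.get_device_capability(i)
    print(f"  cuda:{i} {torch.cuda.get_device_name(i)} (compute capability {major}.{minor})")
print(f"bf16 supported: {torch.cuda.is_bf16_supported()}")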
deepspeed_config_z3_qlora.yaml (generated with accelerate config)
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_config_file: ds_config.json
  deepspeed_moe_layer_cls_names: Qwen3MoeSparseMoeBlock
  zero3_init_flag: true
distributed_type: DEEPSPEED
downcast_bf16: "no"
dynamo_config:
  dynamo_backend: INDUCTOR
  # dynamo_mode: reduce-overhead
  dynamo_mode: default
  dynamo_use_dynamic: true
  dynamo_use_fullgraph: true
  dynamo_use_regional_compilation: true
enable_cpu_affinity: true
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
ds_config.json
{
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "sub_group_size": 1e9,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": "auto"
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false,
  "aio": {
    "enabled": true,
    "block_size": 256,
    "queue_depth": 8,
    "thread_count": 1,
    "single_submit": false,
    "overlap_events": true
  }
}
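With offload_param and offload_optimizer both set to cpu, the GPUs only hold the working set while host RAM holds the full parameter and optimizer partitions. DeepSpeed ships a memory estimator that can sanity-check whether the host has enough RAM for this; the snippet below is a sketch that assumes the utility's current import path and that the machine can hold one full copy of the model (substitute a smaller checkpoint just to check the arithmetic):

# Rough ZeRO-3 memory estimate for the 2-GPU setup used above (sketch; loads the
# whole model on CPU first, which itself needs substantial host RAM).
from transformers import AutoModelForCausalLM
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live

model = AutoModelForCausalLM.from_pretrained(
    "models/Qwen3-Coder-30B-A3B-Instruct", trust_remote_code=True
)
# Prints per-GPU and per-CPU memory needs for params, gradients and optimizer
# states with and without the offload options enabled in ds_config.json.
estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=2, num_nodes=1)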
utils.py
import os
import json
import torch
import transformers
import packaging.version

from enum import Enum
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from datasets.builder import DatasetGenerationError
from datasets import DatasetDict, load_dataset, load_from_disk

DEFAULT_CHATML_CHAT_TEMPLATE = "{% for message in messages %}\n{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% if loop.last and add_generation_prompt %}{{'<|im_start|>assistant\n' }}{% endif %}{% endfor %}"
DEFAULT_ZEPHYR_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"


class ZephyrSpecialTokens(str, Enum):
    user = "<|user|>"
    assistant = "<|assistant|>"
    system = "<|system|>"
    eos_token = "</s>"
    bos_token = "<s>"
    pad_token = "<pad>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]


class ChatmlSpecialTokens(str, Enum):
    user = "<|im_start|>user"
    assistant = "<|im_start|>assistant"
    system = "<|im_start|>system"
    eos_token = "<|im_end|>"
    bos_token = "<s>"
    pad_token = "<pad>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]


def format_alpaca_to_chatml(sample):
    """Convert Alpaca format to ChatML format"""
    messages = []
    # Add system message
    messages.append({"role": "system", "content": "You are a helpful assistant."})
    # Add user message
    if sample["input"].strip():
        content = f"{sample['instruction']}\n{sample['input']}"
    else:
        content = sample["instruction"]
    messages.append({"role": "user", "content": content})
    # Add assistant message
    messages.append({"role": "assistant", "content": sample["output"]})
    return {"messages": messages}


def create_datasets(tokenizer, data_args, training_args, apply_chat_template=False):
    def preprocess(samples):
        batch = []
        for conversation in samples["messages"]:
            batch.append(tokenizer.apply_chat_template(conversation, tokenize=False))
        return {"content": batch}

    def load_alpaca_dataset(data_path):
        """Load and convert Alpaca format dataset to ChatML format"""
        with open(data_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # Convert Alpaca format to ChatML format
        chatml_data = [format_alpaca_to_chatml(item) for item in data]
        # Convert to a Hugging Face Dataset
        from datasets import Dataset

        return Dataset.from_list(chatml_data)

    raw_datasets = DatasetDict()
    # Check whether the dataset is a local JSON file
    if data_args.dataset_name.endswith(".json") and os.path.exists(data_args.dataset_name):
        # Load the full dataset
        full_dataset = load_alpaca_dataset(data_args.dataset_name)
        # Split the dataset according to the `splits` argument
        if "train" in data_args.splits and "test" in data_args.splits:
            # 90% train, 10% test
            split_dataset = full_dataset.train_test_split(test_size=0.1, seed=training_args.seed)
            raw_datasets["train"] = split_dataset["train"]
            raw_datasets["test"] = split_dataset["test"]
        elif "train" in data_args.splits:
            raw_datasets["train"] = full_dataset
        elif "test" in data_args.splits:
            raw_datasets["test"] = full_dataset
        else:
            raise ValueError(f"Split type {data_args.splits} not recognized")
    else:
        # Hub dataset or a directory of saved splits
        for split in data_args.splits.split(","):
            try:
                # Try loading from the Hub
                dataset = load_dataset(data_args.dataset_name, split=split)
                raw_datasets[split] = dataset
            except Exception:
                # Fall back to a local on-disk dataset
                try:
                    dataset = load_from_disk(os.path.join(data_args.dataset_name, split))
                    raw_datasets[split] = dataset
                except Exception as e:
                    raise ValueError(f"Could not load dataset split {split}: {str(e)}")

    if apply_chat_template:
        raw_datasets = raw_datasets.map(
            preprocess,
            batched=True,
            remove_columns=raw_datasets["train"].column_names,
        )

    train_data = raw_datasets["train"]
    valid_data = raw_datasets["test"] if "test" in raw_datasets else None
    print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data) if valid_data else 0}")
    print(f"A sample of train dataset: {train_data[0]}")
    return train_data, valid_data


def create_and_prepare_model(args, data_args, training_args):
    if args.use_unsloth:
        from unsloth import FastLanguageModel

    bnb_config = None
    quant_storage_dtype = None

    if (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and torch.distributed.get_world_size() > 1
        and args.use_unsloth
    ):
        raise NotImplementedError("Unsloth is not supported in distributed training")

    if args.use_4bit_quantization:
        compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
        quant_storage_dtype = getattr(torch, args.bnb_4bit_quant_storage_dtype)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=args.use_4bit_quantization,
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=args.use_nested_quant,
            bnb_4bit_quant_storage=quant_storage_dtype,
        )

        if compute_dtype == torch.float16 and args.use_4bit_quantization:
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                print("=" * 80)
                print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16")
                print("=" * 80)
    elif args.use_8bit_quantization:
        bnb_config = BitsAndBytesConfig(load_in_8bit=args.use_8bit_quantization)

    if args.use_unsloth:
        if torch.xpu.is_available():
            raise NotImplementedError("XPU hasn't supported unsloth yet")
        # Load model
        model, _ = FastLanguageModel.from_pretrained(
            model_name=args.model_name_or_path,
            max_seq_length=training_args.max_seq_length,
            dtype=None,
            load_in_4bit=args.use_4bit_quantization,
        )
    else:
        torch_dtype = (
            quant_storage_dtype if quant_storage_dtype and quant_storage_dtype.is_floating_point else torch.float32
        )
        # Prepare model loading arguments
        model_kwargs = {
            "trust_remote_code": True,
            "torch_dtype": torch_dtype,
        }
        if args.use_flash_attn:
            # Make sure float16 is used if required
            # model_kwargs["torch_dtype"] = torch.float16
            if torch.xpu.is_available():
                print("XPU hasn't supported flash_attn yet, use eager implementation instead.")
                model_kwargs["attn_implementation"] = "eager"
            else:
                model_kwargs["attn_implementation"] = "flash_attention_2"
        # Only add quantization_config if bnb_config is not None
        if bnb_config is not None:
            model_kwargs["quantization_config"] = bnb_config
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)

    peft_config = None
    chat_template = None
    if args.use_peft_lora and not args.use_unsloth:
        peft_config = LoraConfig(
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=args.lora_target_modules.split(",")
            if args.lora_target_modules != "all-linear"
            else args.lora_target_modules,
        )

    special_tokens = None
    chat_template = None
    if args.chat_template_format == "chatml":
        special_tokens = ChatmlSpecialTokens
        chat_template = DEFAULT_CHATML_CHAT_TEMPLATE
    elif args.chat_template_format == "zephyr":
        special_tokens = ZephyrSpecialTokens
        chat_template = DEFAULT_ZEPHYR_CHAT_TEMPLATE

    if special_tokens is not None:
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path,
            pad_token=special_tokens.pad_token.value,
            bos_token=special_tokens.bos_token.value,
            eos_token=special_tokens.eos_token.value,
            additional_special_tokens=special_tokens.list(),
            trust_remote_code=True,
        )
        tokenizer.chat_template = chat_template

        # make embedding resizing configurable?
        # Transformers 4.46.0+ uses mean_resizing by default, which fails with QLoRA + FSDP because the
        # embedding could be on meta device; therefore we set mean_resizing=False in that case (i.e. the
        # status quo ante). See https://github.com/huggingface/accelerate/issues/1620.
        uses_transformers_4_46 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.46.0")
        uses_fsdp = os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
        # Check if the model is quantized
        is_quantized = (bnb_config is not None) or (getattr(model, "hf_quantizer", None) is not None)
        if is_quantized and uses_fsdp and uses_transformers_4_46:
            model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8, mean_resizing=False)
        else:
            model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token

    if args.use_unsloth:
        # Do model patching and add fast LoRA weights
        model = FastLanguageModel.get_peft_model(
            model,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            target_modules=args.lora_target_modules.split(",")
            if args.lora_target_modules != "all-linear"
            else args.lora_target_modules,
            use_gradient_checkpointing=training_args.gradient_checkpointing,
            random_state=training_args.seed,
            max_seq_length=training_args.max_seq_length,
        )

    return model, peft_config, tokenizer
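A quick way to check the Alpaca-to-ChatML conversion and the rendered chat template outside of training is the sketch below (the sample record is made up; create_datasets() performs the same steps when dataset_name points at a local .json file):

# Sketch: convert one Alpaca record and render it with the model's chat template.
from transformers import AutoTokenizer
from utils import format_alpaca_to_chatml

sample = {
    "instruction": "Summarize the following text.",
    "input": "DeepSpeed ZeRO stage 3 partitions parameters across ranks.",
    "output": "ZeRO-3 shards model parameters so each GPU holds only a slice.",
}

tokenizer = AutoTokenizer.from_pretrained(
    "models/Qwen3-Coder-30B-A3B-Instruct", trust_remote_code=True
)
messages = format_alpaca_to_chatml(sample)["messages"]
# This mirrors preprocess() when chat_template_format != "none".
print(tokenizer.apply_chat_template(messages, tokenize=False))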
train.py
import os
import gc
import sys
import torch

from typing import Optional
from trl import SFTConfig, SFTTrainer
from dataclasses import dataclass, field
from transformers import HfArgumentParser, set_seed
from utils import create_and_prepare_model, create_datasets

# Clear caches before start-up
gc.collect()
torch.cuda.empty_cache()


# Define and parse arguments.
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    max_seq_length: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum total input sequence length after tokenization."},
    )
    chat_template_format: Optional[str] = field(
        default="none",
        metadata={"help": "chatml|zephyr|none. Pass `none` if the dataset is already formatted with the chat template."},
    )
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=64)
    lora_target_modules: Optional[str] = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={"help": "comma separated list of target modules to apply LoRA layers to"},
    )
    use_nested_quant: Optional[bool] = field(
        default=False,
        metadata={"help": "Activate nested quantization for 4bit base models"},
    )
    bnb_4bit_compute_dtype: Optional[str] = field(
        default="float16",
        metadata={"help": "Compute dtype for 4bit base models"},
    )
    bnb_4bit_quant_storage_dtype: Optional[str] = field(
        default="uint8",
        metadata={"help": "Quantization storage dtype for 4bit base models"},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type fp4 or nf4"},
    )
    use_flash_attn: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash attention for training."},
    )
    use_peft_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables PEFT LoRA for training."},
    )
    use_8bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading model in 8bit."},
    )
    use_4bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading model in 4bit."},
    )
    use_reentrant: Optional[bool] = field(
        default=False,
        metadata={"help": "Gradient Checkpointing param. Refer the related docs"},
    )
    use_unsloth: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables UnSloth for training."},
    )


@dataclass
class DataTrainingArguments:
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The preference dataset to use."},
    )
    append_concat_token: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, appends `eos_token_id` at the end of each sample being packed."},
    )
    add_special_tokens: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, tokenizers adds special tokens to each sample being packed."},
    )
    splits: Optional[str] = field(
        default="train,test",
        metadata={"help": "Comma separated list of the splits to use from the dataset."},
    )


def main(model_args, data_args, training_args):
    # Set seed for reproducibility
    set_seed(training_args.seed)

    # model
    model, peft_config, tokenizer = create_and_prepare_model(model_args, data_args, training_args)

    # gradient ckpt
    model.config.use_cache = not training_args.gradient_checkpointing
    training_args.gradient_checkpointing = training_args.gradient_checkpointing and not model_args.use_unsloth
    if training_args.gradient_checkpointing:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": model_args.use_reentrant}
    training_args.dataset_kwargs = {
        "append_concat_token": data_args.append_concat_token,
        "add_special_tokens": data_args.add_special_tokens,
    }

    # datasets
    train_dataset, eval_dataset = create_datasets(
        tokenizer,
        data_args,
        training_args,
        apply_chat_template=model_args.chat_template_format != "none",
    )

    # trainer
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=peft_config,
    )
    trainer.accelerator.print(f"{trainer.model}")
    if hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()

    # train
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    trainer.train(resume_from_checkpoint=checkpoint)

    # saving final model
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model()


if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SFTConfig))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    try:
        main(model_args, data_args, training_args)
        # Clear caches
        gc.collect()
        torch.cuda.empty_cache()
    finally:
        import torch.distributed as dist

        # Make sure resources are released whether training succeeded or failed
        if dist.is_initialized():
            dist.destroy_process_group()
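train.py saves only the LoRA adapter to --output_dir ("xxx_adapter" in the launch script). A minimal inference sketch for loading it afterwards is shown below; the paths follow the launch script, and merge_and_unload() is optional and needs enough memory for a full-precision copy of the base model:

# Sketch: load the trained adapter onto the base model for inference.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_path = "models/Qwen3-Coder-30B-A3B-Instruct"
base = AutoModelForCausalLM.from_pretrained(
    base_path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
model = PeftModel.from_pretrained(base, "xxx_adapter")
model = model.merge_and_unload()  # fold the LoRA weights into the base model

tokenizer = AutoTokenizer.from_pretrained(base_path, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a hello-world in Python."}],
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(out[0], skip_special_tokens=True))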