基于 Transformer RoBERTa的情感分类任务实践总结之四——PGM、EMA

整合了以下五大核心技术:R-Drop、PGM 对抗训练、EMA、标签平滑、CosineAnnealing 学习率调度。

1. R-Drop(Regularized Dropout)

原理:同一个样本做两次前向传播(同 dropout mask),计算两次输出的 KL 散度,作为正则项加入损失中。
目标:增强鲁棒性,提升泛化能力。
损失组合:

loss = CrossEntropy(logits1, labels) + CrossEntropy(logits2, labels) + α * KL(logits1 || logits2)

2. PGM(Projected Gradient Method)对抗训练

机制:

在词嵌入空间中添加扰动,制造“敌意样本”。
多步迭代(PGM_STEPS=3),每步计算扰动梯度并累积。
作用:
增强模型对小扰动的鲁棒性,提高对抗泛化能力。
干预时机:在每次主 loss 反向传播后注入对抗 loss 的梯度。

3. EMA(Exponential Moving Average)

思路:
模型参数滑动平均(shadow weights),推理时使用这些平滑参数。
核心优势:
抑制训练波动、缓解过拟合、稳定收敛。

4. 标签平滑(Label Smoothing)

方式:将 one-hot 标签略微“平滑”,防止模型过度自信。
具体值:label_smoothing=0.1
结果:能缓解过拟合、提升模型稳定性。

5. Cosine Annealing 学习率衰减

调度策略:余弦退火(cosine),带 warmup。
优势:
前期快速学习,后期逐步收敛;适合 fine-tuning 场景。

模型训练流程

Trainer 子类化:自定义 AdvancedTrainer,重载 compute_loss 以支持双前向(R-Drop)训练。
Callbacks 集成:
PGMCallback:注入多步对抗扰动。
EmaCallback:更新并应用 shadow 参数。
EarlyStoppingCallback:监控 f1,连续 3 轮无改进则提前停止。

总体优势

多重正则和鲁棒性增强机制叠加,极大提升模型泛化能力和抗干扰能力。
适合工业级 NLP 分类任务的强化训练。

代码

# Advanced RoBERTa Sentiment Classifier with R-Drop + PGM + EMA + LabelSmoothing + CosineAnnealingimport os
import numpy as np
import torch
import torch.nn as nn
from transformers import (AutoTokenizer,AutoModelForSequenceClassification,Trainer,TrainingArguments,DataCollatorWithPadding,set_seed,EarlyStoppingCallback,TrainerCallback
)
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score# 固定随机种子
set_seed(42)# 配置参数
MODEL_NAME = "roberta-base"
NUM_LABELS = 2
R_DROP_ALPHA = 5.0
LABEL_SMOOTHING = 0.1
PGM_EPSILON = 1.0
PGM_ALPHA = 0.3
# --- PGM 多步迭代次数 ---
PGM_STEPS = 3 # 例如,迭代 3 次来生成对抗扰动
EMA_DECAY = 0.999
# 加载数据
dataset = load_dataset("imdb")
train_dataset = dataset["train"]
test_dataset = dataset["test"]# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)def preprocess_function(examples):return tokenizer(examples["text"], truncation=True)train_dataset = train_dataset.map(preprocess_function, batched=True)
test_dataset = test_dataset.map(preprocess_function, batched=True)# 数据整理器
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)# 加载模型
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS
)# --- R-Drop Loss ---
class RDropLoss(nn.Module):def __init__(self, alpha=1.0, label_smoothing=0.0):super().__init__()self.alpha = alphaself.ce = nn.CrossEntropyLoss(label_smoothing=label_smoothing)self.kl = nn.KLDivLoss(reduction="batchmean")def forward(self, logits1, logits2, labels):ce_loss1 = self.ce(logits1, labels)ce_loss2 = self.ce(logits2, labels)ce_loss = 0.5 * (ce_loss1 + ce_loss2)p = torch.log_softmax(logits1, dim=-1)q = torch.log_softmax(logits2, dim=-1)p_softmax = torch.softmax(logits1, dim=-1)q_softmax = torch.softmax(logits2, dim=-1)kl_loss = 0.5 * (self.kl(p, q_softmax) + self.kl(q, p_softmax))return ce_loss + self.alpha * kl_loss# --- PGM ---
class PGM:def __init__(self, model, epsilon=1.0, alpha=0.3, emb_name='embeddings.word_embeddings'):self.model = modelself.epsilon = epsilonself.alpha = alphaself.emb_name = emb_nameself.backup = {}
# 扰动词嵌入可以被理解为在原始单词的语义空间中进行微小的“移动”,
# 使其略微偏离原来的意义,但又不至于完全改变其含义,从而模拟“对抗性样本”。def attack(self, is_first_attack=False):for name, param in self.model.named_parameters():if param.requires_grad and self.emb_name in name and param.grad is not None:if is_first_attack:self.backup[name] = param.data.clone()norm = torch.norm(param.grad)if norm != 0:r_at = self.alpha * param.grad / normparam.data.add_(r_at)param.data = self.project(name, param.data, self.backup[name])def restore(self):for name, param in self.model.named_parameters():if name in self.backup:param.data = self.backup[name]self.backup = {}def project(self, param_name, param_data, param_backup):r = param_data - param_backupif torch.norm(r) > self.epsilon:r = self.epsilon * r / torch.norm(r)return param_backup + r# --- EMA ---
class EMA:def __init__(self, model, decay):self.model = modelself.decay = decayself.shadow = {}self.backup = {}def register(self):for name, param in self.model.named_parameters():if param.requires_grad:if name not in self.shadow:self.shadow[name] = param.data.clone()def update(self):for name, param in self.model.named_parameters():if param.requires_grad:if name not in self.shadow:continue  # 保护:skip 未注册 param,避免 KeyErrornew_avg = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]self.shadow[name] = new_avg.clone()def apply_shadow(self):for name, param in self.model.named_parameters():if param.requires_grad:if name not in self.shadow:continueself.backup[name] = param.data.clone()param.data = self.shadow[name]def restore(self):for name, param in self.model.named_parameters():if param.requires_grad:if name in self.backup:param.data = self.backup[name]self.backup = {}# --- Callbacks ---
class PGMCallback(TrainerCallback):def __init__(self, pgm, rdrop_loss_fn, pgm_steps=1):self.pgm = pgmself.rdrop_loss_fn = rdrop_loss_fnself.pgm_steps = pgm_steps # 对抗迭代步数def on_after_backward(self, args, state, control, model=None, inputs=None, optimizer=None, **kwargs):# 备份原始的梯度状态,以便在对抗训练结束后恢复# PyTorch 优化器在 step() 时会清零梯度,但我们需要在 PGM 内部操作时保留它们# 更安全的做法是使用一个更细粒度的梯度累积或在每次 PGM 迭代后清零# 为了简化,我们假设在这个回调中梯度是可用的# 开始多步 PGMfor step in range(self.pgm_steps):is_first_attack = (step == 0)# 在第一次攻击时备份参数并施加扰动# 在后续攻击时,只施加扰动,不备份self.pgm.attack(is_first_attack=is_first_attack)# 在扰动后的模型上进行前向传播# 这里需要确保模型处于训练模式,并且梯度是开启的model.train() # 确保模型处于训练模式model.zero_grad() # 在每次PGM步清零模型梯度adv_outputs1 = model(**{k: v for k, v in inputs.items() if k != "labels"})adv_outputs2 = model(**{k: v for k, v in inputs.items() if k != "labels"})adv_logits1 = adv_outputs1.logitsadv_logits2 = adv_outputs2.logitslabels = inputs["labels"]# 计算对抗损失adv_loss = self.rdrop_loss_fn(adv_logits1, adv_logits2, labels)# 对抗损失的反向传播# 注意:这里不能简单地直接调用 .backward()# 因为 Trainer 已经处理了主损失的梯度累积和优化器步骤# PGM 的梯度应该累积到主梯度中,而不是覆盖它们# 最简单的集成方式是让对抗损失也产生梯度,并累积到参数上# 在 Hugging Face Trainer 的 on_after_backward 中,# 已经进行了一次主损失的 backward,因此这里的 adv_loss.backward() 会累积梯度。# 但是,为了避免在多步中梯度累积不当,需要更细致的控制。# 通常,PGM 是在优化器步骤之前,对参数进行修改并重新计算损失。# --- 关键点:如何处理多步 PGM 的梯度 ---# 这里的 `adv_loss.backward()` 会计算并累积梯度。# 由于每次 `pgm.attack()` 都会修改参数,所以 `adv_loss` 都会基于当前扰动后的参数计算。# 在每次 `step` 中,我们计算 `adv_loss` 的梯度并累加到参数上。# 注意:`model.zero_grad()` 放在循环内部可以确保每次 PGM 步只计算当前扰动下的梯度,# 如果放在循环外部,则所有 PGM 步的梯度会累积到同一个梯度值上。# 这里设置为每次 PGM 步清零梯度,然后计算当前步的对抗梯度。# 这样做可以确保 `adv_loss.backward()` 每次计算的是相对于当前扰动参数的梯度。accelerator = kwargs.get("accelerator", None)if accelerator is not None:accelerator.backward(adv_loss)else:adv_loss.backward()# 多步 PGM 结束后,恢复模型参数到原始状态(即未被 PGM 扰动前的状态)self.pgm.restore()# 此时,model 的所有 param.grad 中已经包含了# (主损失的梯度) + (最后一次 PGM 迭代的对抗损失的梯度)# HuggingFace Trainer 会紧接着调用优化器的 step() 方法来更新模型的参数。#最终留下并用于优化器更新的梯度,是最后一次 PGM 迭代所产生的对抗损失的梯度。class EmaCallback(TrainerCallback):def __init__(self, ema):self.ema = emadef on_step_end(self, args, state, control, **kwargs):self.ema.update()def on_evaluate(self, args, state, control, **kwargs):self.ema.apply_shadow()def on_evaluate_end(self, args, state, control, **kwargs):self.ema.restore()# --- AdvancedTrainer  ---
class AdvancedTrainer(Trainer):def __init__(self, *args, alpha=1.0, label_smoothing=0.0, ema=None, **kwargs):super().__init__(*args, **kwargs)self.rdrop_loss_fn = RDropLoss(alpha=alpha, label_smoothing=label_smoothing)self.ema = emaif self.ema is not None:self.ema.register()def compute_loss(self, model, inputs, return_outputs=False, **kwargs):labels = inputs["labels"]# 两次前向传播用于 R-Dropoutputs1 = model(**{k: v for k, v in inputs.items() if k != "labels"})outputs2 = model(**{k: v for k, v in inputs.items() if k != "labels"})logits1 = outputs1.logitslogits2 = outputs2.logitsloss = self.rdrop_loss_fn(logits1, logits2, labels)return (loss, outputs1) if return_outputs else loss# --- Metrics ---
def compute_metrics(eval_pred):logits, labels = eval_predprobs = torch.softmax(torch.tensor(logits), dim=-1).numpy()predictions = np.argmax(logits, axis=-1)acc = accuracy_score(labels, predictions)f1 = f1_score(labels, predictions)try:auc = roc_auc_score(labels, probs[:, 1])except:auc = 0.0return {"accuracy": acc, "f1": f1, "auc": auc}# --- TrainingArguments  ---
training_args = TrainingArguments(output_dir="./results_adv_rdrop_pgm_ema_multistep", # 更改输出目录eval_strategy="epoch",save_strategy="epoch",learning_rate=2e-5,per_device_train_batch_size=16,per_device_eval_batch_size=16,num_train_epochs=5,weight_decay=0.01,warmup_ratio=0.1,lr_scheduler_type="cosine",logging_dir="./logs_adv_rdrop_pgm_ema_multistep", # 更改日志目录logging_steps=50,load_best_model_at_end=True,metric_for_best_model="f1",fp16=True,save_total_limit=2,
)# --- 初始化模块 ---
pgm = PGM(model, epsilon=PGM_EPSILON, alpha=PGM_ALPHA)
ema = EMA(model, decay=EMA_DECAY)# --- Trainer ---
trainer = AdvancedTrainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=test_dataset,tokenizer=tokenizer, # 使用 tokenizer 而不是 processing_classdata_collator=data_collator,compute_metrics=compute_metrics,alpha=R_DROP_ALPHA,label_smoothing=LABEL_SMOOTHING,callbacks=[PGMCallback(pgm=pgm, rdrop_loss_fn=RDropLoss(alpha=R_DROP_ALPHA, label_smoothing=LABEL_SMOOTHING), pgm_steps=PGM_STEPS),EmaCallback(ema=ema),EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01),],
)# --- 训练 ---
trainer.train()# --- 评估 ---
trainer.evaluate()

结果

{'eval_loss': 0.2900645434856415, 'eval_accuracy': 0.95836, 'eval_f1': 0.9586822782298074, 'eval_auc': 0.9911689504000001, 'eval_runtime': 275.0978, 'eval_samples_per_second': 90.877, 'eval_steps_per_second': 5.682, 'epoch': 5.0}                                                                                                             
{'train_runtime': 171019.0699, 'train_samples_per_second': 0.731, 'train_steps_per_second': 0.046, 'train_loss': 0.30841634256749756, 'epoch': 5.0}                      
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7815/7815 [47:30:19<00:00, 21.88s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [04:32<00:00,  5.75it/s]

TensorBoard

在这里插入图片描述

注意

实际操作,这里要保存模型。还要转成ONNX模型,用C++ OnnxRuntime推理等等推理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/pingmian/84612.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

录制mp4 rospy

ros 预览摄像头 #!/usr/bin/env python import rospy from sensor_msgs.msg import Image from cv_bridge import CvBridge import cv2# 初始化 bridge bridge CvBridge()def image_callback(msg):# 将ROS图像消息转换为OpenCV图像cv_image bridge.imgmsg_to_cv2(msg, desir…

超简单部署离线语音合成TTS和语音识别

一篇文章讲清楚超简单 离线语音合成TTS 和 离线语音识别 系统部署 本文只介绍两个轻量级的 语音合成用piper, 语音识别用vosk 部署简单,效果勉强 语音合成 推荐 piper (其他没用过) 安装 linux下安装 pip install piper-tts下载模型(63M) 中文模型下载 zh_CN-huayan-medi…

【算力网】

一、算力网-DNS 1.1、核心架构设计 1.1.1 设计框架 基于SRv6的智能DNS算法设计框架&#xff0c;结合IPv6路由可编程性、动态路径优化及业务感知能力&#xff0c;实现网络性能与用户体验的双重提升&#xff1a;​ ​SRv6-DNS融合架构​ ​控制平面​&#xff1a; DNS服务器集…

shell分析nginx日志的指令

shell指令 查看有多少个IP访问&#xff1a; awk {print $1} log_file|sort|uniq|wc -l 查看某一个页面被访问的次数&#xff1a; grep "/index.php" log_file | wc -l 查看每一个IP访问了多少个页面&#xff1a; awk {S[$1]} END {for (a in S) print a,S[a]} …

CMS软件以及常见分类

CMS&#xff08;Content Management System&#xff0c;内容管理系统&#xff09;是 让非技术人员也能便捷创建、编辑、管理网站内容的软件 &#xff0c;核心是 分离 “内容” 和 “页面设计”&#xff08;内容存在数据库&#xff0c;页面用模板生成&#xff09;&#xff0c;无…

Spring @Value 典型用法

典型用法 注入常量值 Value("Hello World") private String message;注入配置文件中的属性值&#xff08;如 application.properties&#xff09; // 假设你有如下配置&#xff1a; app.nameMyApp app.version1.0.0// Java 类中使用&#xff1a; Value("${ap…

golang -- map实现原理

目录 一、前言二、结构1. hmap(map) 结构2. bmap(buckets) 结构 三、哈希冲突四、负载因子五、哈希函数六、扩容增量扩容等量扩容 一、前言 在现代编程语言中&#xff0c;map 是一种非常重要的数据结构&#xff0c;广泛用于存储和快速查找键值对。Go 语言中的 map 是一种高效且…

Vue2 Extends 继承机制与组件复用实践

extends在某些场景下依然发挥作用&#xff0c;如Options API。子组件将继承父组件的属性、方法、生命周期钩子函数以及混合&#xff08;mixins&#xff09;等选项。 注意&#xff1a;子组件可以覆盖、或继承扩展父组件的选项。子组件的生命周期钩子和父组件的钩子一起执行。 &l…

openSUSE MicroOS不可变Linux

openSUSE MicroOS不可Linux 1、openSUSE MicroOS简介安装时可能遇到的问题 2、ssh登录3、openSUSE MicroOS配置国内软件源4、系统变更openSUSE MicroOS安装软件包方法1&#xff1a;进入事务性更新模式安装软件包方法2&#xff1a;继续快照id基于这个快照进行增量安装方法3&…

建站SEO优化之站点地图sitemap

文章目录 编写规范小型网站站点地图小型网站规范示例站点地图说明 大型网站站点地图大型网站规范示例以豆瓣站点地图为例 近期文章&#xff1a; 个人建站做SEO网站外链这一点需要注意&#xff0c;做错了可能受到Google惩罚一文搞懂SEO优化之站点robots.txt网页常见水印实现方式…

Java分层开发必知:PO、BO、DTO、VO、POJO概念详解

目录 引言一、核心概念与定义1、PO&#xff08;Persistent Object&#xff0c;持久化对象&#xff09;2、BO&#xff08;Business Object&#xff0c;业务对象&#xff09;3、DTO&#xff08;Data Transfer Object&#xff0c;数据传输对象&#xff09;4、VO&#xff08;View O…

Linux下OLLAMA安装卡住怎么办?

网络环境不理想&#xff0c;经常在官方的linux安装脚本执行时卡住&#xff0c;其实主要是下载文件卡住&#xff0c;于是我想到了是否可以把其中下载的过程显化、分步&#xff0c;这样更可控&#xff0c;于是修改了官方的install.sh #!/bin/sh # This script installs Ollama o…

C++面试(5)-----删除链表中指定值的节点

操作系统&#xff1a;ubuntu22.04 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 给定一个单向链表的头节点 head 和一个特定值 val&#xff0c;要求编写一个函数来删除链表中所有值等于 val 的节点&#xff0c;并返回修改后的链表头节点。 示例&#xff1a; 输…

如何用AI赋能学习

由于博主是大学生&#xff0c;今天花费了大量的时间去进行期末的复习&#xff0c;不过从复习中得到了一些学习的灵感&#xff0c;即&#xff1a;如何用AI赋能学习 当我们需要掌握一门新的技能的时候&#xff0c;我们很容易的想到三种办法&#xff1a;买书自己学&#xff0c;报…

【threejs】每天一个小案例讲解:常见材质

代码仓 GitHub - TiffanyHoo/three_practices: Learning three.js together! 可自行clone&#xff0c;无需安装依赖&#xff0c;直接liver-server运行/直接打开chapter01中的html文件 运行效果图 知识要点 1. MeshBasicMaterial&#xff08;基础网格材质&#xff09; • 特…

springboot后端与鸿蒙的结合

软件&#xff1a;鸿蒙devceo3.1&#xff0c;springboot项目采用IDEA 目的&#xff1a; 1、结合springboot后端与鸿蒙的结合运用。 2、Log日志查看console语句的信息。 3、引入 import http from ohos.net.http。 4、调用springboot后端提供的链接发送post 5、TextInput的…

minio集群通过mc mirror命令进行定时备份,支持X86和arm两种架构

文章目录 前言一、思路二、使用步骤1.下载mc二进制文件2.手动测试备份命令3.配置定时任务4.成功截图 总结 前言 通过mc mirror命令对minio集群进行定时备份。 一、思路 通过mc mirror命令配合crond定时任务进行周期性的备份 二、使用步骤 1.下载mc二进制文件 wget https:…

三大能力升级,为老项目重构开辟新路径

在软件技术飞速迭代的今天&#xff0c;老项目重构是开发者们绕不开的难题。接口实现缺失、业务逻辑矛盾、架构规划偏离等问题如同拦路虎&#xff0c;让重构工作举步维艰。而传统的 AI 辅助方式&#xff0c;因未充分关联项目实际情况&#xff0c;犹如 “空中造楼”&#xff0c;难…

AES加密

AES加密算法详解 AES&#xff08;Advanced Encryption Standard&#xff09;是一种对称密钥分组加密算法&#xff0c;用于保护电子数据的安全性。其核心特点是通过相同的密钥进行加密和解密&#xff0c;属于对称加密体系。。以下从核心特性、加密流程及安全性三方面展开说明&a…

关于联咏(Novatek )自动曝光中Lv值的计算方式实现猜想

目录 一、常见Lv对应的实际场景 二、常见光圈值 三、最小二乘法计算SV中的系数K