4.27-5.4学习周报

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 摘要
  • Abstract
  • 一、方法介绍
    • 2.Rainbow Memory(RM)
      • 2.1多样性感知内存更新
      • 2.2通过数据增强增强样本多样性(DA)
  • 二、使用步骤
    • 1.实验概况
    • 2.RM核心代码
  • 总结


摘要

本博客概述了文章《Rainbow Memory: Continual Learning with a Memory of Diverse Samples》聚焦于任务边界模糊的持续学习场景,提出基于样本分类不确定性和数据增强的Rainbow Memory (RM)记忆管理策略。多数研究在任务不共享类别的较人为的设置下评估相关方法,但在现实世界应用场景中,任务之间的类分布是不断变化的,更现实和实用的是任务共享类别的模糊CIL设置。在这种设置下,之前存储少量旧数据的方法虽在缓解灾难性遗忘方面有成果,但也引出了如何管理记忆(memory)的最优策略问题。基于该问题,研究者在新定义的模糊CIL设置下更好地持续学习的两个因素:记忆的采样和记忆中的数据增强,进而提出Rainbow Memory(RM)方法。通过在MNIST、CIFAR10、CIFAR100和ImageNet数据集上的实证验证,RM在模糊持续学习设置中显著提高了准确性,大幅超越现有技术。

文章链接

Abstract

This blog summarizes the article “Rainbow Memory: Continual Learning with a Memory of Diverse Samples”, which focuses on the continuous learning scenario with fuzzy task boundaries, and proposes a Rainbow Memory (RM) memory management strategy based on sample classification uncertainty and data augmentation. Most studies evaluate the relevant methods in a more artificial setting where tasks do not share categories, but in real-world application scenarios, the class distribution between tasks is constantly changing, and it is more realistic and practical to see the fuzzy CIL settings of task sharing categories. In this setting, the previous method of storing a small amount of old data has been successful in mitigating catastrophic forgetting, but it also raises the question of the optimal strategy for managing memory. Based on this problem, the researchers proposed a rainbow memory (RM) method for better continuous learning under the newly defined fuzzy CIL setting: memory sampling and data enhancement in memory. Through empirical verification on MNIST, CIFAR10, CIFAR100, and ImageNet datasets, RM significantly improves accuracy in fuzzy continuous learning settings, significantly outperforming existing technologies.

一、方法介绍

模糊类增量学习的设置要求如下:1)每个任务作为流顺序地给出,(2)大多数(分配的)任务类别彼此不同,以及(3)模型只能利用先前任务的非常小的一部分数据。 如下图所示,在模糊CIL中,任务共享类,与传统的不相交CIL相反。建议的记忆管理策略更新的情景记忆与当前任务的样本,以保持不同的样本在内存中。数据扩充(DA)进一步增强了内存中样本的多样性。
在这里插入图片描述

2.Rainbow Memory(RM)

在模糊类增量学习的场景中,现有方法因样本多样性不足导致模型易过拟合或遗忘严重。为了解决该问题,研究者提出了Rainbow Memory(RM),RM提出通过多样性记忆管理和数据增强解决 Blurry-CIL 问题。

2.1多样性感知内存更新

研究者认为,被选择存储在内存中的样本应该不仅是代表其相应的类,还要识别其他类。为了选择这样的样本,研究者认为,在分类边界附近的样本是最具鉴别力的,靠近分布中心的样本是最具代表性的。为了满足这两个特点,研究者建议抽样的样本是不同的特征空间。

由于计算样本与样本之间的距离O(N2)较为复杂和昂贵,研究者通过分类模型估计的样本的不确定性来估计相对位置,即假设模型的更确定的样本将位于更靠近类分布的中心,通过测量扰动样本的模型输出方差来计算样本的不确定性,扰动样本通过各种数据增强转换方法进行:包括颜色抖动、剪切和剪切,如下图所示:
在这里插入图片描述

通过蒙特-卡罗(MC)法近似计算分布p(y = c)的不确定度|x),当给定扰动样本x的先验时,即p(x| x)的情况下,推导过程可以写成:
在这里插入图片描述

其中,x、x^~、y和A分别表示样本、扰动样本、样本的标签和扰动方法的数量。分布D * 表示由扰动样本λ x定义的数据分布。特别地,扰动样本λ x由随机函数fr(·)绘制,如下:
在这里插入图片描述

其中θr是表示第r次扰动的随机因子的超参数。
测量样品相对于扰动的不确定性为:
在这里插入图片描述

其中u(x)表示样本x的不确定性,Sc是类别c是预测的前1类别的次数。1c表示二进制类索引向量。较低的u(x)值对应于扰动上更一致的top-1类,表明x位于模型强置信的区域.

2.2通过数据增强增强样本多样性(DA)

为了进一步增强记忆中的示例的多样性,研究者采用了数据增强(DA)。 DA的通过图像级或特征扰动使给定的样本多样化,这对应于通过确保多样性来更新内存的理念。
随着任务迭代的进行,新任务中的样本可能会遵循与情节内存中的样本(即,从以前的任务中)遵循不同的分布。 研究者在新任务的类别和内存中旧类的示例中采用混合标记的DA来“混合”图像。 这种混合标签DA减轻了由类分布在任务上的变化引起的副作用,并改善了表现。
混合标记的DA方法之一,CutMix 生成了混合样品和平滑标签,鉴于一组监督样品(X1,Y1)和(X2,Y2),其公式如下:

在这里插入图片描述

二、使用步骤

1.实验概况

研究者通过将RM与各种实验设置中的艺术状态进行比较,从经验上验证了RM的功效。 基准测试的CIL任务设置,情节内存的内存大小和性能指标。在MNIST、CIFAR10、CIFAR100和ImageNet数据集上进行实验。采用多种CIL任务设置、不同的记忆大小和性能指标评估RM方法。将RM与EWC、Rwalk、iCaRL等标准CIL方法对比 ,比较不同方法在各种设置下的Last Accuracy(A5)、Last Forgetting(F5)和Intransigence(I5)等指标。分析RM在不同模糊水平(如Blurry0、Blurry10、Blurry30)下的性能,还探究了不确定性测量方法、记忆更新算法、数据增强方法等对性能的影响。

2.RM核心代码

RM部分的完整核心代码如下:

import logging
import randomimport numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom methods.finetune import Finetune
from utils.data_loader import cutmix_data, ImageDatasetlogger = logging.getLogger()
writer = SummaryWriter("tensorboard")def cycle(iterable):# iterate with shufflingwhile True:for i in iterable:yield iclass RM(Finetune):def __init__(self, criterion, device, train_transform, test_transform, n_classes, **kwargs):super().__init__(criterion, device, train_transform, test_transform, n_classes, **kwargs)self.batch_size = kwargs["batchsize"]self.n_worker = kwargs["n_worker"]self.exp_env = kwargs["stream_env"]if kwargs["mem_manage"] == "default":self.mem_manage = "uncertainty"def train(self, cur_iter, n_epoch, batch_size, n_worker, n_passes=0):if len(self.memory_list) > 0:mem_dataset = ImageDataset(pd.DataFrame(self.memory_list),dataset=self.dataset,transform=self.train_transform,)memory_loader = DataLoader(mem_dataset,shuffle=True,batch_size=(batch_size // 2),num_workers=n_worker,)stream_batch_size = batch_size - batch_size // 2else:memory_loader = Nonestream_batch_size = batch_size# train_list == streamed_list in RMtrain_list = self.streamed_listtest_list = self.test_listrandom.shuffle(train_list)# Configuring a batch with streamed and memory data equally.train_loader, test_loader = self.get_dataloader(stream_batch_size, n_worker, train_list, test_list)logger.info(f"Streamed samples: {len(self.streamed_list)}")logger.info(f"In-memory samples: {len(self.memory_list)}")logger.info(f"Train samples: {len(train_list)+len(self.memory_list)}")logger.info(f"Test samples: {len(test_list)}")# TRAINbest_acc = 0.0eval_dict = dict()self.model = self.model.to(self.device)for epoch in range(n_epoch):# initialize for each taskif epoch <= 0:  # Warm start of 1 epochfor param_group in self.optimizer.param_groups:param_group["lr"] = self.lr * 0.1elif epoch == 1:  # Then set to maxlrfor param_group in self.optimizer.param_groups:param_group["lr"] = self.lrelse:  # Aand go!self.scheduler.step()train_loss, train_acc = self._train(train_loader=train_loader, memory_loader=memory_loader,optimizer=self.optimizer, criterion=self.criterion)eval_dict = self.evaluation(test_loader=test_loader, criterion=self.criterion)writer.add_scalar(f"task{cur_iter}/train/loss", train_loss, epoch)writer.add_scalar(f"task{cur_iter}/train/acc", train_acc, epoch)writer.add_scalar(f"task{cur_iter}/test/loss", eval_dict["avg_loss"], epoch)writer.add_scalar(f"task{cur_iter}/test/acc", eval_dict["avg_acc"], epoch)writer.add_scalar(f"task{cur_iter}/train/lr", self.optimizer.param_groups[0]["lr"], epoch)logger.info(f"Task {cur_iter} | Epoch {epoch+1}/{n_epoch} | train_loss {train_loss:.4f} | train_acc {train_acc:.4f} | "f"test_loss {eval_dict['avg_loss']:.4f} | test_acc {eval_dict['avg_acc']:.4f} | "f"lr {self.optimizer.param_groups[0]['lr']:.4f}")best_acc = max(best_acc, eval_dict["avg_acc"])return best_acc, eval_dictdef update_model(self, x, y, criterion, optimizer):optimizer.zero_grad()do_cutmix = self.cutmix and np.random.rand(1) < 0.5if do_cutmix:x, labels_a, labels_b, lam = cutmix_data(x=x, y=y, alpha=1.0)logit = self.model(x)loss = lam * criterion(logit, labels_a) + (1 - lam) * criterion(logit, labels_b)else:logit = self.model(x)loss = criterion(logit, y)_, preds = logit.topk(self.topk, 1, True, True)loss.backward()optimizer.step()return loss.item(), torch.sum(preds == y.unsqueeze(1)).item(), y.size(0)def _train(self, train_loader, memory_loader, optimizer, criterion):total_loss, correct, num_data = 0.0, 0.0, 0.0self.model.train()if memory_loader is not None and train_loader is not None:data_iterator = zip(train_loader, cycle(memory_loader))elif memory_loader is not None:data_iterator = memory_loaderelif train_loader is not None:data_iterator = train_loaderelse:raise NotImplementedError("None of dataloder is valid")for data in data_iterator:if len(data) == 2:stream_data, mem_data = datax = torch.cat([stream_data["image"], mem_data["image"]])y = torch.cat([stream_data["label"], mem_data["label"]])else:x = data["image"]y = data["label"]x = x.to(self.device)y = y.to(self.device)l, c, d = self.update_model(x, y, criterion, optimizer)total_loss += lcorrect += cnum_data += dif train_loader is not None:n_batches = len(train_loader)else:n_batches = len(memory_loader)return total_loss / n_batches, correct / num_datadef allocate_batch_size(self, n_old_class, n_new_class):new_batch_size = int(self.batch_size * n_new_class / (n_old_class + n_new_class))old_batch_size = self.batch_size - new_batch_sizereturn new_batch_size, old_batch_size

1.内存管理与数据混合(对应论文 Section 4.1)
将内存中的旧任务样本(memory_loader)与当前任务的流数据(train_loader)按比例混合(默认各占50%)。

使用cycle(memory_loader)循环读取内存数据,避免内存样本因容量限制被忽略。
实现多样性记忆回放,通过混合新旧任务样本缓解灾难性遗忘,确保模型同时学习新任务和巩固旧任务知识。

def train(self, cur_iter, n_epoch, batch_size, n_worker, n_passes=0):# 加载内存数据(旧任务样本)和流数据(新任务样本)if len(self.memory_list) > 0:mem_dataset = ImageDataset(self.memory_list, transform=self.train_transform)memory_loader = DataLoader(mem_dataset, batch_size=(batch_size // 2), ...)stream_batch_size = batch_size - batch_size // 2else:memory_loader = Nonestream_batch_size = batch_size# 混合流数据和内存数据data_iterator = zip(train_loader, cycle(memory_loader))  # 循环迭代内存数据x = torch.cat([stream_data["image"], mem_data["image"]])y = torch.cat([stream_data["label"], mem_data["label"]])

数据增强:CutMix
以50%概率应用CutMix,将两张图像局部区域混合,并生成对应的混合标签(labels_a和labels_b)。
计算混合损失(lam * loss_a + (1-lam) * loss_b),鼓励模型学习更鲁棒的特征,实现标签混合增强(Section 4.2),通过生成边界复杂的样本提升记忆库多样性,增强模型泛化能力。

def update_model(self, x, y, criterion, optimizer):# CutMix增强:混合图像和标签do_cutmix = self.cutmix and np.random.rand(1) < 0.5if do_cutmix:x, labels_a, labels_b, lam = cutmix_data(x=x, y=y, alpha=1.0)logit = self.model(x)loss = lam * criterion(logit, labels_a) + (1 - lam) * criterion(logit, labels_b)else:logit = self.model(x)loss = criterion(logit, y)

动态学习率与批量调整

# Warm start学习率调整
if epoch <= 0:for param_group in self.optimizer.param_groups:param_group["lr"] = self.lr * 0.1  # 初始低学习率
elif epoch == 1:param_group["lr"] = self.lr  # 恢复基准学习率
else:self.scheduler.step()         # 后续按计划调整# 动态调整新旧任务批量大小
def allocate_batch_size(self, n_old_class, n_new_class):new_batch_size = int(self.batch_size * n_new_class / (n_old_class + n_new_class))old_batch_size = self.batch_size - new_batch_sizereturn new_batch_size, old_batch_size

初始阶段使用低学习率(10%基准值)进行预热(Warm-up),避免训练初期不稳定。

根据新旧类别比例动态分配批量大小,平衡新旧任务的学习强度,防止新任务数据主导学习过程。
4. 训练流程与评估

# 训练与评估循环
for epoch in range(n_epoch):train_loss, train_acc = self._train(...)  # 训练eval_dict = self.evaluation(...)          # 评估logger.info(f"Task {cur_iter} | Epoch {epoch+1} | train_acc {train_acc:.4f} | test_acc {eval_dict['avg_acc']:.4f}")

3.实验结果
研究者将提出的RM与各种数据集的“ Blurry10-Online”设置中的其他方法进行了比较,并总结了如下表的结果,如表所示,RM始终优于所有其他方法,并且当类(| C |)增加时,增益会更大。但是,在MNIST上,没有DA的RM表现最好。 研究者认为,DA会干扰模型培训,因为示例足以避免忘记。
在这里插入图片描述

下表列出了三个情节记忆大小(K)的CIFAR10-Blurry10Online的比较; 200、500和1,000。结果表明,这些方法在最终任务中保留了有效的示例,足以恢复以前任务中发生的遗忘。 ICARL,GDUMB和BIC对于不固定(i5)的有效性较小,并且与EWC和RWALK相比,它们在忘记方面的表现较大,作为权衡。
在这里插入图片描述

研究者进一步比较了任务流的准确性轨迹; 由随机分配的函数ψ(c)生成的三个流,具有不同的随机种子,用于Imagenet和单个流,用Imagenet,并总结了下图中的结果:
在这里插入图片描述

RM在整个任务流中始终优于其他基线。

总结

研究结论:研究者提出一种名为彩虹记忆(RM)的方法,用于处理任务共享类别(模糊 - CIL)的现实持续学习场景。通过基于样本分类不确定性的新的多样性增强采样方法和多种数据增强技术,在CIFAR10、CIFAR100和ImageNet的模糊 - CIL场景中,RM大幅优于现有方法,在不连续和离线CIL设置中也有可比性能。
研究的创新性:一是提出基于样本扰动不确定性的多样性增强采样方法管理有限容量记忆;二是采用多种数据增强技术提高样本多样性,增强记忆中样本的代表性和判别性。
研究展望:可研究基于不确定性的记忆更新和数据增强在训练时的关系,及其对不同CIL任务的影响。还可探索RM在更多类型数据集或其他领域持续学习场景中的应用效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/903858.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/903858.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI Rack架构高速互连的挑战:损耗设计与信号完整性的设计框架

在AI驱动的时代&#xff0c;系统设计已经从单一PCB的视角&#xff0c;逐步转向以整个rack为单位来考量。 对于信号完整性而言&#xff0c;焦点以不再局限于单一PCB上的损耗&#xff0c;而是扩展到芯片与芯片之间的端到端互连损耗&#xff08;end-to-end interconnect loss&…

杭电oj(1180、1181)题解

目录 1180 题目 思路 问题概述 代码思路分析 1. 数据结构与全局变量 2. BFS 函数 bfs 3. 主函数 main 总结 代码 1181 题目 思路 1. 全局变量的定义 2. 深度优先搜索函数 dfs 3. 主函数 main 总结 代码 1180 题目 思路 注&#xff1a;当走的方向和楼梯方向一…

软件测试概念

这里写目录标题 需求开发模型软件生命周期瀑布模型螺旋模型增量模型、迭代模型敏捷模型Scrum 测试模型V模型W模型&#xff08;双V模型&#xff09; 需求 用户需求&#xff1a;没有经过合理的评估&#xff0c;通常就是一句话 软件需求&#xff1a;是开发人员和测试人员执行工作…

数字基带信号和频带信号的区别解析

数字基带信号和数字频带信号是通信系统中两种不同的信号形式&#xff0c;它们的核心区别在于是否经过调制以及适用的传输场景。以下是两者的主要区别和分析&#xff1a; 1. 定义与核心区别 数字基带信号&#xff08;Digital Baseband Signal&#xff09; 未经调制的原始数字信号…

Linux52 运行百度网盘 解决故障无法访问repo nosandbox 未解决:疑似libstdc++版本低导致无法运行baidu网盘

昨日参考 哦 我是root Cannot find a valid baseurl for repo: centos-sclo-rh/x86_64 没了 计划去手动下一个 还是不行 放弃 猜测是 centos7 过期了 一些依赖组件也没地方下载了 通过阿里云镜像站下载 之前安装的好像不是这个版本 还是计划用yum去下载依赖&#xff0c;先处…

2000-2022年上市公司数字经济专利申请数据

2000-2022年上市公司数字经济专利申请数据 1、时间&#xff1a;2000-2022年 2、来源&#xff1a;国家知识产权局 3、指标&#xff1a;年份、股票代码、股票简称、行业名称、行业代码、省份、城市、区县、行政区划代码、城市代码、区县代码、首次上市年份、上市状态、数字经济…

机器学习之五:基于解释的学习

正如人们有各种各样的学习方法一样&#xff0c;机器学习也有多种学习方法。若按学习时所用的方法进行分类&#xff0c;则机器学习可分为机械式学习、指导式学习、示例学习、类比学习、解释学习等。这是温斯顿在1977年提出的一种分类方法。 有关机器学习的基本概念&#xff0c;…

Chromium 134 编译指南 - Android 篇:安装构建依赖项(七)

1. 引言 欢迎来到《Chromium 134 编译指南》系列的第七篇文章&#xff01;在前面的章节中&#xff0c;我们已经成功获取了Chromium源代码&#xff0c;并将其配置为支持Android平台。这些步骤为我们的编译之旅奠定了坚实的基础&#xff0c;但在开始实际编译之前&#xff0c;我们…

java 进阶 1.0

静态方法 static 就是能直接用&#xff0c;不用再new一个对象了 一般java中Math等静态类就是可以直接使用其方法 main函数里面不能包含太多的逻辑性语句&#xff0c;全部写成模块 写好程序之后如何测试呢&#xff1f; 使用junit&#xff0c;不能在main函数里测试 测试本身就…

中小企业MES系统详细设计

版本&#xff1a;V1.1 日期&#xff1a;2025年5月2日 一、设备协议兼容性设计 1.1 设备接入框架 #mermaid-svg-PkwqEMRIIlIBPP58 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-PkwqEMRIIlIBPP58 .error-icon{fill…

Spring Security会话管理

用户认证通过后&#xff0c;为了避免用户的每次操作都进行认证&#xff0c;可以将用户的信息保存在会话中。会话就是系统为了保持当前用户的登录状态所提供的机制&#xff0c;常见的有基于Session方式、基于Token方式等。Spring Security提供会话管理功能&#xff0c;只需要配置…

PostgreSQL数据库操作基本命令

常用操作sql &#x1f510; 用户管理 -- 创建用户 CREATE USER username WITH PASSWORD password;-- 修改用户密码 ALTER USER username WITH PASSWORD newpassword;-- 删除用户 DROP USER username;&#x1f4e6; 数据库操作 -- 创建数据库 CREATE DATABASE dbname;-- 删除…

[吾爱出品] 网文提取精灵_4.0

网文提取精灵 链接&#xff1a;https://pan.xunlei.com/s/VOPDvKljcT3EWLjpt5LeDZvfA1?pwdw8kq# 易语言写的&#xff0c;介意的不要下载 相对网文提取工具_2.10.02版&#xff0c;因为是重写界面&#xff0c;目前版本限制最高5线程&#xff0c;暂时不支持批处理。 虽然不支…

每日算法-250502

每日算法 - 2025.05.02 记录一下今天刷的几道 LeetCode 算法题。 3191. 使二进制数组全部等于 1 的最少操作次数 I 题目 思路 贪心 解题过程 遍历数组 nums。当我们遇到 nums[i] 时&#xff1a; 如果 nums[i] 是 1&#xff0c;我们不需要进行操作&#xff0c;因为目标是全 …

移动端开发中设备、分辨率、浏览器兼容性问题

以下是针对移动端开发中设备、分辨率、浏览器兼容性问题的 系统化解决方案&#xff0c;按开发流程和技术维度拆解&#xff0c;形成可落地的执行步骤&#xff1a; 一、基础环境适配&#xff1a;从「起点」杜绝兼容性隐患 1. Viewport 元标签标准化 <meta name"viewpor…

2025最新AI绘画系统源码 - 画图大模型/GPT-4全支持/AI换脸/自定义智能体

在AI绘画技术日新月异的2025年&#xff0c;比象AI绘画系统源码以其突破性的技术创新重新定义了数字艺术创作的边界。作为第四代AI绘画引擎&#xff0c;我们不仅集成了最先进的GPT-4o多模态画图模型&#xff0c;实现了从基础文生图到专业级艺术创作的全面进化。本系统源码经过多…

构造函数详解

构造函数的作用 构造函数的主要任务是初始化对象&#xff0c;而不是创建对象&#xff08;对象的内存空间在构造函数被调用前已经分配好&#xff09;。 构造函数特性 命名规则&#xff1a;函数名必须与类名完全相同。 返回值&#xff1a;构造函数没有返回值类型&#xff08;连…

jaffree 封装ffmpeg 转换视频格式,获取大小,时间,封面

下载 参考网址 【收藏级教程】FFmpeg音视频处理宝典&#xff1a;从入门到精通的50个实用技巧_ffmpeg教程-CSDN博客 配置环境变量 验证 重启idea开发工具 springboot maven集成 <dependency><groupId>com.github.kokorin.jaffree</groupId><artifactId&…

2505C++,wmi客户端示例

原文 #define _WIN32_DCOM #include <iostream> using namespace std; #include <comdef.h> #include <Wbemidl.h> #pragma comment(lib, "wbemuuid.lib") int main(int argc, char **argv) {HRESULT hres;//初化COM.hres CoInitializeEx(0, CO…

[面试]SoC验证工程师面试常见问题(三)

SoC验证工程师面试常见问题(三) 在 SoC 验证工程师的面试中,面试官可能会要求候选人现场编写 SystemVerilog、UVM (Universal Verification Methodology) 或 SystemC 代码,以评估其编程能力、语言掌握程度以及解决实际验证问题的能力。这种随机抽题写代码的环节通常…