Paper Reading Notes: "Curriculum Coarse-to-Fine Selection for High-IPC Dataset Distillation"
- 1. Background and Motivation
- 2. Core Contributions
- 3. Method Details
- 4. Experimental Results
- Main Code
- Overall Algorithm Logic
CVPR25 · GitHub
One-sentence summary:
CCFS builds on the combined paradigm (trajectory matching + selecting real images). Through a curriculum-style "coarse filtering + fine selection" framework, it dynamically patches the weaknesses of the synthetic set, markedly improving dataset distillation under high-IPC settings; it is the current SOTA method for high-IPC scenarios.
1. Background and Motivation
- Dataset Distillation: compress a large training set into a small synthetic dataset, such that a model trained on the synthetic set performs close to one trained on the full original data.
- IPC (Images Per Class): the number of synthetic images per class. In low-IPC settings (a few images per class), existing methods do well; but as IPC grows (more images to synthesize), performance often degrades, sometimes falling below simple random sampling.
- Core problem: at high IPC the synthetic set becomes too "average", lacking rare/complex features (hard samples), so its coverage is insufficient. Existing hybrid methods that mix distillation with real samples (e.g. SelMatch) select real samples statically in one shot and lack dynamic complementarity with the synthetic set.
2. Core Contributions
- Incompatibility diagnosis: analyzes how, under the "select real samples first, then distill" paradigm, the static real samples complement the dynamically distilled set poorly.
- CCFS method: a curriculum-style coarse-to-fine framework for dynamically selecting real samples, split into two stages:
  - Coarse filtering: a filter model trained on the current synthetic set identifies real samples that are "not yet learned" (i.e. misclassified).
  - Fine selection: among these candidates, a difficulty score (or the filter's logits directly) picks the "simplest yet unlearned" samples, which are incrementally added to the synthetic set.
- Empirical results: under high-IPC settings (compression ratios of 5%–30%) on CIFAR-10/100 and Tiny-ImageNet, CCFS sets multiple new SOTA results, in some cases only 0.3% below full-data training.
3. Method Details
Overall pipeline:
- Initialization
  - Obtain an initial synthetic set $D_{distill}$ from any base distillation algorithm (e.g. CDA).
  - Set the current synthetic set $S_0 = D_{distill}$.
- Curriculum loop ($J$ stages in total)
  - Train the filter: distill-train a filter model $\phi_j$ on $S_{j-1}$ so that it learns the current synthetic set's decision boundary.
  - Coarse: run $\phi_j$ over the original training set $T$ and collect the set of misclassified samples $D_{mis}^j$.
  - Fine: rank the samples within $D_{mis}^j$ and select the top $k_j$ "simplest yet unlearned" ones per class, forming $D_{real}^j$.
  - Update: $S_j = S_{j-1} \cup D_{real}^j$.
Why select "simple yet unlearned" samples?
The misclassified set $D_{mis}$ reflects the limitations of $S$. Among these limitations, simpler features benefit training more than complex ones because they are easier to learn. A precomputed difficulty score measures the relative difficulty of sample features from a global perspective and guides the subsequent fine selection. Selecting the simplest features among the misclassified samples yields an optimal $D_{real}$ while avoiding overly complex features that could hurt the performance of $S$. A minimal sketch of one round follows.
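To make one round concrete, here is a minimal sketch of the coarse-to-fine step. train_filter(), predict() and the difficulty array are hypothetical stand-ins for the paper's actual components, not its API:

# Minimal sketch of one CCFS curriculum round (hypothetical helpers).
def ccfs_round(S_prev, images, labels, k_per_class, difficulty, num_classes):
    phi = train_filter(S_prev)                  # filter learns S_{j-1}'s decision boundary
    preds = predict(phi, images)                # inference over the full original training set
    # Coarse: misclassified samples reflect what S_{j-1} has not yet taught the model
    mis = [i for i in range(len(labels)) if preds[i] != labels[i]]
    selected = []
    for c in range(num_classes):
        cand = [i for i in mis if labels[i] == c]
        cand.sort(key=lambda i: difficulty[i])  # Fine: ascending difficulty = "simplest" first
        selected += cand[:k_per_class]
    return S_prev + [(images[i], labels[i]) for i in selected]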
4. Experimental Results
- Datasets: CIFAR-10/100, Tiny-ImageNet.
- New SOTA in high-IPC settings:
  - CIFAR-10/100 at 10% IPC: ~6.1% / ~5.8% improvement over the best baseline, respectively;
  - Tiny-ImageNet at 20% IPC: only 0.3% below full-data training.
- Cross-architecture generalization: synthetic sets generated with ResNet-18 also train ResNet-50/101, DenseNet-121, RegNet and other networks, all outperforming CDA, SelMatch and similar methods.
- Thorough ablations:
  - validate combinations of coarse (misclassified vs. confidently correct) and fine (simple vs. hard vs. random) strategies;
  - compare difficulty scores; the forgetting score works best;
  - study how the number of curriculum rounds trades performance against efficiency; 3 rounds is a good compromise.
Main Code
import os
import datetime
import time
import warnings
import numpy as np
import random
import torch
import torch.utils.data
import torchvision
import utils
from torch import nn
import torchvision.transforms as transforms
from imagenet_ipc import ImageFolderIPC
import torch.nn.functional as F
from tqdm import tqdm
import json
warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")


def get_args_parser(add_help=True):
    import argparse
    parser = argparse.ArgumentParser(description="CCFS on CIFAR-100", add_help=add_help)
    parser.add_argument("--data-path", default=None, type=str, help="path to CIFAR-100 data folder")
    parser.add_argument("--filter-model", default="resnet18", type=str, help="filter model name")
    parser.add_argument("--teacher-model", default="resnet18", type=str, help="teacher model name")
    parser.add_argument("--teacher-path", default=None, type=str, help="path to teacher model")
    parser.add_argument("--eval-model", default="resnet18", type=str, help="model for final evaluation")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument("-b", "--batch-size", default=64, type=int, help="Batch size")
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="# training epochs for both the filter and the evaluation model")
    parser.add_argument("-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)")
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument("--wd", "--weight-decay", default=1e-4, type=float, metavar="W", help="weight decay (default: 1e-4)", dest="weight_decay")
    parser.add_argument("--lr-scheduler", default="steplr", type=str, help="the lr scheduler (default: steplr)")
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument("--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)")
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    # needed by the StepLR / ExponentialLR branches in curriculum_train
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("-T", "--temperature", default=20, type=float, help="temperature for distillation loss")
    parser.add_argument("--print-freq", default=1000, type=int, help="print frequency")
    # --- CCFS parameters ---
    # Target number of images per class (IPC) in the final synthetic set
    parser.add_argument("--image-per-class", default=50, type=int, help="number of synthetic images per class")
    parser.add_argument("--distill-data-path", default=None, type=str, help="path to already distilled data")
    # Distillation portion: the ratio of distilled (synthetic) images to selected real images;
    # cpc = IPC * alpha is the number of condensed images per class,
    # spc = IPC - cpc is the number of real images to select per class.
    parser.add_argument('--alpha', type=float, default=0.2, help='Distillation portion')
    # Number of curriculum stages for selection (e.g. 3)
    parser.add_argument('--curriculum-num', type=int, default=None, help='Number of curricula')
    # Coarse stage: select samples the filter misclassifies (True) or classifies correctly (False)
    parser.add_argument('--select-misclassified', action='store_true', help='Selection strategy in coarse stage')
    # Fine stage strategy: simple / hard / random ("simplest yet unlearned" / "hardest" / "random" in the paper)
    parser.add_argument('--select-method', type=str, default='simple', choices=['random', 'hard', 'simple'], help='Selection strategy in fine stage')
    # Whether to select a class-balanced number of samples
    parser.add_argument('--balance', action='store_true', help='Whether to balance the amount of the synthetic data between classes')
    # Which difficulty score to use in the fine stage
    parser.add_argument('--score', type=str, default='forgetting', choices=['logits', 'forgetting', 'cscore'], help='Difficulty score used in fine stage')
    # If the score is precomputed (e.g. forgetting score) rather than logits, load it from this path
    parser.add_argument('--score-path', type=str, default=None, help='Path to the difficulty score')
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--num-eval", default=1, type=int, help="number of evaluations")
    return parser


def load_data(args):
    '''
    Dataset loading.
    Returns:
        dataset: the distilled dataset
        images_og, labels_og: the full original training set (used for selection)
        dataset_test: the validation (test) set, plus the corresponding samplers
    '''
    # Data loading code
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
    print("Loading distilled data")
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # ImageFolderIPC is a custom dataset reader that picks cpc images per class
    # (sequentially or randomly) from a larger distilled IPC dataset.
    # cpc = IPC * alpha is the number of condensed images per class.
    dataset = ImageFolderIPC(root=args.distill_data_path, ipc=args.cpc, transform=train_transform)
    print("Loading validation data")
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    # Load the validation (test) set
    dataset_test = torchvision.datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=val_transform)
    print("Loading original training data")
    # Load the original training set (used for coarse selection / teacher correctness, etc.)
    dataset_og = torchvision.datasets.CIFAR100(root=args.data_path, train=True, download=True, transform=val_transform)
    # Materialize the original training data: unpack all CIFAR-100 images into one big
    # tensor images_og with matching labels labels_og. This is fine while memory allows,
    # but at larger scale it could be optimized into batched or lazy access.
    images_og = [torch.unsqueeze(dataset_og[i][0], dim=0) for i in range(len(dataset_og))]
    labels_og = [dataset_og[i][1] for i in range(len(dataset_og))]
    images_og = torch.cat(images_og, dim=0)
    labels_og = torch.tensor(labels_og, dtype=torch.long)
    print("Creating data loaders")
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    return dataset, images_og, labels_og, dataset_test, train_sampler, test_sampler


def create_model(model_name, device, num_classes, path=None):
    # Build the backbone by name (TODO: no pretrained weights are loaded by default)
    model = torchvision.models.get_model(model_name, weights=None, num_classes=num_classes)
    # Adapt the stem to CIFAR-style inputs: 3x3 stride-1 first conv, no max-pooling
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()
    # Load checkpoint weights if a path is given (TODO: whether to load pretrained weights)
    if path is not None:
        checkpoint = torch.load(path, map_location="cpu")
        if "model" in checkpoint:
            checkpoint = checkpoint["model"]
        elif "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]
        if "module." in list(checkpoint.keys())[0]:
            checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
        model.load_state_dict(checkpoint)
    model.to(device)
    return model


def curriculum_arrangement(spc, curriculum_num):
    '''
    Curriculum allocation: split the total number of real images to select per class
    (spc) evenly across curriculum_num rounds. E.g. spc=7, curriculum_num=3 gives
    [3, 2, 2] (the remainder goes to the earlier rounds).
    '''
    remainder = spc % curriculum_num
    arrangement = [spc // curriculum_num] * curriculum_num
    for i in range(remainder):
        arrangement[i] += 1
    return arrangement


def train_one_epoch(model, teacher_model, criterion, optimizer, data_loader, device, epoch, args):
    """
    One epoch (one pass over data_loader) of KL-divergence distillation: the student
    (model) learns the teacher_model's soft labels. Both teacher and student logits are
    divided by the temperature T before log_softmax and fed to KLDivLoss; the loss is
    then scaled by T^2 so the temperature's effect on the gradients stays consistent.
    """
    # Student in train mode
    model.train()
    # Teacher in eval mode: forward only, no updates
    teacher_model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
    header = f"Epoch: [{epoch}]"
    for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        # 1) Teacher forward
        teacher_output = teacher_model(image)
        # 2) Student forward
        output = model(image)
        # Divide the logits by the temperature, then take log_softmax
        teacher_output_log_softmax = F.log_softmax(teacher_output / args.temperature, dim=1)
        output_log_softmax = F.log_softmax(output / args.temperature, dim=1)
        # KL-divergence loss, scaled by T^2 to offset the temperature's gradient scaling
        loss = criterion(output_log_softmax, teacher_output_log_softmax) * (args.temperature ** 2)
        # Standard backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Student top-1/top-5 accuracy on the original hard labels (target)
        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        # Update the metric logger
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))


def evaluate(model, criterion, data_loader, device, log_suffix=""):
    """
    Full forward pass over the test/validation set, computing cross-entropy loss and
    top-1/top-5 accuracy. Inference only, no gradient updates.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = f"Test: {log_suffix}"
    num_processed_samples = 0  # running count of processed samples
    with torch.inference_mode():
        for image, target in data_loader:
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            # Forward
            output = model(image)
            # Cross-entropy against the hard labels
            loss = criterion(output, target)
            # top-1/top-5
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size
    # In the distributed case, sum the per-GPU sample counts
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and torch.distributed.get_rank() == 0
    ):
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )
    metric_logger.synchronize_between_processes()
    return metric_logger.acc1.global_avg


def curriculum_train(current_curriculum, dst_train, test_loader, model, teacher_model, args):
    """
    One full distillation-style training run of the filter model on the current
    "synthetic + selected real" set dst_train:
    - adjusts the batch size according to the dataset size
    - builds the DataLoader / criteria / optimizer / LR scheduler (with warmup)
    - trains for args.epochs, validating in the last 20% of epochs and tracking best acc1
    Returns the trained model and the best top-1 accuracy best_acc1.
    """
    best_acc1 = 0
    # 1. Pick a rough batch size from the size of dst_train (synthetic + real)
    if len(dst_train) < 50 * args.num_classes:
        args.batch_size = 32
    elif 50 * args.num_classes <= len(dst_train) < 100 * args.num_classes:
        args.batch_size = 64
    else:
        args.batch_size = 128
    # 2. Random sampler so every epoch is reshuffled
    train_sampler = torch.utils.data.RandomSampler(dst_train)
    train_loader = torch.utils.data.DataLoader(
        dst_train,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
    )
    # 3. Losses: cross-entropy for hard labels, KLDiv for distilled soft labels
    criterion = nn.CrossEntropyLoss()
    criterion_kl = nn.KLDivLoss(reduction='batchmean', log_target=True)
    parameters = utils.set_weight_decay(model, args.weight_decay)
    # 4. Build the optimizer
    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
        optimizer = torch.optim.SGD(
            parameters,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov="nesterov" in opt_name,
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
    elif opt_name == "adamw":
        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
    # 5. Main LR scheduler: StepLR / CosineAnnealingLR / ExponentialLR
    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "steplr":
        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=0.0)
    elif args.lr_scheduler == "exponentiallr":
        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
            "are supported."
        )
    # 6. With warmup enabled, chain the warmup scheduler in front of the main one
    if args.lr_warmup_epochs > 0:
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs)
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs)
        else:
            raise RuntimeError(f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported.")
        # milestones: switch to main_lr_scheduler after args.lr_warmup_epochs epochs
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs])
    else:
        lr_scheduler = main_lr_scheduler
    # 7. Training loop
    print("Start training on synthetic dataset...")
    start_time = time.time()
    pbar = tqdm(range(args.epochs), ncols=100)
    for epoch in pbar:
        # Each epoch runs train_one_epoch above (KL distillation)
        train_one_epoch(model, teacher_model, criterion_kl, optimizer, train_loader, args.device, epoch, args)
        # Step the LR schedule after each epoch
        lr_scheduler.step()
        # Validate only in the last 20% of epochs to save time
        if epoch > args.epochs * 0.8:
            acc1 = evaluate(model, criterion, test_loader, device=args.device)  # evaluate uses hard-label loss & accuracy
            # Track the best accuracy
            if acc1 > best_acc1:
                best_acc1 = acc1
            # Show current/best accuracy on the progress bar
            pbar.set_description(f"Epoch[{epoch}] Test Acc: {acc1:.2f}% Best Acc: {best_acc1:.2f}%")
    print(f"Best Accuracy {best_acc1:.2f}%")
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
    return model, best_acc1


def coarse_filtering(images_all, labels_all, filter, batch_size, args, get_correct=True):
    """
    Run the filter model over the full original training set images_all:
    - if get_correct=True, return the list of indices of correctly predicted samples;
      otherwise return the indices of misclassified samples.
    - also return the raw (pre-softmax) logits for all samples.
    """
    true_labels = labels_all.cpu()
    filter.eval()  # forward only, no parameter updates
    logits = None
    # Batched inference to avoid OOM on a single pass
    for select_times in range((len(images_all) + batch_size - 1) // batch_size):
        # Slice out the current batch; detach to cut gradient tracking, then move to device
        current_data_batch = images_all[batch_size * select_times : batch_size * (select_times + 1)].detach().to(args.device)
        # Forward
        batch_logits = filter(current_data_batch)
        # Concatenate the batches
        if logits is None:
            logits = batch_logits.detach()
        else:
            logits = torch.cat((logits, batch_logits.detach()), 0)
    # argmax over each row gives the predicted label
    predicted_labels = torch.argmax(logits, dim=1).cpu()
    # Keep correct or incorrect indices depending on get_correct
    target_indices = torch.where(true_labels == predicted_labels)[0] if get_correct else torch.where(true_labels != predicted_labels)[0]
    target_indices = target_indices.tolist()
    print('Acc on training set: {:.2f}%'.format(100 * len(target_indices) / len(images_all) if get_correct else 100 * (1 - len(target_indices) / len(images_all))))
    return target_indices, logits


def selection_logits(selected_idx, teacher_correct_idx, images_all, labels_all, filter, args):
    """
    Fine-stage selection using the filter model's logits:
    - teacher_correct_idx: indices the teacher predicts correctly on the original training set
    - selected_idx: indices already selected in earlier rounds, to avoid duplicates
    Returns the list of indices newly selected in this round.
    """
    batch_size = 512
    true_labels = labels_all.cpu()
    filter.eval()
    print('Coarse Filtering...')
    # --- Coarse stage: decide which samples enter the fine stage.
    # With select_misclassified=True, keep the samples the filter got wrong.
    if args.select_misclassified:
        target_indices, logits = coarse_filtering(images_all, labels_all, filter, batch_size, args, get_correct=False)
    else:
        target_indices, logits = coarse_filtering(images_all, labels_all, filter, batch_size, args, get_correct=True)
    # -- Cross-filtering: keep only samples the teacher also predicts correctly,
    # and drop those already selected. teacher_correct_idx holds the teacher's correct
    # indices on the original training set (the paper only considers samples the teacher
    # can classify correctly).
    if teacher_correct_idx is not None:
        # Intersect teacher_correct_idx with target_indices, then subtract selected_idx
        target_indices = list(set(teacher_correct_idx) & set(target_indices) - set(selected_idx))
    else:
        target_indices = list(set(target_indices) - set(selected_idx))
    print('Fine Selection...')
    selection = []
    if args.balance:
        # Class-balanced: select args.curpc samples per class
        target_idx_per_class = [[] for c in range(args.num_classes)]
        for idx in target_indices:
            target_idx_per_class[true_labels[idx]].append(idx)
        for c in range(args.num_classes):
            if args.select_method == 'random':
                # Random sampling
                selection += random.sample(target_idx_per_class[c], args.curpc)
            elif args.select_method == 'hard':
                # Ascending by logits[c]: a lower logit means less confidence => "harder"
                selection += sorted(target_idx_per_class[c], key=lambda i: logits[i][c], reverse=False)[:args.curpc]
            elif args.select_method == 'simple':
                # Descending by logits[c]: a higher logit means more confidence => "simpler"
                selection += sorted(target_idx_per_class[c], key=lambda i: logits[i][c], reverse=True)[:args.curpc]
    else:
        # No class balancing: select curpc * num_classes in total from all target_indices
        if args.select_method == 'random':
            selection = random.sample(target_indices, args.curpc * args.num_classes)
        elif args.select_method == 'hard':
            selection = sorted(target_indices, key=lambda i: logits[i][true_labels[i]], reverse=False)[:args.curpc * args.num_classes]
        elif args.select_method == 'simple':
            selection = sorted(target_indices, key=lambda i: logits[i][true_labels[i]], reverse=True)[:args.curpc * args.num_classes]
    return selection


def selection_score(selected_idx, teacher_correct_idx, images_all, labels_all, filter, score, reverse, args):
    """
    Fine-stage selection using a precomputed difficulty score:
    - score: numpy array, score[i] is the difficulty of sample i
    - reverse: bool, sort direction for the "hard" strategy
    Same flow as selection_logits, but the sort key is the score.
    """
    batch_size = 512
    true_labels = labels_all.cpu()
    filter.eval()
    print('Coarse Filtering...')
    # Coarse stage, as above
    if args.select_misclassified:
        target_indices, _ = coarse_filtering(images_all, labels_all, filter, batch_size, args, get_correct=False)
    else:
        target_indices, _ = coarse_filtering(images_all, labels_all, filter, batch_size, args, get_correct=True)
    # Cross-filter against teacher_correct_idx and drop already-selected indices
    if teacher_correct_idx is not None:
        target_indices = list(set(teacher_correct_idx) & set(target_indices) - set(selected_idx))
    else:
        target_indices = list(set(target_indices) - set(selected_idx))
    print('Fine Selection...')
    selection = []
    if args.balance:
        target_idx_per_class = [[] for c in range(args.num_classes)]
        for idx in target_indices:
            target_idx_per_class[true_labels[idx]].append(idx)
        for c in range(args.num_classes):
            if args.select_method == 'random':
                selection += random.sample(target_idx_per_class[c], min(args.curpc, len(target_idx_per_class[c])))
            elif args.select_method == 'hard':
                selection += sorted(target_idx_per_class[c], key=lambda i: score[i], reverse=reverse)[:args.curpc]
            elif args.select_method == 'simple':
                # Use the external (precomputed) difficulty score
                selection += sorted(target_idx_per_class[c], key=lambda i: score[i], reverse=not reverse)[:args.curpc]
    else:
        if args.select_method == 'random':
            selection = random.sample(target_indices, min(args.curpc * args.num_classes, len(target_indices)))
        elif args.select_method == 'hard':
            selection = sorted(target_indices, key=lambda i: score[i], reverse=reverse)[:args.curpc * args.num_classes]
        elif args.select_method == 'simple':
            selection = sorted(target_indices, key=lambda i: score[i], reverse=not reverse)[:args.curpc * args.num_classes]
    return selection


def main(args):
    '''Preparation'''
    print('=> args.output_dir', args.output_dir)
    start_time = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    log_dir = os.path.join(args.output_dir, 'CIFAR-100', start_time)
    os.makedirs(log_dir, exist_ok=True)
    device = torch.device(args.device)
    if device.type == 'cuda':
        print('Using GPU')
        torch.backends.cudnn.benchmark = True
    # Compute cpc (condensed images per class) and spc (real images to select per class)
    args.cpc = int(args.image_per_class * args.alpha)  # condensed images per class
    args.spc = args.image_per_class - args.cpc  # selected real images per class
    args.num_classes = 100
    print('Target IPC: {}, num_classes: {}, distillation portion: {}, distilled images per class: {}, real images to be selected per class: {}'.format(args.image_per_class, args.num_classes, args.alpha, args.cpc, args.spc))
    # Load the data
    dataset_dis, images_og, labels_og, dataset_test, train_sampler, test_sampler = load_data(args)
    # Load the difficulty score
    if args.score == 'forgetting':
        score = np.load(args.score_path)
        reverse = True
    elif args.score == 'cscore':
        score = np.load(args.score_path)
        reverse = False
    curriculum_num = args.curriculum_num
    # Curriculum arrangement: how many real samples to select in each round
    arrangement = curriculum_arrangement(args.spc, curriculum_num)
    # Test loader
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=512, sampler=test_sampler, num_workers=args.workers, pin_memory=True)
    # Load the teacher model
    teacher_model = create_model(args.teacher_model, device, args.num_classes, args.teacher_path)
    # Freeze the teacher's parameters
    for p in teacher_model.parameters():
        p.requires_grad = False
    teacher_model.eval()
    # Pre-filter once with the teacher on the original data: only samples the teacher
    # predicts correctly are eligible for selection.
    teacher_correct_idx, _ = coarse_filtering(images_og, labels_og, teacher_model, 512, args, get_correct=True)
    print('teacher acc@1 on original training data: {:.2f}%'.format(100 * len(teacher_correct_idx) / len(images_og)))

    '''Curriculum selection'''
    idx_selected = []
    dataset_sel = None
    dst_sel_transform = transforms.Compose([
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(),
    ])
    print('Selected images per class arrangement in each curriculum: ', arrangement)
    # Curriculum loop
    for i in range(curriculum_num):
        print('----Curriculum [{}/{}]----'.format(i + 1, curriculum_num))
        args.curpc = arrangement[i]
        # Round 0 starts from the distilled synthetic set
        if i == 0:
            print('Begin with distilled dataset')
            syn_dataset = dataset_dis
            dataset_sel = []
        print('Synthetic dataset size:', len(syn_dataset), "distilled data:", len(dataset_dis), "selected data:", len(dataset_sel))
        # Train a fresh filter (from scratch in every round)
        filter = create_model(args.filter_model, device, args.num_classes)
        # TODO: curriculum training, with the teacher model providing soft labels
        filter, best_acc1 = curriculum_train(i, syn_dataset, test_loader, filter, teacher_model, args)
        print('Selecting real data...')
        if args.score == 'logits':
            selection = selection_logits(idx_selected, teacher_correct_idx, images_og, labels_og, filter, args)
        else:
            selection = selection_score(idx_selected, teacher_correct_idx, images_og, labels_og, filter, score, reverse, args)
        idx_selected += selection
        print('Selected {} in this curriculum'.format(len(selection)))
        imgs_select = images_og[idx_selected]
        labs_select = labels_og[idx_selected]
        dataset_sel = utils.TensorDataset(imgs_select, labs_select, dst_sel_transform)
        syn_dataset = torch.utils.data.ConcatDataset([dataset_dis, dataset_sel])
    print('----All curricula finished----')
    print('Final synthetic dataset size:', len(syn_dataset), "distilled data:", len(dataset_dis), "selected data:", len(dataset_sel))
    print('Saving selected indices...')
    idx_file = os.path.join(log_dir, f'selected_indices.json')
    with open(idx_file, 'w') as f:
        json.dump({'ipc': args.image_per_class, 'alpha': args.alpha, 'idx_selected': idx_selected}, f)

    '''Final evaluation'''
    num_eval = args.num_eval
    accs = []
    for i in range(num_eval):
        print(f'Evaluation {i + 1}/{num_eval}')
        eval_model = create_model(args.eval_model, device, args.num_classes)
        _, best_acc1 = curriculum_train(0, syn_dataset, test_loader, eval_model, teacher_model, args)
        accs.append(best_acc1)
    acc_mean = np.mean(accs)
    acc_std = np.std(accs)
    print('----Evaluation Results----')
    print(f'Acc@1(mean): {acc_mean:.2f}%, std: {acc_std:.2f}')
    print('Saving results...')
    log_file = os.path.join(log_dir, f'exp_log.txt')
    with open(log_file, 'w') as f:
        f.write('EXP Settings: \n')
        f.write(f'IPC: {args.image_per_class},\tdistillation portion: {args.alpha},\tcurriculum_num: {args.curriculum_num}\n')
        f.write(f'filter model: {args.filter_model},\tteacher model: {args.teacher_model},\tbatch_size: {args.batch_size},\tepochs: {args.epochs}\n')
        f.write(f"coarse stage strategy: {'select misclassified' if args.select_misclassified else 'select correctly classified'}\n")
        f.write(f'fine stage strategy: {args.select_method},\tdifficulty score: {args.score},\tbalance: {args.balance}\n')
        f.write(f'eval model: {args.eval_model},\tAcc@1: {acc_mean:.2f}%,\tstd: {acc_std:.2f}\n')


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)
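For reference, a typical run might look like the following. The script name and all paths are placeholders; the flags match get_args_parser above:

python ccfs_cifar100.py \
    --data-path ./data \
    --distill-data-path ./distilled/cifar100_cda_ipc50 \
    --teacher-path ./ckpt/teacher_resnet18.pth \
    --image-per-class 50 --alpha 0.2 --curriculum-num 3 \
    --select-misclassified --select-method simple \
    --score forgetting --score-path ./scores/forgetting_cifar100.npy \
    --balance --output-dir ./results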
Overall Algorithm Logic
Algorithm inputs
- Distilled synthetic set $D_{distill}$: a small "synthetic" dataset already produced by some distillation algorithm, containing $cpc = IPC \times \alpha$ images per class.
- Original training set (images_og, labels_og): the full CIFAR-100 training samples, used for selection.
- Teacher model $\phi_{teacher}$: a fixed model that performs well on the original training set and provides "correct" soft-label references.
- Hyperparameters: $IPC$ (final images per class), $\alpha$ (distillation portion), number of curriculum rounds $J$, coarse strategy, fine strategy, difficulty score type, etc.
Overall flow
- Initialization
  - Parse the command-line arguments and compute $cpc = \lfloor IPC \times \alpha \rfloor$ and $spc = IPC - cpc$.
  - Load $D_{distill}$, the original training set, and the validation set.
  - Load and freeze the teacher model $\phi_{teacher}$; run one inference pass over the original training set and record the index set of samples the teacher predicts correctly, $I_{teacher}$.
- Curriculum arrangement
  Split the $spc$ real images to select per class evenly across $J$ rounds: $[k_1, k_2, \dots, k_J]$ with $\sum_j k_j = spc$ (see the worked example after this item).
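As a quick check, with $IPC = 50$ and $\alpha = 0.2$ we get $spc = 50 - 10 = 40$; the snippet below mirrors the logic of curriculum_arrangement from the code:

# Worked example of the per-round allocation (same logic as curriculum_arrangement)
spc, J = 40, 3
arrangement = [spc // J] * J   # [13, 13, 13]
for i in range(spc % J):       # remainder 1 goes to the earliest round(s)
    arrangement[i] += 1
print(arrangement)             # -> [14, 13, 13]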
- Multi-round coarse-to-fine selection loop
  Start with $S_0 = D_{distill}$ and the selected index set $I_{sel} = \emptyset$.
  For each curriculum stage $j = 1 \dots J$:
  i. Distill-train the filter
    - On $S_{j-1}$, train a new filter model $\phi_j$ by distilling the teacher's soft labels.
  ii. Coarse (coarse filtering)
    - Run $\phi_j$ over the entire original training set to obtain logits and predicted labels for all samples.
    - Depending on select_misclassified, keep the indices of either the misclassified or the correctly classified samples; call this candidate set $C$.
    - Cross-filtering: keep only indices that are in $C$, also in $I_{teacher}$, and not yet in $I_{sel}$.
  iii. Fine (fine selection)
    - Rank the candidate indices by the filter's logits or by an external precomputed difficulty score; the rule can be simple, hard, or random.
    - With --balance, take $k_j$ per class; otherwise take $k_j \times num\_classes$ in total across all classes.
    - Add the indices newly selected in this round to $I_{sel}$.
  iv. Update the synthetic set
    $S_j \leftarrow D_{distill} \cup \{\text{real sample } i : i \in I_{sel}\}$
- Save & final evaluation
  - Write $I_{sel}$ to JSON for later reproduction.
  - Train a fresh evaluation model on the final mixed set $S_J$, measure top-1 accuracy over several runs, and report the mean and standard deviation.
This flow ensures that every round supplements, in an orderly way, exactly what the model has genuinely failed to learn, so the final synthetic set covers both common knowledge and the key hard cases.