R3GAN训练自己的数据集

简介

简介:这篇论文挑战了"GANs难以训练"的广泛观点,通过提出一个更稳定的损失函数和现代化的网络架构,构建了一个简洁而高效的GAN基线模型R3GAN。作者证明了通过合适的理论基础和架构设计,GANs可以稳定训练并达到优异性能。

论文题目:The GAN is dead; long live the GAN! A Modern Baseline GAN

会议:NeurIPS 2024

源码地址:https://www.github.com/brownvc/R3GAN

本文在调试代码的时候对代码做了一些修改,如果有遇到报错的问题可以直接复制我这篇博客修改后的代码:R3GAN利用配置好的Pytorch训练自己的数据集-CSDN博客这篇论文挑战了"GANs难以训练"的广泛观点,通过提出一个更稳定的损失函数和现代化的网络架构,构建了一个简洁而高效的GAN基线模型R3GAN。作者证明了通过合适的理论基础和架构设计,GANs可以稳定训练并达到优异性能。 https://blog.csdn.net/LJ1147517021/article/details/148315781?fromshare=blogdetail&sharetype=blogdetail&sharerId=148315781&sharerefer=PC&sharesource=LJ1147517021&sharefrom=from_link

摘要:论文反驳了GANs难以训练的普遍观点,提出了一个理论有保障的现代GAN基线。首先,推导出一个良好行为的正则化相对论GAN损失函数,解决了模式丢弃和不收敛问题,并数学证明了其局部收敛性。其次,该损失函数允许丢弃所有经验性技巧,用现代架构替换常见GANs中的过时骨干网络。以StyleGAN2为例,展示了简化和现代化的路线图,产生了新的极简基线R3GAN。尽管简单,该方法在FFHQ、ImageNet、CIFAR和Stacked MNIST数据集上超越了StyleGAN2,与最先进的GANs和扩散模型相比表现优异。

模型结构

生成器架构

核心设计原则:

  • 基于现代化ResNet架构,摒弃VGG-like设计
  • 每个分辨率阶段包含一个过渡层和两个残差块
  • 采用分组卷积和倒置瓶颈设计

关键特性:

  • 无归一化层:避免批量归一化等数据相关的归一化
  • Fix-up初始化:零初始化每个残差块的最后一层卷积
  • 双线性插值:用于上采样,避免棋盘效应

鉴别器架构

设计特点:

  • 与生成器完全对称的架构
  • 相同的残差块结构和过渡层设计
  • 分类器头:全局4×4深度卷积 + 线性层

损失函数

相对论配对GAN损失 (RpGAN):

L(θ,ψ) = E[f(D_ψ(G_θ(z)) - D_ψ(x))]

R1正则化:

R1(ψ) = (γ/2) * E[||∇_x D_ψ(x)||²]  (x~p_D)

R2正则化:

R2(θ,ψ) = (γ/2) * E[||∇_x D_ψ(x)||²]  (x~p_θ)

训练自己的数据集

1. 准备数据集

首先使用 dataset_tool.py 将您的图像数据转换为适合训练的格式:

# 从文件夹创建数据集
python dataset_tool.py --source=path/to/your/images --dest=path/to/output.zip# 如果需要调整分辨率和裁剪
python dataset_tool.py --source=path/to/your/images --dest=path/to/output.zip \--resolution=256x256 --transform=center-crop

数据集要求:

  • 图像必须是正方形(如256x256, 512x512)
  • 分辨率必须是2的幂次(64, 128, 256, 512, 1024等)
  • 支持RGB或灰度图像
  • 可以是文件夹或ZIP格式

2. 创建自定义训练配置

train.py 中添加您自己的预设配置。参考现有预设,在 main() 函数中添加:

if opts.preset == 'YOUR_DATASET':# 网络架构参数WidthPerStage = [768, 768, 768, 512, 256]  # 每阶段宽度BlocksPerStage = [2, 2, 2, 2, 2]           # 每阶段块数CardinalityPerStage = [96, 96, 96, 48, 24] # 每阶段基数FP16Stages = [-1, -2, -3, -4]              # FP16优化的阶段NoiseDimension = 64                         # 噪声维度# 如果是条件生成(有类别标签)if opts.cond:c.G_kwargs.ConditionEmbeddingDimension = NoiseDimensionc.D_kwargs.ConditionEmbeddingDimension = WidthPerStage[0]# 训练调度参数ema_nimg = 500 * 1000      # EMA开始的图像数decay_nimg = 2e7           # 总衰减图像数# 各种调度器c.ema_scheduler = { 'base_value': 0, 'final_value': ema_nimg, 'total_nimg': decay_nimg }c.aug_scheduler = { 'base_value': 0, 'final_value': 0.3, 'total_nimg': decay_nimg }c.lr_scheduler = { 'base_value': 2e-4, 'final_value': 5e-5, 'total_nimg': decay_nimg }c.gamma_scheduler = { 'base_value': 2, 'final_value': 0.2, 'total_nimg': decay_nimg }c.beta2_scheduler = { 'base_value': 0.9, 'final_value': 0.99, 'total_nimg': decay_nimg }

3. 开始训练

# 无条件生成(如人脸、风景等)
python train.py \--outdir=./training-runs \--data=./datasets/your_dataset.zip \--gpus=4 \--batch=256 \--mirror=1 \--aug=1 \--preset=YOUR_DATASET \--tick=1 \--snap=200# 条件生成(有类别标签)
python train.py \--outdir=./training-runs \--data=./datasets/your_dataset.zip \--gpus=4 \--batch=256 \--mirror=1 \--aug=1 \--cond=1 \--preset=YOUR_DATASET \--tick=1 \--snap=200

4. 参数说明

  • --gpus: GPU数量
  • --batch: 总批次大小
  • --mirror: 是否启用水平翻转增强
  • --aug: 是否启用数据增强
  • --cond: 是否训练条件模型(需要标签)
  • --tick: 多少kimg输出一次进度
  • --snap: 多少tick保存一次模型

5. 生成图像

训练完成后,使用保存的模型生成图像:

# 生成8张图像
python gen_images.py \--seeds=0-7 \--outdir=generated_images \--network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl# 条件生成(指定类别)
python gen_images.py \--seeds=0-7 \--outdir=generated_images \--class=5 \--network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl

6. 评估指标

python calc_metrics.py \--metrics=fid50k_full,kid50k_full \--data=./datasets/your_dataset.zip \--network=training-runs/xxxxx-your_dataset/network-snapshot-xxxxx.pkl

7.报错指南

1.UnboundLocalError: local variable 'NoiseDimension' referenced before assignment

解决办法:在 train.py 中,NoiseDimension 只在特定的预设配置块中定义(如 CIFAR10、FFHQ-64 等)。如果您使用的 --preset 参数不匹配任何现有预设,这个变量就不会被定义,导致使用时出错。可以使用作者定义好的预先设置。

--preset=CIFAR10
--preset=FFHQ-64  
--preset=FFHQ-256
--preset=ImageNet-32
--preset=ImageNet-64

2.RuntimeError: Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "R3GAN\torch_utils\custom_ops.py".

解决办法:这个错误是因为R3GAN使用了自定义的CUDA操作符,需要C++编译器来编译。在Windows系统上缺少MSVC/GCC/CLANG编译器。

修改 torch_utils/custom_ops.py:找到 get_plugin 函数(大约第84行),在函数开头添加:

def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):# 禁用所有自定义插件return Nonedef bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='ref'):# 强制使用 'ref' 实现impl = 'ref'

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/bicheng/82937.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【PhysUnits】15.1 引入P1后的加一特质(add1.rs)

一、源码 代码实现了类型系统中的"加一"操作(Add1 trait),用于在编译期进行数字的增量计算。 //! 加一操作特质实现 / Increment operation trait implementation //! //! 说明: //! 1. Z0、P1,、N1 1&#xff0…

记录算法笔记(2025.5.29)最小栈

设计一个支持 push ,pop ,top 操作,并能在常数时间内检索到最小元素的栈。 实现 MinStack 类: MinStack() 初始化堆栈对象。void push(int val) 将元素val推入堆栈。void pop() 删除堆栈顶部的元素。int top() 获取堆栈顶部的元素。int get…

Android高级开发第一篇 - JNI(初级入门篇)

文章目录 Android高级开发JNI开发第一篇(初级入门篇)🧠 一、什么是 JNI?✅ 为什么要用 JNI? ⚙️ 二、开发环境准备开发工具 🚀 三、创建一个支持 JNI 的 Android 项目第一步:创建新项目项目结构…

PyTorch Image Models (timm) 技术指南

timm PyTorch Image Models (timm) 技术指南功能概述 一、引言二、timm 库概述三、安装 timm 库四、模型加载与推理示例4.1 通用推理流程4.2 具体模型示例4.2.1 ResNeXt50-32x4d4.2.2 EfficientNet-V2 Small 模型4.2.3 DeiT-3 large 模型4.2.4 RepViT-M2 模型4.2.5 ResNet-RS-1…

openEuler安装MySql8(tar包模式)

操作系统版本: openEuler release 22.03 (LTS-SP4) MySql版本: 下载地址: https://dev.mysql.com/downloads/mysql/ 准备安装: 关闭防火墙: 停止防火墙 #systemctl stop firewalld.service 关闭防火墙 #systemc…

从零开始的数据结构教程(六) 贪心算法

🍬 标题一:贪心核心思想——发糖果时的最优分配策略 贪心算法 (Greedy Algorithm) 是一种简单直观的算法策略。它在每一步选择中都采取在当前状态下最好或最优(即最有利)的选择,从而希望得到一个全局最优解。这就像你…

CPP中CAS std::chrono 信号量与Any类的手动实现

前言 CAS(Compare and Swap) 是一种用于多线程同步的原子指令。它通过比较和交换操作来确保数据的一致性和线程安全性。CAS操作涉及三个操作数:内存位置V、预期值E和新值U。当且仅当内存位置V的值与预期值E相等时,CAS才会将内存位…

Axure设计案例——科技感对比柱状图

想让数据对比展示摆脱平淡无奇,瞬间抓住观众的眼球吗?那就来看看这个Axure设计的科技感对比柱状图案例!科技感设计风格运用独特元素打破传统对比柱状图的常规,营造出一种极具冲击力的视觉氛围。每一组柱状体都仿佛是科技战场上的士…

怒更一波免费声音克隆和AI配音功能

宝子们! 最近咱软件TransDuck的免费声音克隆和AI配音功能被大家用爆啦!感谢各位自来水疯狂安利!! DD这里也是收到好多用户提的宝贵建议!所以,连夜肝了波更新! 这次重点更新使用克隆音色进行A…

UDP协议原理与Java编程实战:无连接通信的奥秘

1.UDP协议核心原理 1. 无连接特性:快速通信的基石 UDP(User Datagram Protocol,用户数据报协议)是TCP/IP协议族中无连接的轻量级传输层协议。与TCP的“三次握手”建立连接不同,UDP通信无需提前建立链路,发送…

vue-seamless-scroll 结束从头开始,加延时后滚动

今天遇到一个大屏需求: 1️⃣初始进入页面停留5秒,然后开始滚动 2️⃣最后一条数据出现在最后一行时候暂停5秒,然后返回1️⃣ 依次循环,发现vue-seamless-scroll的方法 ScrollEnd是监测最后一条数据消失在第一行才回调&#xff…

[Protobuf] 快速上手:安全高效的序列化指南

标题:[Protobuf] (1)快速上手 水墨不写bug 文章目录 一、什么是protobuf?二、protobuf的特点三、使用protobuf的过程?1、定义消息格式(.proto文件)(1)指定语法版本(2)package 声明符 2、使用protoc编译器生成代码&…

uniapp调用java接口 跨域问题

前言 之前在Windows10本地 调试一个旧项目,手机移动端用的是Uni-app,vue的版本是v2。后端是java spring-boot。运行手机移动端的首页请求后台接口老是提示错误信息。 错误信息如下: Access to XMLHttpRequest at http://localhost:8080/api/…

[ Qt ] | Qlabel使用

目录 属性 setTextFormat 插入图片 设置图片根据窗口大小实时变化 边框和对其方式 ​编辑 设置缩进 设置伙伴 Qlabel可以用来显式图片和文字 属性 text textFormat Qlabel独有的机制:buddy setTextFormat 插入图片 设置图片根据窗口大小实时变化 Qt中表…

Springboot 项目一启动就获取HttpSession

在 Spring Boot 项目中,HttpSession 是有状态的,通常只有在用户发起 HTTP 请求并建立会话后才会创建。因此,在项目启动时(即应用刚启动还未处理任何请求)是无法获取到 HttpSession 的。 方法一:使用 HttpS…

Step9—Ambari Web UI 初始化安装 (Ambari3.0.0)

Ambari Web UI 安装 如果还不会系统性的部署,或者前置内容不熟悉,建议从Step1 开始阅读。不通版本针对于不同操作系统可能存在差异!这里我也整理好了 https://doc.janettr.com/install/manual/ 1. 进入 Ambari Web UI 并登录 在浏览器中访…

热门大型语言模型(LLM)应用开发框架

我们来深入探索这些强大的大型语言模型(LLM)应用开发框架,并且我会尝试用文本形式描述一些核心的流程图,帮助您更好地理解它们的工作机制。由于我无法直接生成图片,我会用文字清晰地描述流程图的各个步骤和连接。 Lang…

机器学习数据降维方法

1.数据类型 2.如何选择降维方法进行数据降维 3.线性降维:主成分分析(PCA)、线性判别分析(LDA) 4.非线性降维 5.基于特征选择的降维 6.基于神经网络的降维 数据降维是将高维数据转换为低维表示的过程,旨在保…

太阳系运行模拟程序-html动画

太阳系运行模拟程序-html动画 by AI: <!DOCTYPE html> <html lang"zh"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>交互式太阳系…

2025年全国青少年信息素养大赛 scratch图形化编程挑战赛 小低组初赛 内部集训模拟题解析

2025年信息素养大赛初赛scratch模拟题解析 博主推荐 所有考级比赛学习相关资料合集【推荐收藏】 scratch资料 Scratch3.0系列视频课程资料零基础学习scratch3.0【入门教学 免费】零基础学习scratch3.0【视频教程 114节 免费】 历届蓝桥杯scratch国赛真题解析历届蓝桥杯scr…