PyTorch图像预处理全解析(transforms)

1. 引言

在深度学习计算机视觉任务中,数据预处理和数据增强是模型训练的关键步骤,直接影响模型的泛化能力和最终性能表现。PyTorch 提供的 torchvision.transforms 模块,封装了丰富的图像变换方法,能够高效地完成图像标准化、裁剪、翻转等操作。该模块支持两种主要的使用方式:单步变换(Single Transform)和组合变换(Compose),可以灵活应对不同场景下的图像处理需求。

本文将详细解析 transforms 的核心 API、参数含义,并通过完整代码示例演示其使用方法。主要内容包括:

  1. 基础变换操作

    • 尺寸调整:Resize(target_size)
    • 随机裁剪:RandomCrop(size, padding=None, pad_if_needed=False)
    • 中心裁剪:CenterCrop(size)
    • 随机水平/垂直翻转:RandomHorizontalFlip(p=0.5), RandomVerticalFlip(p=0.5)
  2. 颜色空间变换

    • 颜色抖动:ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
    • 随机灰度化:RandomGrayscale(p=0.1)
    • 高斯模糊:GaussianBlur(kernel_size, sigma=(0.1, 2.0))
  3. 数据标准化

    • 归一化:Normalize(mean, std)
    • 张量转换:ToTensor()
  4. 实用组合方法

    • 变换链:Compose([transforms1, transforms2,...])
    • 随机选择:RandomApply(transforms, p=0.5)
    • 随机排序:RandomOrder(transforms)

以图像分类任务为例,一个典型的数据增强流程可能如下:

from torchvision import transformstrain_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])

其中,训练集使用更丰富的增强策略以提高模型鲁棒性,而验证集则采用较简单的预处理保持数据原始分布。通过合理配置这些变换参数,可以显著提升模型在各种视觉任务(如图像分类、目标检测、语义分割等)中的表现。


2. transforms 概述

transforms 是 PyTorch 生态系统中 torchvision 库的核心模块之一,专门用于计算机视觉任务中的图像数据处理。它提供了丰富的图像变换方法,主要分为三大类功能:

  1. 图像预处理

    • 尺寸调整:transforms.Resize() 可将图像统一缩放到指定尺寸(如 256x256)
    • 归一化:transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 使用 ImageNet 的均值和标准差进行标准化
    • 中心裁剪:transforms.CenterCrop(224) 从图像中心裁剪出指定大小的区域
  2. 数据增强(常用于训练阶段防止过拟合):

    • 随机裁剪:transforms.RandomCrop(224) 在随机位置裁剪
    • 颜色变换:transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
    • 随机水平翻转:transforms.RandomHorizontalFlip(p=0.5)
    • 随机旋转:transforms.RandomRotation(degrees=15)
  3. 格式转换

    • PIL图像转张量:transforms.ToTensor() 将图像转换为 PyTorch 张量(并自动将像素值归一化到 [0,1])
    • 张量转PIL图像:transforms.ToPILImage()

组合使用示例

from torchvision import transforms# 训练阶段的变换流水线
train_transform = transforms.Compose([transforms.Resize(256),              # 缩放至256x256transforms.RandomCrop(224),          # 随机裁剪224x224transforms.RandomHorizontalFlip(),   # 随机水平翻转transforms.ToTensor(),               # 转为张量transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 标准化std=[0.229, 0.224, 0.225])
])# 验证阶段的变换流水线(通常不包含随机增强)
val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

在实际应用中,这些变换可以显著提升模型的泛化能力,特别是在数据量有限的情况下。对于不同的计算机视觉任务(如图像分类、目标检测等),可以根据具体需求组合不同的变换操作。


3. 核心 API 详解

3.1 基础变换

(1) Resize(size)
  • 功能:调整图像尺寸。

  • 参数

    • size (int or tuple):目标尺寸。如果是 int,短边缩放至该值,长边按比例调整;如果是 (h, w),则强制缩放到指定大小。

  • 示例

transform = transforms.Resize(256)  # 短边缩放到256,长边按比例调整
transform = transforms.Resize((224, 224))  # 强制缩放到224x224
(2) CenterCrop(size)
  • 功能:从图像中心裁剪指定大小的区域。

  • 参数

    • size (int or tuple):裁剪尺寸(int 表示正方形,(h, w) 表示矩形)。

  • 示例

transform = transforms.CenterCrop(224)  # 裁剪224x224的正方形
(3) RandomCrop(size)
  • 功能:随机位置裁剪图像。

  • 参数

    • size (int or tuple):裁剪尺寸。

    • padding (int or tuple, optional):填充边缘(防止裁剪过小)。

  • 示例

transform = transforms.RandomCrop(224, padding=10)  # 随机裁剪224x224,边缘填充10像素
(4) RandomHorizontalFlip(p=0.5)
  • 功能:以概率 p 水平翻转图像(默认 p=0.5)。

  • 示例

transform = transforms.RandomHorizontalFlip(p=0.7)  # 70%概率水平翻转
(5) RandomRotation(degrees)
  • 功能:随机旋转图像。

  • 参数

    • degrees (float or tuple):旋转角度范围(如 30 表示 [-30°, 30°](10, 30) 表示 [10°, 30°])。

  • 示例

transform = transforms.RandomRotation(30)  # 随机旋转 ±30°

3.2 张量转换 & 标准化

(1) ToTensor()
  • 功能

    • 将 PIL.Image 或 numpy.ndarray 转换为 torch.Tensor[C, H, W] 格式)。

    • 像素值从 [0, 255] 缩放到 [0.0, 1.0]

  • 示例

transform = transforms.ToTensor()  # 转换为张量
(2) Normalize(mean, std)
  • 功能:对张量进行标准化(逐通道计算:(x - mean) / std)。

  • 参数

    • mean (list):各通道均值(如 ImageNet 的 [0.485, 0.456, 0.406])。

    • std (list):各通道标准差(如 ImageNet 的 [0.229, 0.224, 0.225])。

  • 示例

transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]
)

3.3 颜色变换

(1) ColorJitter
  • 功能:随机调整亮度、对比度、饱和度和色相。

  • 参数说明:

  • brightness (float 或 tuple):亮度调整范围

    • 当输入为单个浮点数时(如 0.2),表示亮度调整范围为 [1-0.2, 1+0.2] = [0.8, 1.2]
    • 当输入为元组时(如 (0.7, 1.3)),表示直接指定亮度调整范围
    • 示例:brightness=0.5 表示图片亮度将在原始值的50%-150%之间随机调整
  • contrast (float 或 tuple):对比度调整范围

    • 调节方式与brightness相同
    • 示例:contrast=(0.8, 1.5) 表示对比度将在原始值的80%-150%之间随机调整
    • 应用场景:

    • 这些参数常用于图像增强和数据增强任务
    • 在训练深度学习模型时,随机调整这些参数可以增加训练数据的多样性
    • 每个参数的调整都是在指定范围内随机取值,而不是固定值
    • saturation (float 或 tuple):饱和度调整范围

      • 调节方式与brightness相同
      • 示例:saturation=0.3 表示饱和度将在原始值的70%-130%之间随机调整
    • hue (float 或 tuple):色相调整范围

      • 当输入为单个浮点数时(如 0.1),表示色相调整范围为 [-0.1, 0.1]
      • 当输入为元组时(如 (-0.2, 0.3)),表示直接指定色相调整范围
      • 注意:色相值通常以弧度表示,范围一般为[-0.5, 0.5]
      • 示例:hue=0.05 表示色相将在[-0.05, 0.05]范围内随机调整
  • 示例

transform = transforms.ColorJitter(brightness=0.2,contrast=0.2,saturation=0.2,hue=0.1
)
(2) Grayscale(num_output_channels=1)
  • 功能:将图像转为灰度图。

  • 参数

    • num_output_channels:输出通道数(1 或 3)。

  • 示例

transform = transforms.Grayscale(num_output_channels=3)  # 转为3通道灰度图

4. 完整代码示例

4.1 定义训练和测试的变换

from torchvision import transforms# 训练集变换(含数据增强)
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),      # 随机缩放裁剪至224x224transforms.RandomHorizontalFlip(),      # 50%概率水平翻转transforms.ColorJitter(                 # 随机颜色调整brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),                 # 转为张量 [C, H, W], 值范围[0, 1]transforms.Normalize(                  # 标准化(ImageNet参数)mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 测试集变换(仅预处理)
test_transform = transforms.Compose([transforms.Resize(256),                # 短边缩放到256transforms.CenterCrop(224),            # 中心裁剪224x224transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

4.2 应用到数据集 

from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader# 加载CIFAR10数据集(应用变换)
train_dataset = CIFAR10(root='./data', train=True, transform=train_transform,  # 应用训练变换download=True
)test_dataset = CIFAR10(root='./data', train=False, transform=test_transform,   # 应用测试变换download=True
)# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

5. 总结

使用 Compose 可以方便地组合多个变换操作,这些变换会按照添加顺序依次执行。例如:

transforms.Compose([transforms.Resize(256),          # 调整图像大小transforms.RandomCrop(224),      # 随机裁剪transforms.ToTensor(),           # 转换为张量transforms.Normalize(            # 标准化mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])

在实际应用中,训练和测试阶段通常采用不同的转换策略:

标准化(Normalize)是一个关键步骤,它能:

当使用预训练模型时,应该采用该模型训练时使用的均值和标准差(常见的是 ImageNet 的统计值:mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])。如果不使用预训练模型,可以计算自己数据集的统计值进行标准化。

  • PyTorch 中的 transforms 模块是计算机视觉任务中图像处理的核心工具,它提供了一系列用于图像预处理、数据增强和数据类型转换的功能。这些转换操作可以高效地将原始图像数据转换为适合深度学习模型训练的格式。

    主要功能包括:

  • 预处理:如图像大小调整(Resize)、中心裁剪(CenterCrop)、转换为张量(ToTensor)等基础操作
  • 数据增强:训练时增加数据多样性的随机变换,如随机水平翻转(RandomHorizontalFlip)、随机旋转(RandomRotation)
  • 张量转换:将 PIL 图像或 numpy 数组转换为 PyTorch 张量,并进行数值归一化等操作
  • 训练阶段:建议使用数据增强来提升模型泛化能力,常用增强方法包括:
    • RandomCrop:随机裁剪图像
    • ColorJitter:随机调整亮度、对比度、饱和度
    • RandomHorizontalFlip:随机水平翻转
    • RandomRotation:随机旋转
  • 测试阶段:通常只需基础预处理,如固定大小的裁剪和标准化
  • 将输入数据缩放到相近的数值范围
  • 加速模型收敛过程
  • 提高训练稳定性

掌握 transforms 的使用,可以显著提升计算机视觉任务的效率和模型性能! 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/89515.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/89515.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

slam中的eskf观测矩阵推导

在之前的《slam中的eskf推导》一文中,没有写观测矩阵 H 矩阵的过程,现在补上这部分。前置列举几个等下推导需要用到的一些点:平面特征点构造观测矩阵例如在 fastlio 中,是利用平面特征点到拟合平面的距离来构造观测方程&#xff0…

Python_2

逻辑判断 首先得首先&#xff0c;我们想判断一个逻辑的正确与否&#xff0c;一定是需要一个能够表现出逻辑的词 如果我只说一个1 2&#xff0c;那么大家都不知道我在说什么但是如果我说1<2,那么大家就能判断这个语句的正确与否了 下面是几个常用的逻辑词 < 小于>大于&…

Liunx-Lvs配置项目练习

1.实验环境配置Lvs调度器有两块网卡 一块仅主机和一块nat网卡&#xff0c;客户端nat模式&#xff0c;两台服务器为仅主机模式2.集群和分布式简介集群与分布式系统简介集群 (Cluster)集群是指将多台计算机(通常为同构的)通过高速网络连接起来&#xff0c;作为一个整体对外提供服…

T5(Text-to-Text Transfer Transformer) 模型

下面是对 T5&#xff08;Text-to-Text Transfer Transformer&#xff09; 模型的详细介绍&#xff0c;包括其原理、架构、训练方式、优势与局限&#xff0c;以及与其他模型&#xff08;如 BERT、GPT&#xff09;的对比。一、T5 是什么&#xff1f;T5&#xff08;Text-to-Text T…

PostgreSQL技术大讲堂 - 第97讲:PG数据库编码和区域(locale)答疑解惑

PostgreSQL从入门到精通系列课程&#xff0c;近100节PG技术讲解&#xff0c;让你从小白一步步成长为独当一面的PG专业人员&#xff0c;点击这里查看章节内容。 PostgreSQL从入门到精通课程&#xff0c;持续更新&#xff0c;欢迎加入。第97讲&#xff1a;PostgreSQL 数据库编码…

【IEEE独立出版 】第六届机器学习与计算机应用国际学术会议(ICMLCA 2025)

第六届机器学习与计算机应用国际学术会议&#xff08;ICMLCA 2025&#xff09; 大会简介 第六届机器学习与计算机应用国际学术会议(ICMLCA 2025)定于2025年10月17-19日在中国深圳隆重举行。本届会议将主要关注机器学习和计算机应用面临的新的挑战问题和研究方向&#xff0c;着力…

对于编码电机-520直流减速电机

编码电机的介绍 编码器是一种将角位移或者直线位移转换成一连串电数字脉冲的一种传感器。我们可以通过编码器测量电机转动的位移或者速度信息。 编码器按照工作原理&#xff0c;可以分为增量式编码器和绝对式编码器&#xff0c;绝对式编码器的每一个位置对应一个确定的数字码&a…

Rust入门之并发编程基础(三)

Rust入门之并发编程基础&#xff08;三&#xff09; 题记&#xff1a;6月底7月初&#xff0c;结束北京的工作生活回到二线省会城市发展了&#xff0c;鸽了较久了&#xff0c;要继续坚持学习Rust&#xff0c;坚持写博客。 背景 我们平时使用计算机完成某项工作的时候&#xf…

一文读懂循环神经网络—深度循环神经网络(DRNN)

目录 一、从 RNN 到 DRNN&#xff1a;为什么需要 “深度”&#xff1f; 二、DRNN 的核心结构 1. 时间维度&#xff1a;循环传递 2. 空间维度&#xff1a;多层隐藏层 3. 双向 DRNN&#xff08;Bidirectional DRNN&#xff09; 三、DRNN 的关键挑战与优化 1. 梯度消失 / 爆…

磁悬浮轴承系统中由不平衡力引发的恶性循环机制深度解析

磁悬浮轴承系统中由不平衡力引发的 “振动-激励-更大振动”恶性循环 是一个典型的 正反馈失控过程,其核心在于 传感器信号的污染 与 控制器对真实位移的误判。以下是其逐步演进的原理详解: 恶性循环的触发与演进 1:不平衡力的产生(根源) 转子存在质量偏心,质心(CM)偏离…

优迅股份IPO隐忧:毛利水平“两连降”,研发费用率不及行业均值

撰稿|行星来源|贝多财经近日&#xff0c;厦门优迅芯片股份有限公司&#xff08;下称“优迅股份”&#xff09;的科创板IPO审核状态变更为“已问询”&#xff0c;中信证券为其保荐机构。天眼查App信息显示&#xff0c;优迅股份成立于2003年2月&#xff0c;是中国首批专业从事光通…

Linux探秘坊-------15.线程概念与控制

1.线程概念 1.什么是线程2.线程 vs 进程不同的操作系统有不同的实现方式&#xff1a; linux &#xff1a;直接使用pcb的功能来模拟线程&#xff0c;不创建新的数据结构windows&#xff1a; 使用新的数据结构TCB&#xff0c;来进行实现&#xff0c;一个PCB里有很多个TCB 3.资源划…

Github库镜像到本地私有Gitlab服务器

上一节我们看了如何架设自己的Gitlab服务器&#xff0c;今天我们看怎么把Github库转移到自己的Gitlab上。 首先登录github&#xff0c;进入自己的库复制地址。 克隆镜像库 在本地新建一个文件夹 在文件夹执行CMD指令 git clone --mirror gitgithub.com:thinbug/A.git–mirror参…

【C++】——类和对象(中)——默认成员函数

一、类的默认成员函数默认成员函数就是用户没有显示实现&#xff0c;不过编译器会自动生成的成员函数&#xff0c;称为默认成员函数。一个类默认成员函数一共有6个&#xff0c;在我们不写的情况下&#xff0c;编译器就会自动生成这6个成员函数&#xff0c;不过我们重点要学习的…

MATLAB知识点总结

1.将A图与B图相同范围内归一化显示在同一个figure上&#xff1a; figure, plot(A(150:450,500)/max(A(150:450,500))) hold on plot(D(150:450,500)/max(D(150:450,500)),‘R’) 将两幅图像的一定范围显示在同一图像上。 figure,plot(A(350,100:450)) hold on plot(G(350,100:4…

易天光通信10G SFP+ 1550nm 120KM 双纤光模块:远距离传输的实力担当

目录 前言 一、10G SFP双纤光模块概述 二、易天10G SFP 120KM 双纤光模块核心优势与应用 核心优势&#xff1a; 主要关键应用如下&#xff1a; 三、易天10G SFP 120KM 双纤光模块客户优势 总结 关于易天 前言 在构建高效稳定的网络架构时&#xff0c;10G SFP 光模块 12…

【深度学习】神经网络 批量标准化-part6

九、批量标准化是一种广泛使用的神经网络正则化技术&#xff0c;对每一层的输入进行标准化&#xff0c;进行缩放和平移&#xff0c;目的是加速训练&#xff0c;提高模型稳定性和泛化能力&#xff0c;通常在全连接层或是卷积层之和&#xff0c;激活函数之前使用核心思想对每一批…

【数据可视化-67】基于pyecharts的航空安全深度剖析:坠毁航班数据集可视化分析

&#x1f9d1; 博主简介&#xff1a;曾任某智慧城市类企业算法总监&#xff0c;目前在美国市场的物流公司从事高级算法工程师一职&#xff0c;深耕人工智能领域&#xff0c;精通python数据挖掘、可视化、机器学习等&#xff0c;发表过AI相关的专利并多次在AI类比赛中获奖。CSDN…

【科研绘图系列】R语言绘制分组箱线图

文章目录 介绍 加载R包 数据下载 导入数据 画图1 画图2 合并图 系统信息 参考 介绍 【科研绘图系列】R语言绘制分组箱线图 加载R包 library(ggplot2) library(patchwork)rm(list = ls()) options(stringsAsFactors = F)

基于Android的旅游计划App

项目介绍系统打开进入登录页面&#xff0c;如果没有注册过账号&#xff0c;点击注册按钮输入账号、密码、邮箱即可注册&#xff0c;注册后可登录进入系统&#xff0c;系统分为首页、预订、我的三大模块&#xff0c;下面具体详细说说三大模块功能说明。1.首页显示旅游备忘或旅游…