【过拟合和欠拟合】——深度学习.全连接神经网络

目录

1 概念认知

1.1 过拟合

1.2 欠拟合

1.3 如何判断

2 解决欠拟合

3 解决过拟合

3.1 L2正则化

3.1.1 数学表示

3.1.2 梯度更新

3.1.3 作用

3.1.4 代码实现

3.2 L1正则化

3.2.1 数学表示

3.2.2 梯度更新

3.2.3 作用

3.2.4 与L2对比

3.2.5 代码实现

3.3 Dropout

3.3.1 基本实现

3.3.2 权重影响

3.4 数据增强

3.4.1 图片缩放

3.4.2 随机裁剪

3.4.3 随机水平翻转

3.4.4 调整图片颜色

3.4.5 随机旋转

3.4.6 图片转Tensor

3.4.7 Tensor转图片

3.4.8 归一化

3.4.9 数据增强整合


        在训练深层神经网络时,由于模型参数较多,在数据量不足时很容易过拟合。而正则化技术主要就是用于防止过拟合,提升模型的泛化能力(对新数据表现良好)和鲁棒性(对异常数据表现良好)。

1 概念认知

1.1 过拟合

过拟合是指模型对训练数据拟合能力很强并表现很好,但在测试数据上表现较差。

过拟合常见原因有:

  1. 数据量不足:当训练数据较少时,模型可能会过度学习数据中的噪声和细节。

  2. 模型太复杂:如果模型很复杂,也会过度学习训练数据中的细节和噪声。

  3. 正则化强度不足:如果正则化强度不足,可能会导致模型过度学习训练数据中的细节和噪声。

举个例子:

1.2 欠拟合

欠拟合是由于模型学习能力不足,无法充分捕捉数据中的复杂关系。

1.3 如何判断

过拟合

训练误差低,但验证时误差高。模型在训练数据上表现很好,但在验证数据上表现不佳,说明模型可能过度拟合了训练数据中的噪声或特定模式。

欠拟合

训练误差和测试误差都高。模型在训练数据和测试数据上的表现都不好,说明模型可能太简单,无法捕捉到数据中的复杂模式。

2 解决欠拟合

欠拟合的解决思路比较直接:

  1. 增加模型复杂度:引入更多的参数、增加神经网络的层数或节点数量,使模型能够捕捉到数据中的复杂模式。

  2. 增加特征:通过特征工程添加更多有意义的特征,使模型能够更好地理解数据。

  3. 减少正则化强度:适当减小 L1、L2 正则化强度,允许模型有更多自由度来拟合数据。

  4. 训练更长时间:如果是因为训练不足导致的欠拟合,可以增加训练的轮数或时间.

3 解决过拟合

避免模型参数过大是防止过拟合的关键步骤之一。

模型的复杂度主要由权重w决定,而不是偏置b。偏置只是对模型输出的平移,不会导致模型过度拟合数据。

怎么控制权重w,使w在比较小的范围内?

考虑损失函数,损失函数的目的是使预测值与真实值无限接近,如果在原来的损失函数上添加一个非0的变量


L_1(\hat{y},y) = L(\hat{y},y) + f(w)

其中f(w)是关于权重w的函数,f(w)>0

要使L1变小,就要使L变小的同时,也要使f(w)变小。从而控制权重w在较小的范围内。

3.1 L2正则化

L2 正则化通过在损失函数中添加权重参数的平方和来实现,目标是惩罚过大的参数值。

3.1.1 数学表示

设损失函数为 L(\theta),其中 \theta 表示权重参数,加入L2正则化后的损失函数表示为:


L_{\text{total}}(\theta) = L(\theta) + \lambda \cdot \frac{1}{2} \sum_{i} \theta_i^2

其中:

  • L(\theta)是原始损失函数(比如均方误差、交叉熵等)。

  • \lambda是正则化强度,控制正则化的力度。

  • \theta_i是模型的第 i 个权重参数。

  • \frac{1}{2} \sum_{i} \theta_i^2是所有权重参数的平方和,称为 L2 正则化项。

L2 正则化会惩罚权重参数过大的情况,通过参数平方值对损失函数进行约束。

为什么是\frac{\lambda}{2}

假设没有1/2,则对L2 正则化项\theta_i的梯度为:2\lambda\theta_i,会引入一个额外的系数 2,使梯度计算和更新公式变得复杂。

添加1/2后,对\theta_i的梯度为:\lambda\theta_i

3.1.2 梯度更新

在 L2 正则化下,梯度更新时,不仅要考虑原始损失函数的梯度,还要考虑正则化项的影响。更新规则为:
\theta_{t+1} = \theta_t - \eta \left( \nabla L(\theta_t) + \lambda \theta_t \right)

其中:

  • \eta 是学习率。

  • \nabla L(\theta_t)是损失函数关于参数\theta_t的梯度。

  • \lambda \theta_t 是 L2 正则化项的梯度,对应的是参数值本身的衰减。

很明显,参数越大惩罚力度就越大,从而让参数逐渐趋向于较小值,避免出现过大的参数。

3.1.3 作用

  1. 防止过拟合:当模型过于复杂、参数较多时,模型会倾向于记住训练数据中的噪声,导致过拟合。L2 正则化通过抑制参数的过大值,使得模型更加平滑,降低模型对训练数据噪声的敏感性。

  2. 限制模型复杂度:L2 正则化项强制权重参数尽量接近 0,避免模型中某些参数过大,从而限制模型的复杂度。通过引入平方和项,L2 正则化鼓励模型的权重均匀分布,避免单个权重的值过大。

  3. 提高模型的泛化能力:正则化项的存在使得模型在测试集上的表现更加稳健,避免在训练集上取得极高精度但在测试集上表现不佳。

  4. 平滑权重分布:L2 正则化不会将权重直接变为 0,而是将权重值缩小。这样模型就更加平滑的拟合数据,同时保留足够的表达能力。

3.1.4 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 设置随机种子以保证可重复性
torch.manual_seed(42)# 生成随机数据
n_samples = 100
n_features = 20
X = torch.randn(n_samples, n_features)  # 输入数据
y = torch.randn(n_samples, 1)  # 目标值# 定义一个简单的全连接神经网络
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(n_features, 50)self.fc2 = nn.Linear(50, 1)def forward(self, x):x = torch.relu(self.fc1(x))return self.fc2(x)# 训练函数
def train_model(use_l2=False, weight_decay=0.01, n_epochs=100):# 初始化模型model = SimpleNet()criterion = nn.MSELoss()  # 损失函数(均方误差)# 选择优化器if use_l2:optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=weight_decay)  # 使用 L2 正则化else:optimizer = optim.SGD(model.parameters(), lr=0.01)  # 不使用 L2 正则化# 记录训练损失train_losses = []# 训练过程for epoch in range(n_epochs):optimizer.zero_grad()  # 清空梯度outputs = model(X)  # 前向传播loss = criterion(outputs, y)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数train_losses.append(loss.item())  # 记录损失if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{n_epochs}], Loss: {loss.item():.4f}')return train_losses# 训练并比较两种模型
train_losses_no_l2 = train_model(use_l2=False)  # 不使用 L2 正则化
train_losses_with_l2 = train_model(use_l2=True, weight_decay=0.01)  # 使用 L2 正则化# 绘制训练损失曲线
plt.plot(train_losses_no_l2, label='Without L2 Regularization')
plt.plot(train_losses_with_l2, label='With L2 Regularization')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss: L2 Regularization vs No Regularization')
plt.legend()
plt.show()

3.2 L1正则化

L1 正则化通过在损失函数中添加权重参数的绝对值之和来约束模型的复杂度。

3.2.1 数学表示

设模型的原始损失函数为L(\theta),其中 \theta 表示模型权重参数,则加入 L1 正则化后的损失函数表示为:
L_{\text{total}}(\theta) = L(\theta) + \lambda \sum_{i} |\theta_i|

其中:

  • L(\theta) 是原始损失函数。

  • \lambda是正则化强度,控制正则化的力度。

  • |\theta_i|是模型第i 个参数的绝对值。

  • \sum_{i} |\theta_i|是所有权重参数的绝对值之和,这个项即为 L1 正则化项。

3.2.2 梯度更新

在 L1 正则化下,梯度更新时的公式是:
\theta_{t+1} = \theta_t - \eta \left( \nabla L(\theta_t) + \lambda \cdot \text{sign}(\theta_t) \right)

其中:

  • \eta 是学习率。

  • \nabla L(\theta_t)是损失函数关于参数\theta_t的梯度。

  • \text{sign}(\theta_t) 是参数\theta_t的符号函数,表示当 \theta_t为正时取值为 1,为负时取值为 -1,等于 0 时为 0。

因为 L1 正则化依赖于参数的绝对值,其梯度更新时不是简单的线性缩小,而是通过符号函数来直接调整参数的方向。这就是为什么 L1 正则化能促使某些参数完全变为 0。

3.2.3 作用

  1. 稀疏性:L1 正则化的一个显著特性是它会促使许多权重参数变为 。这是因为 L1 正则化倾向于将权重绝对值缩小到零,使得模型只保留对结果最重要的特征,而将其他不相关的特征权重设为零,从而实现 特征选择 的功能。

  2. 防止过拟合:通过限制权重的绝对值,L1 正则化减少了模型的复杂度,使其不容易过拟合训练数据。相比于 L2 正则化,L1 正则化更倾向于将某些权重完全移除,而不是减小它们的值。

  3. 简化模型:由于 L1 正则化会将一些权重变为零,因此模型最终会变得更加简单,仅依赖于少数重要特征。这对于高维度数据特别有用,尤其是在特征数量远多于样本数量的情况下。

  4. 特征选择:因为 L1 正则化会将部分权重置零,因此它天然具有特征选择的能力,有助于自动筛选出对模型预测最重要的特征。

3.2.4 与L2对比

  • L1 正则化 更适合用于产生稀疏模型,会让部分权重完全为零,适合做特征选择。

  • L2 正则化 更适合平滑模型的参数,避免过大参数,但不会使权重变为零,适合处理高维特征较为密集的场景。

3.2.5 代码实现

l1_lambda = 0.001
# 计算 L1 正则化项并将其加入到总损失中
l1_norm = sum(p.abs().sum() for p in model.parameters())
loss = loss + l1_lambda * l1_norm

3.3 Dropout

Dropout 的工作流程如下:

  1. 在每次训练迭代中,随机选择一部分神经元(通常以概率 p丢弃,比如 p=0.5)。

  2. 被选中的神经元在当前迭代中不参与前向传播和反向传播。

  3. 在测试阶段,所有神经元都参与计算,但需要对权重进行缩放(通常乘以 1−p),以保持输出的期望值一致。

Dropout 是一种在训练过程中随机丢弃部分神经元的技术。它通过减少神经元之间的依赖来防止模型过于复杂,从而避免过拟合。

3.3.1 基本实现

import torch
import torch.nn as nndef dropout():dropout = nn.Dropout(p=0.5)x = torch.randint(0, 10, (5, 6), dtype=torch.float)print(x)# 开始dropoutprint(dropout(x))if __name__ == "__main__":dropout()

Dropout过程:

  1. 按照指定的概率把部分神经元的值设置为0;

  2. 为了规避该操作带来的影响,需对非 0 的元素使用缩放因子1/(1-p)进行强化。

假设某个神经元的输出为 x,Dropout 的操作可以表示为:

  • 在训练阶段:

  • 在测试阶段:

    y=x

为什么要使用缩放因子1/(1-p)?

在训练阶段,Dropout 会以概率 p随机将某些神经元的输出设置为 0,而以概率 1−p 保留这些神经元。

假设某个神经元的原始输出是 x,那么在训练阶段,它的期望输出值为:

E(y_train)=(1-p)\cdot (\frac{x}{1-p})+p\cdot 0=x

通过这种缩放,训练阶段的期望输出值仍然是 x,与没有 Dropout 时一致。

3.3.2 权重影响

示例:对图片进行随机丢弃

import torch
from torch import nn
from PIL import Image
from torchvision import transforms
import osfrom matplotlib import pyplot as plttorch.manual_seed(42)def load_img(path, resize=(224, 224)):pil_img = Image.open(path).convert('RGB')print("Original image size:", pil_img.size)  # 打印原始尺寸transform = transforms.Compose([transforms.Resize(resize),transforms.ToTensor()  # 转换为Tensor并自动归一化到[0,1]])return transform(pil_img)  # 返回[C,H,W]格式的tensorif __name__ == '__main__':dirpath = os.path.dirname(__file__)path = os.path.join(dirpath, 'img', '100.jpg')  # 使用os.path.join更安全# 加载图像 (已经是[0,1]范围的Tensor)trans_img = load_img(path)# 添加batch维度 [1, C, H, W],因为Dropout默认需要4D输入trans_img = trans_img.unsqueeze(0)# 创建Dropout层dropout = nn.Dropout2d(p=0.2)drop_img = dropout(trans_img)# 移除batch维度并转换为[H,W,C]格式供matplotlib显示trans_img = trans_img.squeeze(0).permute(1, 2, 0).numpy()drop_img = drop_img.squeeze(0).permute(1, 2, 0).numpy()# 确保数据在[0,1]范围内drop_img = drop_img.clip(0, 1)# 显示图像fig = plt.figure(figsize=(10, 5))ax1 = fig.add_subplot(1, 2, 1)ax1.imshow(trans_img)ax2 = fig.add_subplot(1, 2, 2)ax2.imshow(drop_img)plt.show()

效果:

说明:

nn.Dropout2d(p):Dropout2d 是针对二维数据设计的 Dropout 层,它在训练过程中随机将输入张量的某些通道(二维平面)置为零。

参数要求格式示例形状说明
输入(N, C, H, W)(16, 64, 32, 32)批大小×通道×高×宽
输出(N, C, H, W)(16, 64, 32, 32)与输入同形,部分通道归零

3.4 数据增强

样本数量不足(即训练数据过少)是导致过拟合(Overfitting)的常见原因之一,可以从以下角度理解:

  • 当训练数据过少时,模型容易“记住”有限的样本(包括噪声和无关细节),而非学习通用的规律。

  • 简单模型更可能捕捉真实规律,但数据不足时,复杂模型会倾向于拟合训练集中的偶然性模式(噪声)。

  • 样本不足时,训练集的分布可能与真实分布偏差较大,导致模型学到错误的规律。

  • 小数据集中,个别样本的噪声(如标注错误、异常值)会被放大,模型可能将噪声误认为规律。

数据增强(Data Augmentation)是一种通过人工生成或修改训练数据来增加数据集多样性的技术,常用于解决过拟合问题。数据增强通过“模拟”更多训练数据,迫使模型学习泛化性更强的规律,而非训练集中的偶然性模式。其本质是一种低成本的正则化手段,尤其在数据稀缺时效果显著。

在了解计算机如何处理图像之前,需要先了解图像的构成元素。

图像是由像素点组成的,每个像素点的值范围为: [0, 255], 像素值越大意味着较亮。比如一张 200x200 的图像, 则是由 40000 个像素点组成, 如果每个像素点都是 0 的话, 意味着这是一张全黑的图像。

我们看到的彩色图一般都是多通道的图像, 所谓多通道可以理解为图像由多个不同的图像层叠加而成, 例如我们看到的彩色图像一般都是由 RGB 三个通道组成的,还有一些图像具有 RGBA 四个通道,最后一个通道为透明通道,该值越小,则图像越透明。

数据增强是提高模型泛化能力(鲁棒性)的一种有效方法,尤其在图像分类、目标检测等任务中。数据增强可以模拟更多的训练样本,从而减少过拟合风险。数据增强通过torchvision.transforms模块来实现。

数据增强的好处

大幅度降低数据采集和标注成本;

模型过拟合风险降低,提高模型泛化能力;

官方地址:

transforms:Transforming and augmenting images — Torchvision 0.22 documentation

transforms:

常用变换类

  • transforms.Compose:将多个变换操作组合成一个流水线。

  • transforms.ToTensor:将 PIL 图像或 NumPy 数组转换为 PyTorch 张量,将图像数据从 uint8 类型 (0-255) 转换为 float32 类型 (0.0-1.0)。

  • transforms.Normalize:对张量进行标准化。

  • transforms.Resize:调整图像大小。

  • transforms.CenterCrop:从图像中心裁剪指定大小的区域。

  • transforms.RandomCrop:随机裁剪图像。

  • transforms.RandomHorizontalFlip:随机水平翻转图像。

  • transforms.RandomVerticalFlip:随机垂直翻转图像。

  • transforms.RandomRotation:随机旋转图像。

  • transforms.ColorJitter:随机调整图像的亮度、对比度、饱和度和色调。

  • transforms.RandomGrayscale:随机将图像转换为灰度图像。

  • transforms.RandomResizedCrop:随机裁剪图像并调整大小。

3.4.1 图片缩放

具体参考官方文档:Illustration of transforms — Torchvision 0.22 documentation

参考代码:

from PIL import Imagedef test03():img1 = plt.imread('./img/100.jpg')plt.imshow(img1)plt.show()img = Image.open('./img/100.jpg')transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])r_img = transform(img)print(r_img.shape)r_img = r_img.permute(1, 2, 0)plt.imshow(r_img)plt.show()

3.4.2 随机裁剪

img = Image.open('./img/100.jpg')transform = transforms.Compose([transforms.RandomCrop(size=(224, 224)), transforms.ToTensor()])r_img = transform(img)print(r_img.shape)r_img = r_img.permute(1, 2, 0)plt.imshow(r_img)plt.show()

3.4.3 随机水平翻转

RandomHorizontalFlip(p):随机水平翻转图像,参数p表示翻转概率(0 ≤ p ≤ 1),p=1 表示必定翻转,p=0 表示不翻转

img = Image.open('./img/100.jpg')transform = transforms.Compose([transforms.RandomHorizontalFlip(p=1), transforms.ToTensor()])r_img = transform(img)print(r_img.shape)r_img = r_img.permute(1, 2, 0)plt.imshow(r_img)plt.show()

3.4.4 调整图片颜色

transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)

brightness:

  • 亮度调整的范围。

  • 可以float(min, max) 元组:

    • 如果是 float(如 brightness=0.2),则亮度在 [max(0, 1 - 0.2), 1 + 0.2] = [0.8, 1.2] 范围内随机缩放。

    • 如果是 (min, max)(如 brightness=(0.5, 1.5)),则亮度在 [0.5, 1.5] 范围内随机缩放。

contrast:

  • 对比度调整的范围。

  • 格式与 brightness 相同。

saturation:

  • 饱和度调整的范围。

  • 格式与 brightness 相同。

hue:

  • 色调调整的范围。

  • 可以是一个浮点数(表示相对范围)或一个元组 (min, max)。

  • 取值范围必须为 [-0.5, 0.5](因为色相在 HSV 色彩空间中是循环的,超出范围会导致颜色异常)。

  • 例如,hue=0.1 表示色调在 [-0.1, 0.1] 之间随机调整。

img = Image.open('./img/100.jpg')transform = transforms.Compose([transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), transforms.ToTensor()])r_img = transform(img)print(r_img.shape)r_img = r_img.permute(1, 2, 0)plt.imshow(r_img)plt.show()

3.4.5 随机旋转

RandomRotation用于对图像进行随机旋转。

transforms.RandomRotation(degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0
)

degrees:

  • 旋转角度的范围,可以是一个浮点数或元组 (min_degree, max_degree)。

  • 例如,degrees=30 表示旋转角度在 [-30, 30] 之间随机选择。

  • 例如,degrees=(30, 60) 表示旋转角度在 [30, 60] 之间随机选择。

interpolation:

  • 插值方法,用于旋转图像。

  • 默认是 InterpolationMode.NEAREST(最近邻插值)。

  • 其他选项包括 InterpolationMode.BILINEAR(双线性插值)、InterpolationMode.BICUBIC(双三次插值)等。

expand:

  • 是否扩展图像大小以适应旋转后的图像。如:当需要保留完整旋转后的图像时(如医学影像、文档扫描)

  • 如果为 True,旋转后的图像可能会比原始图像大。

  • 如果为 False,旋转后的图像大小与原始图像相同。

center:

  • 旋转中心点的坐标,默认为图像中心。

  • 可以是一个元组 (x, y),表示旋转中心的坐标。

fill:

  • 旋转后图像边缘的填充值。

  • 可以是一个浮点数(用于灰度图像)或一个元组(用于 RGB 图像)。默认填充0(黑色)

# 加载图像image = Image.open('./img/100.jpg')# 定义 RandomRotation 变换transform = transforms.RandomRotation(degrees=30)  # 旋转角度在 [-30, 30] 之间随机选择# 应用变换rotated_image = transform(image)# 显示图像plt.imshow(rotated_image)plt.axis('off')plt.show()

3.4.6 图片转Tensor

import torch
from PIL import Image
from torchvision import transforms
import osdef test001():dir_path = os.path.dirname(__file__)file_path = os.path.join(dir_path,'img', '1.jpg')file_path = os.path.relpath(file_path)print(file_path)# 1. 读取图片img = Image.open(file_path)# transforms.ToTensor()用于将 PIL 图像或 NumPy 数组转换为 PyTorch 张量,并自动进行数值归一化和维度调整# 将像素值从 [0, 255] 缩放到 [0.0, 1.0](浮点数)# 自动将图像格式从 (H, W, C)(高度、宽度、通道)转换为 PyTorch 标准的 (C, H, W)transform = transforms.ToTensor()img_tensor = transform(img)print(img_tensor)if __name__ == "__main__":test001()

3.4.7 Tensor转图片

import torch
from PIL import Image
from torchvision import transformsdef test002():# 1. 随机一个数据表示图片img_tensor = torch.randn(3, 224, 224)# 2. 创建一个transformstransform = transforms.ToPILImage()# 3. 转换为图片img = transform(img_tensor)img.show()# 4. 保存图片img.save("./test.jpg")if __name__ == "__main__":test002()

练习:通过一个Demo加深对Torch的API理解和使用

import torch
from PIL import Image
from torchvision import transforms
import osdef test003():# 获取文件的相对路径dir_path = os.path.dirname(__file__)file_path = os.path.relpath(os.path.join(dir_path, 'dog.jpg'))# 加载图片img = Image.open(file_path)# 转换图片为tensortransform = transforms.ToTensor()t_img = transform(img)print(t_img.shape)# 获取GPU资源,将图片处理移交给CUDAdevice = 'cuda' if torch.cuda.is_available() else 'cpu't_img = t_img.to(device)t_img += 0.3# 将图片移交给CPU进行图片保存处理,一般IO操作是基于CPU的t_img = t_img.cpu()transform = transforms.ToPILImage()img = transform(t_img)img.show()if __name__ == "__main__":test003()

3.4.8 归一化

  • 标准化:将图像的像素值从原始范围(如 [0, 255] 或 [0, 1])转换为均值为 0、标准差为 1 的分布。

  • 加速训练:标准化后的数据分布更均匀,有助于加速模型训练。

  • 提高模型性能:标准化可以使模型更容易学习到数据的特征,提高模型的收敛性和稳定性。

img = Image.open('./img/100.jpg')transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])r_img = transform(img)print(r_img.shape)r_img = r_img.permute(1, 2, 0)plt.imshow(r_img)plt.show()

mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]

均值(Mean):数据集中所有图像在每个通道上的像素值的平均值。

标准差(Std):数据集中所有图像在每个通道上的像素值的标准差。

RGB 三个通道的均值和标准差 不是随便定义的,而是需要根据具体的数据集进行统计计算。这些值是 ImageNet 数据集的统计结果,已成为计算机视觉任务的默认标准。

数据集计算均值和标准差

以CIFAR10数据集为例:

# 获取数据集
train_data = datasets.CIFAR10(root='./cifar10',train=True,download=True,transform=transforms.ToTensor()  # 自动将PIL图像转为[0,1]范围的张量
)def compute_mean_std(dataset):# 初始化累加器mean = torch.zeros(3)std = torch.zeros(3)num_samples = len(dataset)# 遍历数据集计算均值for img, _ in dataset:# dim=(1, 2) 表示对图像的高度(H)和宽度(W)维度求均值,保留通道维度(C)。mean += img.mean(dim=(1, 2)) # 全局的通道均值mean /= num_samplesprint(mean)# 遍历数据集计算标准差for img, _ in dataset:# 原始mean 是一个形状为 [3] 的张量,表示每个通道的均值。# 使用 view(3, 1, 1) 将 mean 的形状从 [3] 改变为 [3, 1, 1]。# 这样,mean 的形状变为 [3, 1, 1],其中 3 表示通道数,1 和 1 分别表示高度和宽度的维度。# 当执行 img - mean.view(3, 1, 1) 时,PyTorch 会利用广播机制将 mean 自动扩展到与 img 相同的形状 [3, H, W]。# 然后利用方差公式计算:var=E(x-E(x))^2std += (img - mean.view(3, 1, 1)).pow(2).mean(dim=(1, 2))# 计算出所有图片的方差后,计算平均方差,然后求标准差std = torch.sqrt(std / num_samples)return mean, stdmean, std = compute_mean_std(train_data)
print(f"Mean: {mean}")  # 输出类似 [0.4914, 0.4822, 0.4465]
print(f"Std: {std}")  # 输出类似 [0.2470, 0.2435, 0.2616]

3.4.9 数据增强整合

使用transforms.Compose()把要增强的操作整合到一起:

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms, datasets, utilsdef test01():# 定义数据增强和归一化transform = transforms.Compose([transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.RandomRotation(10),  # 随机旋转 ±10 度transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),  # 随机裁剪到 32x32,缩放比例在0.8到1.0之间transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # 随机调整亮度、对比度、饱和度、色调transforms.ToTensor(),  # 转换为 Tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 归一化,这是一种常见的经验设置,适用于数据范围 [0, 1],使其映射到 [-1, 1]])# 加载 CIFAR-10 数据集,并应用数据增强trainset = datasets.CIFAR10(root="./cifar10_data", train=True, download=True, transform=transform)dataloader = DataLoader(trainset, batch_size=4, shuffle=False)# 获取一个批次的数据images, labels = next(iter(dataloader))# 还原图片并显示plt.figure(figsize=(10, 5))for i in range(4):# 反归一化:将像素值从 [-1, 1] 还原到 [0, 1]img = images[i] / 2 + 0.5# 转换为 PIL 图像img_pil = transforms.ToPILImage()(img)# 显示图片plt.subplot(1, 4, i + 1)plt.imshow(img_pil)plt.axis('off')plt.title(f'Label: {labels[i]}')plt.show()if __name__ == "__main__":test01()

代码解释:

transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

若数据分布与ImageNet差异较大(如医学影像、卫星图、MNIST等),或均值和标准差未知时,可用此简化设置。

将图片进行归一化,使数据更符合正态分布,归一化公式:

normalized\_img=\frac{img-0.5}{0.5}=2*img-1

img = img / 2 + 0.5

表示反归一化,是归一化的逆运算:

img=\frac{normalized\_img+1}{2} =\frac{normalized\_img}{2}+0.5

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/89900.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/89900.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java设计模式之行为型模式(备忘录模式)应用场景分析

最近看到一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站 一、用户交互与编辑操作 文本编辑器撤销/重做 场景描述:用户编辑文档时,可通过CtrlZ撤销误操作,或通过Ctr…

5.Java的4个权限修饰符

1.private(私有访问权限)最严格的访问修饰符,它限定被修饰的成员仅能在声明它的当前类内部访问。其他任何外部类都无法直接访问该成员。作用:强制封装,确保类内部实现细节的隐藏性和数据安全性2.默认权限(包…

Linux入门介绍

目录 一、环境 二、Linux发展历史 1、计算机 2、操作系统 四、认识Linux的 内核版本名称 一、环境 一般是Centos 7 Ubuntu 20.04 / 22.04 前者已经停止更新与维护,但很多公司还在使用前者 二、Linux发展历史 1、计算机 1945年 2.14---埃尼阿克---军事用处&…

spring boot2升级boot3

spring boot2升级boot3 整体流程如下 1、借助于开源的自动化代码重构工具OpenRewrite,快速地进行代码重构等 2、相关坐标升级更改 3、配置文件属性更改 4、打包、构建与运行验证 1. 前期准备工作第一步:确保升级之前项目是可编译运行的第二步&#xff1a…

mac终端设置代理

在Mac上配置终端走代理,需设置终端(如zsh或bash)使用HTTP/HTTPS/SOCKS代理,以便命令行工具(如curl、git、npm)通过代理访问网络。以下是详细步骤,适用于macOS 10.15及以上版本。 前提条件 代理服…

VSTO Excel中打开WinForm.ShowDialog()后,如果要使用当前的wb.Application在后台操作其他Excel文件(保持隐藏状态)

在VSTO Excel中打开WinForm.ShowDialog()后,如果要使用当前的wb.Application在后台操作其他Excel文件(保持隐藏状态),可以通过以下几种方式实现: 方法一:设置Application属性控制可见性 // 在WinForm中获取…

【网络安全】DDOS攻击

如果文章不足还请各位师傅批评指正!你有没有过这种经历:双 11 抢券时页面卡成幻灯片,游戏团战突然全员掉线,刷视频时进度条永远转圈圈?除了 “网渣”,可能还有个更糟的原因 —— 你正被 DDoS 攻击 “堵门”…

第9天 | openGauss中一个表空间可以存储多个数据库

接着昨天继续学习openGauss,今天是第9天了。今天学习内容是o一个数据库可以存储在多个表空间中。 老规矩,先登陆墨天轮为我准备的实训实验室 rootmodb:~# su - omm ommmodb:~$ gsql -r作业要求 1.创建表空间newtbs1 omm# CREATE TABLESPACE newtbs1 RELATIVE LOCATI…

H3C路由器模拟PPPOE拨号

拓扑简图 效果图 PPPoE服务器端脚本 1. 基础配置 system-view sysname PPPoE-Server # 可选,设置设备名称2. 创建本地用户(认证账号)​ local-user pppuser class network # 创建网络类用户 password simple 123456 # 设置密码(PAP/CHAP共用) service-type ppp #

Github Actions Workflows 上传 Dropbox

一、注册 访问 https://www.dropbox.com/register选择 "个人" 如果想免费使用,一定要选择 “继续使用2GB的Dropbox Basic 套餐”,如下: 二、在 Dropbox 中 创建app 需要去注册的邮箱中验证一下邮箱.访问 https://www.dropbox.com…

生产管理系统实现生产全过程可视化

随着现代工业的不断发展,智能制造、数字化转型已成为企业提高竞争力的重要途径。生产管理作为企业运营的核心环节,直接关系到产品质量、生产效率以及成本控制。传统的生产管理方式大多依赖手工记录和经验管理,存在信息滞后、数据不一致、响应…

CSS实现背景色下移10px

众所周知,背景颜色是不能移动的,通常是填充满当前容器。 不过可以想想其它办法。。 🧐 利用css3的线性属性linear,在垂直方向向下推要移动的距离设成透明颜色,能在视觉上巧妙实现下移的效果。 .title {height: 20px;background: linear-gradient(to bottom,rgba(255, …

访问 gitlab 跳转 0.0.0.0

1、检查防火墙是否关闭2、检查服务器端口是否被占用3、检查服务器是否对外开放80端口(gitlab 默认使用80端口)以阿里云服务器为例如果没有SSH 、HTTP、HTTPS 开放,需要增加规则进行添加点击确定即可。

Kotlin集合与空值

我们已经学习了 Kotlin 中的空安全(null safety)。在本节中,我们将讨论如何处理集合中的空值(null),因为集合比其他数据类型更复杂。我们还将讨论如何处理可空元素时常用的便利方法。 集合与空值 可空集合和…

nextjs编程式跳转

Next.js 中&#xff0c;你可以通过多种方式实现编程式导航&#xff08;即通过代码而非 <Link> 组件跳转页面&#xff09;。以下是完整的实现方法&#xff1a; 1. 使用 useRouter Hook&#xff08;函数组件&#xff09; 这是最常用的方法&#xff0c;适用于函数组件&#…

Git Remote命令介绍:远程仓库管理

一、Git Remote 是什么 git remote主要用于管理远程仓库&#xff0c;可以轻松地与远程仓库进行交互&#xff0c;实现代码的共享与同步 。 二、Git Remote 的作用 &#xff08;一&#xff09;连接桥梁 假设你正在参与一个大型的 Web 应用开发项目&#xff0c;团队成员分布在…

Android开发中的11种行为型设计模式深度解析

在Android应用开发中&#xff0c;设计模式是解决特定问题的可重用方案&#xff0c;其中行为型设计模式尤其重要&#xff0c;它们专注于对象之间的通信和职责分配。本文将深入解析Android开发中最常用的11种行为型设计模式&#xff0c;每个模式都配有详细的介绍和实际应用示例&a…

Python 模块未找到?这样解决“ModuleNotFoundError”

在 Python 开发中&#xff0c;遇到“ModuleNotFoundError”时&#xff0c;通常是因为 Python 解释器无法找到你尝试导入的模块。这可能是由于多种原因导致的&#xff0c;比如模块未安装、路径不正确、虚拟环境未激活等。今天&#xff0c;就让我们一起探讨如何解决“ModuleNotFo…

Numpy库,矩阵形状与维度操作

目录 一.numpy库简介与安装 numpy库的安装 二.numpy核心功能 1.矩阵处理 2.数学运算 三.数据的维度与属性 1.维度管理 2.属性方法 四.数据类型与存储范围 五.矩阵形状与维度操作 六.数据升维与reshape()方法 一.numpy库简介与安装 NumPy是Python中用于科学计算的核心…

图论(2):最短路

最短路一、模板1. Floyd2. 01BFS3. SPFA4. Dijkstra&#xff08;弱化版&#xff09;5. Dijkstra&#xff08;优化版&#xff09;二、例题1. Floyd1.1 传送门1.2 无向图最小环1.3 灾后重建1.4 飞猪2. 01BFS2.1 Kathiresan2.2 障碍路线2.3 奇妙的棋盘3. SPFA3.1 奶牛派对3.2 营救…