【一起来学AI大模型】PyTorch 实战示例:使用 BatchNorm 处理张量(Tensor)

PyTorch 实战示例 演示如何在神经网络中使用 BatchNorm 处理张量(Tensor),涵盖关键实现细节和常见陷阱。示例包含数据准备、模型构建、训练/推理模式切换及结果分析。


示例场景:在 CIFAR-10 数据集上实现带 BatchNorm 的 CNN

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 1. 数据准备 & 预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化到[-1,1]
])train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False)# 2. 定义带 BatchNorm 的 CNN
class CNNWithBN(nn.Module):def __init__(self):super().__init__()self.features = nn.Sequential(# Conv-BN-ReLU-Pool 模块nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.BatchNorm2d(64),  # 关键!通道数=64nn.ReLU(),nn.MaxPool2d(2, 2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.BatchNorm2d(128),  # 通道数=128nn.ReLU(),nn.MaxPool2d(2, 2))self.classifier = nn.Sequential(nn.Linear(128 * 8 * 8, 512),nn.BatchNorm1d(512),  # 全连接层也适用BNnn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)  # 展平return self.classifier(x)model = CNNWithBN().to(device)# 3. 训练循环(重点:BN的训练模式)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)  # 配合BN的Weight Decaydef train(epoch):model.train()  # 切换到训练模式(启用BN的mini-batch统计)for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 4. 测试推理(重点:BN的推理模式)
def test():model.eval()  # 切换到评估模式(使用全局统计量)correct = 0with torch.no_grad():  # 禁用梯度计算for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = outputs.max(1)correct += predicted.eq(labels).sum().item()accuracy = 100. * correct / len(test_set.dataset)print(f'Test Accuracy: {accuracy:.2f}%')return accuracy# 5. 执行训练与测试
for epoch in range(10):train(epoch)acc = test()# 6. 查看BN层参数(实战调试)
print("\nBatchNorm层参数检查:")
for name, module in model.named_modules():if isinstance(module, nn.BatchNorm2d):print(f"{name}: weight={module.weight.data.mean().item():.4f}, "f"bias={module.bias.data.mean().item():.4f}")print(f"  Running Mean: {module.running_mean.mean().item():.4f}, "f"Running Var: {module.running_var.mean().item():.4f}")

关键实战细节解析

1. BatchNorm 层初始化
nn.BatchNorm2d(num_features)  # 必须与输入通道数一致
nn.BatchNorm1d(512)          # 全连接层适用
2. 模式切换的重要性
模式代码BN行为忘记切换的后果
训练model.train()使用当前batch的统计量更新 running_mean/running_var推理时统计量错误,精度大幅下降
推理model.eval()固定使用训练积累的 running_mean/running_var训练引入测试噪声,收敛不稳定
3. 参数解读(以 nn.BatchNorm2d 为例)
# 可学习参数
bn_layer.weight   # γ (缩放因子), shape=(C,)
bn_layer.bias     # β (偏移因子), shape=(C,)# 自动统计量(训练时更新)
bn_layer.running_mean   # 全局均值估计, shape=(C,)
bn_layer.running_var    # 全局方差估计, shape=(C,)
4. 常见错误及解决方案
  • 错误1:Batch Size 过小(<16)

    # 解决方案:使用GroupNorm替代
    nn.GroupNorm(num_groups=32, num_channels=128)

  • 错误2:忘记在测试时调用 model.eval()

    # 正确做法:在推理前显式切换模式
    model.eval()
    with torch.no_grad():output = model(input_tensor)

  • 错误3:微调时错误处理 BN 统计量

    # 冻结BN的统计量(只更新γ/β)
    for module in model.modules():if isinstance(module, nn.BatchNorm2d):module.eval()  # 固定running_mean/var


BatchNorm 张量变换可视化

假设输入张量维度:(batch_size, channels, height, width) = (4, 3, 2, 2)

input_tensor = torch.randn(4, 3, 2, 2)  # 模拟输入数据# BatchNorm2d 操作步骤
bn = nn.BatchNorm2d(3)  # 通道数=3# 前向传播分解:
# 1. 计算每个通道的均值和方差
mean_per_channel = input_tensor.mean(dim=[0, 2, 3])  # shape=(3,)
var_per_channel = input_tensor.var(dim=[0, 2, 3], unbiased=False)# 2. 标准化 (x - μ) / √(σ² + ε)
normalized = (input_tensor - mean_per_channel[None, :, None, None]) / torch.sqrt(var_per_channel[None, :, None, None] + 1e-5)# 3. 缩放和偏移
output = normalized * bn.weight[None, :, None, None] + bn.bias[None, :, None, None]

性能对比(CIFAR-10 实验结果)

模型测试精度收敛速度训练稳定性
无 BatchNorm78.2%慢 (20 epochs)需要精细调参
带 BatchNorm86.7%快 (8 epochs)高学习率鲁棒
BatchNorm + Dropout85.9%最优正则化

注意:BN 的轻微正则化效果可能部分替代 Dropout,但组合使用需调整丢弃概率

通过这个实战示例,你可以直观理解 BatchNorm 如何操作张量,以及它在实际训练中的关键作用。建议在 Colab 中运行代码并尝试修改 BN 参数(如 momentum 参数控制统计量更新速度),观察对结果的影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/88161.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/88161.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第8章:应用层协议HTTP、SDN软件定义网络、组播技术、QoS

应用层协议HTTP 应用层协议概述 应用层协议非常多&#xff0c;我们重点熟悉以下常见协议功能即可。 Telnet:远程登录协议&#xff0c;基于TCP 23端口&#xff0c;用于远程管理设备&#xff0c;采用明文传输。安全外壳协议 (SecureShell,SSH) ,基于TCP 22端口&#xff0c;用于…

uniapp页面间通信

uniapp中通过eventChannel实现页面间通信的方法&#xff0c;这是一种官方推荐的高效传参方式。我来解释下这种方式的完整实现和注意事项&#xff1a;‌发送页面&#xff08;父页面&#xff09;‌&#xff1a;uni.navigateTo({url: /pages/detail/detail,success: (res) > {/…

Android ViewModel机制与底层原理详解

Android 的 ViewModel 是 Jetpack 架构组件库的核心部分&#xff0c;旨在以生命周期感知的方式存储和管理与 UI 相关的数据。它的核心目标是解决两大痛点&#xff1a; 数据持久化&#xff1a; 在配置变更&#xff08;如屏幕旋转、语言切换、多窗口模式切换&#xff09;时保留数…

双倍硬件=双倍性能?TDengine线性扩展能力深度实测验证!

软件扩展能力是软件架构设计中的一个关键要素&#xff0c;具有良好扩展能力的软件能够充分利用新增的硬件资源。当软件性能与硬件增加保持同步比例增长时&#xff0c;我们称这种现象为软件具有线性扩展能力。要实现这种线性扩展并不简单&#xff0c;它要求软件架构精心设计&…

频繁迭代下完成iOS App应用上架App Store:一次快速交付项目的完整回顾

在一次面向商户的会员系统App开发中&#xff0c;客户要求每周至少更新一次版本&#xff0c;涉及功能迭代、UI微调和部分支付方案的更新。团队使用Flutter进行跨平台开发&#xff0c;但大部分成员日常都在Windows或Linux环境&#xff0c;只有一台云Mac用于打包。如何在高频率发布…

springsecurity03--异常拦截处理(认证异常、权限异常)

目录 Spingsecurity异常拦截处理 认证异常拦截 权限异常拦截 注册异常拦截器 设置跨域访问 Spingsecurity异常拦截处理 认证异常拦截 /*自定义认证异常处理器类*/ Component public class MyAuthenticationExceptionHandler implements AuthenticationEntryPoint {Overr…

企业如何制作网站?网站制作的步骤与流程?

以下是2025年网站制作的综合指南&#xff0c;涵盖核心概念、主流技术及实施流程&#xff1a; 一、定义与范畴 网站制作是通过页面结构设计、程序设计、数据库开发等技术&#xff0c;将视觉设计转化为可交互网页的过程&#xff0c;包含前端展示与后台功能实现。其核心目标是为企…

Rust+Blender:打造高性能游戏引擎

基于Rust和Blender的游戏引擎 以下是基于Rust和Blender的游戏引擎开发实例,涵盖不同应用场景和技术方向的实际案例。案例分为工具链整合、渲染技术、物理模拟等类别,每个案例附核心代码片段或实现逻辑。 工具链整合案例 案例1:Blender模型导出到Bevy引擎 使用blender-bev…

Git基本操作1

Git 是一款分布式版本控制系统&#xff0c;主要用于高效管理代码版本和团队协作开发。它能精确记录每次代码修改&#xff0c;支持版本回溯和分支管理&#xff0c;让开发者可以并行工作而互不干扰。通过本地提交和远程仓库同步&#xff0c;Git 既保障了代码安全&#xff0c;又实…

React Native 组件间通信方式详解

React Native 组件间通信方式详解 在 React Native 开发中&#xff0c;组件间通信是核心概念之一。以下是几种主要的组件通信方式及其适用场景&#xff1a; 简单父子通信&#xff1a;使用 props 和回调函数兄弟组件通信&#xff1a;提升状态到共同父组件跨多级组件&#xff1a;…

TCP的可靠传输机制

TCP通过校验和、序列号、确认应答、重发控制、连接管理以及窗口控制等机制实现可靠性的传输。 先来看第一个可靠性传输的方法。 通过序列号和可靠性提供可靠性 TCP是面向字节的。TCP把应用层交下来的报文&#xff08;可能要划分为许多较短的报文段&#xff09;看成一个一个字节…

没有DBA的敏捷开发管理

前言一家人除了我都去旅游了&#xff0c;我这项请假&#xff0c;请不动啊。既然在家了&#xff0c;闲着也是闲着&#xff0c;就复盘下最近的工作&#xff0c;今天就复盘表结构管理吧&#xff0c;随系统启动的&#xff0c;不是flyway&#xff0c;而是另一个liquibase&#xff0c…

go-carbon v2.6.10发布,轻量级、语义化、对开发者友好的 golang 时间处理库

carbon 是一个轻量级、语义化、对开发者友好的 Golang 时间处理库&#xff0c;提供了对时间穿越、时间差值、时间极值、时间判断、星座、星座、农历、儒略日 / 简化儒略日、波斯历 / 伊朗历的支持。 carbon 目前已捐赠给 dromara 开源组织&#xff0c;已被 awesome-go 收录&am…

【AI News | 20250708】每日AI进展

AI Repos 1、claude-code-templates Claude Code Templates是一款全面的命令行工具&#xff0c;旨在为不同编程语言和框架&#xff08;如JavaScript/TypeScript、Python等&#xff0c;Go和Rust即将推出&#xff09;提供优化的Claude Code配置。它通过交互式设置、自动化钩子&a…

Nginx源码安装+静态站点部署指南(CentOS 7)

安装包&#xff1a;可自行前往我的飞书下载 Docs 也可以进入 nginx 官网&#xff0c;下载自己所需适应版本 nginx 开始安装nginx 1. 创建准备目录 cd /opt mkdir soft module # 创建软件包和源码解压目录 2. 安装依赖环境 yum -y install make zlib zlib-devel gcc-c l…

交换机的核心原理和作用

一、交换机的核心原理交换机是一种用于连接多台设备的网络硬件&#xff0c;其核心原理基于二层网络&#xff08;数据链路层&#xff09;的 MAC 地址寻址1. MAC 地址学习与存储当交换机接收到数据帧时&#xff0c;会读取帧中的源 MAC 地址&#xff0c;并将该地址与对应的端口号记…

【工具变量】上市公司企业金融强监管数据、资管新规数据(2001-2024年)

数据简介&#xff1a;参考顶刊《经济研究》李青原&#xff08;2022&#xff09;老师的做法&#xff0c;Post 为时间虚拟变量&#xff0c;根据资管新规实施的时间&#xff0c;当观测期为2018 年上半年及之后时&#xff0c;Post 取值1&#xff0c;否则取值0。PreFin 为资管新规实…

CSS Grid与Flexbox布局实战对比

概述 CSS布局技术在过去几年经历了重大变革&#xff0c;从传统的基于浮动和定位的方法&#xff0c;到现在强大的Flexbox和Grid布局系统。这两种现代布局方法极大地简化了复杂界面的开发过程&#xff0c;但它们各自适用于不同的场景。本文将对Flexbox和Grid进行深入比较&#x…

[Pytest][Part 4]多种测试运行方式

实现需求2&#xff1a;有两种运行测试的方式&#xff1a;通过config配置文件运行&#xff0c;测试只需要修改config配置文件cmdline 运行这里是新建一个config类来存储所有的测试配置&#xff0c;以后配置有修改的话也只需要修改这个类。根据目前的测试需求&#xff0c;config中…

平衡二叉树的删除操作

对于平衡二叉树的操作应对与考试只需要模拟出过程即可&#xff0c;且他的过程和插入的平衡方法一样&#xff0c;不一样的只是对于平衡因子的计算上。接下来我将给出方法①删除结点&#xff08;方法同“二叉排序树”&#xff09; ②一路向北找到最小不平衡子树&#xff0c;找不到…