15.手动实现BatchNorm(BN)

15.1 BatchNorm操作手动实现

import torch 
from torch import nndef batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):if not torch.is_grad_enabled():#这个是推理模式X_hat=(X-moving_mean)/torch.sqrt(moving_var+eps)else:assert len(X.shape) in (2,4)if len(X.shape)==2:mean=X.mean(dim=0)var=((X-mean)**2).mean(dim=0)else:mean=X.mean(dim=(0,2,3),keepdim=True)var=((X-mean)**2).mean(dim=(0,2,3),keepdim=True)# 更新移动平均的均值和方差X_hat=(X-mean)/torch.sqrt(var+eps)moving_mean=momentum*moving_mean+(1.0-momentum)*meanmoving_var=momentum*moving_var+(1.0-momentum)*varY=gamma*X_hat+betareturn Y,moving_mean.data,moving_var.data
class BatchNorm(nn.Module):def __init__(self, num_features,num_dims):super().__init__()if num_dims==2:shape=(1,num_features)else:shape=(1,num_features,1,1)#这是两个需要更新的参数self.gamma=nn.Parameter(torch.ones(shape))self.beta=nn.Parameter(torch.zeros(shape))self.moving_mean=torch.zeros(shape)self.moving_var=torch.ones(shape)#这个不能为0,应该是/sqrt(var)def forward(self,X):#计算设备对齐if self.moving_mean.device!=X.device:self.moving_mean=self.moving_mean.to(X.device)self.moving_var=self.moving_var.to(X.device)Y,self.moving_mean,self.moving_var=batch_norm(X,self.gamma,self.beta,self.moving_mean,self.moving_var,eps=1e-5,momentum=0.9)return Y
model=nn.Sequential(nn.Conv2d(1,6,kernel_size=5),BatchNorm(6,num_dims=4),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),BatchNorm(16,num_dims=4),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2,stride=2),nn.Flatten(),#Flatten()之后就是[batch_size,features] 2维度的向量矩阵nn.Linear(16*4*4,120),BatchNorm(120,num_dims=2),nn.Sigmoid(),nn.Linear(120,84),BatchNorm(84,num_dims=2),nn.Sigmoid(),nn.Linear(84,10))

15.2 BatchNorm实验效果

################################################################################################################
"""BatchNorm"""
################################################################################################################
import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
plt.rcParams['font.family']=['Times New Roman']
class Reshape(torch.nn.Module):def forward(self,x):return x.view(-1,1,28,28)#[bs,1,28,28]
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):epochs = range(1, len(train_loss_list) + 1)plt.figure(figsize=(4, 3))plt.plot(epochs, train_loss_list, label='Train Loss')plt.plot(epochs, train_acc_list, label='Train Acc',linestyle='--')plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')plt.xlabel('Epoch')plt.ylabel('Value')plt.title(title)plt.legend()plt.grid(True)plt.tight_layout()plt.show()
def train_model(model,train_data,test_data,num_epochs):train_loss_list = []train_acc_list = []test_acc_list = []for epoch in range(num_epochs):total_loss=0total_acc_sample=0total_samples=0loop=tqdm(train_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")for X,y in loop:#X=X.reshape(X.shape[0],-1)#print(X.shape)X=X.to(device)y=y.to(device)y_hat=model(X)loss=CEloss(y_hat,y)optimizer.zero_grad()loss.backward()optimizer.step()#loss累加total_loss+=loss.item()*X.shape[0]y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()total_acc_sample+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数total_samples+=X.shape[0]test_acc_samples=0test_samples=0for X,y in test_data:X=X.to(device)y=y.to(device)#X=X.reshape(X.shape[0],-1)y_hat=model(X)y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()test_acc_samples+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数test_samples+=X.shape[0]avg_train_loss=total_loss/total_samplesavg_train_acc=total_acc_sample/total_samplesavg_test_acc=test_acc_samples/test_samplestrain_loss_list.append(avg_train_loss)train_acc_list.append(avg_train_acc)test_acc_list.append(avg_test_acc)print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Trian Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")plot_metrics(train_loss_list, train_acc_list, test_acc_list)return model
def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)
def batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):if not torch.is_grad_enabled():#这个是推理模式X_hat=(X-moving_mean)/torch.sqrt(moving_var+eps)else:assert len(X.shape) in (2,4)if len(X.shape)==2:mean=X.mean(dim=0)var=((X-mean)**2).mean(dim=0)else:mean=X.mean(dim=(0,2,3),keepdim=True)var=((X-mean)**2).mean(dim=(0,2,3),keepdim=True)# 更新移动平均的均值和方差X_hat=(X-mean)/torch.sqrt(var+eps)moving_mean=momentum*moving_mean+(1.0-momentum)*meanmoving_var=momentum*moving_var+(1.0-momentum)*varY=gamma*X_hat+betareturn Y,moving_mean.data,moving_var.data
class BatchNorm(nn.Module):def __init__(self, num_features,num_dims):super().__init__()if num_dims==2:shape=(1,num_features)else:shape=(1,num_features,1,1)#这是两个需要更新的参数self.gamma=nn.Parameter(torch.ones(shape))self.beta=nn.Parameter(torch.zeros(shape))self.moving_mean=torch.zeros(shape)self.moving_var=torch.ones(shape)#这个不能为0,应该是/sqrt(var)def forward(self,X):#计算设备对齐if self.moving_mean.device!=X.device:self.moving_mean=self.moving_mean.to(X.device)self.moving_var=self.moving_var.to(X.device)Y,self.moving_mean,self.moving_var=batch_norm(X,self.gamma,self.beta,self.moving_mean,self.moving_var,eps=1e-5,momentum=0.9)return Y
################################################################################################################
transforms=transforms.Compose([transforms.Resize(28),transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
train_img=torchvision.datasets.FashionMNIST(root="./data",train=True,transform=transforms,download=True)
test_img=torchvision.datasets.FashionMNIST(root="./data",train=False,transform=transforms,download=True)
train_data=DataLoader(train_img,batch_size=128,num_workers=4,shuffle=True)
test_data=DataLoader(test_img,batch_size=128,num_workers=4,shuffle=False)
################################################################################################################
device=torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')
model=nn.Sequential(nn.Conv2d(1,6,kernel_size=5),BatchNorm(6,num_dims=4),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),BatchNorm(16,num_dims=4),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2,stride=2),nn.Flatten(),#Flatten()之后就是[batch_size,features] 2维度的向量矩阵nn.Linear(16*4*4,120),BatchNorm(120,num_dims=2),nn.Sigmoid(),nn.Linear(120,84),BatchNorm(84,num_dims=2),nn.Sigmoid(),nn.Linear(84,10)).to(device)
model.apply(init_weights)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
CEloss=nn.CrossEntropyLoss()
model=train_model(model,train_data,test_data,num_epochs=15)
################################################################################################################
print("BatchNorm算法学习参数效果:")
print("gamma:",model[1].gamma.reshape((-1,)))
print("beta:",model[1].beta.reshape((-1,)))

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/914777.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/914777.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【项目实践】SMBMS(Javaweb版)汇总版

文章目录前期准备工作数据库、数据表创建web项目创建项目文件目录配置Tomcat,导入依赖建立实体类编写基础公共方法类导入基础资源登录功能登录页面持久层dao层的用户登录及接口实现dao层接口实现所需的方法业务层sevice层的接口的实现接口实现相关的业务逻辑编写ser…

隐藏源IP的核心方案与高防实践

一、源IP暴露的风险 直接DDoS攻击:2025年Q2全球DDoS攻击峰值达3.8Tbps(来源:Cloudflare报告)漏洞利用:暴露的SSH端口平均每天遭受12,000暴力破解尝试数据泄露:直接连接数据库风险提升300% 二、4种有效隐藏方…

深度学习图像分类数据集—五种电器识别分类

该数据集为图像分类数据集,适用于ResNet、VGG等卷积神经网络,SENet、CBAM等注意力机制相关算法,Vision Transformer等Transformer相关算法。 数据集信息介绍:五种电器识别分类:[notebook, phone, powerbank, tablet, w…

Windows11家庭版配置frigate 嵌入自研算法(基于Yolov8)-【2】

使用 YOLOv8 的 results.xyxy 结构,下面是一个完整的 MQTT 推送脚本,用于把识别到的目标(比如突涌水、水渍、障碍物等)发送到 Frigate 的 MQTT 接口。✅ 前提假设 YOLOv8 推理代码已经运行并生成 results.xyxy。每一行是 [x1, y1,…

安装llama-factory报错 error: subprocess-exited-with-error

报错信息如下 Using cached https://mirrors.aliyun.com/pypi/packages/17/89/940a509ee7e9449f0c877fa984b37b7cc485546035cc67bbc353f2ac20f3/av-15.0.0.tar.gz (3.8 MB)Preparing metadata (pyproject.toml) ... errorerror: subprocess-exited-with-error Preparing metad…

QT 多线程 管理串口

记录一下自己使用多线程进行串口管理和数据读取的过程。如果有问题的话可以发消息给我。背景在使用QT制作一个串口数据读取处理的小软件的时候,发现了存在界面卡顿的情况,感觉性能太低,于是考虑把串口数据的读取和处理都放到子线程的缓冲区中…

在虚拟环境中复现论文(环境配置)

前提:已经下载condawinR,输入cmd进入命令行conda create -n PPT python3.8.3 pytorch1.7.0conda create -n PPT(虚拟环境名) python3.8.3(包名) pytorch1.7.0(包名)安装完毕,激活虚拟环境:conda activate PPT根据论文readme要求安…

Flutter Web 的发展历程:Dart、Flutter 与 WasmGC

Flutter Web 应该是 Flutter 开发者里最不“受宠”的平台了,但是其实 Flutter 和 Dart 团队对于 Web 的投入一直没有减少,这也和 Flutter 还有 Dart 的"出生"有关系,今天就借着 Dart 团队的 mer Ağacan 和 Martin Kustermann 在油…

c#方法关键字,ref、out、int

在 C# 中,ref、out 和 in 是用于方法参数传递的关键字,它们控制参数如何在方法和调用者之间传递数据。以下是对这三个关键字的详细分析:1. ref 关键字(引用传递)作用允许方法修改调用者的变量:通过引用传递…

设计模式—初识设计模式

1.设计模式经典面试题分析几个常见的设计模式对应的面试题。1.1原型设计模式1.使用UML类图画出原型模式核心角色(意思就是使用会考察使用UML画出设计模式中关键角色和关系图等)2.原型设计模式的深拷贝和浅拷贝是什么,写出深拷贝的两种方式的源…

深度学习-参数初始化、损失函数

A、参数初始化参数初始化对模型的训练速度、收敛性以及最终的性能产生重要影响。它可以尽量避免梯度消失和梯度爆炸的情况。一、固定值初始化在神经网络训练开始时,将权重或偏置初始化为常数。但这种方法在实际操作中并不常见。1.1全零初始化将所有的权重参数初始化…

格密码--Ring-SIS和Ring-LWE

1. 多项式环&#xff08;Polynomial Rings&#xff09; 设 f∈Z[x]f \in \mathbb{Z}[x]f∈Z[x] 是首一多项式&#xff08;最高次项系数为1&#xff09; 则环 RZ[x]/(f)R \mathbb{Z}[x]/(f)RZ[x]/(f) 元素为&#xff1a;所有次数 <deg⁡(f)< \deg(f)<deg(f) 的多项式…

前端工作需要和哪些人打交道?

前端工作中需要协作的角色及协作要点 前端工作中需要协作的角色及协作要点 前端开发处于产品实现的 “中间环节”,既要将设计方案转化为可交互的界面,又要与后端对接数据,还需配合团队推进项目进度。日常工作中,需要频繁对接的角色包括以下几类,每类协作都有其核心目标和…

万字长文解析 OneCode3.0 AI创新设计

一、研究概述与背景 1.1 研究背景与意义 在 AI 技术重塑软件开发的浪潮中&#xff0c;低代码平台正经历从 “可视化编程” 到 “意图驱动开发” 的根本性转变。这种变革不仅提升了开发效率&#xff0c;更重新定义了人与系统的交互方式。作为国内领先的低代码平台&#xff0c;On…

重学前端006 --- 响应式网页设计 CSS 弹性盒子

文章目录盒模型一、盒模型的基本概念二、两种盒模型的对比 举例三、总结Flexbox 弹性盒子布局一、Flexbox 的核心概念​​二、Flexbox 的基本语法​​​​1. 定义 Flex 容器​​​2. Flex 容器的主要属性​​​​3. Flex 项目的主要属性​​​​三、Flexbox 的常见布局示例​​…

rLLM:用于LLM Agent RL后训练的创新框架

rLLM&#xff1a;用于LLM Agent RL后训练的创新框架 本文介绍了rLLM&#xff0c;一个用于语言智能体后训练的可扩展框架。它能让用户轻松构建自定义智能体与环境&#xff0c;通过强化学习进行训练并部署。文中还展示了用其训练的DeepSWE等智能体的出色表现&#xff0c;以及rLL…

rocky8 --Elasticsearch+Logstash+Filebeat+Kibana部署【7.1.1版本】

软件说明&#xff1a; 所有软件包下载地址&#xff1a;Past Releases of Elastic Stack Software | Elastic 打开页面后选择对应的组件及版本即可&#xff01; 所有软件包名称如下&#xff1a; 架构拓扑&#xff1a; 集群模式&#xff1a; 单机模式 架构规划&#xff1a…

【JVM】内存分配与回收原则

在 Java 开发中&#xff0c;自动内存管理是 JVM 的核心能力之一&#xff0c;而内存分配与回收的策略直接影响程序的性能和稳定性。本文将详细解析 JVM 的内存分配机制、对象回收规则以及背后的设计思想&#xff0c;帮助开发者更好地理解 JVM 的 "自动化" 内存管理逻辑…

Qt获取hid设备信息

Qt 中通过 HID&#xff08;Human Interface Device&#xff09;接口获取指定的 USB 设备&#xff0c;并读取其数据。资源文件中包含了 hidapi.h、hidapi.dll 和 hidapi.lib。通过这些文件&#xff0c;您可以在 Qt 项目中实现对 USB 设备的 HID 接口调用。#include <QObject&…

Anaconda Jupyter 使用注意事项

Anaconda Jupyter 使用注意事项 1.将cell转换为markdown。 First, select the cell you want to convertPress Esc to enter command mode (the cell border should turn blue)Press M to convert the cell to Markdown在编辑模式下按下ESC键&#xff0c;使单元块&#xff08;c…