PyTorch中nn.Module详解和综合代码示例

在 PyTorch 中,nn.Module 是神经网络中最核心的基类,用于构建所有模型。理解并熟练使用 nn.Module 是掌握 PyTorch 的关键。


一、什么是 nn.Module

nn.Module 是 PyTorch 中所有神经网络模块的基类。可以把它看作是“神经网络的容器”,它封装了以下几件事:

  1. 网络层(如 Linear、Conv2d 等)
  2. 前向传播逻辑(forward 函数)
  3. 模型参数(自动注册并可训练)
  4. 可嵌套(可以包含多个子模块)
  5. 便捷的模型保存 / 加载等工具函数

二、基础用法

2.1 自定义模型类

import torch
import torch.nn as nnclass MyNet(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(784, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return x

2.2 实例化与调用

model = MyNet()
x = torch.randn(32, 784)     # batch_size = 32
output = model(x)            # 自动调用 forward

三、构造方法详解

3.1 __init__()

  • 定义子模块、层等结构。
  • 例如 self.conv1 = nn.Conv2d(...) 会被自动注册为模型参数。

3.2 forward()

  • 定义前向传播逻辑。
  • 不能手动调用,应使用 model(x) 形式。

四、常见模块层

模块名作用示例
nn.Linear全连接层nn.Linear(128, 64)
nn.Conv2d卷积层nn.Conv2d(3, 16, 3)
nn.ReLU激活函数nn.ReLU()
nn.Sigmoid激活函数nn.Sigmoid()
nn.BatchNorm2d批归一化nn.BatchNorm2d(16)
nn.DropoutDropout 层nn.Dropout(0.5)
nn.LSTMLSTM 层nn.LSTM(10, 20)
nn.Sequential层的顺序容器见下文说明

五、模型嵌套结构(子模块)

你可以将一个 nn.Module 作为另一个模块的子模块嵌套:

class Block(nn.Module):def __init__(self):super().__init__()self.layer = nn.Sequential(nn.Linear(64, 64),nn.ReLU())def forward(self, x):return self.layer(x)class Net(nn.Module):def __init__(self):super().__init__()self.block1 = Block()self.block2 = Block()self.output = nn.Linear(64, 10)def forward(self, x):x = self.block1(x)x = self.block2(x)return self.output(x)

六、内置方法和属性

方法 / 属性说明
model.parameters()返回所有可训练参数(用于优化器)
model.named_parameters()返回带名字的参数迭代器
model.children()返回子模块迭代器
model.eval()设置为评估模式(Dropout、BN失效)
model.train()设置为训练模式
model.to(device)将模型转移到 GPU/CPU
model.state_dict()获取模型参数字典(保存)
model.load_state_dict()加载模型参数字典

七、使用 nn.Sequential

nn.Sequential 是一个顺序容器,可以用来简化网络结构定义:

model = nn.Sequential(nn.Linear(784, 128),nn.ReLU(),nn.Linear(128, 10)
)

等价于手写的自定义 nn.Module。适合前向传播是线性“流动”的结构。


八、实战完整示例:MNIST 分类网络

class MNISTNet(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Flatten(),nn.Linear(28*28, 256),nn.ReLU(),nn.Linear(256, 10))def forward(self, x):return self.net(x)# 实例化模型
model = MNISTNet()
print(model)# 配置训练
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)# 示例训练循环
for epoch in range(10):for images, labels in train_loader:output = model(images)loss = criterion(output, labels)optimizer.zero_grad()loss.backward()optimizer.step()

九、常见陷阱和建议

问题说明
forward() 不起作用应该使用 model(x),而不是手动调用 model.forward(x)
忘记 super().__init__()子模块将不会被注册
参数未注册层/模块必须赋值为 self.xxx = ...
训练/测试模式混淆注意 model.eval()model.train()

十、总结

项目说明
__init__()定义模型结构(子模块、层)
forward()定义前向传播
自动注册参数所有 self.xxx = nn.XXX(...) 都会被追踪
嵌套模块支持递归子模块调用
便捷方法.parameters().to().eval()

十一、综合示例

以下是基于 PyTorch nn.Module 封装的三种经典深度学习架构(ResNet18UNetTransformer)的简洁而完整的实现,适合初学者快速上手。


1、ResNet18 简洁实现(适合图像分类)

import torch
import torch.nn as nn
import torch.nn.functional as Fclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_planes, planes, stride=1, downsample=None):super().__init__()self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1   = nn.BatchNorm2d(planes)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)self.bn2   = nn.BatchNorm2d(planes)self.downsample = downsampledef forward(self, x):identity = xif self.downsample:identity = self.downsample(x)out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += identityreturn F.relu(out)class ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000):super().__init__()self.in_planes = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1   = nn.BatchNorm2d(64)self.pool  = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64,  layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc      = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, planes, blocks, stride=1):downsample = Noneif stride != 1 or self.in_planes != planes * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_planes, planes * block.expansion,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(planes * block.expansion))layers = [block(self.in_planes, planes, stride, downsample)]self.in_planes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.in_planes, planes))return nn.Sequential(*layers)def forward(self, x):x = self.pool(F.relu(self.bn1(self.conv1(x))))x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x).flatten(1)return self.fc(x)def ResNet18(num_classes=1000):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)

2、UNet(适合图像分割)

class UNetBlock(nn.Module):def __init__(self, in_ch, out_ch):super().__init__()self.block = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(out_ch, out_ch, 3, padding=1),nn.ReLU(inplace=True))def forward(self, x):return self.block(x)class UNet(nn.Module):def __init__(self, in_channels=1, out_channels=1):super().__init__()self.enc1 = UNetBlock(in_channels, 64)self.enc2 = UNetBlock(64, 128)self.enc3 = UNetBlock(128, 256)self.enc4 = UNetBlock(256, 512)self.pool = nn.MaxPool2d(2)self.bottleneck = UNetBlock(512, 1024)self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)self.dec4 = UNetBlock(1024, 512)self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)self.dec3 = UNetBlock(512, 256)self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)self.dec2 = UNetBlock(256, 128)self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)self.dec1 = UNetBlock(128, 64)self.final = nn.Conv2d(64, out_channels, kernel_size=1)def forward(self, x):e1 = self.enc1(x)e2 = self.enc2(self.pool(e1))e3 = self.enc3(self.pool(e2))e4 = self.enc4(self.pool(e3))b  = self.bottleneck(self.pool(e4))d4 = self.upconv4(b)d4 = self.dec4(torch.cat([d4, e4], dim=1))d3 = self.upconv3(d4)d3 = self.dec3(torch.cat([d3, e3], dim=1))d2 = self.upconv2(d3)d2 = self.dec2(torch.cat([d2, e2], dim=1))d1 = self.upconv1(d2)d1 = self.dec1(torch.cat([d1, e1], dim=1))return self.final(d1)

3、简化版 Transformer 编码器(适合序列建模)

class TransformerBlock(nn.Module):def __init__(self, embed_dim, heads, ff_hidden_dim, dropout=0.1):super().__init__()self.attn = nn.MultiheadAttention(embed_dim, heads, dropout=dropout, batch_first=True)self.ff = nn.Sequential(nn.Linear(embed_dim, ff_hidden_dim),nn.ReLU(),nn.Linear(ff_hidden_dim, embed_dim))self.norm1 = nn.LayerNorm(embed_dim)self.norm2 = nn.LayerNorm(embed_dim)self.dropout = nn.Dropout(dropout)def forward(self, x, mask=None):attn_out, _ = self.attn(x, x, x, attn_mask=mask)x = self.norm1(x + self.dropout(attn_out))ff_out = self.ff(x)x = self.norm2(x + self.dropout(ff_out))return xclass TransformerEncoder(nn.Module):def __init__(self, vocab_size, embed_dim=512, n_heads=8, ff_dim=2048, num_layers=6, max_len=512):super().__init__()self.embedding = nn.Embedding(vocab_size, embed_dim)self.pos_encoding = self._generate_positional_encoding(max_len, embed_dim)self.layers = nn.ModuleList([TransformerBlock(embed_dim, n_heads, ff_dim)for _ in range(num_layers)])self.dropout = nn.Dropout(0.1)def _generate_positional_encoding(self, max_len, d_model):pos = torch.arange(0, max_len).unsqueeze(1)i = torch.arange(0, d_model, 2)angle_rates = 1 / torch.pow(10000, (i / d_model))pos_enc = torch.zeros(max_len, d_model)pos_enc[:, 0::2] = torch.sin(pos * angle_rates)pos_enc[:, 1::2] = torch.cos(pos * angle_rates)return pos_enc.unsqueeze(0)def forward(self, x):B, T = x.shapex = self.embedding(x) + self.pos_encoding[:, :T].to(x.device)x = self.dropout(x)for layer in self.layers:x = layer(x)return x

4、 总结对比

模型类型场景特点
ResNet18图像分类深残差网络结构,适合迁移学习
UNet图像分割对称结构,编码 + 解码 + skip
TransformerNLP / 序列建模全注意力机制,无卷积无循环

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/90562.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/90562.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入解析三大Web安全威胁:文件上传漏洞、SQL注入漏洞与WebShell

文章目录文件上传漏洞SQL注入漏洞WebShell三者的核心关联:攻击链闭环文件上传漏洞 文件上传漏洞(File Upload Vulnerability) 当Web应用允许用户上传文件但未实施充分的安全验证时,攻击者可上传恶意文件(如WebShell、…

【对比】群体智能优化算法 vs 贝叶斯优化

在机器学习、工程优化和科学计算中,优化算法的选择直接影响问题求解的效率与效果。群体智能优化算法(Swarm Intelligence, SI)和贝叶斯优化(Bayesian Optimization, BO)是两种截然不同的优化范式,分别以不同…

LLMs之Agent:ChatGPT Agent发布—统一代理系统将研究与行动无缝对接,开启智能助理新时代

LLMs之Agent:ChatGPT Agent发布—统一代理系统将研究与行动无缝对接,开启智能助理新时代 目录 OpenAI重磅发布ChatGPT Agent—统一代理系统将研究与行动无缝对接,开启智能助理新时代 第一部分:Operator 和深度研究的自然演进 第…

Linux726 raid0,raid1,raid5;raid 创建、保存、停止、删除

RAID创建 创建raid0 安装mdadm yum install mdadm mdadm --create /dev/md0 --raid-devices2 /dev/sdb5 /dev/sdb6 [rootsamba caozx26]# mdadm --create /dev/md0 --raid-devices2 /dev/sdb3 /dev/sdb5 --level0 mdadm: Defaulting to version 1.2 metadata mdadm: array /dev…

深入剖析 MetaGPT 中的提示词工程:WriteCode 动作的提示词设计

今天,我想和大家分享关于 AI 提示词工程的文章。提示词(Prompt)是大型语言模型(LLM)生成高质量输出的关键,而在像 MetaGPT 这样的 AI 驱动软件开发框架中,提示词的设计直接决定了代码生成的可靠…

关于 ESXi 中 “ExcelnstalledOnly 已禁用“ 的解决方案

第一步:使用ssh登录esxi esxcli system settings advanced list -o /User/execInstalledOnly可能会得到以下内容 esxcli system settings advanced list -o /User/execInstalledOnlyPath: /User/ExecInstalledOnlyType: integerInt Value: 0Default Int Value: 1Min…

HTML5 Canvas 绘制圆弧效果

HTML5 Canvas 绘制圆弧效果 以下是一个使用HTML5 Canvas绘制圆弧的完整示例&#xff0c;你可以直接在浏览器中运行看到效果&#xff1a; <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"view…

智能Agent场景实战指南 Day 18:Agent决策树与规划能力

【智能Agent场景实战指南 Day 18】Agent决策树与规划能力 开篇 欢迎来到"智能Agent场景实战指南"系列的第18天&#xff01;今天我们将深入探讨智能Agent的核心能力之一&#xff1a;决策树与规划能力。在现代业务场景中&#xff0c;Agent需要具备类似人类的决策能力…

AI 编程工具 Trae 重要的升级。。。

大家好&#xff0c;我是樱木。 今天打开 Trae &#xff0c;已经看到它进行图标升级&#xff0c;之前的图标&#xff0c;国际和国内版本长得非常像&#xff0c;现在做了很明显的区分&#xff0c;这点给 Trae 团队点个赞。 自从 Claude 使出了压力以来&#xff0c;Cursor 锁区&…

排序算法,咕咕咕

1.选择排序void selectsort(vector<int>& v) { for(int i0;i<v.size()-1;i) {int minii;for(int ji1;j<v.size();j){if(v[i]>v[j]){minij;}}if(mini!i)swap(v[i],v[mini]); } }2.堆排序void adjustdown(vector<int>& v,int root,int size) { int …

数据库查询系统——pyqt+python实现Excel内查课

一、引言 数据库查询系统处处存在&#xff0c;在教育信息化背景下&#xff0c;数据库查询技术更已深度融入教务管理场景。本系统采用轻量化架构&#xff0c;结合Excel课表&#xff0c;通过PythonPyQt5实现跨平台桌面应用&#xff0c;以实现简单查课效果。 二、GUI界面设计 使用…

base64魔改算法 | jsvmp日志分析并还原

前言 上一篇我们讲了标准 base64 算法还原&#xff0c;为了进一步学习 base64 算法特点&#xff0c;本文将结合 jsvmp 日志&#xff0c;实战还原出 base64 魔改算法。 为了方便大家学习&#xff0c;我将入参和上篇文章一样&#xff0c;入参为 Hello, World!。 插桩 在js代码中&…

vue3笔记(2)自用

目录 一、作用域插槽 二、pinia的使用 一、Pinia 基本概念与用法 1. 安装与初始化 2. 创建 Store 3. 在组件中使用 Store 4. 高级用法 5、storeToRefs 二、Pinia 与 Vuex 的主要区别 三、为什么选择 Pinia&#xff1f; 三、定义全局指令 1.封装通用 DOM 操作&#…

大模型面试回答,介绍项目

1. 模型准备与转换&#xff08;PC端/服务器&#xff09;你先在PC上下载或训练好大语言模型&#xff08;如HuggingFace格式&#xff09;。用RKLLM-Toolkit把模型转换成瑞芯微NPU能用的专用格式&#xff08;.rkllm&#xff09;&#xff0c;并可选择量化优化。把转换好的模型文件拷…

Oracle 19.20未知BUG导致oraagent进程内存泄漏

故障现象查询操作系统进程的使用排序&#xff0c;这里看到oraagent的物理内存达到16G&#xff0c;远远超过正常环境&#xff08;正常环境在19.20大概就是100M多一点&#xff09;[rootorastd tmp]# ./hmem|more PID NAME VIRT(kB) SHARED(kB) R…

尝试几道算法题,提升python编程思维

一、跳跃游戏题目描述&#xff1a; 给定一个非负整数数组 nums&#xff0c;你最初位于数组的第一个下标。数组中的每个元素代表你在该位置可以跳跃的最大长度。判断你是否能够到达最后一个下标。示例&#xff1a;输入&#xff1a;nums [2,3,1,1,4] → 输出&#xff1a;True输入…

【菜狗处理脏数据】对很多个不同时间序列数据的文件聚类—20250722

目录 具体做法 可视化方法1&#xff1a;PCA降维 可视化方法2、TSNE降维可视化&#xff08;非线性降维&#xff0c;更适合聚类&#xff09; 可视化方法3、轮廓系数评判好坏 每个文件有很多行列的信息&#xff0c;每列是一个驾驶相关的数据&#xff0c;需要对这些文件进行聚类…

Qwen-MT:翻得快,译得巧

我们再向大家介绍一位新朋友&#xff1a;机器翻译模型Qwen-MT。开发者朋友们可通过Qwen API&#xff08;qwen-mt-turbo&#xff09;&#xff0c;来直接体验它又快又准的翻译技能。 本次更新基于强大的 Qwen3 模型&#xff0c;进一步使用超大规模多语言和翻译数据对模型进行训练…

在 OceanBase 中,使用 TO_CHAR 函数 直接转换日期格式,简洁高效的解决方案

SQL语句SELECT TO_CHAR(TO_DATE(your_column, DD-MON-YY), YYYY-MM-DD) AS formatted_date FROM your_table;关键说明&#xff1a;核心函数&#xff1a;TO_DATE(30-三月-15, DD-MON-YY) → 将字符串转为日期类型TO_CHAR(..., YYYY-MM-DD) → 格式化为 2015-03-30处理中文月份&a…

pnpm运行electronic项目报错,npm运行正常。electronic项目打包为exe报错

pnpm运行electronic项目报错 使用 pnpm 运行 electronic 项目报错&#xff0c;npm 运行正常&#xff0c;报错内容如下 error during start dev server and electron app: Error: Electron uninstallat getElectronPath (file:///E:/project/xxx-vue/node_modules/.pnpm/elect…