Python day50

@浙大疏锦行 python day50.

  • 在预训练模型(resnet18)中添加cbam注意力机制,需要修改模型的架构,同时应该考虑插入的cbam注意力机制模块的位置;
import torch
import torch.nn as nn
from torchvision import models# 自定义ResNet18模型,插入CBAM模块
class ResNet18_CBAM(nn.Module):def __init__(self, num_classes=10, pretrained=True, cbam_ratio=16, cbam_kernel=7):super().__init__()# 加载预训练ResNet18self.backbone = models.resnet18(pretrained=pretrained) # 修改首层卷积以适应32x32输入(CIFAR10)self.backbone.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)self.backbone.maxpool = nn.Identity()  # 移除原始MaxPool层(因输入尺寸小)# 在每个残差块组后添加CBAM模块self.cbam_layer1 = CBAM(in_channels=64, ratio=cbam_ratio, kernel_size=cbam_kernel)self.cbam_layer2 = CBAM(in_channels=128, ratio=cbam_ratio, kernel_size=cbam_kernel)self.cbam_layer3 = CBAM(in_channels=256, ratio=cbam_ratio, kernel_size=cbam_kernel)self.cbam_layer4 = CBAM(in_channels=512, ratio=cbam_ratio, kernel_size=cbam_kernel)# 修改分类头self.backbone.fc = nn.Linear(in_features=512, out_features=num_classes)def forward(self, x):# 主干特征提取x = self.backbone.conv1(x)x = self.backbone.bn1(x)x = self.backbone.relu(x)  # [B, 64, 32, 32]# 第一层残差块 + CBAMx = self.backbone.layer1(x)  # [B, 64, 32, 32]x = self.cbam_layer1(x)# 第二层残差块 + CBAMx = self.backbone.layer2(x)  # [B, 128, 16, 16]x = self.cbam_layer2(x)# 第三层残差块 + CBAMx = self.backbone.layer3(x)  # [B, 256, 8, 8]x = self.cbam_layer3(x)# 第四层残差块 + CBAMx = self.backbone.layer4(x)  # [B, 512, 4, 4]x = self.cbam_layer4(x)# 全局平均池化 + 分类x = self.backbone.avgpool(x)  # [B, 512, 1, 1]x = torch.flatten(x, 1)  # [B, 512]x = self.backbone.fc(x)  # [B, 10]return x# 初始化模型并移至设备
model = ResNet18_CBAM().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
  • 修改模型结构后,需要考虑模型训练的策略,一般来说可以先冻结原有的部分进行训练以期待新增的部分可以获得一个不错的表现;之后解冻原有部分中的高层layer并赋予一个较低的学习率来保证不会出现不应该的错误;最后解冻所有参数,也是赋予较低的学习率,来学习最终的端到端任务。
def set_trainable_layers(model, trainable_parts):print(f"\n---> 解冻以下部分并设为可训练: {trainable_parts}")for name, param in model.named_parameters():param.requires_grad = Falsefor part in trainable_parts:if part in name:param.requires_grad = Truebreakdef train_staged_finetuning(model, criterion, train_loader, test_loader, device, epochs):optimizer = None# 初始化历史记录列表,与你的要求一致all_iter_losses, iter_indices = [], []train_acc_history, test_acc_history = [], []train_loss_history, test_loss_history = [], []for epoch in range(1, epochs + 1):epoch_start_time = time.time()# --- 动态调整学习率和冻结层 ---if epoch == 1:print("\n" + "="*50 + "\n🚀 **阶段 1:训练注意力模块和分类头**\n" + "="*50)set_trainable_layers(model, ["cbam", "backbone.fc"])optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)elif epoch == 6:print("\n" + "="*50 + "\n✈️ **阶段 2:解冻高层卷积层 (layer3, layer4)**\n" + "="*50)set_trainable_layers(model, ["cbam", "backbone.fc", "backbone.layer3", "backbone.layer4"])optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)elif epoch == 21:print("\n" + "="*50 + "\n🛰️ **阶段 3:解冻所有层,进行全局微调**\n" + "="*50)for param in model.parameters(): param.requires_grad = Trueoptimizer = optim.Adam(model.parameters(), lr=1e-5)# --- 训练循环 ---model.train()running_loss, correct, total = 0.0, 0, 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 记录每个iteration的损失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append((epoch - 1) * len(train_loader) + batch_idx + 1)running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 按你的要求,每100个batch打印一次if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totaltrain_loss_history.append(epoch_train_loss)train_acc_history.append(epoch_train_acc)# --- 测试循环 ---model.eval()test_loss, correct_test, total_test = 0, 0, 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_testtest_loss_history.append(epoch_test_loss)test_acc_history.append(epoch_test_acc)# 打印每个epoch的最终结果print(f'Epoch {epoch}/{epochs} 完成 | 耗时: {time.time() - epoch_start_time:.2f}s | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')# 训练结束后调用绘图函数print("\n训练完成! 开始绘制结果图表...")plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)# 返回最终的测试准确率return epoch_test_accmodel = ResNet18_CBAM().to(device)
criterion = nn.CrossEntropyLoss()
epochs = 50print("开始使用带分阶段微调策略的ResNet18+CBAM模型进行训练...")
final_accuracy = train_staged_finetuning(model, criterion, train_loader, test_loader, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")torch.save(model.state_dict(), 'resnet18_cbam_finetuned.pth')
print("模型已保存为: resnet18_cbam_finetuned.pth")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/96216.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/96216.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VPS海外节点性能监控全攻略:从基础配置到高级优化

在全球化业务部署中,VPS海外节点的稳定运行直接影响用户体验。本文将深入解析如何构建高效的性能监控体系,涵盖网络延迟检测、资源阈值设置、告警机制优化等核心环节,帮助运维人员实现跨国服务器的可视化管控。 VPS海外节点性能监控全攻略&am…

C语言初学者笔记【结构体】

文章目录一、结构体的使用1. 结构体声明2. 变量创建与初始化3. 特殊声明与陷阱二、内存对齐1. 规则:2. 示例分析:3. 修改默认对齐数:三、结构体传参四、结构体实现位段1. 定义2. 内存分配3. 应用场景4. 跨平台问题:5. 注意事项&am…

基于XGBoost算法的数据回归预测 极限梯度提升算法 XGBoost

一、作品详细简介 1.1附件文件夹程序代码截图 全部完整源代码,请在个人首页置顶文章查看: 学行库小秘_CSDN博客​编辑https://blog.csdn.net/weixin_47760707?spm1000.2115.3001.5343 1.2各文件夹说明 1.2.1 main.m主函数文件 该MATLAB 代码实现了…

数据安全系列4:常用的对称算法浅析

常用的算法介绍 常用的算法JAVA实现 jce及其它开源包介绍、对比 传送门 数据安全系列1:开篇 数据安全系列2:单向散列函数概念 数据安全系列3:密码技术概述 时代有浪潮,就有退去的时候 在我的博客文章里面,其中…

云计算学习100天-第26天

地址重写地址重写语法——关于Nginx服务器的地址重写,主要用到的配置参数是rewrite 语法格式: rewrite regex replacement flag rewrite 旧地址 新地址 [选项]地址重写步骤:#修改配置文件(访问a.html重定向到b.html) cd /usr/local/ngin…

【Python办公】字符分割拼接工具(GUI工具)

目录 专栏导读 项目简介 功能特性 🔧 核心功能 1. 字符分割功能 2. 字符拼接功能 🎨 界面特性 现代化设计 用户体验优化 技术实现 开发环境 核心代码结构 关键技术点 使用指南 安装步骤 完整代码 字符分割操作 字符拼接操作 应用场景 数据处理 文本编辑 开发辅助 项目优势 …

Windows 命令行:dir 命令

专栏导航 上一篇:Windows 命令行:Exit 命令 回到目录 下一篇:MFC 第一章概述 本节前言 学习本节知识,需要你首先懂得如何打开一个命令行界面,也就是命令提示符界面。链接如下。 参考课节:Windows 命令…

软考高级--系统架构设计师--案例分析真题解析

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录前言试题一 软件架构设计一、2019年 案例分析二、2020年 案例分析三、2021年 案例分析四、2022年 案例分析试题二 软件系统设计一、2019年 案例分析二、2020年 案例分…

css中的性能优化之content-visibility: auto

content-visibility: auto的核心机制是让浏览器智能跳过屏幕外元素的渲染工作,包括布局和绘制,直到它们接近视口时才渲染。这与虚拟滚动等传统方案相比优势明显,只需要一行CSS就能实现近似效果。值得注意的是必须配合contain-intrinsic-size属…

通过uniapp将vite vue3项目打包为android系统的.apk包,并实现可自动升级功能

打包vue项目,注意vite.config.ts文件和路由文件设置 vite.config.ts,将base等配置改为./ import {fileURLToPath, URL } from node:urlimport {defineConfig } from vite import vue from @vitejs/plugin-vue import AutoImport from unplugin-auto-import/vite import Com…

经营帮租赁经营板块:解锁资产运营新生态,赋能企业增长新引擎

在商业浪潮奔涌向前的当下,企业资产运营与租赁管理的模式不断迭代,“经营帮” 以其租赁经营板块为支点,构建起涵盖多元业务场景、适配不同需求的生态体系,成为众多企业破局资产低效困局、挖掘增长新动能的关键助力。本文将深度拆解…

C语言---编译的最小单位---令牌(Token)

文章目录C语言中令牌几类令牌是编译器理解源代码的最小功能单元,是编译过程的第一步。C语言中令牌几类 1、关键字: 具有固定含义的保留字,如 int, if, for, while, return 等。 2、标识符: 由程序员定义的名称,用于变…

机器学习 | Python中进行特征重要性分析的9个常用方法

在Python中,特征重要性分析是机器学习模型解释和特征选择的关键步骤。以下是9种常用方法及其实现示例: 1. 基于树的模型内置特征重要性 原理:树模型(如随机森林、XGBoost)根据特征分裂时的纯度提升(基尼不纯度/信息增益)计算重要性。 from sklearn.ensemble import Ra…

心路历程-了解网络相关知识

在做这个题材的时候,考虑的一个点就是:自己的最初的想法;可是技术是不断更新的; 以前的材料会落后,但是万变不能变其中;所以呈现出来的知识点也相对比较老旧,为什么呢? 因为最新的素…

CAT1+mqtt

文章目录 MQTT知识点mqtt数据固定报头可变报头(连接请求)有效载荷 阿里云MQTT测试订阅Topic下发数据给MQTT.fxMQTT.fx 发布消息给服务器 下载mqtt(C-嵌入式版)我的W5500项目路径使用Cat1连接阿里云平台AT指令串口连接1. 开机联网2. 激活内置SIM卡(贴片卡)3. 我这里使用连接的是…

AiPPT怎么样?好用吗?

AiPPT怎么样?好用吗?AiPPT 是一款智能高效的PPT生成工具,通过AI技术快速将主题或文档(如Word/PDF)转化为专业PPT,提供超10万套行业模板,覆盖商务、教育等22场景,支持一键生成大纲、文…

恶补DSP:2.F28335的定时器系统

一、定时器原理F28335 城市的三座时钟塔(Timer0、Timer1、Timer2)是城市时间管理的核心设施,每座均为32位精度,依靠城市能源脉冲(系统时钟 SYSCLKOUT,典型频率为150 MHz)驱动。它们由两个核心模…

用倒计时软件为考研备考精准导航 复习 模拟考试 日期倒计时都可以用

考研,是一场与时间的博弈。从决定报名的那一刻起,日历上的每一个数字都被赋予了特殊意义 —— 报名截止日、现场确认期、初试倒计时、成绩查询点…… 这些节点如同航标,指引着备考者的方向。而在这场漫长的征途里,一款精准、易用的…

React学习(七)

目录:1.react-进阶-antd-搜索2.react-进阶-antd-依赖项说明 3.react-进阶-antd-删除1.react-进阶-antd-搜索我们jsx代码里只能返回一个最顶层的根元素下拉框简化写法:把这个对象结构赋值一下:清空定义个参数类型做修改事件需要定义三个…

Unix Domain Socket(UDS)和 TCP/IP(使用 127.0.0.1)进程间通信(IPC)的比较

Unix Domain Socket(UDS)和 TCP/IP(使用 127.0.0.1 或 localhost)都是进程间通信(IPC)的方式,但它们在实现、性能和适用场景上有显著区别。以下是两者的对比:1. 通信机制Unix Domain…