Python训练营---Day44

DAY 44 预训练模型

知识点回顾:

  1. 预训练的概念
  2. 常见的分类预训练模型
  3. 图像预训练模型的发展史
  4. 预训练的策略
  5. 预训练代码实战:resnet18

作业:

  1. 尝试在cifar10对比如下其他的预训练模型,观察差异,尽可能和他人选择的不同
  2. 尝试通过ctrl进入resnet的内部,观察残差究竟是什么

选用 DenseNet121预训练模型,注意DenseNet121 模型的最后分类层名为classifier,而不是 ResNet 中的fc

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
from torchvision.models import resnet18, densenet121, vgg16# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 1. 数据预处理(训练集增强,测试集标准化)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 2. 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./cifar_data',train=True,download=True,transform=train_transform
)test_dataset = datasets.CIFAR10(root='./cifar_data',train=False,transform=test_transform
)# 3. 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 定义DenseNet121模型
def create_densenet121(pretrained=True, num_classes=10):model = models.densenet121(pretrained=pretrained)# 修改最后一层全连接层in_features = model.classifier.in_featuresmodel.classifier = nn.Linear(in_features, num_classes) # DenseNet121 的最后一层分类器名称是classifierreturn model.to(device)# 5. 冻结/解冻模型层的函数
# 这种设计允许我们在迁移学习中保留预训练模型的特征提取部分(卷积层),只训练新添加的分类层(全连接层)。
def freeze_model(model, freeze=True):"""冻结或解冻模型的卷积层参数"""# 冻结/解冻除fc层外的所有参数for name, param in model.named_parameters():if 'classifier' not in name:    #排除名称中包含 "fc" 的参数,这些通常是全连接层的参数param.requires_grad = not freeze    #param.requires_grad是 PyTorch 中控制参数是否参与反向传播和梯度更新的标志# 打印冻结状态frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)   #统计所有requires_grad=False的参数数量total_params = sum(p.numel() for p in model.parameters())if freeze:print(f"已冻结模型卷积层参数 ({frozen_params}/{total_params} 参数)")else:print(f"已解冻模型所有参数 ({total_params}/{total_params} 参数可训练)")return model# 6. 训练函数(支持阶段式训练)
def train_with_freeze_schedule(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs, freeze_epochs=5):"""前freeze_epochs轮冻结卷积层,之后解冻所有层进行训练"""train_loss_history = []test_loss_history = []train_acc_history = []test_acc_history = []all_iter_losses = []iter_indices = []# 初始冻结卷积层if freeze_epochs > 0:model = freeze_model(model, freeze=True)for epoch in range(epochs):# 解冻控制:在指定轮次后解冻所有层if epoch == freeze_epochs:model = freeze_model(model, freeze=False)# 解冻后调整优化器(可选)optimizer.param_groups[0]['lr'] = 1e-4  # 降低学习率防止过拟合model.train()  # 设置为训练模式running_loss = 0.0correct_train = 0total_train = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 记录Iteration损失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)# 统计训练指标running_loss += iter_loss_, predicted = output.max(1)total_train += target.size(0)correct_train += predicted.eq(target).sum().item()# 每100批次打印进度if (batch_idx + 1) % 100 == 0:print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} "f"| 单Batch损失: {iter_loss:.4f}")# 计算 epoch 级指标epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct_train / total_train# 测试阶段model.eval()correct_test = 0total_test = 0test_loss = 0.0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_test# 记录历史数据train_loss_history.append(epoch_train_loss)test_loss_history.append(epoch_test_loss)train_acc_history.append(epoch_train_acc)test_acc_history.append(epoch_test_acc)# 更新学习率调度器if scheduler is not None:scheduler.step(epoch_test_loss)# 打印 epoch 结果print(f"Epoch {epoch+1} 完成 | 训练损失: {epoch_train_loss:.4f} "f"| 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%")# 绘制损失和准确率曲线plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)return epoch_test_acc  # 返回最终测试准确率# 7. 绘制Iteration损失曲线
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7)plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('训练过程中的Iteration损失变化')plt.grid(True)plt.show()# 8. 绘制Epoch级指标曲线
def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 5))# 准确率曲线plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='训练准确率')plt.plot(epochs, test_acc, 'r-', label='测试准确率')plt.xlabel('Epoch')plt.ylabel('准确率 (%)')plt.title('准确率随Epoch变化')plt.legend()plt.grid(True)# 损失曲线plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='训练损失')plt.plot(epochs, test_loss, 'r-', label='测试损失')plt.xlabel('Epoch')plt.ylabel('损失值')plt.title('损失值随Epoch变化')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 主函数:训练模型
def main():# 参数设置epochs = 40  # 总训练轮次freeze_epochs = 5  # 冻结卷积层的轮次learning_rate = 1e-3  # 初始学习率weight_decay = 1e-4  # 权重衰减# 创建DenseNet121模型(加载预训练权重)model = create_densenet121(pretrained=True, num_classes=10)# 定义优化器和损失函数optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)criterion = nn.CrossEntropyLoss()# 定义学习率调度器scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)# 开始训练(前5轮冻结卷积层,之后解冻)final_accuracy = train_with_freeze_schedule(model=model,train_loader=train_loader,test_loader=test_loader,criterion=criterion,optimizer=optimizer,scheduler=scheduler,device=device,epochs=epochs,freeze_epochs=freeze_epochs)print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")# # 保存模型# torch.save(model.state_dict(), 'resnet18_cifar10_finetuned.pth')# print("模型已保存至: resnet18_cifar10_finetuned.pth")if __name__ == "__main__":main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/news/908347.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

1.文件操作相关的库

一、filesystem(C17) 和 fstream 1.std::filesystem::path - cppreference.cn - C参考手册 std::filesystem::path 表示路径 构造函数: path( string_type&& source, format fmt auto_format ); 可以用string进行构造,也可以用string进行隐式类…

【 java 集合知识 第二篇 】

目录 1.Map集合 1.1.快速遍历Map 1.2.HashMap实现原理 1.3.HashMap的扩容机制 1.4.HashMap在多线程下的问题 1.5.解决哈希冲突的方法 1.6.HashMap的put过程 1.7.HashMap的key使用什么类型 1.8.HashMapkey可以为null的原因 1.9.HashMap为什么不采用平衡二叉树 1.10.Hash…

【Dify 知识库 API】“根据文本更新文档” 真的是差异更新吗?一文讲透真实机制!

在使用 Dify 知识库 API 过程中,很多开发者在调用 /datasets/{dataset_id}/document/update-by-text 接口时,常常会产生一个疑问: 👉 这个接口到底是 “智能差异更新” 还是 “纯覆盖更新”? 网上的资料并不多,很多人根据接口名误以为是增量更新。今天我结合官方源码 …

大模型如何革新用户价值、内容匹配与ROI预估

写在前面 在数字营销的战场上,理解用户、精准触达、高效转化是永恒的追求。传统方法依赖结构化数据和机器学习模型,在用户价值评估、人群素材匹配以及策略ROI预估等核心问题上取得了显著成就。然而,随着数据维度日益复杂,用户行为愈发多变,传统方法也面临着特征工程繁琐、…

基于端到端深度学习模型的语音控制人机交互系统

基于端到端深度学习模型的语音控制人机交互系统 摘要 本文设计并实现了一个基于端到端深度学习模型的人机交互系统,通过语音指令控制其他设备的程序运行,并将程序运行结果通过语音合成方式反馈给用户。系统采用Python语言开发,使用PyTorch框架实现端到端的语音识别(ASR)…

【2025年】解决Burpsuite抓不到https包的问题

环境:windows11 burpsuite:2025.5 在抓取https网站时,burpsuite抓取不到https数据包,只显示: 解决该问题只需如下三个步骤: 1、浏览器中访问 http://burp 2、下载 CA certificate 证书 3、在设置--隐私与安全--…

Jenkins 工作流程

1. 触发构建 Jenkins 的工作流程从触发构建开始。构建可以由以下几种方式触发: 代码提交触发:通过与版本控制系统(如 Git、SVN)集成,当代码仓库有新的提交时,Jenkins 会自动触发构建。 定时触发&#xff…

Jmeter如何进行多服务器远程测试?

🍅 点击文末小卡片 ,免费获取软件测试全套资料,资料在手,涨薪更快 JMeter是Apache软件基金会的开源项目,主要来做功能和性能测试,用Java编写。 我们一般都会用JMeter在本地进行测试,但是受到…

Kafka入门-生产者

生产者 生产者发送流程: 延迟时间为0ms时,也就意味着每当有数据就会直接发送 异步发送API 异步发送和同步发送的不同在于:异步发送不需要等待结果,同步发送必须等待结果才能进行下一步发送。 普通异步发送 首先导入所需的k…

分类预测 | Matlab实现CNN-LSTM-Attention高光谱数据分类

分类预测 | Matlab实现CNN-LSTM-Attention高光谱数据分类 目录 分类预测 | Matlab实现CNN-LSTM-Attention高光谱数据分类分类效果功能概述程序设计参考资料 分类效果 功能概述 代码功能 该MATLAB代码实现了一个结合CNN、LSTM和注意力机制的高光谱数据分类模型,核心…

gemini和chatgpt数据对比:谁在卷性能、价格和场景?

先把结论“剧透”给赶时间的朋友:顶配 Gemini Ultra/2.5 Pro 在纸面成绩上普遍领先,而 ChatGPT 家族(GPT-4o / o3 / 4.1)则在延迟、生态和稳定性上占优。下面把核心数据拆开讲,方便你对号入座。附带参考来源&#xff0…

代码训练LeetCode(23)随机访问元素

代码训练(23)LeetCode之随机访问元素 Author: Once Day Date: 2025年6月5日 漫漫长路,才刚刚开始… 全系列文章可参考专栏: 十年代码训练_Once-Day的博客-CSDN博客 参考文章: 380. O(1) 时间插入、删除和获取随机元素 - 力扣(LeetCode)力…

C++面试5——对象存储区域详解

C++对象存储区域详解 核心观点:内存是程序员的战场,存储区域决定对象的生杀大权!栈对象自动赴死,堆对象生死由你,全局对象永生不死,常量区对象只读不灭。 一、四大地域生死簿 栈区(Stack) • 特点:自动分配释放,速度极快(类似高铁进出站) • 生存期:函数大括号{}就…

STM32 智能小车项目 L298N 电机驱动模块

今天开始着手做智能小车的项目了 在智能小车或机器人项目中,我们经常会听到一个词叫 “H 桥电机驱动”,尤其是常见的 L298N 模块,就是基于“双 H 桥”原理设计的。那么,“H 桥”到底是什么?为什么要用“双 H 桥”来驱动…

python项目如何创建docker环境

这里写自定义目录标题 python项目创建docker环境docker配置国内镜像源构建一个Docker 镜像验证镜像合理的创建标题,有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPant…

MySQL-多表关系、多表查询

一. 一对多(多对一) 1. 例如;一个部门下有多个员工 在数据库表中多的一方(员工表)、添加字段,来关联一的一方(部门表)的主键 二. 外键约束 1.如将部门表的部门直接删除,然而员工表还存在其部门下的员工,出现了数据的不一致问题&am…

【 HarmonyOS 5 入门系列 】鸿蒙HarmonyOS示例项目讲解

【 HarmonyOS 5 入门系列 】鸿蒙HarmonyOS示例项目讲解 一、前言:移动开发声明式 UI 框架的技术变革 在移动操作系统的发展历程中,UI 开发模式经历了从命令式到声明式的重大变革。 根据华为开发者联盟 2024 年数据报告显示,HarmonyOS 设备…

【SSM】SpringMVC学习笔记7:前后端数据传输协议和异常处理

这篇学习笔记是Spring系列笔记的第7篇,该笔记是笔者在学习黑马程序员SSM框架教程课程期间的笔记,供自己和他人参考。 Spring学习笔记目录 笔记1:【SSM】Spring基础: IoC配置学习笔记-CSDN博客 对应黑马课程P1~P20的内容。 笔记2…

借助 Spring AI 和 LM Studio 为业务系统引入本地 AI 能力

Spring AI 1.0.0-SNAPSHOTLM Studio 0.3.16qwen3-4b 参考 Unable to use spring ai with LMStudio using spring-ai openai module Issue #2441 spring-projects/spring-ai GitHub LM Studio 下载安装 LM Studio下载 qwen3-4b 模型。对于 qwen3 系列模型,测试…

C++学习-入门到精通【13】标准库的容器和迭代器

C学习-入门到精通【13】标准库的容器和迭代器 目录 C学习-入门到精通【13】标准库的容器和迭代器一、标准模板库简介1.容器简介2.STL容器总览3.近容器4.STL容器的通用函数5.首类容器的通用typedef6.对容器元素的要求 二、迭代器简介1.使用istream_iterator输入,使用…