python训练day44 预训练模型

预训练模型发展史

预训练模型的训练策略

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 1. 数据预处理(训练集增强,测试集标准化)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 2. 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=train_transform
)test_dataset = datasets.CIFAR10(root='./data',train=False,transform=test_transform
)# 3. 创建数据加载器(可调整batch_size)
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 训练函数(支持学习率调度器)
def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs):model.train()  # 设置为训练模式train_loss_history = []test_loss_history = []train_acc_history = []test_acc_history = []all_iter_losses = []iter_indices = []for epoch in range(epochs):running_loss = 0.0correct_train = 0total_train = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 记录Iteration损失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)# 统计训练指标running_loss += iter_loss_, predicted = output.max(1)total_train += target.size(0)correct_train += predicted.eq(target).sum().item()# 每100批次打印进度if (batch_idx + 1) % 100 == 0:print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} "f"| 单Batch损失: {iter_loss:.4f}")# 计算 epoch 级指标epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct_train / total_train# 测试阶段model.eval()correct_test = 0total_test = 0test_loss = 0.0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_test# 记录历史数据train_loss_history.append(epoch_train_loss)test_loss_history.append(epoch_test_loss)train_acc_history.append(epoch_train_acc)test_acc_history.append(epoch_test_acc)# 更新学习率调度器if scheduler is not None:scheduler.step(epoch_test_loss)# 打印 epoch 结果print(f"Epoch {epoch+1} 完成 | 训练损失: {epoch_train_loss:.4f} "f"| 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%")# 绘制损失和准确率曲线plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)return epoch_test_acc  # 返回最终测试准确率# 5. 绘制Iteration损失曲线
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7)plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('训练过程中的Iteration损失变化')plt.grid(True)plt.show()# 6. 绘制Epoch级指标曲线
def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 5))# 准确率曲线plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='训练准确率')plt.plot(epochs, test_acc, 'r-', label='测试准确率')plt.xlabel('Epoch')plt.ylabel('准确率 (%)')plt.title('准确率随Epoch变化')plt.legend()plt.grid(True)# 损失曲线plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='训练损失')plt.plot(epochs, test_loss, 'r-', label='测试损失')plt.xlabel('Epoch')plt.ylabel('损失值')plt.title('损失值随Epoch变化')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 导入ResNet模型
from torchvision.models import resnet18# 定义ResNet18模型(支持预训练权重加载)
def create_resnet18(pretrained=True, num_classes=10):# 加载预训练模型(ImageNet权重)model = resnet18(pretrained=pretrained)# 修改最后一层全连接层,适配CIFAR-10的10分类任务in_features = model.fc.in_featuresmodel.fc = nn.Linear(in_features, num_classes)# 将模型转移到指定设备(CPU/GPU)model = model.to(device)return model# 创建ResNet18模型(加载ImageNet预训练权重,不进行微调)
model = create_resnet18(pretrained=True, num_classes=10)
model.eval()  # 设置为推理模式# 测试单张图片(示例)
from torchvision import utils# 从测试数据集中获取一张图片
dataiter = iter(test_loader)
images, labels = dataiter.next()
images = images[:1].to(device)  # 取第1张图片# 前向传播
with torch.no_grad():outputs = model(images)_, predicted = torch.max(outputs.data, 1)# 显示图片和预测结果
plt.imshow(utils.make_grid(images.cpu(), normalize=True).permute(1, 2, 0))
plt.title(f"预测类别: {predicted.item()}")
plt.axis('off')
plt.show()
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 1. 数据预处理(训练集增强,测试集标准化)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 2. 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data',train=True,download=True,transform=train_transform
)test_dataset = datasets.CIFAR10(root='./data',train=False,transform=test_transform
)# 3. 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 定义ResNet18模型
def create_resnet18(pretrained=True, num_classes=10):model = models.resnet18(pretrained=pretrained)# 修改最后一层全连接层in_features = model.fc.in_featuresmodel.fc = nn.Linear(in_features, num_classes)return model.to(device)# 5. 冻结/解冻模型层的函数
def freeze_model(model, freeze=True):"""冻结或解冻模型的卷积层参数"""# 冻结/解冻除fc层外的所有参数for name, param in model.named_parameters():if 'fc' not in name:param.requires_grad = not freeze# 打印冻结状态frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)total_params = sum(p.numel() for p in model.parameters())if freeze:print(f"已冻结模型卷积层参数 ({frozen_params}/{total_params} 参数)")else:print(f"已解冻模型所有参数 ({total_params}/{total_params} 参数可训练)")return model# 6. 训练函数(支持阶段式训练)
def train_with_freeze_schedule(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs, freeze_epochs=5):"""前freeze_epochs轮冻结卷积层,之后解冻所有层进行训练"""train_loss_history = []test_loss_history = []train_acc_history = []test_acc_history = []all_iter_losses = []iter_indices = []# 初始冻结卷积层if freeze_epochs > 0:model = freeze_model(model, freeze=True)for epoch in range(epochs):# 解冻控制:在指定轮次后解冻所有层if epoch == freeze_epochs:model = freeze_model(model, freeze=False)# 解冻后调整优化器(可选)optimizer.param_groups[0]['lr'] = 1e-4  # 降低学习率防止过拟合model.train()  # 设置为训练模式running_loss = 0.0correct_train = 0total_train = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 记录Iteration损失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)# 统计训练指标running_loss += iter_loss_, predicted = output.max(1)total_train += target.size(0)correct_train += predicted.eq(target).sum().item()# 每100批次打印进度if (batch_idx + 1) % 100 == 0:print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} "f"| 单Batch损失: {iter_loss:.4f}")# 计算 epoch 级指标epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct_train / total_train# 测试阶段model.eval()correct_test = 0total_test = 0test_loss = 0.0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_test# 记录历史数据train_loss_history.append(epoch_train_loss)test_loss_history.append(epoch_test_loss)train_acc_history.append(epoch_train_acc)test_acc_history.append(epoch_test_acc)# 更新学习率调度器if scheduler is not None:scheduler.step(epoch_test_loss)# 打印 epoch 结果print(f"Epoch {epoch+1} 完成 | 训练损失: {epoch_train_loss:.4f} "f"| 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%")# 绘制损失和准确率曲线plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)return epoch_test_acc  # 返回最终测试准确率# 7. 绘制Iteration损失曲线
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7)plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('训练过程中的Iteration损失变化')plt.grid(True)plt.show()# 8. 绘制Epoch级指标曲线
def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 5))# 准确率曲线plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='训练准确率')plt.plot(epochs, test_acc, 'r-', label='测试准确率')plt.xlabel('Epoch')plt.ylabel('准确率 (%)')plt.title('准确率随Epoch变化')plt.legend()plt.grid(True)# 损失曲线plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='训练损失')plt.plot(epochs, test_loss, 'r-', label='测试损失')plt.xlabel('Epoch')plt.ylabel('损失值')plt.title('损失值随Epoch变化')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 主函数:训练模型
def main():# 参数设置epochs = 40  # 总训练轮次freeze_epochs = 5  # 冻结卷积层的轮次learning_rate = 1e-3  # 初始学习率weight_decay = 1e-4  # 权重衰减# 创建ResNet18模型(加载预训练权重)model = create_resnet18(pretrained=True, num_classes=10)# 定义优化器和损失函数optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)criterion = nn.CrossEntropyLoss()# 定义学习率调度器scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)# 开始训练(前5轮冻结卷积层,之后解冻)final_accuracy = train_with_freeze_schedule(model=model,train_loader=train_loader,test_loader=test_loader,criterion=criterion,optimizer=optimizer,scheduler=scheduler,device=device,epochs=epochs,freeze_epochs=freeze_epochs)print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")# # 保存模型# torch.save(model.state_dict(), 'resnet18_cifar10_finetuned.pth')# print("模型已保存至: resnet18_cifar10_finetuned.pth")if __name__ == "__main__":main()

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/89038.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/89038.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[论文阅读]MISSRce

论文title: MISSRec: Pre-training and Transferring Multi-modal Interest-aware Sequence Representation for Recommendation

Redis学习笔记——黑马点评 附近商铺到UV统计 完结

前言: 今天完结了Redis的所有实战篇。 学习收获: GEO数据结构: GEO就是Geolocation的简写形式,代表地理坐标。Redis在3.2版本中加入对Geo的支持,存储、管理和操作地理空间数据的特殊数据结构,它能高效处…

【客户端排查】mac电脑怎么查看客户端的实时运行日志

先退出客户端;打开访达里的应用程序; 打开【显示包内容】; 找到MacOS 双击里面的终端程序; 双击后,客户端会自动启动,且可以在终端中查看客户端的实时日志啦~

HarmonyOS NEXT仓颉开发语言实战案例:健身App

各位好,今日分享一个健身app的首页: 这个页面看起比之前的案例要稍微复杂一些,主要在于顶部部分,有重叠的背景,还有偏移的部分。重叠布局可以使用Stack容器实现,超出容器范围的偏移可以使用负数间距来实现&…

TreeMap源码分析 红黑树

今天尝试刨一下TreeMap的祖坟。 底层结构对比 先来看一下与HashMap、LinkedHashMap和TreeMap的对比,同时就当是复习一下: HashMap使用数组存储数据,并使用单向链表结构存储hash冲突数据,同一个冲突桶中数据量大的时候&#xff…

华为云Flexus+DeepSeek征文|基于Dify构建拍照识题智能学习助手

华为云FlexusDeepSeek征文|基于Dify构建拍照识题智能学习助手 一、构建拍照识题智能学习助手前言二、构建拍照识题智能学习助手环境2.1 基于FlexusX实例的Dify平台2.2 基于MaaS的模型API商用服务 三、构建拍照识题智能学习助手实战3.1 配置Dify环境3.2 配置Dify工具…

题解:CF2120E Lanes of Cars

根据贪心,不难想到每次会把最长队伍末尾的那辆车移动到最短队伍的末尾。但由于 k k k 的存在,会导致一些冗余移动的存在。设需要挪动 C C C 辆车,则怒气值可以表示为 f ( C ) k C f(C) kC f(C)kC,其中 f ( C ) f(C) f(C) 是…

Excel基础:选择和移动

本文演示Excel中基础的选择和移动操作,并在最后提供了一张思维导图,方便记忆。 文章目录 一、选择1.1 基础选择1.1.1 选择单个单元格1.1.2 选择连续范围 1.2 行列选择1.2.1 选择整行整列1.2.2 选择多行多列 1.3 全选1.3.1 全选所有单元格1.3.2 智能选择…

Java面试宝典:基础四

80. int vs Integer 维度intInteger类型基本数据类型(8种之一)包装类默认值0null应用场景性能敏感场景(计算密集)Web表单、ORM框架(区分null和0)特殊能力无提供工具方法(如parseInt())和常量(如MAX_VALUE)示例:

RabbitMQ + JMeter 深度集成指南:中间件性能优化全流程解析!

在 2025 年的数字化浪潮中,中间件性能直接决定系统的稳定性和用户体验,而 RabbitMQ 作为消息队列的“老大哥”,在分布式系统中扮演着关键角色。然而,高并发场景下,消息堆积、延迟激增等问题可能让系统不堪重负&#xf…

uniapp image引用本地图片不显示问题

1. uniapp image引用本地图片不显示问题 在uniapp 开发过程中采用image引入本地资源图片。 1.1. 相对路径和绝对路径问题 在UniApp中开发微信小程序时,引入图片时,相对路径和绝对路径可能会有一些差异。这差异主要涉及到小程序和UniApp框架的文件结构、…

论文阅读:arxiv 2025 ThinkSwitcher: When to Think Hard, When to Think Fast

总目录 大模型安全相关研究:https://blog.csdn.net/WhiffeYF/article/details/142132328 ThinkSwitcher: When to Think Hard, When to Think Fast https://arxiv.org/pdf/2505.14183#page2.08 https://www.doubao.com/chat/10031179784579842 文章目录 速览一、…

智能体记忆原理-prompt设计

智能体记忆的管理与设计开发分为以下几步: 1.记忆的抽取; 2.记忆的存储; 3.记忆的搜索; 一、记忆抽取一: FACT_RETRIEVAL_PROMPT f"""你是一位个人信息整理助手,专门负责准确存储事实、用…

026 在线文档管理系统技术架构解析:基于 Spring Boot 的企业级文档管理平台

在线文档管理系统技术架构解析:基于Spring Boot的企业级文档管理平台 在企业数字化转型的进程中,高效的文档管理系统已成为提升协作效率的核心基础设施。本文将深入解析基于Spring Boot框架构建的在线文档管理系统,该系统整合公告信息管理、…

AWTK-MVVM的一些使用技巧总结(1)

在项目中用了一段时间的AWTK-MVVM框架,由于AWTK-MVVM本身的文档十分欠缺,自己经过一段时间的研究折腾出了几个技巧,在此记录总结。 用fscript启用传统UI代码 AWTK-MVVM里面重新设计了navigator机制,重定位了navigator_to的调用方…

openwrt使用quilt工具制作补丁

前言:简单聊一下为什么需要制作补丁,因为openwrt的编译是去下载很多组件放到dl目录下面,这些组件都是压缩包。如果我们要修改这些组件里面的源码,就需要对这些组件打pacth,也就是把我们的差异点在编译的时候合入到对应…

强化学习 (1)基本概念

grid-world example 一个由多个格子组成的二维网格 三种格子:accessible可通行的; forbidden禁止通行的; target目标 state状态 state是智能体相对于环境的状态(情况) 在grid-world example里,state指的…

【Typst】纵向时间轴

概述 6月10日实验了一个纵向时间轴排版效果,当时没有做成单独的模块,也存在一些Bug。 今天(6月29日)在原基础上进行了一些改进,并总结为模块。 目前暂时发布出来,可用,后续可能会进行大改。 使用案例 导入模块使用…

【Visual Studio Code上传文件到服务器】

在 Visual Studio Code (VS Code) 中上传文件到 Linux 系统主要通过 SSH 协议实现,结合图形界面(GUI)或命令行工具操作。以下是具体说明及进度查看、断点续传的实现方法: ⚙️ 一、VS Code 上传文件到 Linux 的机制 SSH 远程连接 …

手机控车一键启动汽车智能钥匙

手机一键启动车辆的方法 手机一键启动车辆是一种便捷的汽车启动方式,它通过智能手机应用程序实现对车辆的远程控制。以下是详细的步骤: 完成必要的认证与激活步骤。打开手机上的相关移动管家手机控车APP,并与车载蓝牙建立连接。在APP的主界面…