和鲸社区深度学习基础训练营2025年关卡4

使用 pytorch 构建一个简单的卷积神经网络(CNN)模型,完成对 CIFAR-10 数据集的图像分类任务。 直接使用 CNN 进行分类的模型性能。 提示: 数据集:CIFAR-10 网络结构:可以使用 2-3 层卷积层,ReLU 激活,MaxPooling 层,最后连接全连接层。

#1. 数据预处理与加载
import torch
import torchvision
import torchvision.transforms as transforms# 数据增强与归一化(使用CIFAR-10官方均值和标准差)
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),       # 随机裁剪增强泛化性transforms.RandomHorizontalFlip(),          # 随机水平翻转transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])# 加载数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)# 数据加载器
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)#2. CNN模型架构
import torch.nn as nn
import torch.nn.functional as Fclass SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)  # 输入通道3(RGB),输出32通道self.bn1 = nn.BatchNorm2d(32)                 # 批量归一化self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.bn2 = nn.BatchNorm2d(64)self.conv3 = nn.Conv2d(64, 128, 3, padding=1)self.bn3 = nn.BatchNorm2d(128)self.pool = nn.MaxPool2d(2, 2)                # 池化层(尺寸减半)self.fc1 = nn.Linear(128 * 4 * 4, 256)       # 全连接层(输入尺寸计算:32x32 → 16x16 → 8x8 → 4x4)self.fc2 = nn.Linear(256, 10)                 # 输出10类def forward(self, x):x = self.pool(F.relu(self.bn1(self.conv1(x))))  # 32x32 → 16x16x = self.pool(F.relu(self.bn2(self.conv2(x))))  # 16x16 → 8x8x = self.pool(F.relu(self.bn3(self.conv3(x))))  # 8x8 → 4x4x = x.view(-1, 128 * 4 * 4)                    # 展平x = F.relu(self.fc1(x))x = self.fc2(x)return x# 实例化模型并移至GPU(若可用)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = SimpleCNN().to(device)#3. 训练与优化
import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # 每5轮学习率×0.1# 训练循环(10个epoch)
for epoch in range(10):net.train()running_loss = 0.0for i, (inputs, labels) in enumerate(trainloader):inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:  # 每100批次打印一次print(f'Epoch [{epoch+1}/10], Step [{i+1}/{len(trainloader)}], Loss: {running_loss/100:.3f}')running_loss = 0.0scheduler.step()  # 更新学习率print(f"Epoch {epoch+1} completed, learning rate: {scheduler.get_last_lr()[0]:.6f}")#4. 模型评估与可视化
net.eval()
correct, total = 0, 0
with torch.no_grad():for (images, labels) in testloader:images, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/914192.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/914192.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端性能优化全攻略:从加载到渲染

目录 前言网络请求优化资源加载优化JavaScript执行优化渲染优化用户体验优化性能监控与分析总结 前言 随着Web应用复杂度不断提升,前端性能优化变得尤为重要。本文将系统性地介绍从资源加载到页面渲染的全链路性能优化策略,帮助开发者构建高效、流畅的…

hiredis: 一个轻量级、高性能的 C 语言 Redis 客户端库

目录 1.简介 2.安装和配置 2.1.源码编译安装(通用方法) 2.2.包管理器安装(特定系统) 2.3.Windows 安装 3.常用的函数及功能 3.1.连接管理函数 3.2.命令执行函数 3.3.异步操作函数 3.4.回复处理函数 3.5.错误处理 3.6.…

TCP套接字

1.概念套接字是专门进行网络间数据通信的一种文件类型,可以实现不同主机之间双向通信,包含了需要交换的数据和通信双方的IP地址和port端口号。2.套接字文件的创建int socket(int domain, int type, int protocol); 功能:该函数用来创建各种各…

Go语言高并发聊天室(一):架构设计与核心概念

Go语言高并发聊天室(一):架构设计与核心概念 🚀 引言 在当今互联网时代,实时通信已成为各类应用的核心功能。从微信、QQ到各种在线协作工具,高并发聊天系统的需求无处不在。本系列文章将手把手教你使用Go语…

Java基础:泛型

什么是泛型? 简单来说,Java泛型是JDK 5引入的一种特性,它允许你在定义类、接口和方法时使用类型参数(Type Parameters)。这些类型参数可以在编译时被具体的类型(如 String, Integer, MyCustomClass 等&…

RMSNorm实现

当前Qwen、Llama等系列RMSNorm实现源码均一致。具体现实如下: class RMSNorm(nn.Module):def __init__(self, hidden_size, eps1e-6):super().__init__()self.weight nn.Parameter(torch.ones(hidden_size))self.variance_epsilon epsdef forward(self, hidden_s…

智能Agent场景实战指南 Day 11:财务分析Agent系统开发

【智能Agent场景实战指南 Day 11】财务分析Agent系统开发 文章标签 AI Agent,财务分析,LLM应用,智能财务,Python开发 文章简述 本文是"智能Agent场景实战指南"系列第11篇,聚焦财务分析Agent系统的开发。文章深入解析如何构建一个能够自动处理财务报表…

人工智能安全基础复习用:可解释性

一、可解释性的核心作用1. 错误检测与模型改进发现模型的异常行为(如过拟合、偏见),优化性能。例:医疗模型中,可解释性帮助识别误诊原因。2. 安全与可信性关键领域(医疗、军事)需透明决策&#…

Qt:QCustomPlot类介绍

QCustomPlot的核心类就是QCustomPlot类。这个类继承自QWidget,因此可以像其他QWidget一样使用,比如放入布局中。QCustomPlot类基本结构一个QCustomPlot对象可以包含多个图层(通过QCPLayer表示),通常使用默认图层。它包…

Visual Studio 2022 上使用ffmpeg

目录 1. 添加包含目录 2. 添加库目录 3. 添加依赖项 4. 添加动态库目录 5. 测试 在解决方案中右击项目名称,弹出的窗口中选择 "属性"。 1. 添加包含目录 "C/C" -> "常规" -> "附加包含目录"中添加 ffmpeg中的…

Elasticsearch 线程池

Elasticsearch 线程池「每个线程池到底采用哪种实现策略」:Elasticsearch 线程池(ThreadPool)中 **所有内置线程池名称的常量定义**。 每个字符串常量对应一个 **线程池的名字(name)**,也就是你在 Thread…

深入理解 Next.js API 路由:构建全栈应用的终极指南

Next.js 是一个强大的 React 框架,不仅支持服务端渲染(SSR)和静态站点生成(SSG),还提供了内置的 API 路由功能,使开发者能够轻松构建全栈应用。传统的全栈开发通常需要单独搭建后端服务&#xf…

【6.1.2 漫画分布式事务技术选型】

漫画分布式事务技术选型 🎯 学习目标:掌握架构师核心技能——分布式事务技术选型与一致性解决方案,构建高可靠的分布式系统 🎭 第一章:分布式事务模式对比 🤔 2PC vs 3PC vs TCC vs Saga 想象分布式事务就…

液冷智算数据中心崛起,AI算力联动PC Farm与云智算开拓新蓝海(二)

从算法革新到基础设施升级,从行业渗透到地域布局,人工智能算力正以 “规模扩张 效率提升”双轮驱动中国数字经济转型。中国智能算力规模将在 2025 年突破 1000 EFLOPS,2028 年达到 2781.9 EFLOPS,五年复合增长率 46.2%&#xff0…

《QtPy:Python与Qt的完美桥梁》

QtPy 是什么 在 Python 的广袤编程宇宙中,当涉及到图形用户界面(GUI)开发,Qt 框架宛如一颗璀璨的明星,散发着独特的魅力。而 QtPy,作为 Python 与 Qt 生态系统交互中的关键角色,更是为开发者们开…

ubuntu环境下调试 RT-Thread

调试 RT-Thread 下载源码 github 搜索 RT-Thread 下载源码 安装 python scons 环境 你已经安装了 kconfiglib,但 scons --menuconfig 仍然提示找不到它。这种情况通常是由于 Python 环境不一致 导致的:你在一个 Python 环境中安装了 kconfiglib&#xff…

【数据结构初阶】--顺序表(二)

🔥个人主页:草莓熊Lotso 🎬作者简介:C研发方向学习者 📖个人专栏: 《C语言》 《数据结构与算法》《C语言刷题集》《Leetcode刷题指南》 ⭐️人生格言:生活是默默的坚持,毅力是永久的…

Java中的方法传参机制

1. 概述Java中的方法传参机制分为两种:值传递(Pass by Value) 和 引用传递(Pass by Reference)。然而,Java中所有的参数传递都是值传递,只不过对于对象来说,传递的是对象的引用地址的…

C++——this关键字和new关键字

一、this 关键字1. 什么是 this?this 是 C 中的一个隐式指针,它指向当前对象(即调用成员函数的对象),在成员函数内部使用,用于引用调用该函数的对象。每个类的非静态成员函数内部都可以使用 this。使用 thi…

Python中类静态方法:@classmethod/@staticmethod详解和实战示例

在 Python 中,类方法 (classmethod) 和静态方法 (staticmethod) 是类作用域下的两种特殊方法。它们使用装饰器定义,并且与实例方法 (def func(self)) 的行为有所不同。1. 三种方法的对比概览方法类型是否访问实例 (self)是否访问类 (cls)典型用途实例方法…