14.使用GoogleNet/Inception网络进行Fashion-Mnist分类

14.1 GoogleNet网络结构设计

在这里插入图片描述
在这里插入图片描述

import torch
from torch import nn
from torch.nn import functional as F
from torchsummary import summary
class Inception(nn.Module):def __init__(self, in_channels,c1,c2,c3,c4,**kwargs):super(Inception,self).__init__(**kwargs)#第一条路线:1*1的卷积层self.p1_1=nn.Conv2d(in_channels,c1,kernel_size=1)#第二条路线:1*1的卷积层+3*3的卷积层self.p2_1=nn.Conv2d(in_channels,c2[0],kernel_size=1)self.p2_2=nn.Conv2d(c2[0],c2[1],kernel_size=3,padding=1)#第三条路线:1*1的卷积层+5*5的卷积层self.p3_1=nn.Conv2d(in_channels,c3[0],kernel_size=1)self.p3_2=nn.Conv2d(c3[0],c3[1],kernel_size=5,padding=2)#第四条路线:3*3Maxpool+1*1 convsself.p4_1=nn.MaxPool2d(kernel_size=3,stride=1,padding=1)self.p4_2=nn.Conv2d(in_channels,c4,kernel_size=1)def forward(self,x):p1=F.relu(self.p1_1(x))#第一层p2=F.relu(self.p2_2(F.relu(self.p2_1(x))))p3=F.relu(self.p3_2(F.relu(self.p3_1(x))))p4=F.relu(self.p4_2(self.p4_1(x)))ft=torch.concat((p1,p2,p3,p4),dim=1)return ft
#组建googlenet
b1=nn.Sequential(nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b2=nn.Sequential(nn.Conv2d(64,64,kernel_size=1),nn.ReLU(),nn.Conv2d(64,192,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b3=nn.Sequential(Inception(192,64,(96,128),(16,32),32),Inception(256,128,(128,192),(32,96),64),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b4=nn.Sequential(Inception(480,192,(96,208),(16,48),64),Inception(512,160,(112,224),(24,64),64),Inception(512,128,(128,256),(24,64),64),Inception(512,112,(144,288),(32,64),64),Inception(528,256,(160,320),(32,128),128),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b5=nn.Sequential(Inception(832,64,(96,128),(16,32),32),Inception(256,128,(128,192),(32,96),64),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten())
device=torch.device("cuda" if torch.cuda.is_available() else 'cpu')
model=nn.Sequential(b1,b2,b3,b4,b5,nn.Linear(480,10)).to(device)
summary(model,input_size=(1,224,224),batch_size=1)

在这里插入图片描述

14.2 GoogleNet网络实现Fashion-Mnist分类

import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
plt.rcParams['font.family']=['Times New Roman']
class Reshape(torch.nn.Module):def forward(self,x):return x.view(-1,1,28,28)#[bs,1,28,28]
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):epochs = range(1, len(train_loss_list) + 1)plt.figure(figsize=(4, 3))plt.plot(epochs, train_loss_list, label='Train Loss')plt.plot(epochs, train_acc_list, label='Train Acc',linestyle='--')plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')plt.xlabel('Epoch')plt.ylabel('Value')plt.title(title)plt.legend()plt.grid(True)plt.tight_layout()plt.show()
def train_model(model,train_data,test_data,num_epochs):train_loss_list = []train_acc_list = []test_acc_list = []for epoch in range(num_epochs):total_loss=0total_acc_sample=0total_samples=0loop=tqdm(train_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")for X,y in loop:#X=X.reshape(X.shape[0],-1)#print(X.shape)X=X.to(device)y=y.to(device)y_hat=model(X)loss=CEloss(y_hat,y)optimizer.zero_grad()loss.backward()optimizer.step()#loss累加total_loss+=loss.item()*X.shape[0]y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()total_acc_sample+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数total_samples+=X.shape[0]test_acc_samples=0test_samples=0for X,y in test_data:X=X.to(device)y=y.to(device)#X=X.reshape(X.shape[0],-1)y_hat=model(X)y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()test_acc_samples+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数test_samples+=X.shape[0]avg_train_loss=total_loss/total_samplesavg_train_acc=total_acc_sample/total_samplesavg_test_acc=test_acc_samples/test_samplestrain_loss_list.append(avg_train_loss)train_acc_list.append(avg_train_acc)test_acc_list.append(avg_test_acc)print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Trian Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")plot_metrics(train_loss_list, train_acc_list, test_acc_list)return model
def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)
class Inception(nn.Module):def __init__(self, in_channels,c1,c2,c3,c4,**kwargs):super(Inception,self).__init__(**kwargs)#第一条路线:1*1的卷积层self.p1_1=nn.Conv2d(in_channels,c1,kernel_size=1)#第二条路线:1*1的卷积层+3*3的卷积层self.p2_1=nn.Conv2d(in_channels,c2[0],kernel_size=1)self.p2_2=nn.Conv2d(c2[0],c2[1],kernel_size=3,padding=1)#第三条路线:1*1的卷积层+5*5的卷积层self.p3_1=nn.Conv2d(in_channels,c3[0],kernel_size=1)self.p3_2=nn.Conv2d(c3[0],c3[1],kernel_size=5,padding=2)#第四条路线:3*3Maxpool+1*1 convsself.p4_1=nn.MaxPool2d(kernel_size=3,stride=1,padding=1)self.p4_2=nn.Conv2d(in_channels,c4,kernel_size=1)def forward(self,x):p1=F.relu(self.p1_1(x))#第一层p2=F.relu(self.p2_2(F.relu(self.p2_1(x))))p3=F.relu(self.p3_2(F.relu(self.p3_1(x))))p4=F.relu(self.p4_2(self.p4_1(x)))ft=torch.concat((p1,p2,p3,p4),dim=1)return ft
#组建googlenet
b1=nn.Sequential(nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b2=nn.Sequential(nn.Conv2d(64,64,kernel_size=1),nn.ReLU(),nn.Conv2d(64,192,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b3=nn.Sequential(Inception(192,64,(96,128),(16,32),32),Inception(256,128,(128,192),(32,96),64),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b4=nn.Sequential(Inception(480,192,(96,208),(16,48),64),Inception(512,160,(112,224),(24,64),64),Inception(512,128,(128,256),(24,64),64),Inception(512,112,(144,288),(32,64),64),Inception(528,256,(160,320),(32,128),128),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b5=nn.Sequential(Inception(832,64,(96,128),(16,32),32),Inception(256,128,(128,192),(32,96),64),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten())
device=torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')
model=nn.Sequential(b1,b2,b3,b4,b5,nn.Linear(480,10)).to(device)
transforms=transforms.Compose([transforms.Resize(96),transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])#第一个是mean,第二个是std
train_img=torchvision.datasets.FashionMNIST(root="./data",train=True,transform=transforms,download=True)
test_img=torchvision.datasets.FashionMNIST(root="./data",train=False,transform=transforms,download=True)
train_data=DataLoader(train_img,batch_size=128,num_workers=4,shuffle=True)
test_data=DataLoader(test_img,batch_size=128,num_workers=4,shuffle=False)
################################################################################################################
model.apply(init_weights)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
CEloss=nn.CrossEntropyLoss()
model=train_model(model,train_data,test_data,num_epochs=15)
################################################################################################################

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/88804.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/88804.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NE综合实验2:RIP 与 OSPF 动态路由精细配置、FTPTELNET 服务搭建及精准访问限制

NE综合实验2:RIP 与 OSPF 动态路由精细配置、FTPTELNET 服务搭建及精准访问限制 涉及的协议可以看我之前的文章: RIP实验 OSPF协议:核心概念与配置要点解析 ACL协议:核心概念与配置要点解析 基于OSPF动态路由与ACL访问控制的网…

Android 插件化实现原理详解

Android 插件化实现原理详解 插件化技术是Android开发中一项重要的高级技术,它允许应用动态加载和执行未安装的APK模块。以下是插件化技术的核心实现原理和关键技术点: 一、插件化核心思想宿主与插件: 宿主(Host):主应用APK&#…

空间智能-李飞飞团队工作总结(至2025.07)

李飞飞团队在空间智能(Spatial Intelligence)领域的研究自2024年起取得了一系列突破性进展,其里程碑成果可归纳为以下核心方向: 一、理论框架提出与定义(2024年) 1、空间智能概念系统化 a.定义: 李飞飞首次明确空间智能为“机器在3D空间和时间中感知、推理和行动的能…

【算法深练】BFS:“由近及远”的遍历艺术,广度优先算法题型全解析

前言 宽度优先遍历BFS与深度优先遍历DFS有本质上的区别,DFS是一直扩到低之后找返回,而BFS是一层层的扩展就像剥洋葱皮一样。 通常BFS是将所有路径同时进行尝试,所以BFS找到的第一个满足条件的位置,一定是路径最短的位置&#xf…

ZW3D 二次开发-创建球体

使用中望3d用户函数 cvxPartSphere 创建球体 函数定义: ZW_API_C evxErrors cvxPartSphere(svxSphereData *Sphere, int *idShape); typedef struct svxSphereData {evxBoolType Combine; /**<@brief combination method */svxPoint Center; /**<@brief sphere ce…

艺术总监的构图“再造术”:用PS生成式AI,重塑照片叙事框架

在视觉叙事中&#xff0c;我们常常面临一个核心的“对立统一”&#xff1a;一方面是**“被捕捉的瞬间”&#xff08;The Captured Moment&#xff09;&#xff0c;即摄影师在特定时间、特定地点所记录下的客观现实&#xff1b;另一方面是“被期望的叙事”**&#xff08;The Des…

ChatGPT无法登陆?分步排查指南与解决方案

ChatGPT作为全球领先的AI对话工具&#xff0c;日均处理超百万次登录请求&#xff0c;登陆问题可能导致用户无法正常使用服务&#xff0c;影响工作效率或学习进度。 无论是显示「网络错误」「账号未激活」&#xff0c;还是持续加载无响应&#xff0c;本文将从网络连接、账号状态…

用Joern执行CPGQL找到C语言中不安全函数调用的流程

1. 引入 静态应用程序安全测试&#xff08;Static application security testing&#xff09;简称SAST&#xff0c;是透过审查程式源代码来识别漏洞&#xff0c;提升软件安全性的作法。 Joern 是一个强大的开源静态应用安全测试&#xff08;SAST&#xff09;工具&#xff0c;专…

读文章 Critiques of World model

论文名称&#xff1a;对世界模型的批判 作者单位&#xff1a; CMU&#xff0c; UC SD 原文链接&#xff1a;https://arxiv.org/pdf/2507.05169 摘要&#xff1a; 世界模型&#xff08;World Model&#xff09;——即真实世界环境的算法替代物&#xff0c;是生物体所体验并与之…

利用docker部署前后端分离项目

后端部署数据库:redis部署:拉取镜像:doker pull redis运行容器:docker run -d -p 6379:6379 --name my_redis redismysql部署:拉取镜像:docker pull mysql运行容器:我这里3306被占了就用的39001映射docker run -d -p 39001:3306 -v /home/mysql/conf:/etc/mysql/conf.d -v /hom…

YOLOv11调参指南

YOLOv11调参 1. YOLOv11参数体系概述 YOLOv11作为目标检测领域的前沿算法&#xff0c;其参数体系可分为四大核心模块&#xff1a; 模型结构参数&#xff1a;决定网络深度、宽度、特征融合方式训练参数&#xff1a;控制学习率、优化器、数据增强策略检测参数&#xff1a;影响预测…

云原生核心技术解析:Docker vs Kubernetes vs Docker Compose

云原生核心技术解析&#xff1a;Docker vs Kubernetes vs Docker Compose &#x1f6a2;☸️⚙️ 一、云原生核心概念 ☁️ 云原生&#xff08;Cloud Native&#xff09; 是一种基于云计算模型构建和运行应用的方法论&#xff0c;核心目标是通过以下技术实现弹性、可扩展、高可…

keepalive模拟操作部署

目录 keepalived双机热备 一、配置准备 二、配置双机热备&#xff08;基于nginx&#xff09; web1端 修改配置文件 配置脚本文件 web2端 修改配置文件 配置脚本文件 模拟检测 开启keepalived服务 访问结果 故障模拟 中止nginx 查看IP 访问浏览器 重启服务后…

Java 中的 volatile 是什么?

&#x1f449; volatile &#xff1a;不稳定的 英[ˈvɒlətaɪl] 美[ˈvɑːlətl] adj. 不稳定的;<计>易失的;易挥发的&#xff0c;易发散的;爆发性的&#xff0c;爆炸性的;易变的&#xff0c;无定性的&#xff0c;无常性的;短暂的&#xff0c;片刻的;活泼的&#xff…

MongoDB性能优化实战指南:原理、实践与案例

MongoDB性能优化实战指南&#xff1a;原理、实践与案例 在大规模数据存储与查询场景下&#xff0c;MongoDB凭借其灵活的文档模型和水平扩展能力&#xff0c;成为众多互联网及企业级应用的首选。然而&#xff0c;在生产环境中&#xff0c;随着数据量和并发的增长&#xff0c;如何…

细谈kotlin中缀表达式

Kotlin 是一种适应你编程风格的语言&#xff0c;允许你在想什么时候写代码就什么时候写代码。Kotlin 提供了一些机制&#xff0c;帮助我们编写易读易懂的代码。其中一个非常有趣的机制是 中缀表达式&#xff08;infix notation&#xff09;。它允许我们定义和调用函数时省略点号…

[Nagios Core] CGI接口 | 状态数据管理.dat | 性能优化

链接&#xff1a;https://assets.nagios.com/downloads/nagioscore/docs/nagioscore/4/en/ docs&#xff1a;Nagios Core Nagios Core 是功能强大的基础设施监控系统&#xff0c;包含 CGI 程序&#xff0c;允许用户通过 Web 界面查看当前状态、历史记录等。通过以下技术栈实现…

Linux进程优先级机制深度解析:从Nice值到实时调度

前言 在Linux系统中&#xff0c;进程优先级决定了CPU资源的分配顺序&#xff0c;直接影响系统性能和关键任务的响应速度。无论是优化服务器负载、确保实时任务稳定运行&#xff0c;还是避免低优先级进程拖慢系统&#xff0c;合理调整进程优先级都是系统管理和性能调优的重要技能…

深入浅出Kafka Broker源码解析(下篇):副本机制与控制器

一、副本机制深度解析 1.1 ISR机制实现 1.1.1 ISR管理核心逻辑 ISR&#xff08;In-Sync Replicas&#xff09;是Kafka保证数据一致性的核心机制&#xff0c;其实现主要分布在ReplicaManager和Partition类中&#xff1a; public class ReplicaManager {// ISR变更集合&#xff0…

Fluent许可文件安装和配置

在使用Fluent软件进行流体动力学模拟之前&#xff0c;正确安装和配置Fluent许可文件是至关重要的一步。本文将为您提供详细的Fluent许可文件安装和配置指南&#xff0c;帮助您轻松完成许可文件的安装和配置&#xff0c;确保Fluent软件能够顺利运行。 一、Fluent许可文件安装步骤…