MobileNet V1的Pytorch实现并加载预训练模型进行验证

一. 环境

  1. windonws 11
  2. RTX5060
  3. CUDA 12.8
  4. Pytorch 2.9.0dev20250630+cu128
  5. torchvision 0.23.0dev20250701+cu128

二. 代码

基于Mobilenet-CustomData 的Mobilenet_Pretrain.ipynb

1. 定义Mobile Net V1

import os
import time
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasetsfrom utils_pretrain import *# Define your data path here
IMAGENET_PATH = r'E:\AI\tiny-imagenet-200\tiny-imagenet-200'class Net(nn.Module):def __init__(self):super(Net, self).__init__()def conv_bn(inp, oup, stride):return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False),nn.BatchNorm2d(oup),nn.ReLU(inplace=True))def conv_dw(inp, oup, stride):return nn.Sequential(nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),nn.BatchNorm2d(inp),nn.ReLU(inplace=True),nn.Conv2d(inp, oup, 1, 1, 0, bias=False),nn.BatchNorm2d(oup),nn.ReLU(inplace=True),)self.model = nn.Sequential(conv_bn(  3,  32, 2), conv_dw( 32,  64, 1),conv_dw( 64, 128, 2),conv_dw(128, 128, 1),conv_dw(128, 256, 2),conv_dw(256, 256, 1),conv_dw(256, 512, 2),conv_dw(512, 512, 1),conv_dw(512, 512, 1),conv_dw(512, 512, 1),conv_dw(512, 512, 1),conv_dw(512, 512, 1),conv_dw(512, 1024, 2),conv_dw(1024, 1024, 1),nn.AvgPool2d(7),)self.fc = nn.Linear(1024, 1000)def forward(self, x):x = self.model(x)x = x.view(-1, 1024)x = self.fc(x)return xmobilenet_model1 = Net()

2. 加载fine_tune模型

mobilenet_model1 = torch.nn.DataParallel(mobilenet_model1).cuda()
params = torch.load('D:\CodeSpace\MobileNet\Mobilenet-CustomData-master\Mobilenet-CustomData-master\Mobiilenet-finetune\moblienet_30e.pth.tar')
mobilenet_model1.load_state_dict(params,strict=False)

3. validate 设置

criterion = nn.CrossEntropyLoss().cuda()
batch_size = 10
workers = 4
epochs = 1
print_freq = 100valdir = os.path.join(IMAGENET_PATH, 'val')
if not os.path.exists(valdir):print(f"val文件夹 {valdir} 不存在,请检查路径。")exit(1)normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(valdir, transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),normalize,])),batch_size=batch_size, shuffle=False,num_workers=workers, pin_memory=True)def validate(val_loader, model, criterion):batch_time = AverageMeter()losses = AverageMeter()top1 = AverageMeter()top5 = AverageMeter()# switch to evaluate modemodel.eval()end = time.time()for i, (input, target) in enumerate(val_loader):# target = target.cuda(async=True) # non_blockingtarget = target.cuda(non_blocking=True)#input_var = torch.autograd.Variable(input, volatile=True)#target_var = torch.autograd.Variable(target, volatile=True)with torch.no_grad():input_var = input.cuda()target_var = target.cuda()# compute outputoutput = model(input_var)loss = criterion(output, target_var)# measure accuracy and record lossprec1, prec5 = accuracy(output.data, target, topk=(1, 5))# .data[0]改成.item()losses.update(loss.item(), input.size(0))top1.update(prec1.item(), input.size(0))top5.update(prec5.item(), input.size(0))# measure elapsed timebatch_time.update(time.time() - end)end = time.time()if i % print_freq == 0:print('Test: [{0}/{1}]\t''Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t''Loss {loss.val:.4f} ({loss.avg:.4f})\t''Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t''Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(i, len(val_loader), batch_time=batch_time, loss=losses,top1=top1, top5=top5))print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))return top1.avg

4. validate 和结果

best_prec1 = 0for epoch in range(0, epochs):# evaluate on validation setprec1 = validate(val_loader, mobilenet_model1, criterion)# prec3 = validate(val_loader, mobilenet_model3, criterion)# remember best prec@1 and save checkpointis_best = prec1 > best_prec1best_prec1 = max(prec1, best_prec1)
Test: [0/1000]	Time 10.112 (10.112)	Loss 3.6720 (3.6720)	Prec@1 20.000 (20.000)	Prec@5 90.000 (90.000)
Test: [100/1000]	Time 0.008 (0.109)	Loss 2.5464 (3.8335)	Prec@1 40.000 (4.752)	Prec@5 100.000 (84.851)
Test: [200/1000]	Time 0.009 (0.059)	Loss 4.6422 (3.8443)	Prec@1 0.000 (5.174)	Prec@5 70.000 (83.333)
Test: [300/1000]	Time 0.009 (0.042)	Loss 3.4911 (3.8356)	Prec@1 10.000 (5.316)	Prec@5 80.000 (83.588)
Test: [400/1000]	Time 0.009 (0.034)	Loss 4.4792 (3.8361)	Prec@1 0.000 (5.162)	Prec@5 90.000 (83.691)
Test: [500/1000]	Time 0.009 (0.028)	Loss 3.6648 (3.8577)	Prec@1 0.000 (5.070)	Prec@5 90.000 (83.972)
Test: [600/1000]	Time 0.008 (0.025)	Loss 2.6866 (3.8470)	Prec@1 20.000 (5.158)	Prec@5 90.000 (84.160)
Test: [700/1000]	Time 0.008 (0.023)	Loss 3.4880 (3.8465)	Prec@1 0.000 (5.235)	Prec@5 100.000 (84.379)
Test: [800/1000]	Time 0.008 (0.021)	Loss 4.0429 (3.8602)	Prec@1 0.000 (5.006)	Prec@5 70.000 (84.120)
Test: [900/1000]	Time 0.009 (0.019)	Loss 3.8612 (3.8640)	Prec@1 0.000 (5.006)	Prec@5 60.000 (84.029)* Prec@1 5.060 Prec@5 83.860

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/89708.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/89708.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HTTP协议利用TCP的特性来实现长连接

在讨论网络协议时,经常会有人提出这样一个问题:“既然HTTP是基于TCP的,而TCP本身支持长连接,为什么HTTP不支持长连接?”这种说法其实是一种误解。实际上,HTTP确实可以并且经常使用长连接(也称为持久连接)。 什么是长连接? 首先,我们需要明确什么是“长连接”。在网…

整流电路Multisim电路仿真实验汇总——硬件工程师笔记

目录 1 整流电路基础 1.1 整流电路基本原理 1.2 整流电路的类型 1.2.1 单相整流电路 1.2.2 三相整流电路 1.3 整流电路的应用 1.3.1 直流电源 1.3.2 电池充电器 1.3.3 变频调速系统 1.34 电解和电镀 1.4 整流电路的优缺点 1.4.1 优点 1.4.2 缺点 2 二极管整流电路…

LangChain 全面入门

什么是 LangChain? LangChain 是一个专门为 大语言模型 (LLM) 应用开发设计的开源框架,帮你快速实现: • 多轮对话 • 知识库问答 (RAG) • 多工具协同调用 (function calling / tool) • 智能体 Agent 自动决策任务链 解耦 LLM 接口、Prom…

RabbitMQ 高级特性之消息确认

1. 简介 RabbitMQ 的消息发送流程: producer 将消息发送给 broker,consumer 从 broker 中获取消息并消费 那么在这里就涉及到了两种消息发送,即 producer 与 broker 之间和 consumer 与 broker 之间。 “消息确认” 讨论的是 consumer 与…

【51单片机用数码管显示流水灯的种类是按钮控制数码管加一和流水灯】2022-6-14

缘由 #include "REG52.h" unsigned char code smgduan[]{0x3f,0x06,0x5b,0x4f,0x66,0x6d,0x7d,0x07,0x7f,0x6f,0x77,0x7c,0x39,0x5e,0x79,0x71,0,64}; //共阴0~F消隐减号 unsigned char Js0, miao0;//中断计时 秒 分 时 毫秒 sbit k0P3^0; sbit k1P3^1; void smxs(u…

Android15 开机动画播放结束之后如何直接启动应用

问题背景 软件版本:Android15 在一些需求场景里面,需要开机动画播放结束立马去启动一个应用,下面介绍如何实现这种方案。 解决方案 首选我们需要知道开机动画播放结束之后的流程,这里会调用到wms里面,也就是一些enableScreen之类的函数,知道这个大概流程之后,再去对应…

AI实践:大模型痛点和解决方案讨论

大家好,我是星野,欢迎来到我的CSDN博客。在这个技术日新月异的时代,我们一起学习,共同进步。 今天想和大家分享的是大模型在实际应用中的痛点以及解决方案,特别是RAG(检索增强生成)技术。 大模…

Web前端工程化

Web前端工程化 前端工程化是指将软件工程的方法和原则应用到前端开发中,以提高开发效率、保证代码质量、便于团队协作和项目维护的一套体系化实践。以下是前端工程化的主要内容和实践: 核心组成部分 1. 模块化开发 JavaScript模块化:Comm…

Java 原生 HTTP Client

​介绍 Java 原生 HttpClient 是从 Java 11 开始引入的标准库,用于简化 HTTP 请求的发送与响应处理。它支持同步和异步请求,并内置对 HTTP/1.1 和 HTTP/2 协议的支持。HttpClient 提供了易用的 API 来设置请求头、请求体、处理响应以及配置 SSL/TLS 加密…

【C语言刷题】第十天:加量加餐继续,代码题训练,融会贯通IO模式

🔥个人主页:艾莉丝努力练剑 ❄专栏传送门:《C语言》、《数据结构与算法》、C语言刷题12天IO强训、LeetCode代码强化刷题 🍉学习方向:C/C方向 ⭐️人生格言:为天地立心,为生民立命,为…

【WEB】Polar靶场 6-10题 详细笔记

六.jwt 这题我又不会写 先来了解下jwt **JWT(JSON Web Token)**是一种基于JSON的开放标准(RFC 7519),主要用于在网络应用环境间传递声明信息。JWT通常用于身份验证和信息交换,确保在各方之间安全地传输信…

高阶亚马逊运营秘籍:关键词矩阵打法深度解析与应用

当竞争对手还在为单个大词竞价厮杀时,头部卖家已悄然构建了一张覆盖数千长尾关键词的隐形网络,精准触达每一个细分需求,以更低的成本撬动更高的转化率在亚马逊流量红利消退、广告成本高企的2025年,传统“爆款关键词”打法已显疲态…

【问题解决】org.springframework.web.util.NestedServletException Handler dispatch failed;

详细异常信息: org.springframework.web.util.NestedServletException: Handler dispatch failed; nested exception is java.lang.NoClassDefFoundError: javax/xml/bind/DatatypeConverter at org.springframework.web.servlet.DispatcherServlet.doDispatch(Disp…

【已解决】mac 聚焦搜索设置了edge 的地址栏搜索为google,还是跳转到百度

问题详情:在macbook的聚焦搜索中点击edge搜索的时候,跳转到了百度,即使已经将地址栏的搜索引擎设置为了goole,但是还是会跳转到百度。解决方案:1、打开safari浏览器。(看清了,是打开Safari&…

MimicMotion 让你的图片动起来

MimicMotion 是由腾讯公司推出的一款人工智能人像动态视频生成框架。可以模仿视频动作再让图片模仿动作姿态,最后生成视频。 MimicMotion 的核心在于其置信度感知的姿态引导技术,确保视频帧的高质量和时间上的平滑过渡。 以前咱们也手搭过Animate-X让图…

云计算考核 - 分析电子银行需求采用微服务架构对系统进行设计

二、使用的技术以及分析 微服务(Microservices)是一种架构风格,一个大型复杂软件应用由一个或多个微服务组成。系统中的各个微服务可被独立部署,各个微服务之间是松耦合的。每个微服务仅关注于完成一件任务并很好地完成该任务。在…

Ionic 安装使用教程

一、Ionic 简介 Ionic 是一个基于 Web 技术(HTML、CSS、JavaScript)的跨平台移动应用开发框架,结合 Angular、React 或 Vue 可快速构建 iOS 和 Android 应用。Ionic 提供丰富的 UI 组件、命令行工具及原生插件封装,广泛用于混合应…

渗透测试 - 简介

Web渗透测试简介 Web渗透测试(Penetration Testing)是一种模拟黑客攻击的安全评估方法,旨在发现Web应用程序中的漏洞,帮助开发者修复问题并提升系统安全性。它涉及主动测试目标系统(如网站或API)的弱点&am…

云原生AI研发体系建设路径

当AI遇上云原生,就像咖啡遇上牛奶,总能擦出不一样的火花 ☕️ 📋 文章目录 引言:为什么要建设云原生AI研发体系整体架构设计:搭建AI研发的"乐高积木"技术栈选择:选择合适的"武器装备"…

【网络安全】深入理解 IoC 与 IoA:从“事后识别”到“事前防御”

1. 简介 在网络安全领域,IoC(Indicators of Compromise,入侵指标) 和 IoA(Indicators of Attack,攻击指标) 是两个核心概念。它们是安全分析师识别攻击行为、调查事件、制定防御策略的重要依据…