pytorch--模型训练的一般流程

文章目录

    • 前言
    • 0、数据集准备
    • 1、数据集
    • 2、dataset
    • 3、model
    • 4、训练模型

前言

在pytorch中模型训练一般分为以下几个步骤:
0、数据集准备
1、数据集读取(dataset模块)
2、数据集转换为tensor(dataloader模块)
3、定义模型model(编写模型代码,主要是前向传播)
4、定义损失函数loss
5、定义优化器optimizer
6、最后一步是模型训练阶段train:这一步会,利用循环把dataset->dataloader->model->loss->optimizer合并起来。
相比于普通的函数神经网络并没有特别神奇的地方,我们不妨训练过程看成普通函数参数求解的过程,也就是最优化求解参数。以Alex模型为例,进行分类训练。

0、数据集准备

分类数据不需要进行标注,只需要给出类别就可以了,对应分割,检测需要借助labelme或者labelimg进行标注。将数据分为训练集,验证集,测试集。训练集用于模型训练,验证集用于训练过程中检验模型训练参数的表现,测试集是模型训练完成之后验证模型的表现。

1、数据集

从这里下载数据集The TU Darmstadt Database (formerly the ETHZ Database)一个三种类型115 motorbikes + 50 x 2 cars + 112 cows = 327张照片,把数据分为训练train和验证集val

在这里插入图片描述

并对train和val文件夹形成对应的标签文件,每一行为照片的名称和对应的类别编号(从0开始):
在这里插入图片描述

2、dataset

现在写一个名为dataset.py文件,写一个VOCDataset的类,来读取训练集和验证集,VOCDataset继承了torch.utils.data.Dataset,并重写父类的两个函数__getitem__:返回每个图像及其对应的标签,def __len__返回数据集的数量:


import torch  
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from PIL import Image
import osclass VOCDataset(Dataset):def __init__(self, img_dir, label_root, transform=None):self.img_root = img_dirself.label_root = label_rootself.transform = transform# 获取所有图像路径self.img_paths= [os.path.join(self.img_root, f) for f in os.listdir(self.img_root) if f.endswith('.png')]# 读取txt中class标签,txt文件每行格式为: img_name class_idself.label_classes = {}with open(label_root, 'r') as f:for line in f:img_name, class_id = line.strip().split()self.label_classes[img_name] = int(class_id)def __len__(self):return len(self.img_paths)def __getitem__(self, idx):img_path = self.img_paths[idx]img = Image.open(img_path).convert('RGB')# 获取对应的标签img_name = os.path.basename(img_path)target = self.label_classes.get(img_name, -1)if target == -1:raise ValueError(f"Image {img_name} not found in label file.")if self.transform:img = self.transform(img)else:img = transforms.ToTensor()(img)return img, target

3、model

新建一个model.py的文件,写一个Alex的类(参考动手学深度学习7.1),继承torch.nn.Module,重写forword函数:

from torch import nn
from torchvision import modelsclass AlexNet(nn.Module):def __init__(self,num_class=3):super(AlexNet, self).__init__()self.conv2d1=nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=1)self.pool1=nn.MaxPool2d(kernel_size=3,stride=2,padding=0)self.conv2d2=nn.Conv2d(in_channels=96,out_channels=256,kernel_size=5,stride=1,padding=2)self.pool2=nn.MaxPool2d(kernel_size=3,stride=2,padding=0)self.conv2d3=nn.Conv2d(in_channels=256,out_channels=384,kernel_size=3,stride=1,padding=1)self.conv2d4=nn.Conv2d(in_channels=384,out_channels=384,kernel_size=3,stride=1,padding=1)self.conv2d5=nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,stride=1,padding=1)self.pool3=nn.MaxPool2d(kernel_size=3,stride=2,padding=0)# 全连接层4096self.fc1=nn.Linear(256*5*5,4096)self.fc2=nn.Linear(4096,4096)self.fc3=nn.Linear(4096,num_class)self.sequential = nn.Sequential(self.conv2d1,nn.ReLU(),self.pool1,self.conv2d2,nn.ReLU(),self.pool2,self.conv2d3,nn.ReLU(),self.conv2d4,nn.ReLU(),self.conv2d5,nn.ReLU(),self.pool3,nn.Flatten(),self.fc1,nn.ReLU(),nn.Dropout(0.5),self.fc2,nn.ReLU(),nn.Dropout(0.5),self.fc3)# 初始化权重for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)def forward(self,x):x = self.sequential(x)return x

4、训练模型

首先定义损失函数和优化器:

  criterion = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)

新建一个train.py的文件:

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from dataset import VOCDataset
from model import AlexNet, ResnetPretrained
from torchvision import models
from torchvision.datasets import CIFAR10
from dataset import VOCDataset
import tensorboarddef train(model, train_dataset, val_dataset, num_epochs=20, batch_size=32, learning_rate=0.001):# 1. 创建数据加载器train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)# 2. 定义损失函数和优化器criterion = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 3. 修正学习率调度器(放在循环外)scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)# 4. 训练模型best_acc = 0.0for epoch in range(num_epochs):model.train()running_loss = 0.0total = 0for i, (inputs, labels) in enumerate(train_loader):inputs, labels = inputs.cuda(), labels.cuda()optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item() * inputs.size(0)total += inputs.size(0)if i % 100 == 0:avg_loss = running_loss / totalprint(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {avg_loss:.4f}')# 每个epoch结束后验证model.eval()correct = 0total_val = 0val_loss = 0.0with torch.no_grad():for inputs, labels in val_loader:inputs, labels = inputs.cuda(), labels.cuda()outputs = model(inputs)loss = criterion(outputs, labels)_, predicted = torch.max(outputs.data, 1)total_val += labels.size(0)correct += (predicted == labels).sum().item()val_loss += loss.item() * inputs.size(0)epoch_acc = 100 * correct / total_valavg_val_loss = val_loss / total_valprint(f'Epoch {epoch+1}/{num_epochs} | 'f'Train Loss: {running_loss/total:.4f} | 'f'Val Loss: {avg_val_loss:.4f} | 'f'Val Acc: {epoch_acc:.2f}%')# 更新学习率(基于验证集准确率)#scheduler.step(epoch_acc)# 保存最佳模型if epoch_acc > best_acc:best_acc = epoch_acctorch.save(model.state_dict(), 'best_alexnet_cifar10.pth')print(f'Best Validation Accuracy: {best_acc:.2f}%')if __name__ == "__main__":# 1. 定义数据集路径train_img_dir = r'F:\dataset\tud\TUDarmstadt\PNGImages\train'val_img_dir = r'F:\dataset\tud\TUDarmstadt\PNGImages\val'train_label_file = r'F:\dataset\tud\TUDarmstadt\PNGImages/train_set.txt'val_label_file = r'F:\dataset\tud\TUDarmstadt\PNGImages/val_set.txt'# 2. 创建数据集实例# 增强数据增强transform_train = transforms.Compose([transforms.Resize((256, 256)),  # 先放大transforms.RandomCrop(224),  # 随机裁剪transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 验证集不需要数据增强,但需要同样的预处理transform_val = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 创建训练和验证数据集train_dataset = VOCDataset(train_img_dir, train_label_file, transform=transform_train)val_dataset = VOCDataset(val_img_dir, val_label_file, transform=transform_val)print(f'Train dataset size: {len(train_dataset)}')print(f'Validation dataset size: {len(val_dataset)}')# 2. 下载并利用CIFAR-10数据集进行分类# # # 定义数据增强和预处理# transform_train = transforms.Compose([#     transforms.Resize((224, 224)),#     transforms.RandomHorizontalFlip(),#     transforms.RandomCrop(224, padding=4),#     transforms.ToTensor(),#     transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], #                          std=[0.2470, 0.2435, 0.2616])# ])# transform_val = transforms.Compose([#     transforms.Resize((224, 224)),#     transforms.ToTensor(),#     transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], #                          std=[0.2470, 0.2435, 0.2616])# ])# # 下载CIFAR-10训练集和验证集# train_dataset = CIFAR10(root='data', train=True, download=True, transform=transform_train)# val_dataset = CIFAR10(root='data', train=False, download=True, transform=transform_val)# print(f'Train dataset size: {len(train_dataset)}')# print(f'Validation dataset size: {len(val_dataset)}')# 3. 创建模型实例model = AlexNet(num_class=10)  # CIFAR-10有10个类别  # 检查是否有可用的GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)  # 将模型移动到GPU或CPU# 打印模型结构#print(model)# 4. 开始训练train(model, train_dataset, val_dataset, num_epochs=20, batch_size=32, learning_rate=0.001)print('Finished Training')# 5. 保存模型torch.save(model.state_dict(), 'output/alexnet.pth')print('Model saved as alexnet.pth')

运行main函数就可以进行训练了,后面会讲一些如何改进这个模型和一些训练技巧。

参考:
1
2
3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/912651.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/912651.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

智能合同管理实战:基于区块链的电子签约技术实现

在数字经济时代,传统纸质合同签署方式已难以满足企业高效、安全、合规的业务需求。智能合同管理(Smart Contract Management)结合区块链技术,正在重塑电子签约流程,实现合同全生命周期的自动化、可追溯和防篡改。本文将深入探讨基于区块链的电子签约技术实现,涵盖核心架构…

设计模式精讲 Day 22:模板方法模式(Template Method Pattern)

【设计模式精讲 Day 22】模板方法模式(Template Method Pattern) 文章标签 设计模式, 模板方法模式, Java开发, 面向对象设计, 软件架构, 设计模式实战, Java应用开发 文章简述 模板方法模式是一种行为型设计模式,它通过定义一个算法的骨架…

如何在pytorch中使用tqdm:优雅实现训练进度监控

文章目录 为什么需要进度条?tqdm 简介基础用法示例深度学习中的实战应用1. 数据加载进度监控2. 训练循环增强版3. 验证阶段集成 高级技巧与最佳实践1. 自定义进度条样式2. 嵌套进度条(多任务)3. 分布式训练支持4. 与日志系统集成 性能优化建议…

Linux中的xxd命令详解

xxd 是一个 十六进制转储(hex dump)工具,通常用于将二进制文件转换为十六进制格式,或者反向转换(十六进制→二进制)。它是 vim 的一部分,但在大多数 Linux 系统(如 Ubuntu&#xff0…

磐维数据库panweidb3.1.0单节点多实例安装

0 说明 业务科室提单需要在某台主机上部署多个单机磐维数据库,用于业务测试。以下内容展示如何在单节点安装多个磐维数据库实例。 1 部署环境准备 1.1 IP 地址及端口 instipport实例1192.168.131.1717700实例2192.168.131.1727700 在131.17上分别安装两个实例&…

转录组分析流程(三):功能富集分析

我们的教程主要是以一个具体的例子作为线索,通过对公共数据库数据bulk-RNA-seq的挖掘,利用生物信息学分析来探索目标基因集作为某种疾病数据预后基因的潜能及其潜在分子机制,同时在单细胞水平分析(对scRNA-seq进行挖掘)预后基因的表达,了解细胞之间的通讯网络,以期为该疾病…

全面掌握 tkinter:Python GUI 编程的入门与实战指南

在自动化、工具开发、数据可视化等领域,图形用户界面(GUI)往往是提升用户体验的重要方式。作为 Python 官方内置的 GUI 库,tkinter 以其轻量、跨平台、易于学习的特性成为初学者和轻量级应用开发者首选。 本文将以深入浅出的方式…

TDH社区开发版安装教程

(注:本文章来源于星环官网安装手册) 后面放置了视频和安装手册连接 1、硬件及环境要求 Docker17及以上版本,支持Centos,Ubuntu等系统(注:这里我使用CentOS-7版本,最佳版本推荐为7.…

Linux基本命令篇 —— grep命令

grep是Linux/Unix系统中一个非常强大的文本搜索工具,它的名字来源于"Global Regular Expression Print"(全局正则表达式打印)。grep命令用于在文件中搜索包含特定模式的行,并将匹配的行打印出来。 目录 一、基本语法 二…

苍穹外卖问题系列之 苍穹外卖订单详情前端界面和网课给的不一样

问题 如图,我的前端界面和网课里面给的不一样,没有“申请退款”和一些其他的该有的东西。 原因分析 “合计”这一栏显示undefined说明我们的总金额没有输入进去。可以看看订单提交那块的代码,是否可以正确输出。还有就是订单详细界面展示这…

CppCon 2018 学习:EMULATING THE NINTENDO 3DS

我们来逐个分析一下这个 组件交互模型 和 仿真 & 序列化 的关系,特别是主线程(Main Thread)与其他系统组件之间的交互。 1. Main Thread — simple (basically memcpy) --> GPU Main Thread(主线程)负责游戏的…

[Python 基础课程]数字

数字 数字数据类型用于存储数值,比如整数、小数等。数据类型是不允许改变的,这就意味着如果改变数字数据类型的值,将重新分配内存空间。 创建数字类型的变量: var1 1 var2 10创建完变量后,如果想废弃掉这个变量&a…

Linux CentOS环境下Java连接MySQL数据库指南

文章目录 前言一、环境准备1.1 系统更新1.2 Java环境安装1.3 MySQL数据库安装1.4 下载JDBC驱动 二、编写Java程序2.1 代码如下2.2 编译和运行2.3 验证创建结果 三、代码上传至Gitee3.1 安装配置Git3.2 克隆仓库到本地3.3 添加Java项目文件3.4 提交代码到本地仓库3.5 推送到Gite…

LLM面试12

讯飞算法工程师面试题 SVM核函数能否映射到无穷维 可以的,多项式核函数将低维数据映射到高维(维度是有限的),而高斯核函数可以映射到无穷维。由 描述下xgb原理,损失函数 首先需要说一说GBDT,它是一种基于boosting增强…

类加载生命周期与内存区域详解

类加载生命周期与内存区域详解 Java 类加载的生命周期包括加载、验证、准备、解析、初始化五个阶段,每个阶段在内存中的存储区域和赋值机制各有不同。以下是详细解析: 一、类加载生命周期阶段 1. 加载(Loading) 内存区域&…

正交视图三维重建2 笔记 2d线到3d线2 先生成3d线然后判断3d线在不在

应该先连线再判断线在不在 if(fx1tx1&&tx1tx2){ const A[fx1, fy1, ty1];const Ahat[fx1, fy1, ty2];drawlines(A[0], A[1], A[2], Ahat[0], Ahat[1], Ahat[2], lineId, type,2);}if(fx2tx1&&tx1tx2){ const B[fx2, fy2, ty1];const Bhat[fx2, fy2, ty2];drawl…

Hibernate对象生命周期全解析

Hibernate对象生命周期详解 Hibernate作为Java领域主流的ORM框架,其核心机制之一就是对持久化对象生命周期的管理。理解Hibernate对象生命周期对于正确使用Hibernate进行数据持久化操作至关重要。Hibernate将对象分为三种主要状态:瞬时态(Transient)、持久态(Persistent)和游…

MCP 协议使用核心讲解

📄 MCP 协议使用核心讲解 ✅ MCP 协议的核心在于以下几个方面 一、MCP 请求结构(MCPRequest) {"messages": [{"role": "user","content": "帮我查询一下上海的天气"}],"tools"…

云计算中的几何方法:曲面变形的可视化与动画-AI云计算数值分析和代码验证

着重强调微分方程底层的几何和代数结构,以进行更深入的分析和求解方法。开发结构保持的数值方法,以在计算中保持定性特征。统一符号和数值方法,实现有效的数学建模。利用几何解释(如双曲几何)求解经典微分方程。利用计…

OpenCV篇——项目(一)OCR识别读取银行卡号码

目录 信用卡数字识别系统:前言与代码解析 前言 项目代码 ​​​​​​结果演示 代码模块解析 1. 参数解析模块 2. 轮廓排序函数 3. 图像预处理模块 4. 输入图像处理流程 5. 卡号区域定位 6. 数字识别与输出 系统优势 信用卡数字识别系统:前言…