Python 训练营打卡 Day 34

GPU训练及类的call方法

一、GPU训练

与day33采用的CPU训练不同,今天试着让模型在GPU上训练,引入import time比较两者在运行时间上的差异

import torch
# 设置GPU设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 仍然用4特征,3分类的鸢尾花数据集作为我们今天的数据集
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import numpy as np
import time# 加载数据集
iris = load_iris()
x = iris.data
y = iris.target
# 划分数据集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
# 归一化数据,神经网络对于输入数据的尺寸敏感,归一化是最常见的处理方式
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test) #确保训练集和测试集使用相同的缩放因子
# 将数据转换为PyTorch张量并移至GPU
# 分类问题交叉熵损失要求标签为long类型
# 张量具有to(device)方法,可以将张量移动到指定的设备上
x_train = torch.FloatTensor(x_train).to(device)
y_train = torch.LongTensor(y_train).to(device)
x_test = torch.FloatTensor(x_test).to(device)
y_test = torch.LongTensor(y_test).to(device)import torch.nn as nn # 导入PyTorch的神经网络模块
import torch.optim as optim # 导入PyTorch的优化器模块
class MLP(nn.Module): # 定义一个多层感知机(MLP)模型,继承父类nn.Moduledef __init__(self): # 初始化函数super(MLP, self).__init__() # 调用父类的初始化函数# 定义的前三行是八股文,后面的是自定义的self.fc1 = nn.Linear(4, 10) # 定义第一个全连接层,输入维度为4,输出维度为10self.relu = nn.ReLU() # 定义激活函数ReLUself.fc2 = nn.Linear(10, 3) # 定义第二个全连接层,输入维度为10,输出维度为3
# 输出层不需要激活函数,因为后面会用到交叉熵函数cross_entropy,交叉熵函数内部有softmax函数,会把输出转化为概率def forward(self, x):out = self.fc1(x) # 输入x经过第一个全连接层out = self.relu(out) # 激活函数ReLUout = self.fc2(out) # 输入x经过第二个全连接层return out 
# 实例化模型并移至GPU
# MLP继承nn.Module类,所以也具有to(device)方法
model = MLP().to(device)
# 定义损失函数和优化器
# 分类问题使用交叉熵损失函数,适用于多分类问题,应用softmax函数将输出映射到概率分布,然后计算交叉熵损失
criterion = nn.CrossEntropyLoss()
# 使用随机梯度下降优化器(SGD),学习率为0.01
optimizer = optim.SGD(model.parameters(), lr=0.01)# 开始循环训练
# 训练模型
num_epochs = 20000 # 训练的轮数# 用于存储每个 epoch 的损失值
losses = []
start_time = time.time() # 记录开始时间
for epoch in range(num_epochs): # 开始迭代训练过程,range是从0开始,所以epoch是从0开始# 前向传播:将数据输入模型,计算模型预测输出outputs = model.forward(x_train)   # 显式调用forward函数# outputs = model(X_train)  # 常见写法隐式调用forward函数,其实是用了model类的__call__方法loss = criterion(outputs, y_train) # output是模型预测值,y_train是真实标签,计算两者之间损失值# 反向传播和优化optimizer.zero_grad() #梯度清零,因为PyTorch会累积梯度,所以每次迭代需要清零,梯度累计是那种小的bitchsize模拟大的bitchsizeloss.backward() # 反向传播计算梯度,自动完成以下计算:# 1. 计算损失函数对输出的梯度# 2. 从输出层→隐藏层→输入层反向传播# 3. 计算各层权重/偏置的梯度optimizer.step() # 更新参数# 记录损失值losses.append(loss.item())# 打印训练信息if (epoch + 1) % 100 == 0: # range是从0开始,所以epoch+1是从当前epoch开始,每100个epoch打印一次print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
time_all = time.time() - start_time
print(f'Training time: {time_all:.2f} seconds')
import matplotlib.pyplot as plt
# 可视化损失曲线
plt.plot(range(num_epochs), losses) 
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()

 结果显示GPU训练情况下训练时间为44.96s,而昨天的CPU训练时间为14.13s。

本质是因为GPU在计算的时候,相较于CPU多了3个时间上的开销

1.数据传输开销 (CPU 内存 <-> GPU 显存)

  • 在 GPU 进行任何计算之前,数据(输入张量 X_train、y_train,模型参数)需要从计算机的主内存 (RAM) 复制到 GPU 专用的显存 (VRAM) 中。
  • 当结果传回 CPU 时(例如,使用 loss.item() 获取损失值用于打印或记录,或者获取最终预测结果),数据也需要从 GPU 显存复制回 CPU 内存。
  • 对于少量数据和非常快速的计算任务,这个传输时间可能比 GPU 通过并行计算节省下来的时间还要长。

在上述代码中,循环里的 loss.item() 操作会在每个 epoch 都进行一次从 GPU 到 CPU 的数据同步和传输,以便获取标量损失值。对于20000个epoch来说,这会累积不少的传输开销。

2.核心启动开销 (GPU 核心启动时间)

  • GPU 执行的每个操作(例如,一个线性层的前向传播、一个激活函数)都涉及到在 GPU 上启动一个“核心”(kernel)——一个在 GPU 众多计算单元上运行的小程序。
  • 启动每个核心都有一个小的、固定的开销。
  • 如果核心内的实际计算量非常小(本项目的小型网络和鸢尾花数据),这个启动开销在总时间中的占比就会比较大。相比之下,CPU 执行这些小操作的“调度”开销通常更低。

3.性能浪费:计算量和数据批次

  • 这个数据量太少,gpu的很多计算单元都没有被用到,即使用了全批次也没有用到的全部计算单元。

综上,数据传输和各种固定开销的总和,超过了 GPU 在这点计算量上通过并行处理所能节省的时间,导致了 GPU 比 CPU 慢的现象。这些特性导致GPU在处理鸢尾花分类这种“玩具级别”的问题时,它的优势无法体现,反而会因为上述开销显得“笨重”。

那么什么时候 GPU 会发挥巨大优势?

  • 大型数据集: 例如,图像数据集成千上万张图片,每张图片维度很高。
  • 大型模型: 例如,深度卷积网络 (CNNs like ResNet, VGG) 或 Transformer 模型,它们有数百万甚至数十亿的参数,计算量巨大。
  • 合适的批处理大小: 能够充分利用 GPU 并行性的 batch size,不至于还有剩余的计算量没有被 GPU 处理。
  • 复杂的、可并行的运算: 大量的矩阵乘法、卷积等。

针对上面反应的3个时间开销问题,能够优化的只有数据传输时间,针对性解决即可,很容易想到2个思路:

  1. 直接不打印训练过程的loss了,但是这样会没办法记录最后的可视化图片,只能肉眼观察loss数值变化
  2. 每隔200个epoch保存一下loss,不需要20000个epoch每次都打印

经试验,思路一去除图片打印部分的情况下训练时间为:36.77s,思路二改变保存间隔情况下的训练时长为:44.6s

二、__call__方法

在 Python 中,__call__方法是一个特殊的魔术方法,它允许类的实例像函数一样被调用。这种特性使得对象可以表现得像函数,同时保留对象的内部状态。

# 不带参数的call方法
class Counter:def __init__(self):self.count = 0def __call__(self):self.count += 1return self.count# 使用示例
counter = Counter() # 实例化对象,通过__init__方法初始化count为0
print(counter())  # 调用__call__,输出: 1
print(counter())  # counter=1再代入call方法,输出: 2
print(counter.count)  # 输出: 2# 带参数的call方法
class Adder:def __call__(self, a, b):print("唱跳篮球rap")return a + badder = Adder()
print(adder(3, 5))  
# 输出:
# 唱跳篮球rap 
# 8

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/bicheng/82302.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu22.04 系统安装Docker教程

1.更新系统软件包 #确保您的系统软件包是最新的。这有助于避免安装过程中可能遇到的问题 sudo apt update sudo apt upgrade -y 2.安装必要的依赖 sudo apt install apt-transport-https ca-certificates curl software-properties-common -y 3.替换软件源 原来/etc/apt/s…

深入解析前端 JSBridge:现代混合开发的通信基石与架构艺术

引言&#xff1a;被低估的通信革命 在移动互联网爆发式增长的十年间&#xff0c;Hybrid App&#xff08;混合应用&#xff09;始终占据着不可替代的地位。作为连接 Web 与 Native 的神经中枢&#xff0c;JSBridge 的设计质量直接决定了应用的性能上限与开发效率。本文将突破传…

ES 面试题系列「三」

1、在设计 Elasticsearch 索引时&#xff0c;如何考虑数据的建模和映射&#xff1f; 需要根据业务需求和数据特点来确定索引的结构。首先要分析数据的类型&#xff0c;对于结构化数据&#xff0c;如数字、日期等&#xff0c;要明确其数据格式和范围&#xff0c;选择合适的字段…

HTML5快速入门-常用标签及其属性(三)

HTML5快速入门-常用标签及其属性(三) 文章目录 HTML5快速入门-常用标签及其属性(三)音视频标签&#x1f3a7; <audio> 标签 — 插入音频使用 <source> 提供多格式备选&#xff08;提高兼容性&#xff09;&#x1f3a5; <video> 标签 — 插入视频&#x1f3b5…

Qt文件:XML文件

XML文件 1. XML文件结构1.1 基本结构1.2 XML 格式规则1.3 XML vs HTML 2. XML文件操作2.1 DOM 方式&#xff08;QDomDocument&#xff09;读取 XML写入XML 2.2 SAX 方式&#xff08;QXmlStreamReader/QXmlStreamWriter&#xff09;读取XML写入XML 2.3 对比分析 3. 使用场景3.1 …

day24Node-node的Web框架Express

1. Express 基础 1.1 什么是Express node的web框架有Express 和 Koa。常用Express 。 Express 是一个基于 Node.js 的快速、极简的 Web 应用框架,用于构建 服务器端应用(如网站后端、RESTful API 等)。它是 Node.js 生态中最流行的框架之一,以轻量、灵活和易用著称。 …

uniapp实现的简约美观的票据、车票、飞机票模板

采用 uniapp 实现的一款简约美观的票据模板&#xff0c;纯CSS、HTML实现&#xff0c;用户完全可根据自身需求进行更改、扩展&#xff1b;支持web、H5、微信小程序&#xff08;其他小程序请自行测试&#xff09;&#xff0c; 可到插件市场下载尝试&#xff1a; https://ext.dclo…

esp32+IDF V5.1.1版本编译freertos报错

error: portTICK_RATE_MS undeclared (first use in this function); did you mean portTICK_PERIOD_MS 解决方法: 使用命令 idf.py menuconfig 打开配置界面配置freeRtos 使能configENABLE_BACKWARD_COMPATIBLITY

vue 水印组件

Watermark.vue <script setup lang"ts"> import { ref, onMounted, onUnmounted, watch } from vue;interface Props {text?: string;fontSize?: number;color?: string;rotate?: number;zIndex?: number;gap?: number; }const props withDefaults(def…

hbuilder中h5转为小程序提交发布审核

【注意】 [HBuilder] 11:59:15.179 此应用 DCloud appid 为 __UNI__9F9CC77 &#xff0c;您不是这个应用的项目成员。1、联系这个应用的所有者&#xff0c;请求加入项目成员&#xff08;https://dev.dcloud.net.cn "成员管理"-"添加项目成员"&#xff09;…

QT之INI、JSON、XML处理

文章目录 INI文件处理写配置文件读配置文件 JSON 文件处理写入JSON读取JSON XML文件处理写XML文件读XML文件 INI文件处理 首先得引入QSettings QSettings 是用来存储和读取应用程序设置的一个类 #include "wrinifile.h"#include <QSettings> #include <QtD…

道德经总结

道德经 《道德经》是中国古代伟大哲学家老子所著&#xff0c;全书约五千字&#xff0c;共81章&#xff0c;分为“道经”&#xff08;1–37章&#xff09;和“德经”&#xff08;38–81章&#xff09;两部分。 《道德经》是一部融合哲学、政治、人生智慧于一体的经典著作。它提…

行为型:迭代器模式

目录 1、核心思想 2、实现方式 2.1 模式结构 2.2 实现案例 3、优缺点分析 4、适用场景 1、核心思想 目的&#xff1a;将遍历逻辑与数据存储结构解耦 概念&#xff1a;提供一种机制来按顺序访问集合中的各元素&#xff0c;而不需要知道集合内部的构造 举例&#xff1a;…

人脸识别技术合规备案最新政策详解

《人脸识别技术应用安全管理办法》将于2025年6月1日正式实施&#xff0c;该办法从技术应用、个人信息保护、技术替代、监管体系四方面构建了人脸识别技术的治理框架&#xff0c;旨在平衡技术发展与安全风险。 一、明确技术应用的边界 公共场所使用限制&#xff1a;仅在“维护公…

如何把vue项目部署在nginx上

1&#xff1a;在vscode中把vue项目打包会出现dist文件夹 按照图示内容即可把vue项目部署在nginx上

奇好 PDF安全加密 + 自由拆分合并批量处理 OCR 识别

各位办公小能手们&#xff0c;你们好呀&#xff01;今天我要给大家介绍一款超厉害的软件——奇好PDF。它就像是一个PDF文档处理的超级大管家&#xff0c;啥功能都有&#xff0c;格式转换、编辑、提取、安全保护这些统统不在话下&#xff0c;不管是办公、学习&#xff0c;还是设…

Docker-Harbor 私有镜像仓库使用指南

1.用户管理 为项目创建专用用户&#xff0c;并配置权限&#xff0c;确保该用户能够顺利推送镜像到 Harbor 仓库&#xff0c;确保镜像推送操作的安全性和便捷性。 创建完成后可以根据需要选择是否设置为管理员 角色 权限描述 适用场景 系统管理员 拥有系统的完全控制权限 运维…

HomeAssistant开源的智能家居docker快速部署实践笔记(CentOS7)

1. SGCC_Electricity 应用介绍 SGCC_Electricity 是一个用于将国家电网&#xff08;State Grid Corporation of China&#xff0c;简称 SGCC&#xff09;的电费和用电量数据接入 Home Assistant 的自定义集成组件。通过该应用&#xff0c;用户可以实时追踪家庭用电量情况&…

maven 3.0多线程编译提高编译速度

mvn package 默认只使用 单线程 来执行构建生命周期&#xff08;即顺序地构建每一个模块&#xff09;。 如果你使用的是多模块项目&#xff0c;Maven 从 3.0 开始提供了**并行构建&#xff08;parallel build&#xff09;**的能力&#xff0c;但它不是默认开启的。 如何启用多…

python模块管理环境变量

概要 在 Python 应用中&#xff0c;为了将配置信息与代码分离、增强安全性并支持多环境&#xff08;开发、测试、生产&#xff09;运行&#xff0c;使用专门的模块来管理环境变量是最佳实践。常见工具包括&#xff1a; 标准库 os.environ&#xff1a;直接读取操作系统环境变量…