深入理解 PyTorch:从基础到高级应用

在深度学习的浪潮中,PyTorch 凭借其简洁易用、动态计算图等特性,迅速成为众多开发者和研究人员的首选框架。本文将深入探讨 PyTorch 的核心概念、基础操作以及高级应用,带你全面了解这一强大的深度学习工具。​

一、PyTorch 简介​

PyTorch 是一个基于 Python 的科学计算包,主要用于深度学习领域。它由 Facebook 的 AI 研究小组(FAIR)开发,旨在为深度学习提供一个灵活、高效且易于使用的平台。PyTorch 具有以下几个显著特点:​

  1. 动态计算图:与 TensorFlow 等框架使用的静态计算图不同,PyTorch 采用动态计算图。这意味着在运行时可以根据条件和循环动态构建计算图,使得调试更加方便,代码编写也更加灵活。例如,在训练过程中,我们可以根据当前的训练状态动态调整网络结构或计算逻辑。​
  1. Pythonic 风格:PyTorch 的设计理念遵循 Python 的简洁和直观风格,易于学习和使用。对于熟悉 Python 的开发者来说,能够快速上手 PyTorch。其 API 设计也非常符合 Python 的编程习惯,代码可读性强。​
  1. 强大的 GPU 支持:PyTorch 能够充分利用 GPU 的并行计算能力,大幅提升深度学习模型的训练速度。通过简单的操作,就可以将数据和模型移动到 GPU 上进行计算。​
  1. 丰富的生态系统:PyTorch 拥有庞大的社区和丰富的工具库,如 TorchVision(用于计算机视觉任务)、TorchText(用于自然语言处理任务)等,方便开发者快速实现各种深度学习应用。​

二、PyTorch 基础操作​

1. 张量(Tensor)​

张量是 PyTorch 中最基本的数据结构,类似于 NumPy 中的数组。它可以是一个标量(0 维张量)、向量(1 维张量)、矩阵(2 维张量)或更高维的数组。​

创建张量的方式有多种:​

  • 直接创建:​

TypeScript

取消自动换行复制

import torch​

# 创建一个5x3的未初始化张量​

x = torch.empty(5, 3)​

print(x)​

# 创建一个5x3的随机初始化张量​

y = torch.rand(5, 3)​

print(y)​

# 创建一个5x3的全0张量,数据类型为long​

z = torch.zeros(5, 3, dtype=torch.long)​

print(z)​

  • 从数据创建:​

TypeScript

取消自动换行复制

# 从Python列表创建张量​

data = [[1, 2], [3, 4]]​

a = torch.tensor(data)​

print(a)​

  • 基于现有张量创建:​

TypeScript

取消自动换行复制

# 使用现有张量的属性创建新张量​

b = a.new_ones(5, 3, dtype=torch.double)​

print(b)​

# 创建与a相同大小和数据类型的随机张量​

c = torch.randn_like(a, dtype=torch.float)​

print(c)​

张量支持各种数学运算,如加法、减法、乘法等,运算方式与 NumPy 类似:​

TypeScript

取消自动换行复制

# 加法运算​

result = y + z​

print(result)​

# 另一种加法运算方式​

result = torch.add(y, z)​

print(result)​

# 原地加法运算(直接修改z)​

z.add_(y)​

print(z)​

2. 自动求导(Autograd)​

Autograd 是 PyTorch 中用于自动计算梯度的模块。在深度学习中,我们需要通过反向传播计算梯度来更新模型参数,Autograd 可以自动完成这一过程。​

要使用 Autograd,只需将张量的requires_grad属性设置为True,表示需要计算该张量的梯度。例如:​

TypeScript

取消自动换行复制

x = torch.ones(2, 2, requires_grad=True)​

print(x)​

y = x + 2​

print(y)​

z = y * y * 3​

out = z.mean()​

print(out)​

在上述代码中,x、y、z和out的requires_grad属性都为True。通过调用out.backward(),可以自动计算out关于x的梯度:​

TypeScript

取消自动换行复制

out.backward()​

print(x.grad)​

3. 设备(Device)​

PyTorch 支持在 CPU 和 GPU 上进行计算。通过to()方法,可以将张量和模型移动到指定的设备上。首先需要判断是否有可用的 GPU:​

TypeScript

取消自动换行复制

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")​

print(device)​

然后将张量移动到设备上:​

TypeScript

取消自动换行复制

x = torch.tensor([1, 2, 3])​

x = x.to(device)​

print(x)​

对于模型,也可以使用相同的方法将其移动到设备上:​

TypeScript

取消自动换行复制

import torch.nn as nn​

model = nn.Linear(10, 2)​

model = model.to(device)​

三、PyTorch 神经网络​

1. 定义神经网络​

在 PyTorch 中,定义神经网络通常继承nn.Module类,并实现__init__和forward方法。__init__方法用于定义网络层,forward方法用于定义数据的前向传播过程。​

以下是一个简单的全连接神经网络示例:​

TypeScript

取消自动换行复制

import torch.nn as nn​

import torch.nn.functional as F​

class Net(nn.Module):​

def __init__(self):​

super(Net, self).__init__()​

# 输入图像大小为32x32,1个通道,输出6个特征图​

self.conv1 = nn.Conv2d(1, 6, 3)​

# 输入6个特征图,输出16个特征图​

self.conv2 = nn.Conv2d(6, 16, 3)​

# 全连接层,输入16 * 6 * 6个神经元,输出120个神经元​

self.fc1 = nn.Linear(16 * 6 * 6, 120)​

self.fc2 = nn.Linear(120, 84)​

self.fc3 = nn.Linear(84, 10)​

def forward(self, x):​

# 卷积层 + ReLU激活函数 + 最大池化​

x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))​

x = F.max_pool2d(F.relu(self.conv2(x)), 2)​

# 将张量展平为一维向量​

x = x.view(-1, self.num_flat_features(x))​

x = F.relu(self.fc1(x))​

x = F.relu(self.fc2(x))​

x = self.fc3(x)​

return x​

def num_flat_features(self, x):​

size = x.size()[1:] # 除批量维度外的所有维度​

num_features = 1​

for s in size:​

num_features *= s​

return num_features​

net = Net()​

print(net)​

2. 损失函数和优化器​

训练神经网络需要定义损失函数和优化器。常见的损失函数有均方误差损失函数(nn.MSELoss)、交叉熵损失函数(nn.CrossEntropyLoss)等。优化器有随机梯度下降(torch.optim.SGD)、Adam 优化器(torch.optim.Adam)等。​

TypeScript

取消自动换行复制

import torch.optim as optim​

# 定义损失函数​

criterion = nn.CrossEntropyLoss()​

# 定义优化器​

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)​

3. 训练神经网络​

训练神经网络的一般步骤如下:​

  1. 前向传播,计算预测值。​
  1. 计算损失。​
  1. 反向传播,计算梯度。​
  1. 使用优化器更新模型参数。​

TypeScript

取消自动换行复制

for epoch in range(2):​

running_loss = 0.0​

for i, data in enumerate(trainloader, 0):​

# 获取输入数据和标签​

inputs, labels = data[0].to(device), data[1].to(device)​

# 梯度清零​

optimizer.zero_grad()​

# 前向传播 + 反向传播 + 优化​

outputs = net(inputs)​

loss = criterion(outputs, labels)​

loss.backward()​

optimizer.step()​

# 打印统计信息​

running_loss += loss.item()​

if i % 2000 == 1999:​

print('[%d, %5d] loss: %.3f' %​

(epoch + 1, i + 1, running_loss / 2000))​

running_loss = 0.0​

print('Finished Training')​

四、PyTorch 高级应用​

1. 预训练模型​

PyTorch 提供了许多预训练模型,如 ResNet、VGG、BERT 等。我们可以直接加载这些预训练模型,并在其基础上进行微调,以适应特定的任务。​

以加载 ResNet18 预训练模型为例:​

TypeScript

取消自动换行复制

import torchvision.models as models​

# 加载预训练的ResNet18模型​

model = models.resnet18(pretrained=True)​

# 冻结所有参数,不进行训练​

for param in model.parameters():​

param.requires_grad = False​

# 修改最后一层全连接层,以适应新的分类任务​

num_ftrs = model.fc.in_features​

model.fc = nn.Linear(num_ftrs, 2)​

2. 自定义数据集和数据加载器​

在实际应用中,我们通常需要处理自定义的数据集。通过继承torch.utils.data.Dataset类,可以创建自定义数据集,并使用torch.utils.data.DataLoader进行数据加载和批量处理。​

TypeScript

取消自动换行复制

import torch.utils.data as data​

class CustomDataset(data.Dataset):​

def __init__(self, data_list, label_list, transform=None):​

self.data_list = data_list​

self.label_list = label_list​

self.transform = transform​

def __len__(self):​

return len(self.data_list)​

def __getitem__(self, index):​

data = self.data_list[index]​

label = self.label_list[index]​

if self.transform is not None:​

data = self.transform(data)​

return data, label​

# 使用示例​

custom_dataset = CustomDataset(data_list, label_list)​

dataloader = data.DataLoader(custom_dataset, batch_size=4, shuffle=True)​

3. 分布式训练​

对于大规模的深度学习任务,分布式训练可以显著提高训练效率。PyTorch 提供了分布式训练的支持,通过torch.distributed模块可以实现多机多卡的分布式训练。​

以下是一个简单的分布式训练示例(假设在单机多卡环境下):​

TypeScript

取消自动换行复制

import torch.distributed as dist​

import torch.multiprocessing as mp​

def train(rank, world_size):​

# 初始化分布式环境​

dist.init_process_group("nccl", rank=rank, world_size=world_size)​

# 每个进程创建一个模型和优化器​

model = nn.Linear(10, 2).to(rank)​

optimizer = optim.SGD(model.parameters(), lr=0.001)​

# 数据并行包装模型​

model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])​

# 训练过程​

for epoch in range(2):​

running_loss = 0.0​

for i, data in enumerate(trainloader, 0):​

inputs, labels = data[0].to(rank), data[1].to(rank)​

optimizer.zero_grad()​

outputs = model(inputs)​

loss = criterion(outputs, labels)​

loss.backward()​

optimizer.step()​

running_loss += loss.item()​

print('Rank {} loss: {:.3f}'.format(rank, running_loss))​

# 销毁分布式环境​

dist.destroy_process_group()​

if __name__ == '__main__':​

world_size = torch.cuda.device_count()​

mp.spawn(train, args=(world_size,), nprocs=world_size)​

五、总结​

本文全面介绍了 PyTorch 的核心概念、基础操作、神经网络构建以及高级应用。从张量的创建和运算,到自动求导、神经网络训练,再到预训练模型、自定义数据集和分布式训练,涵盖了 PyTorch 在深度学习开发中的主要方面。希望通过本文的学习,你能够对 PyTorch 有更深入的理解,并在实际项目中熟练运用这一强大的深度学习框架。随着深度学习技术的不断发展,PyTorch 也在持续更新和完善,未来还会有更多强大的功能和应用场景等待我们去探索和实践。​

以上博客详细梳理了 Pytorch 从基础到进阶的知识。如果你对某个部分还想进一步了解,或者有特定的应用场景想探讨,欢迎随时告诉我。​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/bicheng/84819.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java 中的 synchronized 与 Lock:深度对比、使用场景及高级用法

💡 前言 在多线程并发编程中,线程安全问题始终是开发者需要重点关注的核心内容之一。Java 提供了多种机制来实现同步控制,其中最常用的两种方式是: 使用 synchronized 关键字使用 java.util.concurrent.locks.Lock 接口&#xf…

Notepad++如何列选

在 Notepad 中,你可以通过 列模式(Column Mode) 进行垂直选择文本(列选),以下是具体操作方法: 方法 1:键盘 鼠标列选 按住 Alt 键(或 Alt Shift)。 按住鼠…

华为OD机考-水仙花数Ⅰ-逻辑分析(JAVA 2025B卷)

import java.util.*; public static Integer get(int count,int c){if(count<3||count>7){return -1;}//存储每位数的最高位……最低位int[] arr new int[count];List<Integer> res new ArrayList<>();for(int i(int) Math.pow(10,count-1);i<(int) Math…

基于 STL+VMD 二次分解的 Informer-LSTM 并行预测模型详解与案例

一、背景与动机 在时间序列预测中,如电力负荷、风速、交通流量等复杂数据常表现为: 非线性:趋势+季节+突变+噪声 多尺度:高频扰动与低频变化共存 长时依赖:远期信息也影响当前预测 传统模型(如 ARIMA、LSTM)往往无法兼顾全局趋势建模与局部扰动感知,因此我们提出一种 …

【Linux Learning】SSH连线出现警告:WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!

问题&#xff1a;WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY! Someone could be eavesdropping on you right now (man-in-the-middle attack)! It is al…

轻量级密码算法PRESENT的C语言实现(无第三方库)

一、PRESENT算法介绍 PRESENT是一种超轻量级分组密码算法&#xff0c;由Bogdanov等人在2007年提出&#xff0c;专门为资源受限环境如RFID标签和传感器网络设计。该算法在硬件实现上仅需1570个门等效电路(GE)&#xff0c;在保持较高安全性的同时实现了极小的硬件占用空间。PRES…

if的简化书写,提高执行效率

很多时候可能有下面判断 if(a0) {b1;} else if(a1) {b0;} 就是ba的反向值&#xff1a; a0;b1&#xff1b; a1;b0; 这时&#xff0c;可以简化如下&#xff1a; ba^1 使用异或&#xff0c;程序更简洁&#xff0c;执行效率也更高 其他的也可以类似使用按位异或优化代码

Vim 调用外部命令学习笔记

Vim 外部命令集成完全指南 文章目录 Vim 外部命令集成完全指南核心概念理解命令语法解析语法对比 常用外部命令详解文本排序与去重文本筛选与搜索高级 grep 搜索技巧文本替换与编辑字符处理高级文本处理编程语言处理其他实用命令 范围操作示例指定行范围处理复合命令示例 实用技…

bash挖矿木马事件全景复盘与企业级防御实战20250612

&#x1f427; CentOS “-bash 挖矿木马” 事件全景复盘与企业级防御实战 ✍️ 作者&#xff1a;Narutolxy | &#x1f4c5; 日期&#xff1a;2025-06-12 | &#x1f3f7;️ 标签&#xff1a;Linux 安全、应急响应、运维加固、实战复盘 &#x1f4d8; 内容简介 本文是一场真实…

「Linux中Shell命令」Shell命令基础

知识点详细解析 Shell简介 Shell是Linux操作系统系统中用户与操作系统内核交互的接口。它既是命令解释器,负责接收用户输入的命令并将其转换为内核能够理解的指令,也是一种脚本编程语言。作为Linux操作系统的重要组成部分,Shell扮演着用户与系统内核之间的"中间人"…

202557读书笔记|《梦里花落知多少(轻经典)》——有你在的地方才最美

《梦里花落知多少&#xff08;轻经典&#xff09;》作者三毛&#xff0c;物极必反&#xff0c;阴晴圆缺&#xff0c;小满即万全么&#xff1f;因为幸福过于满溢。所以幸福被收走了。 没有看过太多三毛的作品&#xff0c;给我的感觉她是很敏感&#xff0c;多愁善感及没有安全感…

对象映射 C# 中 Mapster 和 AutoMapper 的比较

Mapster和AutoMapper是C#领域两大主流对象映射库&#xff0c;各具特色。Mapster以高性能著称&#xff0c;使用表达式树实现零反射映射&#xff0c;首次编译后执行效率极高&#xff0c;适合对性能敏感的场景&#xff1b;AutoMapper则提供更丰富的功能集&#xff0c;如条件映射和…

QEMU源码全解析 —— 块设备虚拟化(26)

接前一篇文章:QEMU源码全解析 —— 块设备虚拟化(25) 本文内容参考: 《趣谈Linux操作系统》 —— 刘超,极客时间 《QEMU/KVM源码解析与应用》 —— 李强,机械工业出版社 Virt

微软PowerBI考试 PL300-选择 Power BI 模型框架【附练习数据】

微软PowerBI考试 PL300-选择 Power BI 模型框架 20 多年来&#xff0c;Microsoft 持续对企业商业智能 (BI) 进行大量投资。 Azure Analysis Services (AAS) 和 SQL Server Analysis Services (SSAS) 基于无数企业使用的成熟的 BI 数据建模技术。 同样的技术也是 Power BI 数据…

RED DA认证-EN18031网络安全常见问题以及解答

Q&#xff1a;RED DA是否对所有无线模块和设备强制要求&#xff1f; A&#xff1a;是的&#xff0c;RED DA适用于欧盟境内销售的所有无线设备&#xff0c;包括WWAN、蓝牙或Wi-Fi模块。唯一例外是GNSS模块&#xff08;仅支持接收功能&#xff0c;无需认证&#xff09;。 Q&…

腾讯开源 ovCompose 跨平台框架:实现一次跨三端(Android/iOS/鸿蒙)

在移动应用开发领域&#xff0c;跨平台技术一直是开发者们追求的目标&#xff0c;它能够帮助企业降低开发成本、提高开发效率&#xff0c;同时保证应用在不同平台上的一致性体验。2025 年 6 月 3 日&#xff0c;腾讯视频团队迎来了一个重要的里程碑 —— 正式发布 ovCompose 跨…

对3D对象进行形变分析

1&#xff0c;目的 分析3D实例对象相对标准参照物的形变。 一般用于质地较软的材质&#xff08;例如橡胶&#xff0c;布料&#xff09;查找&#xff0c;检查等。 标准参考模型 需匹配的实例&#xff1a; 形变后的模型&#xff1a;* 形变后的模型&#xff1a; 实例形变后的…

宝塔面板WordPress中使用Contact Form 7插件收不到邮件的解决方法

如果是宝塔面板的环境下&#xff0c;在WordPress中使用Contact Form 7插件提交表单时显示成功&#xff0c;但邮箱未收到邮件&#xff0c;可能是由于服务器邮件功能配置问题。以下是几种常见解决方法&#xff1a; 1. 检查邮件发送方式 默认情况下&#xff0c;Contact Form 7 使…

Android中的DX、D8、R8

Kotlin 版本所需的 AGP、D8 和 R8 版本 :https://developer.android.google.cn/build/kotlin-support?hlzh_cn R8&#xff1a;https://developer.android.google.cn/tools/retrace?hlzh_cn D8&#xff1a;https://developer.android.google.cn/tools/d8?hlzh_cn 如上图&…

通义灵码 AI IDE 上线!智能体+MCP 从手动调用工具过渡到“AI 主动调度资源”

告诉大家一个好消息&#xff0c;通义灵码发布了 AI 编程 IDE &#xff1a;Lingma IDE &#xff0c;你没看错&#xff0c;通义灵码也推出了自己的 AI IDE 客户端&#xff0c;不是 AI 编程插件&#xff0c;是 IDE 。 Lingma IDE 是基于 VS Code 开源版本构建的智能代码编辑器&am…