PyTorch 张量(Tensor)详解:从基础到实战

1. 引言

在深度学习和科学计算领域,张量(Tensor) 是最基础的数据结构。PyTorch 作为当前最流行的深度学习框架之一,其核心计算单元就是张量。与 NumPy 的 ndarray 类似,PyTorch 张量支持高效的数值计算,但额外提供了 GPU 加速 和 自动微分(Autograd) 功能,使其成为构建和训练神经网络的理想选择。

本文将全面介绍 PyTorch 张量的核心概念、基本操作、高级特性及实际应用,帮助读者掌握张量的使用方法,并理解其在深度学习中的作用。

2. 什么是张量?

张量是多维数组的泛化,可以表示不同维度的数据:

  • 0D 张量(标量):单个数值,如 torch.tensor(5)

  • 1D 张量(向量):一维数组,如 torch.tensor([1, 2, 3])

  • 2D 张量(矩阵):二维数组,如 torch.tensor([[1, 2], [3, 4]])

  • 3D+ 张量(高阶张量):如 RGB 图像(3D)、视频数据(4D)等

PyTorch 张量的主要特点:

  1. 支持 GPU 加速:可无缝切换 CPU/GPU 计算。

  2. 自动微分:用于神经网络的反向传播。

  3. 动态计算图:更灵活的模型构建方式(与 TensorFlow 1.x 的静态计算图不同)。

3. 张量的创建与初始化

3.1 从 Python 列表或 NumPy 数组创建

import torch
import numpy as np# 从列表创建
t1 = torch.tensor([1, 2, 3])  # 1D 张量
t2 = torch.tensor([[1, 2], [3, 4]])  # 2D 张量# 从 NumPy 数组创建
arr = np.array([1, 2, 3])
t3 = torch.from_numpy(arr)  # 共享内存(修改一个会影响另一个)

3.2 特殊初始化方法

# 全零张量
zeros = torch.zeros(2, 3)  # 2x3 的零矩阵# 全一张量
ones = torch.ones(2)  # [1., 1.]# 随机张量
rand_uniform = torch.rand(2, 2)  # 0~1 均匀分布
rand_normal = torch.randn(2, 2)  # 标准正态分布# 类似现有张量的形状
x = torch.tensor([[1, 2], [3, 4]])
x_like = torch.rand_like(x)  # 形状与 x 相同,值随机

4. 张量的基本属性

每个 PyTorch 张量都有以下关键属性:

x = torch.rand(2, 3, dtype=torch.float32, device="cuda")print(x.shape)      # 形状: torch.Size([2, 3])
print(x.dtype)      # 数据类型: torch.float32
print(x.device)     # 存储设备: cpu / cuda
print(x.requires_grad)  # 是否启用梯度计算(用于 Autograd)

4.1 数据类型(dtype)

PyTorch 支持多种数据类型:

  • torch.float32(默认)

  • torch.int64

  • torch.bool(布尔张量)

可以通过 .to() 方法转换:

x = torch.tensor([1, 2], dtype=torch.float32)
y = x.to(torch.int64)  # 转换为整型

4.2 设备(CPU/GPU)

PyTorch 允许张量在 CPU 或 GPU 上运行:

if torch.cuda.is_available():device = torch.device("cuda")x = x.to(device)  # 移动到 GPUy = y.to("cuda")  # 简写方式

5. 张量的基本运算

5.1 算术运算

a = torch.tensor([1, 2])
b = torch.tensor([3, 4])# 加法
c = a + b  # 等价于 torch.add(a, b)# 乘法(逐元素)
d = a * b  # [3, 8]# 矩阵乘法
mat_a = torch.rand(2, 3)
mat_b = torch.rand(3, 2)
mat_c = torch.matmul(mat_a, mat_b)  # 或 mat_a @ mat_b

5.2 形状操作

x = torch.rand(4, 4)# 改变形状(类似 NumPy 的 reshape)
y = x.view(16)  # 展平为一维张量
z = x.view(2, 8)  # 调整为 2x8# 转置
x_t = x.permute(1, 0)  # 行列交换# 扩维 / 压缩
x_expanded = x.unsqueeze(0)  # 增加一个维度(1x4x4)
x_squeezed = x_expanded.squeeze()  # 去除大小为1的维度

5.3 索引与切片

x = torch.rand(3, 4)# 取第一行
row = x[0, :]# 取前两列
cols = x[:, :2]# 布尔索引
mask = x > 0.5
filtered = x[mask]  # 返回满足条件的元素

6. 自动微分(Autograd)

PyTorch 的 autograd 模块支持自动计算梯度,适用于反向传播:

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2 + 3 * x  # 计算图构建
y.backward()  # 反向传播
print(x.grad)  # dy/dx = 2x + 3 → 7.0

6.1 禁用梯度计算

with torch.no_grad():y = x * 2  # 不记录梯度

7. 张量与 NumPy 的互操作

PyTorch 张量可以无缝转换为 NumPy 数组:

# Tensor → NumPy
a = torch.rand(2, 2)
b = a.numpy()  # 共享内存(修改一个会影响另一个)# NumPy → Tensor
c = np.array([1, 2])
d = torch.from_numpy(c)  # 共享内存

8. 实际应用示例

8.1 线性回归(手动实现)

# 数据准备
X = torch.rand(100, 1)
y = 3 * X + 2 + 0.1 * torch.randn(100, 1)# 初始化参数
w = torch.randn(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)# 训练
lr = 0.01
for epoch in range(100):y_pred = w * X + bloss = ((y_pred - y) ** 2).mean()loss.backward()  # 计算梯度with torch.no_grad():w -= lr * w.gradb -= lr * b.gradw.grad.zero_()b.grad.zero_()print(f"w: {w.item()}, b: {b.item()}")

8.2 张量在 CNN 中的应用

import torch.nn as nn# 模拟输入(batch_size=1, channels=3, height=32, width=32)
input_tensor = torch.rand(1, 3, 32, 32)# 定义一个简单的 CNN
model = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3),nn.ReLU(),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(16 * 15 * 15, 10)  # 假设输出 10 类
)output = model(input_tensor)
print(output.shape)  # torch.Size([1, 10])

9. 总结

PyTorch 张量是深度学习的基础数据结构,支持:

  • 多维数组计算(类似 NumPy)

  • GPU 加速(大幅提升计算速度)

  • 自动微分(简化神经网络训练)

  • 动态计算图(灵活调试模型)

掌握张量的基本操作是学习 PyTorch 的关键步骤。建议读者通过官方文档和实际项目加深理解,逐步掌握张量的高级用法(如广播机制、高级索引等)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/97493.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/97493.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CPTS---Hospital

端口扫描 nmap -A -p- -n -Pn -T4 10.10.11.241 22/tcp open ssh OpenSSH 9.0p1 Ubuntu 1ubuntu8.5 (Ubuntu Linux; protocol 2.0) | ssh-hostkey: | 256 e1:4b:4b:3a:6d:18:66:69:39:f7:aa:74:b3:16:0a:aa (ECDSA) |_ 256 96:c1:dc:d8:97:20:95:e7:01:5…

【贪心算法】day5

📝前言说明: 本专栏主要记录本人的贪心算法学习以及LeetCode刷题记录,按专题划分每题主要记录:(1)本人解法 本人屎山代码;(2)优质解法 优质代码;&#xff…

软考中级【网络工程师】第6版教材 第4章 无线通信网 (上)

考点分析: 重要程度:⭐⭐⭐ 选择题考查1 ~ 3分,案例分析可能考查填空和简答 高频考点:802.11信道与频段、CSMA/CA、无线网络优化、无线认证、无线配置步骤 新教材变化:新增4G/5G、删除无线城域网 本章将详述蜂窝移动通信系统、无线局域网以及无线个人网的体系结构和实用技…

vscode+EIDE+Clangd环境导入keil C51以及MDK工程

我最近一直在使用vscodeclangd的编译环境替代了vscode自带的c/c插件。感觉clangd的环境更加优秀,能够更好找到函数、全局变量等定义调用等。如果使用keil C51以及MDK环境开发51单片机或者STM32单片机就需要使用到了EIDE这个插件这个插件现在能够自动生成compile_com…

FTP - 学习/实践

1.应用场景 主要用于学习和使用FTP服务,同时研究其架构实现, 以及日常开发中的使用。 FTP(文件传输协议)是一种用于网络文件传输的标准协议,基于客户端-服务器模型运行,通过控制通道(端口21)和…

【瑞吉外卖】手机号验证码登录(用QQ邮件发送代替)

目录 介绍 一、获取授权码 二、前端代码修改 三、后端代码修改 ①pom依赖 ②yml配置 ③控制层 ④业务层 ⑤工具类 介绍 本文介绍了QQ邮箱验证码登录功能的实现步骤: 获取QQ邮箱授权码并配置;前端修改登录页面,增加验证码发送接口调…

为什么要用 Markdown?以及如何使用它

在处理大量文档时,尤其是在构建知识库、进行文档分析或训练大语言模型(LLM)时,将各种格式的文件(如 PDF、Word、Excel、PPT、HTML 等)转换为统一的 Markdown 格式,能够显著提高处理效率和兼容性…

订餐后台管理系统-day06菜品分类模块

菜品分类显示我们需要先实现分类操作,因为没有菜品分类,我们无法准确知道当前菜品属于哪个分类,在前端显示时,需要根据分类显示数据先显示分类列表页面准备路由manage_bp.route(/food/cat/list) def food_cat_list():# 默认页面从…

More Effective C++ 条款20:协助完成返回值优化(Facilitate the Return Value Optimization)

More Effective C 条款20:协助完成返回值优化(Facilitate the Return Value Optimization)核心思想:返回值优化(RVO)是编译器消除函数返回时临时对象的一种重要优化技术。通过编写适合RVO的代码&#xff0c…

《HelloGitHub》第 113 期

兴趣是最好的老师,HelloGitHub 让你对开源感兴趣!简介HelloGitHub 分享 GitHub 上有趣、入门级的开源项目。github.com/521xueweihan/HelloGitHub这里有实战项目、入门教程、黑科技、开源书籍、大厂开源项目等,涵盖多种编程语言 Python、Java…

萌宝喂养日志-我用AI做喂养记录小程序1-原型设计

准备工作 首先,注册硅基流动账号,并配置Trae开发工具。 ↓现在注册有2000 万 Tokens 的免费额度↓。 硅基流动统一登录 具体可以看我这篇文章:Trae接入自有Deepseek模型,不再排队等待-CSDN博客 实践 设计原型图 我想开发一…

工业产品营销:概念、原理、流程与实践指南

摘要 工业产品营销是针对B2B市场的专业化推广活动,旨在满足企业客户的生产和运营需求。本文详细阐述了工业产品营销的概念与特点,分析其核心原理,包括客户需求驱动、价值传递和关系管理。营销过程涵盖市场调研、细分定位、策略制定、执行、转化及售后服务六个步骤,并提供品…

【读书笔记】《人体微生物的奥秘》

Follow Your Gut:人体微生物的奥秘 引言:从蚊子到微生物 夏天来临,许多人又开始纠结为什么有些人特别招蚊子。有人说是血型问题,有人说是皮肤嫩度,还有人归结于基因。但今天要分享的一本书,虽然标题看似讨论…

【Matplotlib学习】驾驭画布:Matplotlib 布局方式从入门到精通完全指南

目录驾驭画布:Matplotlib 布局方式从入门到精通完全指南一、 核心理念:理解 Figure 和 Axes二、 布局方式大全:从简单到复杂类别一:自动创建与基础单图布局类别二:规律网格布局 - 主力军类别三:复杂网格布局…

【C#】在一个任意旋转的矩形(由四个顶点定义)内绘制一个内切椭圆

核心点:在一个任意旋转的矩形(由四个顶点定义)内绘制一个内切椭圆 实现步骤 计算矩形中心:作为旋转中心点 创建椭圆路径:在未旋转状态下定义椭圆 应用旋转变换:使用矩阵绕中心点旋转路径 绘制变换后的路…

洛谷 P2052 [NOI2011] 道路修建-普及/提高-

P2052 [NOI2011] 道路修建 题目描述 在 W 星球上有 nnn 个国家。为了各自国家的经济发展,他们决定在各个国家之间建设双向道路使得国家之间连通。但是每个国家的国王都很吝啬,他们只愿意修建恰好 n−1n - 1n−1 条双向道路。 每条道路的修建都要付出一定…

springboot连接不上redis,但是redis客户端是能连接上的

除了常规排查,还有一个就是检查配置文件格式。这个旧版本格式会导致读取不到配置,spring:# 对应 RedisProperties 类redis:host: 127.0.0.1port: 6379 # password: 123456 # Redis 服务器密码,默认为空。生产中,一定要设置 Red…

GitBook 完整使用指南:从安装到部署

文章目录 环境准备 Node.js 安装 GitBook CLI 安装 项目初始化 创建项目结构 (可选) npm 初始化 目录结构配置 开发与调试 本地服务启动 构建静态文件 配置文件详解 插件系统 常用插件推荐 插件安装与配置 自定义样式 部署指南 GitHub Pages 部署 Netlify 部署 高级功能 多语言…

VS安装 .NETFramework,Version=v4.6.x

一、前言 在使用VS2019打开项目时提示MSB3644 找不到 .NETFramework,Versionv4.6.2 的引用程序集的错误 二、解决方案 1.百度......找到了解决方法了 2.打开Visual Studio Install 3.点击修改 4.点击单个组件,安装相对应的版本即可

Visual Studio Code中launch.json的解析笔记

<摘要> launch.json 是 Visual Studio Code 中用于配置调试任务的核心文件。本文解析了其最常用的配置字段&#xff0c;涵盖了基本调试设置、程序控制、环境配置和高级调试功能。理解这些字段能帮助开发者高效配置调试环境&#xff0c;提升开发效率。<解析> 1. 背景…