深度学习入门到实战:用PyTorch打通数学、张量与模型训练全链路​

本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在聚客AI学院。

一. 人工智能、机器学习与深度学习的关系

1.1 概念层次解析

  • 人工智能(AI):使机器模拟人类智能的广义领域

  • 机器学习(ML):通过数据驱动的方法让系统自动改进性能

  • 深度学习(DL):基于多层神经网络的机器学习子领域

关系示意图

人工智能 ⊃ 机器学习 ⊃ 深度学习

image.png

二. PyTorch环境配置

2.1 Conda环境管理

# 创建虚拟环境  
conda create -n pytorch_env python=3.9  
# 激活环境  
conda activate pytorch_env  
# 安装PyTorch(根据CUDA版本选择)  
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia  # GPU版本  
conda install pytorch torchvision torchaudio cpuonly -c pytorch  # CPU版本

2.2 验证安装

import torch  
print(torch.__version__)            # 输出:2.0.1  
print(torch.cuda.is_available())    # 输出:True(GPU可用时)

三. 数学基础与张量操作

3.1 标量、向量、张量

image.png

代码示例:张量创建与操作

# 创建张量  
x = torch.empty(2, 3)  # 未初始化  
y = torch.zeros(2, 3, dtype=torch.int32)  
z = torch.randn(2, 3)  # 标准正态分布  
# 数学运算  
a = torch.tensor([[1,2],[3,4]], dtype=torch.float32)  
b = torch.tensor([[5,6],[7,8]], dtype=torch.float32)  
print(a + b)    # 逐元素加法  
print(a @ b.T)  # 矩阵乘法

四. 数据预处理与线性代数

4.1 数据标准化

from torchvision import transforms  
transform = transforms.Compose([  transforms.ToTensor(),  transforms.Normalize(mean=[0.485, 0.456, 0.406],   std=[0.229, 0.224, 0.225])  
])  
# 应用到数据集  
dataset = datasets.CIFAR10(..., transform=transform)

4.2 线性代数核心操作

# 矩阵分解  
A = torch.randn(3, 3)  
U, S, V = torch.svd(A)  # 奇异值分解  
# 特征值计算  
eigenvalues = torch.linalg.eigvalsh(A)  
# 张量缩并  
tensor = torch.einsum('ijk,jl->ikl', a, b)

五. 神经网络基础

5.1 神经元数学模型

image.png

  • wi:权重

  • bb:偏置(提供平移能力)

  • ff:激活函数

代码示例:单神经元实现

class Neuron(nn.Module):  def __init__(self, input_dim):  super().__init__()  self.linear = nn.Linear(input_dim, 1)  self.activation = nn.Sigmoid()  def forward(self, x):  return self.activation(self.linear(x))  
neuron = Neuron(3)  
output = neuron(torch.tensor([0.5, -1.2, 0.8]))

六. 模型训练全流程

6.1 训练要素定义

image.png

6.2 典型训练循环

model = nn.Sequential(  nn.Linear(784, 128),  nn.ReLU(),  nn.Linear(128, 10)  
)  
criterion = nn.CrossEntropyLoss()  
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  
for epoch in range(10):  for inputs, labels in train_loader:  optimizer.zero_grad()  outputs = model(inputs.view(-1, 784))  loss = criterion(outputs, labels)  loss.backward()  optimizer.step()  print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')

七. 激活函数与反向传播

7.1 常见激活函数对比

image.png

7.2 反向传播数学原理

链式法则示例:

image.png

其中 z=∑wixi+bz=∑wixi+b,y=f(z)y=f(z)

代码示例:手动实现反向传播

x = torch.tensor(2.0, requires_grad=True)  
y = torch.tensor(3.0)  
w = torch.tensor(1.0, requires_grad=True)  
b = torch.tensor(0.5, requires_grad=True)  
# 前向计算  
z = w * x + b  
loss = (z - y)**2  
# 反向传播  
loss.backward()  
print(w.grad)  # 输出:4.0 (∂loss/∂w = 2*(wx + b - y)*x = 2*(2+0.5-3)*2 = 4)

image.png

:本文代码基于PyTorch 2.0+版本实现,运行前需安装:

pip install torch torchvision matplotlib

更多AI大模型应用开发学习视频及资料,尽在聚客AI学院。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/bicheng/81974.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

windows服务器部署jenkins工具(一)

jenkins作为一款常用的构建发布工具,极大的简化了项目部署发布流程。jenkins通常是部署在linux服务上,今天给大家分享的是windows服务器上如何搭建jenkins发布工具。 1.首先第一步还是看windows安装docker 这篇文章哈,当然也可以不采用docke…

前端开发规范性利器系列之:ESLint

前言 我是一名从事低代码平台研发的前端CV程序猿,有几十名像我一样的小伙伴协同研发。在长期的多人协作和滚动迭代中,不出意外,代码中会充斥各种“坏味道”,如代码风格不统一、扩展性和灵活性降低等问题。我们是如何解决这些问题的…

数据结构知识点汇总

1、在数据结构中,随机访问是指能够直接访问任一元素,而不需要从特定的起始位置开始,也不需要按顺序访问其他元素。这种访问方式通常不涉及遍历。例如,数组(array)支持随机访问,你可以直接通过索…

ubuntu中上传项目至GitHub仓库教程

一、到github官网注册用户 1.注册用户 地址:https://github.com/ 2.安装Git 打开终端,输入指令git,检查是否已安装Git 如果没有安装就输入指令 sudo apt-get install git 二、上传项目到github 1.创建项目仓库 进入github主页,点击号…

C#在 .NET 9.0 中启用二进制序列化:配置、风险与替代方案

在 .NET 9.0 中启用二进制序列化:配置、风险与替代方案 引言一、启用二进制序列化的步骤二、实现序列化与反序列化三、安全风险与缓解措施四、推荐替代方案五、总结 引言 在 .NET 生态中,二进制序列化(Binary Serialization)曾是…

如何解决鸿蒙应用闪退问题

如何解决鸿蒙应用闪退问题 本文是一份面向 ArkTS/JavaScript/C 多语言开发者的综合性排查与优化手册,覆盖 HarmonyOS/OpenHarmony 5.x 时代 常见闪退根因、诊断流程、调试技巧、CI 监控及线上防护方案,力争帮你把 Crash 数量降到 …

【Java高阶面经:微服务篇】4.大促生存法则:微服务降级实战与高可用架构设计

一、降级决策的核心逻辑:资源博弈下的生存选择 1.1 大促场景的资源极限挑战 在电商大促等极端流量场景下,系统面临的资源瓶颈呈现指数级增长: 流量特征: 峰值QPS可达日常的50倍以上(如某电商大促下单QPS从1万突增至50万)流量毛刺持续时间短(通常2-4小时),但对系统稳…

关于我对传统系统机构向大模型架构演进的认知

最近这段时间在研究大模型,不可避免会接触到架构。从我职业经历一路走来,自然会拿着现有模型的架构和我之前接触到的系统架构进行对比。今天就大模型的架构和传统系统架构进行一下梳理,说一说我的见解。 在我眼里,传统系统架构如…

图片识别(TransFormerCNNMLP)

目录 一、Transformer (一)ViT:Transformer 引入计算机视觉的里程碑 (二)Swin-Transformer:借鉴卷积改进 ViT (三)VAN:使用卷积模仿 ViT (四)…

性能测试、压力测试、负载测试如何区分

一、前言:为何区分三者如此重要? “你们做过压力测试吗?”“系统性能测试做得怎么样?”“负载测试的数据能分享一下吗?” 在很多软件开发与测试团队的日常沟通中,“性能测试”“压力测试”“负载测试”这…

工业路由器WiFi6+5G的作用与使用指南,和普通路由器对比

工业路由器的技术优势 在现代工业环境中,网络连接的可靠性与效率直接影响生产效率和数据处理能力。WiFi 6(即802.11ax)和5G技术的结合,为工业路由器注入了强大的性能,使其成为智能制造、物联网和边缘计算的理想选择。…

紫光同创FPGA实现AD9238数据采集转UDP网络传输,分享PDS工程源码和技术支持和QT上位机

目录 1、前言工程概述免责声明 2、相关方案推荐我已有的所有工程源码总目录----方便你快速找到自己喜欢的项目紫光同创FPGA相关方案推荐我这里已有的以太网方案本方案在Xilinx系列FPGA的应用方案 3、设计思路框架工程设计原理框图AD输入源AD9238数据采集AD9238数据缓存控制模块…

如何修改服务器管理员账号名和密码(1)

命令解析sudo useradd -m -s /bin/bash 新用户名 1. sudo 作用:以超级用户(root)权限执行命令 为什么需要:创建用户需要修改系统文件(/etc/passwd, /etc/shadow等),普通用户没有这个权限 替代方案:如果已经是root用户&#xff0…

Linux shell 正则表达式高效使用

Linux正则表达式高效使用教程 正则表达式是Linux命令行中强大的文本处理工具,能够极大提高搜索和匹配效率。下面为新手提供一个简单教程,介绍如何在grep和find命令中使用正则表达式。 使用建议:使用grep时要加-E选项使其支持扩展正则表达式&…

你通俗易懂的理解——线程、多线程与线程池

一:异常处理 1.1 异常概述 (1)场景 (2)定义 (3)异常抛出机制 Java把不同的异常用不同的类表示 (4)如何对待异常 1.2 常见异常类 (1)Throwable &am…

w~自动驾驶~合集13

我自己的原文哦~ https://blog.51cto.com/whaosoft/13933252 # 小米智能驾驶技术的一些猜测 来蹭一下小米汽车智能驾驶的热度,昨晚听了雷总小米汽车的发布,心潮澎湃寻思下单一辆奈何现实不允许hhh。 言归正传吧, 本来是想主要听一下小米…

AI 面试帮 开发日志

项目源码 https://cnb.cool/szu/TravelBest/Platform/-/tree/main 文章目录 架构微服务网络通信延迟 中间件redisMongoDB 架构 微服务 优点: 模块间解耦、职责清晰,独立部署与扩展,单个服务故障不会影响整个系统,便于持续交付与…

论文阅读(四):Agglomerative Transformer for Human-Object Interaction Detection

论文来源:ICCV(2023) 项目地址:https://github.com/six6607/AGER.git 1.研究背景 人机交互(HOI)检测需要同时定位人与物体对并识别其交互关系,核心挑战在于区分相似交互的细微视觉差异&#…

部署java项目

1.编写shell脚本部署服务 restart.sh #!/bin/bash # # start the user program # echo "-------------------- start jk service --------------------" LOG_DIR"/home/joy/usr/app/ers-log" LOG_FILE"$LOG_DIR/log_$(date "%Y%m%d").txt&…

第18天-NumPy + Pandas + Matplotlib多维度直方图

示例1:带样式的柱状图 python 复制 下载 import numpy as np import pandas as pd import matplotlib.pyplot as plt# 生成数据 df = pd.DataFrame(np.random.randint(10, 100, size=(8, 4)),columns=[Spring, Summer, Autumn, Winter],index=[2015, 2016, 2017, 2018, 20…