Day 35

模型可视化与推理
知识点回顾:

三种不同的模型可视化方法:推荐torchinfo打印summary+权重分布可视化
进度条功能:手动和自动写法,让打印结果更加美观
推理的写法:评估模式
模型结构可视化
理解一个深度学习网络最重要的2点:

1. 了解损失如何定义的,知道损失从何而来----把抽象的任务通过损失函数量化出来

2. 了解参数总量,即知道每一层的设计---层设计决定参数总量

为了了解参数总量,我们需要知道层设计,以及每一层参数的数量。下面介绍1几个层可视化工具:

1. nn.model自带的方法

#  nn.Module 的内置功能,直接输出模型结构
print(model)

这是最基础、最简单的方法,会直接打印模型对象,它会输出模型的结构,显示模型中各个层的名称和参数信息

# nn.Module 的内置功能,返回模型的可训练参数迭代器
for name, param in model.named_parameters():print(f"Parameter name: {name}, Shape: {param.shape}")

可以将模型中带有weight的参数(即权重)提取出来,并转为 numpy 数组形式,对其计算统计分布,并且绘制可视化图表

# 提取权重数据
import numpy as np
weight_data = {}
for name, param in model.named_parameters():if 'weight' in name:weight_data[name] = param.detach().cpu().numpy()# 可视化权重分布
fig, axes = plt.subplots(1, len(weight_data), figsize=(15, 5))
fig.suptitle('Weight Distribution of Layers')for i, (name, weights) in enumerate(weight_data.items()):# 展平权重张量为一维数组weights_flat = weights.flatten()# 绘制直方图axes[i].hist(weights_flat, bins=50, alpha=0.7)axes[i].set_title(name)axes[i].set_xlabel('Weight Value')axes[i].set_ylabel('Frequency')axes[i].grid(True, linestyle='--', alpha=0.7)plt.tight_layout()
plt.subplots_adjust(top=0.85)
plt.show()# 计算并打印每层权重的统计信息
print("\n=== 权重统计信息 ===")
for name, weights in weight_data.items():mean = np.mean(weights)std = np.std(weights)min_val = np.min(weights)max_val = np.max(weights)print(f"{name}:")print(f"  均值: {mean:.6f}")print(f"  标准差: {std:.6f}")print(f"  最小值: {min_val:.6f}")print(f"  最大值: {max_val:.6f}")print("-" * 30)

对比 fc1.weight 和 fc2.weight 的统计信息 ,可以发现它们的均值、标准差、最值等存在差异。这反映了不同层在模型中的作用不同。权重统计信息可以为超参数调整提供参考。

2.torchsummary库的summary方法
# pip install torchsummary -i https://pypi.tuna.tsinghua.edu.cn/simple
from torchsummary import summary
# 打印模型摘要,可以放置在模型定义后面
summary(model, input_size=(4,))

 该方法不显示输入层的尺寸,因为输入的神经网是自己设置的,所以不需要显示输入层的尺寸。但是在使用该方法时,input_size=(4,) 参数是必需的,因为 PyTorch 需要知道输入数据的形状才能推断模型各层的输出形状和参数数量。

        这是因为PyTorch 的模型在定义时是动态的,它不会预先知道输入数据的具体形状。nn.Linear(4, 10) 只定义了 “输入维度是 4,输出维度是 10”,但不知道输入的批量大小和其他维度,比如卷积层需要知道输入的通道数、高度、宽度等信息。----并非所有输入数据都是结构化数据

        因此,要生成模型摘要(如每层的输出形状、参数数量),必须提供一个示例输入形状,让 PyTorch “运行” 一次模型,从而推断出各层的信息。

summary 函数的核心逻辑是:

1. 创建一个与 input_size 形状匹配的虚拟输入张量(通常填充零)

2. 将虚拟输入传递给模型,执行一次前向传播(但不计算梯度)

3. 记录每一层的输入和输出形状,以及参数数量

4. 生成可读的摘要报告

构建神经网络的时候

1. 输入层不需要写:x多少个特征 输入层就有多少神经元

2. 隐藏层需要写,从第一个隐藏层可以看出特征的个数

3. 输出层的神经元和任务有关,比如分类任务,输出层有3个神经元,一个对应每个类别

可学习参数计算

1. Linear-1对应self.fc1 = nn.Linear(4, 10),表明前一层有4个神经元,这一层有10个神经元,每2个神经元之间靠着线相连,所有有4*10个权重参数+10个偏置参数=50个参数

2. relu层不涉及可学习参数,可以把它和前一个线性层看成一层,图上也是这个含义

3. Linear-3层对应代码 self.fc2 = nn.Linear(10,3),10*3个权重参数+3个偏置=33个参数

总参数83个,占用内存几乎为0

1.3 torchinfo库的summary方法
 torchinfo 是提供比 torchsummary 更详细的模型摘要信息,包括每层的输入输出形状、参数数量、计算量等。

# pip install torchinfo -i https://pypi.tuna.tsinghua.edu.cn/simple
from torchinfo import summary
summary(model, input_size=(4, ))

进度条功能
tqdm这个库非常适合用在循环中观察进度。尤其在深度学习这种训练是循环的场景中。他最核心的逻辑如下

1. 创建一个进度条对象,并传入总迭代次数。一般用with语句创建对象,这样对象会在with语句结束后自动销毁,保证资源释放。with是常见的上下文管理器,这样的使用方式还有用with打开文件,结束后会自动关闭文件。

2. 更新进度条,通过pbar.update(n)指定每次前进的步数n(适用于非固定步长的循环)。

1.手动更新

from tqdm import tqdm  # 先导入tqdm库
import time  # 用于模拟耗时操作# 创建一个总步数为10的进度条
with tqdm(total=10) as pbar:  # pbar是进度条对象的变量名# pbar 是 progress bar(进度条)的缩写,约定俗成的命名习惯。for i in range(10):  # 循环10次(对应进度条的10步)time.sleep(0.5)  # 模拟每次循环耗时0.5秒pbar.update(1)  # 每次循环后,进度条前进1步
from tqdm import tqdm
import time# 创建进度条时添加描述(desc)和单位(unit)
with tqdm(total=5, desc="下载文件", unit="个") as pbar:# 进度条这个对象,可以设置描述和单位# desc是描述,在左侧显示# unit是单位,在进度条右侧显示for i in range(5):time.sleep(1)pbar.update(1)  # 每次循环进度+1

unit 参数的核心作用是明确进度条中每个进度单位的含义,使可视化信息更具可读性。在深度学习训练中,常用的单位包括:

  • epoch:训练轮次(遍历整个数据集一次)。
  • batch:批次(每次梯度更新处理的样本组)。
  • sample:样本(单个数据点)
  • @浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/pingmian/82713.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[yolov11改进系列]基于yolov11引入自注意力与卷积混合模块ACmix提高FPS+检测效率python源码+训练源码

[ACmix的框架原理] 1.1 ACMix的基本原理 ACmix是一种混合模型,结合了自注意力机制和卷积运算的优势。它的核心思想是,传统卷积操作和自注意力模块的大部分计算都可以通过1x1的卷积来实现。ACmix首先使用1x1卷积对输入特征图进行投影,生成一组…

[DS]使用 Python 库中自带的数据集来实现上述 50 个数据分析和数据可视化程序的示例代码

使用 Python 库中自带的数据集来实现上述 50 个数据分析和数据可视化程序的示例代码 摘要:由于 sample_data.csv 是一个占位符文件,用于代表任意数据集,我将使用 Python 库中自带的数据集来实现上述 50 个数据分析和数据可视化程序的示例代码…

【Python 中 lambda、map、filter 和 reduce】详细功能介绍及用法总结

以下是 Python 中 lambda、map、filter 和 reduce 的详细功能介绍及用法总结,涵盖基础语法、高频场景和示例代码。 一、lambda 匿名函数 功能 用于快速定义一次性使用的匿名函数。不需要显式命名,适合简化小规模逻辑。 语法 lambda 参数1, 参数2, ..…

贪心算法——分数背包问题

一、背景介绍 给定𝑛个物品,第𝑖个物品的重量为𝑤𝑔𝑡[𝑖−1]、价值为𝑣𝑎𝑙[𝑖−1],和一个容量为𝑐𝑎&#…

《软件工程》第 5 章 - 需求分析模型的表示

目录 5.1需求分析与验证 5.1.1 顺序图 5.1.2 通信图 5.1.3 状态图 5.1.4 扩充机制 5.2 需求分析的过程模型 5.3 需求优先级分析 5.3.1 确定需求项优先级 5.3.2 排定用例分析的优先顺序 5.4 用例分析 5.4.1 精化领域概念模型 5.4.2 设置分析类 5.4.3 构思分析类之间…

基于MATLAB的大规模MIMO信道仿真

1. 系统模型与参数设置 以下是一个单小区大规模MIMO系统的参数配置示例,适用于多发多收和单发单收场景。 % 参数配置 params.N_cell 1; % 小区数量(单小区仿真) params.cell_radius 500; % 小区半径(米&#xff09…

想查看或修改 MinIO 桶的匿名访问权限(public/private/custom)

在 Ubuntu 下,如果你想查看或修改 MinIO 桶的匿名访问权限(public/private/custom),需要使用 mc anonymous 命令而不是 mc policy。以下是详细操作指南: 1. 查看当前匿名访问权限 mc anonymous get minio/test输出示例…

HarmonyOS:相机选择器

一、概述 相机选择器提供相机拍照与录制的能力。应用可选择媒体类型实现拍照和录制的功能。调用此类接口时,应用必须在界面UIAbility中调用,否则无法启动cameraPicker应用。 说明 本模块首批接口从API version 11开始支持。后续版本的新增接口&#xff0…

牛客AI简历筛选:提升招聘效率的智能解决方案

在竞争激烈的人才市场中,企业HR每天需处理海量简历,面临筛选耗时长、标准不统一、误判率高等痛点。牛客网推出的AI简历筛选工具,以“20分钟处理1000份简历、准确率媲美真人HR”的高效表现,成为企业招聘的智能化利器。本文将深度解…

白杨SEO:做AI搜索优化的DeepSeek、豆包、Kimi、百度文心一言、腾讯元宝、通义、智谱、天工等AI生成内容信息采集主要来自哪?占比是多少?

大家好,我是白杨SEO,专注SEO十年以上,全网SEO流量实战派,AI搜索优化研究者。 在开始写之前,先说个抱歉。 上周在上海客户以及线下聚会AI搜索优化分享说各大AI模型的联网搜索是关闭的,最开始上来确实是的。…

QML与C++交互2

在QML与C的交互中,主要有两种方式:在C中调用QML的方法和在QML中调用C的方法。以下是具体的实现方法。 在C中调用QML的方法 首先,我们需要在QML文件中定义一个函数,然后在C代码中调用它。 示例 //QML main.qml文件 import QtQu…

OpenGL Chan视频学习-8 How I Deal with Shaders in OpenGL

bilibili视频链接: 【最好的OpenGL教程之一】https://www.bilibili.com/video/BV1MJ411u7Bc?p5&vd_source44b77bde056381262ee55e448b9b1973 函数网站: docs.gl 说明: 1.之后就不再整理具体函数了,网站直接翻译会更直观也…

动态防御新纪元:AI如何重构DDoS攻防成本格局

1. 传统高防IP的静态瓶颈与成本困境 传统高防IP依赖预定义规则库,面对SYN Flood、CC攻击等威胁时,常因规则更新滞后导致误封合法流量。例如,某电商平台曾因静态阈值过滤误封20%的订单接口流量,直接影响营收。以下代码模拟传统方案…

如何实现高性能超低延迟的RTSP或RTMP播放器

随着直播行业的快速发展,RTSP和RTMP协议成为了广泛使用的流媒体传输协议,尤其是在实时视频直播领域,如何构建一个高性能超低延迟的直播播放器,已经成为了决定直播平台成功与否的关键因素之一。作为音视频直播SDK技术老兵&#xff…

UE5 编辑器工具蓝图

文章目录 简述使用方法样例自动生成Actor,并根据模型的包围盒设置Actor的大小批量修改场景中Actor的属性,设置Actor的名字,设置Actor到指定的文件夹 简述 使用编辑器工具好处是可以在非运行时可以对资源或场景做一些操作,例如自动…

解锁5月游戏新体验 高速电脑配置推荐

很多玩家用户会发现一个规律,618大促前很多商家会提前解锁各种福利,5月选购各种电脑配件有时候会更划算!并且,STEAM在5月还有几个年度主题促销,“生物收集游戏节”、“僵尸大战吸血鬼游戏节”等等,配件大促…

干货|VR全景是什么?

VR全景技术解析:概念、特点与用途 VR全景,全称为虚拟现实全景技术(Virtual Reality Panorama Technology),是基于虚拟现实(Virtual Reality,VR)技术的创新展示方式。VR全景技术利用专业的拍摄设…

Nacos适配GaussDB超详细部署流程,通过二进制包、以及 Docker 打通用镜像包部署保姆级教程

1部署openGauss 官方文档下载 https://support.huaweicloud.com/download_gaussdb/index.html 社区地址 安装包下载 本文主要是以部署轻量级为主要教程,系统为openEuler,ip: 192.168.1.15 1.1系统环境准备 操作系统选择 系统AARCH64X86-64openEuler√√CentOS7√Docker…

MySQL 表内容的增删查改 -- CRUD操作,聚合函数,group by 子句

目录 1. Create 1.1 语法 1.2 单行数据 全列插入 1.3 多行数据 指定列插入 1.4 插入数据否则更新数据 1.5 替换 2. Retrieve 2.1 SELECT 列 2.1.1 全列查询 2.1.2 指定列查询 2.1.3 查询字段为表达式 2.1.4 为查询结果指定别名 2.1.5 结构去重 2.2 WHERE 条件 …

LabVIEW累加器标签通道

主要展示了 Accumulator Tag 通道的使用,通过三个并行运行的循环模拟不同数值的多个随机序列,分别以不同频率向累加器写入数值,右侧循环每秒读取累加器值,同时可切换查看每秒内每次事件的平均值,用于演示多线程数据交互…