Pytorch实现一个简单的贝叶斯卷积神经网络模型

 

贝叶斯深度模型的主要特点和实现说明:

  1. 模型结构

    • 结合了常规卷积层(用于特征提取)和贝叶斯线性层(用于分类)
    • 贝叶斯层将权重视为随机变量,而非传统神经网络中的确定值
    • 使用变分推断来近似权重的后验分布
  2. 贝叶斯特性

    • 通过重参数化技巧实现随机变量的采样,使得模型可训练
    • 损失函数包含两部分:分类损失(交叉熵)和 KL 散度(衡量近似后验与先验的差异)
    • 测试时通过多次采样获取预测分布,体现模型的不确定性
  3. 使用方法

    • 代码会自动下载 MNIST 数据集并进行预处理
    • 支持 GPU 加速(如果可用)
    • 训练完成后会绘制损失和准确率曲线,并保存模型
  4. 与传统神经网络的区别

    • 贝叶斯模型能够提供预测的不确定性估计
    • 通常具有更好的泛化能力,不易过拟合
    • 训练过程更复杂,计算成本更高
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt# 定义贝叶斯线性层 - 使用变分推断近似后验分布
class BayesianLinear(nn.Module):def __init__(self, in_features, out_features):super(BayesianLinear, self).__init__()self.in_features = in_featuresself.out_features = out_features# 先验分布参数 (高斯分布)self.prior_mu = 0.0self.prior_sigma = 1.0# 变分参数 - 权重的均值和标准差self.mu_weight = nn.Parameter(torch.Tensor(out_features, in_features).normal_(0, 0.1))self.sigma_weight = nn.Parameter(torch.Tensor(out_features, in_features).fill_(0.1))# 变分参数 - 偏置的均值和标准差self.mu_bias = nn.Parameter(torch.Tensor(out_features).normal_(0, 0.1))self.sigma_bias = nn.Parameter(torch.Tensor(out_features).fill_(0.1))# 用于重参数化技巧的噪声变量self.epsilon_weight = Noneself.epsilon_bias = Nonedef forward(self, x):# 重参数化技巧:将随机采样转换为确定性操作,便于反向传播if self.training:# 训练时从近似后验分布中采样self.epsilon_weight = torch.normal(torch.zeros_like(self.mu_weight))self.epsilon_bias = torch.normal(torch.zeros_like(self.mu_bias))weight = self.mu_weight + self.sigma_weight * self.epsilon_weightbias = self.mu_bias + self.sigma_bias * self.epsilon_biaselse:# 测试时使用均值(最大后验估计)weight = self.mu_weightbias = self.mu_bias# 计算KL散度(衡量近似后验与先验的差异)kl_loss = self._kl_divergence()return nn.functional.linear(x, weight, bias), kl_lossdef _kl_divergence(self):# 计算KL散度:KL(q(w) || p(w))kl_weight = 0.5 * torch.sum(1 + 2 * torch.log(self.sigma_weight) - torch.square(self.mu_weight) - torch.square(self.sigma_weight)) / (self.prior_sigma ** 2)kl_bias = 0.5 * torch.sum(1 + 2 * torch.log(self.sigma_bias) - torch.square(self.mu_bias) - torch.square(self.sigma_bias)) / (self.prior_sigma ** 2)return kl_weight + kl_bias# 定义贝叶斯卷积神经网络模型
class BayesianCNN(nn.Module):def __init__(self, num_classes=10):super(BayesianCNN, self).__init__()# 卷积层使用常规卷积(为简化模型)self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 全连接层使用贝叶斯层self.fc1 = BayesianLinear(64 * 7 * 7, 128)self.fc2 = BayesianLinear(128, num_classes)self.relu = nn.ReLU()def forward(self, x):# 卷积特征提取部分x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = x.view(-1, 64 * 7 * 7)  # 展平特征图# 贝叶斯全连接部分x, kl1 = self.fc1(x)x = self.relu(x)x, kl2 = self.fc2(x)# 总KL散度total_kl = kl1 + kl2return x, total_kl# 训练函数
def train(model, train_loader, optimizer, criterion, epoch, device):model.train()train_loss = 0correct = 0total = 0# KL散度的权重(根据数据集大小调整)kl_weight = 1.0 / len(train_loader.dataset)for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()# 前向传播output, kl_loss = model(data)# 总损失 = 分类损失 + KL散度正则化loss = criterion(output, target) + kl_weight * kl_loss# 反向传播和优化loss.backward()optimizer.step()# 统计train_loss += loss.item()_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()# 打印训练进度if batch_idx % 100 == 0:print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} 'f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')train_loss /= len(train_loader)train_acc = 100. * correct / totalprint(f'Train set: Average loss: {train_loss:.4f}, Accuracy: {correct}/{total} ({train_acc:.2f}%)')return train_loss, train_acc# 测试函数
def test(model, test_loader, criterion, device, num_samples=10):model.eval()test_loss = 0correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)# 多次采样以获取预测分布(体现贝叶斯模型的不确定性)outputs = []for _ in range(num_samples):output, _ = model(data)outputs.append(output.unsqueeze(0))# 平均多次采样的结果output = torch.mean(torch.cat(outputs, dim=0), dim=0)test_loss += criterion(output, target).item()# 统计准确率_, predicted = torch.max(output.data, 1)total += target.size(0)correct += (predicted == target).sum().item()test_loss /= len(test_loader)test_acc = 100. * correct / totalprint(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{total} ({test_acc:.2f}%)')return test_loss, test_acc# 主函数
def main():# 超参数设置batch_size = 64test_batch_size = 1000epochs = 10lr = 0.001seed = 42num_samples = 10  # 测试时的采样次数,用于获取预测分布# 设置设备(GPU或CPU)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"Using device: {device}")# 设置随机种子,保证结果可复现torch.manual_seed(seed)# 数据预处理transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差])# 加载MNIST数据集train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 创建数据加载器train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)# 初始化模型、损失函数和优化器model = BayesianCNN().to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=lr)# 记录训练过程中的损失和准确率train_losses = []train_accs = []test_losses = []test_accs = []# 开始训练和测试for epoch in range(1, epochs + 1):train_loss, train_acc = train(model, train_loader, optimizer, criterion, epoch, device)test_loss, test_acc = test(model, test_loader, criterion, device, num_samples)train_losses.append(train_loss)train_accs.append(train_acc)test_losses.append(test_loss)test_accs.append(test_acc)# 绘制训练和测试损失曲线plt.figure(figsize=(12, 5))plt.subplot(1, 2, 1)plt.plot(range(1, epochs + 1), train_losses, label='Train Loss')plt.plot(range(1, epochs + 1), test_losses, label='Test Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Loss vs Epoch')plt.legend()# 绘制训练和测试准确率曲线plt.subplot(1, 2, 2)plt.plot(range(1, epochs + 1), train_accs, label='Train Accuracy')plt.plot(range(1, epochs + 1), test_accs, label='Test Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('Accuracy vs Epoch')plt.legend()plt.tight_layout()plt.show()# 保存模型torch.save(model.state_dict(), 'bayesian_cnn_mnist.pth')print("Model saved as 'bayesian_cnn_mnist.pth'")if __name__ == '__main__':main()

在模型规模相似(例如参数总量、网络深度和宽度相近)的情况下,普通卷积神经网络(CNN)的训练效率通常更高,训练速度更快。这主要源于贝叶斯卷积神经网络(Bayesian CNN)的特殊结构和训练机制带来的额外计算开销,具体原因如下:

1. 参数数量与计算复杂度差异

普通 CNN 中,每个权重是确定值,每个层仅需存储和优化一组权重参数(例如卷积核权重、偏置)。
而贝叶斯 CNN 中,权重被视为随机变量(通常假设服从高斯分布),需要用变分推断近似其 posterior 分布。这意味着每个权重需要学习两个参数:均值(μ) 和标准差(σ)(或精度),参数数量几乎是普通 CNN 的 2 倍(对于贝叶斯层而言)。

更多的参数直接导致:

  • 前向传播时需要计算更多变量的组合(例如通过重参数化技巧采样权重:weight = μ + σ·ε);
  • 反向传播时需要计算更多参数的梯度(不仅是均值,还有标准差),增加了梯度计算的复杂度。

2. 额外的损失项计算

普通 CNN 的损失函数通常仅包含任务相关损失(例如分类问题的交叉熵损失)。
而贝叶斯 CNN 的损失函数必须包含两部分:

  • 任务相关损失(与普通 CNN 相同);
  • KL 散度(KL divergence):用于衡量近似后验分布与先验分布的差异,作为正则化项。

KL 散度的计算需要对每个贝叶斯层的权重分布进行积分近似(即使是简化的解析解,也需要对所有权重的均值和标准差进行逐元素运算),这会额外增加计算开销,尤其当贝叶斯层较多时,累积开销显著。

3. 采样操作的开销

贝叶斯 CNN 在训练时,为了通过重参数化技巧实现梯度回传,需要对每个贝叶斯层的权重进行随机采样(例如从N(μ, σ²)中采样噪声ε,再计算weight = μ + σ·ε)。虽然采样操作本身不算复杂,但在大规模网络中,多次采样(即使每个 batch 一次)会累积计算时间。

普通 CNN 则无需采样,权重是确定性的,前向传播更直接高效。

总结

在模型规模相似的情况下,普通 CNN 由于参数更少、计算流程更简单(无额外的 KL 散度计算和采样操作),训练速度显著快于贝叶斯 CNN。

贝叶斯 CNN 的优势不在于训练效率,而在于其能量化预测的不确定性(例如通过多次采样得到预测分布),并在小样本、数据噪声大的场景下可能具有更好的泛化能力,但这是以更高的计算成本为代价的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/94291.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/94291.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Dubbo 3.x源码(32)—Dubbo Provider处理服务调用请求源码

基于Dubbo 3.1,详细介绍了Dubbo Provider处理服务调用请求源码 上文我们学习了,Dubbo消息的编码解的源码。现在我们来学习一下Dubbo Provider处理服务调用请求源码。 当前consumer发起了rpc请求,经过请求编码之后到达provider端,…

每日一leetcode:移动零

目录 解题过程: 描述: 分析条件: 解题思路: 通过这道题可以学到什么: 解题过程: 描述: 给定一个数组 nums,编写一个函数将所有 0 移动到数组的末尾,同时保持非零元素的相对顺序。 请注意 ,必须在不复制数组的情况下原地对数组进行操…

6-Django项目实战-[dtoken]-用户登录模块

1.创建应用 python manage.py startapp dtoken 2.注册应用 settings.py中注册 3.匹配路由4.编写登录功能视图函数 import hashlib import json import timeimport jwt from django.conf import settings from django.http import JsonResponse from user.models import UserPro…

Axure日期日历高保真动态交互原型

在数字化产品设计中,日期日历组件作为高频交互元素,其功能完整性与用户体验直接影响着用户对产品的信任度。本次带来的日期日历高保真动态交互原型,依照Element UI、View UI等主流前端框架为参考,通过动态面板、中继器、函数、交互…

【YOLOv4】

YOLOv4 论文地址::【https://arxiv.org/pdf/2004.10934】 YOLOv4 论文中文翻译地址:【深度学习论文阅读目标检测篇(七)中文版:YOLOv4《Optimal Speed and Accuracy of Object Detection》-CSDN博客】 yol…

【秋招笔试】2025.08.03虾皮秋招笔试-第一题

📌 点击直达笔试专栏 👉《大厂笔试突围》 💻 春秋招笔试突围在线OJ 👉 笔试突围在线刷题 bishipass.com 01. 蛋糕切分的最大收益 问题描述 K小姐经营着一家甜品店,今天她有一块长度为 n n n 厘米的长条蛋糕需要切分。根据店里的规定,她必须将蛋糕切成至少 2 2

2.0 vue工程项目的创建

前提准备.需要电脑上已经安装了nodejs 参考 7.nodejs和npm简单使用_npmjs官网-CSDN博客 创建vue2工程 全局安装 Vue CLI 在终端中运行以下命令来全局安装 Vue CLI: npm install -g vue/cli npm install -g 表示全局安装。vue/cli 是 Vue CLI 的包名。 安装完成后…

视觉图像处理中级篇 [2]—— 外观检查 / 伤痕模式的原理与优化设置方法

外观缺陷检测是工业生产中的关键环节,而伤痕模式作为图像处理的核心算法,能精准识别工件表面的划痕、污迹等缺陷。掌握其原理和优化方法,对提升检测效率至关重要。一、利用伤痕模式进行外观检查虽然总称为外观检查,但根据检查对象…

ethtool,lspci,iperf工具常用命令总结

ethtool、lspci 和 iperf 是 Linux 系统中进行网络硬件查看、配置和性能测试的核心命令行工具。下面是它们的常用命令分析和总结: 核心作用总结: lspci: 侦察兵 - 列出系统所有 PCI/PCIe 总线上的硬件设备信息,主要用于识别网卡型号、制造商、…

DAY10DAY11-新世纪DL(DeepLearning/深度学习)战士:序

本文参考视频[双语字幕]吴恩达深度学习deeplearning.ai_哔哩哔哩_bilibili 参考文章0.0 目录-深度学习第一课《神经网络与深度学习》-Stanford吴恩达教授-CSDN博客 1深度学习概论 1.举例介绍 lg房价预测:房价与面积之间的坐标关系如图所示,由线性回归…

flutter release调试插件

chucker_flutter (只有网络请求的信息,亲测可以用) flutter:3.24.3 使用版本 chucker_flutter: 1.8.2 chucker_flutter | Flutter package void main() async {// 可以控制显示ChuckerFlutter.showNotification false;ChuckerF…

基于开源链动2+1模式AI智能名片S2B2C商城小程序的私域流量拉新策略研究

摘要:私域流量运营已成为企业数字化转型的核心战略,其本质是通过精细化用户运营实现流量价值最大化。本文以“定位、拉新、养熟、成交、裂变、留存”全链路为框架,聚焦开源链动21模式、AI智能名片与S2B2C商城小程序的协同创新,揭示…

华为云云服务高级顾问叶正晖:华为对多模态大模型的思考与实践

嘉宾介绍:叶正晖,华为云云服务高级顾问,全球化企业信息化专家,从业年限超过23年,在华为任职超过21年,涉及运营商、企业、消费者、云服务、安全与隐私等领域,精通云服务、安全合规、隐私保护等领…

【机器学习(二)】KNN算法与模型评估调优

目录 一、写在前面的话 二、KNN(K-Nearest Neighbor) 2.1 KNN算法介绍 2.1.1 概念介绍 2.1.2 算法特点 2.1.3 API 讲解 2.2 样本距离计算 2.2.1 距离的类型 (1)欧几里得距离(Euclidean Distance) …

《Uniapp-Vue 3-TS 实战开发》实现自定义头部导航栏

本文介绍了如何将Vue2组件迁移至Vue3的组合式API。主要内容包括:1) 使用<script setup lang="ts">语法;2) 通过接口定义props类型约束;3) 用defineProps替代props选项;4) 将data变量转为ref响应式变量;5) 使用computed替代计算属性;6) 将created生命周期…

GitCode疑难问题诊疗

问题诊断与解决框架通用问题排查流程&#xff08;适用于大多数场景&#xff09; 版本兼容性验证方法 网络连接与权限检查清单常见错误分类与解决方案仓库克隆失败场景分析 HTTP/SSH协议错误代码解读 403/404错误深层原因排查高级疑难问题处理分支合并冲突的深度解决 .gitignore…

告别物业思维:科技正重构产业园区的价值坐标系

文 | 方寸控股引言&#xff1a;当产业园区的竞争升维为“科技军备竞赛”&#xff0c;土地红利消退&#xff0c;政策优势趋同&#xff0c;传统园区运营陷入增长困局。当招商团队还在用Excel统计企业需求&#xff0c;当能耗管理依赖保安夜间巡检&#xff0c;当企业服务停留在“修…

GitHub 热门项目 PandaWiki:零门槛搭建智能漏洞库,支持 10 + 大模型接入

转自&#xff1a;Khan安全团队你还没有自己的漏洞库吗&#xff1f;一条命令教你搭建。PandaWiki 是一款 AI 大模型驱动的开源知识库搭建系统&#xff0c;帮助你快速构建智能化的 产品文档、技术文档、FAQ、博客系统&#xff0c;借助大模型的力量为你提供 AI 创作、AI 问答、AI …

Python 程序设计讲义(55):Python 的函数——函数的参数

Python 程序设计讲义&#xff08;55&#xff09;&#xff1a;Python 的函数——函数的参数 目录Python 程序设计讲义&#xff08;55&#xff09;&#xff1a;Python 的函数——函数的参数一、声明形参二、传递实参&#xff08;位置参数&#xff09;1、在调用函数进行传递参数时…

机器学习sklearn:支持向量机svm

概述&#xff1a;现在就只知道这个svm可以画出决策边界&#xff0c;对数据的划分。简单举例就是&#xff1a;好的和坏的数据分开&#xff0c;中间的再验证from sklearn.datasets import make_blobs from sklearn.svm import SVC import matplotlib.pyplot as plt import numpy …