基础神经网络模型搭建

nn 包提供通用深度学习网络的模块集合,接收输入张量,计算输出张量,并保存权重。通常使用两种途径搭建 PyTorch 中的模型:nn.Sequential和 nn.Module。

nn.Sequential通过线性层有序组合搭建模型;nn.Module通过__init__ 函数指定层,然后通过 forward 函数将层应用于输入,更灵活地构建自定义模型。

目录

搭建线性层

通过nn.Sequential搭建

通过nn.Module搭建

获取模型摘要


搭建线性层

使用 nn 包搭建线性层。线性层接收 64*1000 维的输入,保存 1000*100 维的权重,并计算 64*100 维的输出。

import torch
from torch import nn
input_tensor = torch.randn(64, 1000)
linear_layer = nn.Linear(1000, 100)
output = linear_layer(input_tensor)
print(input_tensor.size())
print(output.size())

通过nn.Sequential搭建

考虑一个两层的神经网络,四个节点作为输入,五个节点在隐藏层,一个节点作为输出

from torch import nn
model = nn.Sequential(nn.Linear(4, 5),nn.ReLU(),nn.Linear(5, 1),
)
print(model)

通过nn.Module搭建

在 PyTorch 中搭建模型的另一种方法是对 nn.Module 类进行子类化,通过__init__ 函数指定层,然后通过 forward 函数将层应用于输入,更灵活地构建自定义模型。

考虑两个卷积层和两个完全连接层搭建的模型:

import torch.nn.functional as F
class Net(nn.Module):def __init__(self):super(Net, self).__init__()def forward(self, x):pass

定义__init__ 函数和forward 函数

def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5, 1)self.conv2 = nn.Conv2d(20, 50, 5, 1)self.fc1 = nn.Linear(4*4*50, 500)self.fc2 = nn.Linear(500, 10)
def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2) x = x.view(-1, 4*4*50)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)

重写两个类函数并打印模型

重写:子类中实现一个与父类的成员函数原型完全相同的函数

Net.__init__ = __init__
Net.forward = forward
model = Net()
print(model)

 查看模型位置

print(next(model.parameters()).device)

 

将模型移动至CUDA设备 

device = torch.device("cuda:0")
model.to(device)
print(next(model.parameters()).device)

获取模型摘要

借助torchsummary包查获取模型摘要

pip install torchsummary
from torchsummary import summary
summary(model, input_size=(1, 28, 28))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/89716.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/89716.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于单片机出租车计价器设计

传送门 👉👉👉👉其他作品题目速选一览表 👉👉👉👉其他作品题目功能速览 概述 本设计实现了一种基于单片机的智能化出租车计价系统。系统以单片机为核心处理器,集成…

134. Java 泛型 - 上限通配符

文章目录134. Java 泛型 - 上限通配符 (? extends T)**1. 什么是上限通配符 (? extends T)?****2. 为什么使用 ? extends T?****3. 示例:使用 ? extends T 进行数据读取****✅ 示例 1:计算数值列表的总和****4. 注意事项&…

【1】YOLOv13 AI大模型-可视化图形用户(GUI)界面系统开发

【文章内容适用于任意目标检测任务】【GUI界面系统不局限于YOLOV13,主流YOLO系列模型同样适用】本文以车辆行人检测为背景,介绍基于【YOLOV13模型】和【AI大模型】的图形用户(GUI)界面系统的开发。助力大论文实现目标检测模型的应…

小程序常用api

1. wx.request - 发起网络请求 用于向服务器发送 HTTP 请求,获取数据或提交表单。 // 示例:GET 请求获取数据 wx.request({url: https://api.example.com/data, // 替换为实际 API 地址method: GET,success: (res) > {console.log(请求成功, res.da…

PaliGemma 2-轻量级开放式视觉语言模型

PaliGemma 2是轻量级开放式视觉语言模型 (VLM),灵感源自 PaLI-3,基于 SigLIP 视觉模型和 Gemma 语言模型等开放式组件。PaliGemma 同时接受图片和文本作为输入,并且可以回答有关图片的详细问题和背景信息。PaliGemma 2 提供 30 亿、100 亿和 …

腾讯云云服务器深度介绍

以下是围绕腾讯云云服务器(CVM)的详细介绍与推荐文章,结合其核心优势、应用场景及技术特性,为不同用户群体提供参考: 🚀 一、产品定位与核心价值 腾讯云云服务器(Cloud Virtual Machine, CVM&a…

Ceph OSD.419 故障分析

Ceph OSD.419 故障分析 1. 问题描述 在 Ceph 存储集群中,OSD.419 无法正常启动,系统日志显示服务反复重启失败。 2. 初始状态分析 观察到 OSD.419 服务启动失败的系统状态: systemctl status ceph-osd419 ● ceph-osd419.service - Ceph obje…

MySQL持久化原理及其常见问题

目录 MySQL刷盘原理 脏页和干净页 MySQL出现短暂的堵塞SQL现象 情况分析 应对措施 数据库表中数据删除原理 删除表中数据数据库空间大小不会改变 情况分析 应对措施 MySQL刷盘原理 一般主要分为两个步骤 内存更新和 redo log 记录是同一事务修改的两个必要操作&#…

VSCode中Cline无法正确读取终端的问题解决

出现的问题是:Cline 无法正确读取终端输出。 Shell Integration Unavailable Cline won’t be able to view the command’s output. Please update VSCode (CMD/CTRL Shift P → “Update”) and make sure you’re using a supported shell: zsh, bash, fish, o…

scalelsd 笔记 线段识别 本地部署 模型架构

ant-research/scalelsd | DeepWiki https://arxiv.org/html/2506.09369?_immersive_translate_auto_translate1 https://gitee.com/njsgcs/scalelsd https://github.com/ant-research/scalelsd https://huggingface.co/cherubicxn/scalelsd 模型链接: https…

Python, C ++开发个体户/个人品牌打造APP

个体户/个人品牌打造APP开发方案(Python C)一、技术选型与分工1. Python- 核心场景:后端API开发、数据处理、内容管理、第三方服务集成(如社交媒体分享、支付接口)。- 优势:开发效率高,丰富的库…

SQLAlchemy 常见问题笔记

文章目录SQLAlchemy Session对象如何操作数据库SQLAlchemy非序列化对象如何返回1.问题分析2.解决方案方法1:使用 Pydantic 响应模型(推荐)方法2:手动转换为字典(简单快速)方法3:使用 SQLAlchemy…

Shell脚本-uniq工具

一、前言在 Linux/Unix 系统中,uniq 是一个非常实用的文本处理命令,用于对重复的行进行统计、去重和筛选。它通常与 sort 搭配使用,以实现高效的文本数据清洗与统计分析。无论是做日志分析、访问频率统计,还是编写自动化脚本&…

氛围编码(Vice Coding)的工具选择方式

一、前言 在写作过程中,我受益于若干优秀的博客分享,它们给予我宝贵的启发: 《5分钟选对AI编辑器,每天节省2小时开发时间让你早下班!》:https://mp.weixin.qq.com/s/f0Zm3uPTcNz30oxKwf1OQQ 二、AI编辑的…

[硬件电路-57]:根据电子元器件的受控程度,可以把电子元器件分为:不受控、半受控、完全受控三种大类

根据电子元器件的受控程度,可将其分为不受控、半受控、完全受控三大类。这种分类基于元器件的工作状态是否需要外部信号(如电压、电流、光、热等)的主动调控,以及调控的精确性和灵活性。以下是具体分类及实例说明:一、…

基于Pytorch的人脸识别程序

人脸识别原理详解人脸识别是模式识别和计算机视觉领域的重要研究方向,其目标是从图像或视频中识别出特定个体的身份。现代人脸识别技术主要基于深度学习方法,特别是卷积神经网络 (CNN),下面从多个维度详细解析其原理:1. 人脸识别的…

ubuntu 开启ssh踩坑之旅

文章目录确认当前用户为普通用户 or root命令使用ssh还是sshd服务名称的由来apt update和apt upgrade的关系apt upgrade报错:“E: 您在 /var/cache/apt/archives/ 上没有足够的可用空间”开启ssh步骤错误排查查看日志修改sshd_config文件允许防火墙通过22端口确认当…

力扣:动态规划java

sub07 线性DP - O(1) 状态转移2_哔哩哔哩_bilibili 跳楼梯 class Solution {public int climbStairs(int n) {if (n < 1) {return 1; // 处理边界情况}int[] dp new int[n 1]; // 创建长度为n1的数组&#xff0c;比方说跳二级楼梯dp[0] 1; // 初始值设定dp[1] 1;for (…

React Native打开相册选择图片或拍照 -- react-native-image-picker

官方文档&#xff1a;https://www.npmjs.com/package/react-native-image-picker 场景&#xff1a;点击按钮打开相册选择图片或者点击按钮拍照 import { launchCamera, launchImageLibrary } from react-native-image-picker;// ... <TouchableOpacityactiveOpacity{0.7}o…

USRP B210生成信号最大带宽测试之Frank

书接上文&#xff1a; USRP B210生成LFM,SFM,BPSK,Frank信号的最大带宽测试&#xff08;一&#xff09; USRP B210生成信号最大带宽测试&#xff08;二&#xff09;SFM USRP B210生成信号最大带宽测试&#xff08;三&#xff09;LFM USRP B210生成信号最大带宽测试之BPSK …