5 手写卷积函数

5 手写卷积函数

  • 背景
  • 介绍
  • 滑动窗口的方式
    • 代码
    • 问题
  • 矩阵乘法的方式
    • 原理
    • 代码
    • 结果
  • 效果对比
    • 对比代码
    • 日志
    • 结果
  • 一些思考

背景

从现在开始各种手写篇章,先从最经典的卷积开始

介绍

对于卷积层的具体操作,我这里就不在具体说卷积具体是什么东西了。
对于手写卷积操作而言,有两种方式,一种就是最朴素的通过滑动窗口来实现的方式,另一种方式就是使用矩阵乘法来简化操作过程的方式。

滑动窗口的方式

在这里插入图片描述

卷积操作的动图https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

通过上面的图片和连接就可以很直观地感受到卷积操作的方式,也能很直接想到使用简单的滑动窗口来实现,如果还不能理解,建议去B站搜下视频学习下

代码

"""
-*- coding: utf-8 -*-
使用滑动窗口方式的手动卷积
@Author : Leezed
@Time : 2025/6/27 15:33
"""import numpy as npclass ManualSlideWindowConv():"""手动实现卷积操作,使用滑动窗口方式没有实现反向传播功能"""def __init__(self, kernel_size, in_channel, out_channel, stride=1, padding=0, bias=True):self.kernel_size = kernel_sizeself.in_channel = in_channelself.out_channel = out_channelself.stride = strideself.padding = paddingself.bias = biasself.weight = np.random.randn(out_channel, in_channel, kernel_size, kernel_size)if bias:self.bias = np.random.randn(out_channel)else:self.bias = Nonedef print_weight(self):print("Weight shape:", self.weight.shape)print("Weight values:\n", self.weight)def get_weight(self):return self.weightdef set_weight(self, weight):if weight.shape != self.weight.shape:raise ValueError(f"Weight shape mismatch: expected {self.weight.shape}, got {weight.shape}")self.weight = weightdef __call__(self, x, *args, **kwargs):if self.padding > 0:x = np.pad(x, ((0, 0), (0, 0), (self.padding, self.padding), (self.padding, self.padding)), mode='constant')  # 在四周填充0batch_size, in_channel, height, width = x.shapekernel_size = self.kernel_size# 计算输出的高度和宽度out_height = (height - kernel_size) // self.stride + 1out_width = (width - kernel_size) // self.stride + 1output = np.zeros((batch_size, self.out_channel, out_height, out_width))for channel in range(self.out_channel):# 取出当前输出通道的权重kernel = self.weight[channel, :, :, :]# 添加biasif self.bias is not None:output[:, channel, :, :] += self.bias[channel]else:output[:, channel, :, :] = 0for i, end_height in enumerate(range(kernel_size - 1, height, self.stride)):for j, end_width in enumerate(range(kernel_size - 1, width, self.stride)):# 取出图像的滑动窗口start_height = end_height - kernel_size + 1start_width = end_width - kernel_size + 1window = x[:, :, start_height:end_height + 1, start_width:end_width + 1]# 计算卷积result = np.sum(kernel * window, axis=(1, 2, 3))output[:, channel, i, j] += resultreturn outputif __name__ == '__main__':# 测试代码x = np.random.randn(2, 3, 5, 5)  # batch_size=2, in_channel=3, height=5, width=5conv_layer = ManualSlideWindowConv(kernel_size=3, in_channel=3, out_channel=2, stride=1, padding=1)output = conv_layer(x)print("Output shape:", output.shape)conv_layer.print_weight()

问题

但是活动卷积的方式有一个问题就是这个方式太费时了,因为有三层循环,而对于python而言是用循环去计算是一件费力不讨好的事情,具体的时间花费我会在后面画出图来直观地展现

矩阵乘法的方式

原理

https://zhuanlan.zhihu.com/p/360859627
https://gist.github.com/hsm207/7bfbe524bfd9b60d1a9e209759064180
https://blog.csdn.net/caip12999203000/article/details/126494740

具体的原理我就不赘述了,上面的三个链接认真看也能看明白了,他的本质思想就是讲滑动窗口中的多次乘法,直接改成矩阵乘法,通过这种方式来进行加速,而且加速的幅度不小,但是会生成一个较大的矩阵,不可避免的带来内存的开销,也就是说本质是拿空间换时间

具体的例子就在图中
在这里插入图片描述
mmConv = ManualMatMulConv(kernel_size=3, in_channel=3, out_channel=64,padding=1)的卷积层面对
x = np.random.randn(64, 3, 224, 224).astype(np.float32)的特征就要吃1.5G+的内存了。

代码

class ManualMatMulConv():"""手动实现卷积操作,使用卷积乘法方式没有实现反向传播功能"""def __init__(self, kernel_size, in_channel, out_channel, stride=1, padding=0, bias=True):self.kernel_size = kernel_sizeself.in_channel = in_channelself.out_channel = out_channelself.stride = strideself.padding = paddingself.bias = biasself.weight = np.random.randn(out_channel, in_channel, kernel_size, kernel_size)if bias:self.bias = np.random.randn(out_channel)else:self.bias = Nonedef print_weight(self):print("Weight shape:", self.weight.shape)print("Weight values:\n", self.weight)def get_weight(self):return self.weightdef set_weight(self, weight):if weight.shape != self.weight.shape:raise ValueError(f"Weight shape mismatch: expected {self.weight.shape}, got {weight.shape}")self.weight = weightdef __call__(self, x, *args, **kwargs):if self.padding > 0:x = np.pad(x, ((0, 0), (0, 0), (self.padding, self.padding), (self.padding, self.padding)), mode='constant')  # 在四周填充0batch_size, in_channel, height, width = x.shapekernel_size = self.kernel_size# 计算输出的高度和宽度out_height = (height - kernel_size) // self.stride + 1out_width = (width - kernel_size) // self.stride + 1# 将权重转换为矩阵形式weight_matrix = self.weight.reshape(self.out_channel, -1)  # shape (out_channel, in_channel * kernel_size * kernel_size)# 将输入转为矩阵形式 手写unfold方式unfolded_x = []for i in range(0, height - kernel_size + 1, self.stride):for j in range(0, width - kernel_size + 1, self.stride):# 取出图像的滑动窗口 转成矩阵形式window = x[:, :, i:i + kernel_size, j:j + kernel_size].reshape(batch_size, -1)unfolded_x.append(window)unfolded_x = np.array(unfolded_x)  # shape: (num_windows, batch_size, in_channel * kernel_size * kernel_size)unfolded_x = np.transpose(unfolded_x, (1, 0, 2))  # shape: (batch_size, num_windows, in_channel * kernel_size * kernel_size)# 使用矩阵乘法计算卷积output = np.matmul(unfolded_x, weight_matrix.T)  # shape (batch_size, num_windows, out_channel)output = np.transpose(output, (0, 2, 1))  # shape (batch_size, out_channel, num_windows)output = output.reshape(batch_size, self.out_channel, out_height, out_width)# 添加biasif self.bias is not None:output += self.bias.reshape(1, -1, 1, 1)# 输出结果return output

结果

检测代码

if __name__ == '__main__':# 测试代码conv = ManualMatMulConv(kernel_size=3, in_channel=3, out_channel=2, stride=1, padding=0, bias=False)slide_window_conv = ManualSlideWindowConv(kernel_size=3, in_channel=3, out_channel=2, stride=1, padding=0, bias=False)conv.set_weight(slide_window_conv.get_weight())x = np.random.randn(1, 3, 5, 5)  # 输入形状 (batch_size, in_channel, height, width)output = conv(x)slide_window_output = slide_window_conv(x)print("Output shape:", output.shape)print("slide_window_output shape:", slide_window_output.shape)assert np.allclose(conv.get_weight(), slide_window_conv.get_weight()), "Weights do not match!"print("output:")print(output)print("slide_window_output:")print(slide_window_output)# 校验是否相同assert np.allclose(output, slide_window_output), "Outputs do not match!"print("Outputs match!")

在这里插入图片描述

效果对比

这里采取了四种方式的卷积来进行对比

  1. 滑动窗口方式的卷积
  2. 矩阵乘法的卷积
  3. torch.nn.Conv2d
  4. torch.nn.Conv2d 使用cuda

对比代码

import numpy as np
from matplotlib import pyplot as plt
from manual.conv.slide_window import ManualSlideWindowConv
from manual.conv.matmul import ManualMatMulConv
import torch
import time# 对比不同batchsize的卷积速度speeds = {'manual_matmul': [],'manual_slide_window': [],'torch': [],'torch_cuda': []
}swConv = ManualSlideWindowConv(kernel_size=3, in_channel=3, out_channel=64,padding=1)
mmConv = ManualMatMulConv(kernel_size=3, in_channel=3, out_channel=64,padding=1)
torchConv = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
torchCudaConv = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1).cuda()def timing_conv(conv,x):start = time.time()y = conv(x)end = time.time()return y, end - startfor bs in [1, 2, 4, 8, 16, 32]:x = np.random.randn(bs, 3, 224, 224).astype(np.float32)x_torch = torch.from_numpy(x)x_torch_cuda = x_torch.cuda()y, speed = timing_conv(swConv, x)speeds['manual_slide_window'].append(speed)print(f'slide_window bs={bs}, speed={speed:.4f}s')y, speed = timing_conv(mmConv, x)speeds['manual_matmul'].append(speed)print(f'matmul bs={bs}, speed={speed:.4f}s')y, speed = timing_conv(torchConv, x_torch)speeds['torch'].append(speed)print(f'torch bs={bs}, speed={speed:.4f}s')y, speed = timing_conv(torchCudaConv, x_torch_cuda)speeds['torch_cuda'].append(speed)print(f'torch_cuda bs={bs}, speed={speed:.4f}s')print('-' * 50)

日志

slide_window bs=1, speed=39.8342s
matmul bs=1, speed=0.1436s
torch bs=1, speed=0.0080s
torch_cuda bs=1, speed=0.0000s
--------------------------------------------------
slide_window bs=2, speed=39.8841s
matmul bs=2, speed=0.2185s
torch bs=2, speed=0.0172s
torch_cuda bs=2, speed=0.0010s
--------------------------------------------------
slide_window bs=4, speed=44.0416s
matmul bs=4, speed=0.3975s
torch bs=4, speed=0.0329s
torch_cuda bs=4, speed=0.0000s
--------------------------------------------------
slide_window bs=8, speed=41.7520s
matmul bs=8, speed=0.3222s
torch bs=8, speed=0.0588s
torch_cuda bs=8, speed=0.0000s
--------------------------------------------------
slide_window bs=16, speed=45.5278s
matmul bs=16, speed=0.5858s
torch bs=16, speed=0.1067s
torch_cuda bs=16, speed=0.0010s
--------------------------------------------------
slide_window bs=32, speed=58.1965s
matmul bs=32, speed=1.2161s
torch bs=32, speed=0.2045s
torch_cuda bs=32, speed=0.0010s
--------------------------------------------------

结果

在这里插入图片描述

去掉最慢的滑动窗口的结果展示
在这里插入图片描述
可以看到矩阵乘法的方式还是挺快的,至少比滑动窗口快多了

一些思考

但是随之而来的就是还有一个问题,为什么矩阵乘法的方式的内存开销这么大,但是torch.nn.Conv2d好像并没有这个问题

经过查阅了一些资料,我简单总结一下

  1. 瓦片式(Tiling)或分块(Blocking)计算

    虽然 矩阵乘法 概念上是将整个输入和卷积核展开,但实际的硬件实现(如GPU)并不总是一次性处理所有数据。它们可能会将计算任务分解成更小的“瓦片”或“块”。

    局部 矩阵乘法: 不是一次性将整个图像展开,而是每次只对输入的一小部分(例如一个批次、或者一个输出通道的一个小区域)进行 im2col 变换和矩阵乘法。这样可以限制中间矩阵的大小,从而减少瞬时内存占用。计算完成后,再将结果拼接到最终的输出特征图上。

    重用数据: 这种分块策略有助于更好地利用 CPU 缓存或 GPU 显存,因为同一小块数据可以在其被完全处理完毕前,反复用于计算,减少数据在主存和缓存之间的移动。

  2. 智能算法选择
    根据卷积参数动态选择最适合的底层算法(如 Winograd, FFT, 或优化过的直接卷积),而不是单一地依赖 im2col。

这也就是为啥我们写出来的方法跟官方版本的有差距的原因。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/912194.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/912194.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue3+element-plus,实现两个表格同步滚动

需求:现在需要两个表格,为了方便对比左右的数据,需要其中一边的表格滚动时,另一边的表格也跟着一起滚动,并且保持滚动位置的一致性。具体如下图所示。 实现步骤: 确保两个表格的宽度一致:如果两…

Mysql架构

思考:Mysql需要重点学习什么: 索引:索引存储结构、索引优化......事务:锁机制与隔离级别、日志、集群架构 本文是对Mysql架构进行初步学习 1、Mysql链接 Mysql监听器是长连接 BIO(阻塞同步IO调用), 不是NIO. 为什么…

使用deepseek制作“喝什么奶茶”随机抽签小网页

教程很简单,如下操作 1. 新建文本文档,命名为奶茶.txt 2. 打开deepseek,发送下面这段提示词:用html5帮我生成一个喝什么奶茶的网页,点击按钮随机生成奶茶品牌等,包括喜茶等众多常见的奶茶品牌如果不满意还…

WOE值:风险建模中的“证据权重”量化术——从似然比理论到FICO评分卡实践

WOE值(Weight of Evidence,证据权重) 是信用评分和风险建模中用于量化特征分箱对目标变量的预测能力的核心指标。 本文由「大千AI助手」原创发布,专注用真话讲AI,回归技术本质。拒绝神话或妖魔化。搜索「大千AI助手」关…

js递归性能优化

JavaScript 递归性能优化 递归是编程中强大的技术,但在 JavaScript 中如果不注意优化可能会导致性能问题甚至栈溢出。以下是几种优化递归性能的方法: 1. 尾调用优化 (Tail Call Optimization, TCO) ES6 引入了尾调用优化,但只在严格模式下…

vue界面增加自定义水印 js

vue整个界面增加自定义水印 需求:领导想要增加自定义水印 好不容易调完,还是想记录一下,在.vue界面编写 export default {mounted() {this.$nextTick(() > {this.addWatermark()})},methods: {// 关键:添加水印// 动态添加水印addWaterm…

Go开发工程师-Golang基础知识篇

开篇 我们尝试从2个方面来进行介绍: 1. 社招实际面试问题 2. 问题涉及的基础点梳理 社招面试题 米哈游 1. Go 里面使用 Map 时应注意问题和数据结构 2. Map 扩容是怎么做的? 3. Map 的 panic 能被 recover 掉吗?了解 panic 和 recover …

能否仅用两台服务器实现集群的高可用性??

我们将问题分为两部分来回答:一是使用 Redis 或 Hazelcast 确保数据一致性后是否仍需 Oracle 或 MySQL 等数据库;二是能否仅用两台服务器实现集群的高可用性。以下是详细探讨: 1. 使用 Redis 或 Hazelcast 确保数据一致性后,还需要…

spring-ai-alibaba DashScopeCloudStore自动装配问题

问题 在学习spring-ai-alibaba时,发现1.0.0.2版本在自动装配DashScopeCloudStore时,会报如下错误: Field dashScopeCloudStore in com.example.spring_ai_alibaba_examples.examples.SpringAiAlibabaExample01 required a bean of type com…

docker-compose部署nacos

1、docker-compose内容 高版本的nacos使用docker启动,需要将所有的端口放开,仅仅开放8848端口,spring-boot客户端获取nacos配置的时候,可能取到的内容为空。 version: 3# 定义自定义网络,确保服务间通信和外部访问 ne…

CSRF 与 SSRF 的关联与区别

CSRF 与 SSRF 的关联与区别 区别 特性CSRF (跨站请求伪造)SSRF (服务器端请求伪造)攻击方向客户端 → 目标网站服务器 → 内部/外部资源攻击目标利用用户身份执行非预期操作利用服务器访问内部资源或发起对外请求受害者已认证的用户存在漏洞的服务器利用条件用户必须已登录目…

Payload-SDK自动升级

Payload-SDK自动升级 前言 自动升级旨在通过无人机更新负载上的软件,包括不限于:Payload-SDK应用、配置文件等。对于文件的传输,大疆的Payload-SDK给我们提供了两种方式:使用FTP协议和使用大疆自研的DCFTP。我们实现的自动升级是…

第五代移动通信新型调制及非正交多址传输技术研究与设计

第五代移动通信新型调制及非正交多址传输技术研究与设计 一、新型调制技术研究与实现 1. FBMC (滤波器组多载波) 调制实现 import numpy as np import matplotlib.pyplot as plt from scipy.fft import fft, ifft, fftshift from scipy.signal import get_window

AI 智能运维,重塑大型企业软件运维:从自动化到智能化的进阶实践​

一、引言:企业软件运维的智能化转型浪潮​ 在数字化转型加速的背景下,大型企业软件架构日益复杂,微服务、多云环境、分布式系统的普及导致传统运维模式面临效率瓶颈。AI 技术的渗透催生了智能运维(AIOps)的落地&#x…

Apache CXF安装详细教程(Windows)

本章教程,主要介绍,如何在Windows上安装Apache CXF,JDK版本是使用的1.8. 一、下载Apache CXF Apache CXF(Apache Celtix Fireworks)是一个开源的 Web 服务框架,用于 构建和开发服务端与客户端的 Web 服务应用程序。它支持多种 Web 服务标准,尤其是 SOAP(基于 XML 的协议…

逆向入门(22)程序逆向篇-TraceMe

界面看起来很普通 也没有壳,直接搜索字符串找到关键代码处 但是发现这些都是赋值,并没有实现跳转相关的函数。这里通过给弹窗函数下断点,追一下返回函数来找触发点。 再次点击check,触发断点,接着按ctrlF9返回到函数…

中文PDF解析准确率排名

市面上的文档解析工具种类各异,包括更适用于论文解析的,专精于表格数据提取的,针对手写体优化的,适用于技术文档的,擅长处理复杂多语言混排文档的,专门处理政府招标文档表格的,以及擅长金融类表…

Conformal LEC:官方学习教程

相关阅读 Conformal LEChttps://blog.csdn.net/weixin_45791458/category_12993839.html?spm1001.2014.3001.5482 本文是对Conformal Equivalence Checking User Guide中附录实验的翻译(有删改),实验文件可见安装目录Conformal/share/cfm/l…

【Torch】nn.Embedding算法详解

1. 定义 nn.Embedding 是 PyTorch 中的 查表式嵌入层(lookup‐table),用于将离散的整数索引(如词 ID、实体 ID、离散特征类别等)映射到一个连续的、可训练的低维向量空间。它通过维护一个形状为 (num_embeddings, emb…

cdq 三维偏序应用 / P4169 [Violet] 天使玩偶/SJY摆棋子

最近学了 cdq 分治想来做做这道题,结果被有些毒瘤的代码恶心到了。 /ll 题目大意:一开始给定一些平面中的点。然后给定一些修改和询问: 修改:增加一个点。询问:给定一个点,求离这个点最近(定义…