cuda编程笔记(18)-- 使用im2col + GEMM 实现卷积

我们之前介绍了cudnn调用api直接实现卷积,本文我们探究手动实现。

对于直接使用for循环在cpu上的实现方法,就不过多介绍,只要了解卷积的原理,就很容易实现。

im2col 的核心思想

im2col = image to column

  • 把输入 feature map 的每个卷积感受野(sliding window)展开成一列向量

  • 卷积核也展开成一个行向量

  • 然后把卷积转化成 矩阵乘法(GEMM, General Matrix Multiply)

举例

假设:

  • 输入:1 channel, 4×4

  • 卷积核:1×2×2

  • 步长 stride=1

输入:

1  2  3  4
5  6  7  8
9  10 11 12
13 14 15 16

卷积核 2×2 的每个 sliding window:

[[1,2],[5,6]] → 展开为 [1,2,5,6]
[[2,3],[6,7]] → 展开为 [2,3,6,7]
...

im2col 后:

X_col = [[1, 2, 5, 6],   # 第一个位置[2, 3, 6, 7],   # 第二个位置[3, 4, 7, 8],...]

卷积核也展开成:

W_col = [w1, w2, w3, w4]

然后 卷积计算就变成矩阵乘法

Y=W_{col}\cdot X_{col}

  • 输出每个位置就是矩阵乘法的一个元素

  • 对多通道、多卷积核也可以批量做

说人话,就是把卷积核每次对准的这块矩阵区域展平成向量;比如X_col的第一行,与W_col作向量乘法,结果就是第一次卷积得到的结果。

但是一般我们会将W_col的参数作为一行,所以X_col实际存储需要转置一下,这样才符合矩阵乘法的要求

为什么效率高?

  1. GEMM 有高度优化

    • BLAS/cuBLAS/cuDNN 都对矩阵乘法做了很多优化

    • 可以充分利用 SIMD / GPU 并行

  2. 循环嵌套少

    • 原始卷积是 6 层循环(batch, output channel, height, width, input channel, kernel height/width)

    • im2col + GEMM → 只要一次矩阵乘法

  3. 容易扩展到多通道、多 batch

代码实现

代码中对应的X_col就是按列存储展开的元素

#ifndef __CUDACC__
#define __CUDACC__
#endif
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <cudnn.h>
#include <cublas_v2.h>
#include <iostream>
#include<cstdio>
#include <cmath>
#include <cstdlib>
#include <vector>__global__ void im2col_kernel(const float *data_im,//输入图片数据 [C,H,W]int channels,int height,int width,// 输入通道数 C,输入高度 H, 宽度 Wint ksize,int pad,int stride,// 卷积核大小 (假设方形: kH=kW=ksize),padding 大小,卷积步长int height_col,int width_col,// 输出特征图的高宽float *data_col){ // im2col 展开的矩阵输出 [C*ksize*ksize, H_out*W_out]int index=blockIdx.x*blockDim.x+threadIdx.x;//index: 每个线程负责展开一个卷积窗口中的一个元素。int n=channels*ksize*ksize*height_col*width_col;//n: 总的元素数if(index<n){//对比一下w_out*h_out*k_idx*k_w*k_h与n就知道为什么要这么计算这五个变量了//w_out, h_out: 代表当前处理的输出特征图位置 (卷积结果的坐标)。int w_out=index%width_col;int h_out=(index/width_col)%height_col;//k_idx → (c_in, k_h, k_w): 把卷积核的 index 分解成通道号、核内行列。int k_idx=(index/width_col/height_col);int k_w=k_idx%ksize;int k_h=(k_idx/ksize)%ksize;//c_in输入通道索引int c_in=k_idx/(ksize*ksize);//im_row, im_col: 对应输入图片上的真实位置(考虑 stride、padding 后)//这是经典的把输出坐标映射回输入坐标的公式int im_row=h_out*stride-pad+k_h;int im_col=w_out*stride-pad+k_w;//col_index: data_col 的存储索引int col_index=(c_in*ksize*ksize+k_h*ksize+k_w)* (height_col * width_col) + h_out * width_col + w_out;float val=0;//如果 im_row, im_col 在输入图像范围内 → 取值;否则属于padding → 0if(im_row>=0&&im_row<height&&im_col>=0&&im_col<width){val=data_im[(c_in*height+im_row)*width+im_col];}data_col[col_index]=val;}
}
void im2col_gpu(const float *d_im,int channels,int height,int width,int ksize,int pad,int stride,float *d_col){//计算输出大小:就是卷积输出 H_out, W_out 的公式。int height_col=(height +2*pad-ksize)/stride+1;int width_col=(width+2*pad-ksize)/stride+1;int n=channels*ksize*ksize*height_col*width_col;int threads=256;int blocks=(n+threads-1)/threads;im2col_kernel<<<blocks,threads>>>(d_im,channels,height,width,ksize,pad,stride,height_col,width_col,d_col);cudaDeviceSynchronize();
}
void conv_forward_im2col(const float *d_input,// 输入图像 [C,H,W]const float *d_weight,// 卷积核 [K,C,ksize,ksize]float *d_output,//输出 [K,H_out,W_out]int C,int H,int W,// 输入通道数,高,宽int K,// 卷积核数量 (输出通道数)int ksize,int stride,int pad,// 核大小,步长,填充cublasHandle_t &handle){int H_out=(H+2*pad-ksize)/stride+1;int W_out=(W+2*pad-ksize)/stride+1;float *d_col;// im2col buffer: [C*ksize*ksize, H_out*W_out]cudaMalloc(&d_col,sizeof(float)*C*ksize*ksize*H_out*W_out);im2col_gpu(d_input,C,H,W,ksize,pad,stride,d_col);//原本是W*x的顺序,但是由于cublasSgemm函数的特性,写的时候参照实际传递的方式// GEMM: d_weight:[K, C*ksize*ksize] * d_col:[C*ksize*ksize, H_out*W_out] = [K, H_out*W_out]const float alpha=1.0f,beta=0.0f;cublasSgemm(handle,CUBLAS_OP_N,CUBLAS_OP_N,H_out*W_out,K,C*ksize*ksize,&alpha,d_col,H_out*W_out,d_weight,C*ksize*ksize,&beta,d_output,H_out*W_out);cudaFree(d_col);
}
int main() {int C=1, H=5, W=5;int K=1, ksize=3, stride=1, pad=1;std::vector<float> h_input(C*H*W, 1.0f);   // 全1输入std::vector<float> h_weight(K*C*ksize*ksize, 1.0f); // 全1卷积核std::vector<float> h_output(K*H*W);float *d_input, *d_weight, *d_output;cudaMalloc(&d_input, h_input.size()*sizeof(float));cudaMalloc(&d_weight, h_weight.size()*sizeof(float));cudaMalloc(&d_output, h_output.size()*sizeof(float));cudaMemcpy(d_input, h_input.data(), h_input.size()*sizeof(float), cudaMemcpyHostToDevice);cudaMemcpy(d_weight, h_weight.data(), h_weight.size()*sizeof(float), cudaMemcpyHostToDevice);cublasHandle_t handle;cublasCreate(&handle);conv_forward_im2col(d_input, d_weight, d_output,C,H,W,K,ksize,stride,pad, handle);cudaMemcpy(h_output.data(), d_output, h_output.size()*sizeof(float), cudaMemcpyDeviceToHost);cublasDestroy(handle);// 打印结果int H_out=(H+2*pad-ksize)/stride+1;int W_out=(W+2*pad-ksize)/stride+1;std::cout << "Output (" << K << "," << H_out << "," << W_out << "):\n";for(int i=0;i<K;i++){for(int h=0;h<H_out;h++){for(int w=0;w<W_out;w++){std::cout << h_output[i*H_out*W_out+h*W_out+w] << " ";}std::cout << "\n";}}cudaFree(d_input);cudaFree(d_weight);cudaFree(d_output);return 0;
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/98013.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/98013.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Loopback for Mac:一键打造虚拟音频矩阵,实现跨应用音频自由流转

虚拟音频设备创建 模拟物理设备&#xff1a;Loopback允许用户在Mac上创建虚拟音频设备&#xff0c;这些设备可被系统及其他应用程序识别为真实硬件&#xff0c;实现音频的虚拟化传输。多源聚合&#xff1a;支持将麦克风、应用程序&#xff08;如Skype、Zoom、GarageBand、Logic…

深入解析Django重定向机制

概述 核心是一个基类 HttpResponseRedirectBase&#xff0c;以及两个具体的子类 HttpResponseRedirect&#xff08;302 临时重定向&#xff09;和 HttpResponsePermanentRedirect&#xff08;301 永久重定向&#xff09;。它们都是 HttpResponse 的子类&#xff0c;专门用于告诉…

【Java实战⑳】从IO到NIO:Java高并发编程的飞跃

目录一、NIO 与 IO 的深度剖析1.1 IO 的局限性1.2 NIO 核心特性1.3 NIO 核心组件1.4 NIO 适用场景二、NIO 核心组件实战2.1 Buffer 缓冲区2.2 Channel 通道2.3 Selector 选择器2.4 NIO 文件操作案例三、NIO2.0 实战3.1 Path 类3.2 Files 类3.3 Files 类高级操作3.4 NIO2.0 实战…

OpenCV 实战:图像模板匹配与旋转处理实现教程

目录 一、功能概述&#xff1a;代码能做什么&#xff1f; 二、环境准备&#xff1a;先搭好运行基础 1. 安装 Python 2. 安装 OpenCV 库 3. 准备图像文件 三、代码逐段解析&#xff1a;从基础到核心 1. 导入 OpenCV 库 2. 读取图像文件 3. 模板图像旋转&#xff1a;处理…

一、cadence的安装及入门教学(反相器的设计与仿真)

一、Cadence的安装 1、安装VMware虚拟机 2、安装带有cadence软件的Linux系统 注&#xff1a;网盘链接 分享链接&#xff1a;https://disk.ningsuan.com.cn/#s/8XaVdtRQ 访问密码&#xff1a;11111 所有文件压缩包及文档密码&#xff1a; Cadence_ic 3、安装tsmc18工艺库…

用ai写了个UE5插件

文章目录实际需求1.头文件2.源文件3.用法小结实际需求 这个需求来源于之前的一个项目&#xff0c;当时用了一个第三方插件&#xff0c;里边有一些绘制线段的代码&#xff0c;c层用的是drawdebugline&#xff0c;当时看底层&#xff0c;觉得应该没问题&#xff0c;不应该在rele…

机器学习从入门到精通 - 强化学习初探:Q-Learning到Deep Q-Network实战

机器学习从入门到精通 - 强化学习初探&#xff1a;从 Q-Learning 到 Deep Q-Network 实战 一、开场白&#xff1a;推开强化学习这扇门 不知道你有没有过这种感觉 —— 盯着一个复杂的系统&#xff0c;既想让它达到某个目标&#xff0c;又苦于无法用传统规则去精确描述每一步该怎…

【OpenHarmony文件管理子系统】文件访问接口解析

OpenHarmony文件访问接口&#xff08;filemanagement_file_api&#xff09; 概述 OpenHarmony文件访问接口&#xff08;filemanagement_file_api&#xff09;是开源鸿蒙操作系统中的核心文件系统接口&#xff0c;为应用程序提供了完整的文件IO操作能力。该项目基于Node-API&…

云手机运行是否消耗自身流量?

云手机运行是否消耗自身流量&#xff0c;取决于具体的使用场景和设置&#xff1a;若用户在连接云手机时&#xff0c;使用的是家中Wi-Fi、办公室局域网等非移动数据网络&#xff0c;那么在云手机运行过程中&#xff0c;基本不会消耗用户自身的移动数据流量&#xff0c;在家中连接…

JavaSe之多线程

一、多线程基本了解 1、多线程基本知识 1.进程:进入到内存中执行的应用程序 2.线程:内存和CPU之间开通的通道->进程中的一个执行单元 3.线程作用:负责当前进程中程序的运行.一个进程中至少有一个线程,一个进程还可以有多个线程,这样的应用程序就称之为多线程程序 4.简单理解…

产品月报|睿本云8月产品功能迭代

睿本云8月更新已陆续上线&#xff01; 睿本云8月产品月报&#xff0c;点击查收&#x1f447;小程序支付成功弹窗广告、企业会员增加卡券销售和卡券退货模块、工厂端可批量新增多门店订货单、门店端和工厂端新增“极速订货”、商品调拨业务支持自定义多种流程配置等功能迭代更新…

融云:当我们谈论 AI 重构业务时,我们到底在谈论什么

所有业务都值得用 AI 重新做一次。 这句话正在从一句鼓舞人心的口号&#xff0c;演变为一场无人可避的商业现实。AI 带来的结构性机会&#xff0c;意味着企业有机会从根本上重构成本、效率与体验的曲线。但这一切最终都要回到一个无比务实的问题上&#xff1a; AI 究竟如何在我…

org.yaml.snakeyaml.error.YAMLException: java.nio.charset.MalformedInputException: Input length = 1异常

org.yaml.snakeyaml.error.YAMLException: java.nio.charset.MalformedInputException: Input length 1异常问题解决一、问题背景二、错误现象三、原因分析核心问题&#xff1a;字符集不匹配四、解决过程试错路径记录五、最终方案1.创建launch.json文件&#xff0c;修改VSCode…

【C语言】深入理解指针(5)

目录 sizeof和strlen 1.sizeof 2.strlen 3. sizeof 和 strlen 的对比 sizeof和strlen 1.sizeof sizeo正名&#xff1a;sizeof是操作符&#xff0c;不是函数&#xff0c;sizeof是操作符&#xff0c;括号内如果有计算不会进行计算sizeof 是操作符&#xff0c;用于计算变量所…

动态代理设计模式

JDK动态代理实现 动态代理利用了JDK API,动态地在内存中构建代理对象,从而实现对目标对象的代理功能.动态代理又被称为JDK代理或接口代理. 静态代理与动态代理的区别: 静态代理在编译时就已经实现了,编译完成后代理类是一个实际的class文 动态代理是在运行时动态生成的,即编译…

《Html泛型魔法学院:用霍格沃茨风格网页教授集合框架》

一、项目概述 这个创意教学网页&#xff0c;将Java泛型与集合框架知识融入霍格沃茨魔法世界主题。通过沉浸式UI设计和交互式代码练习&#xff0c;让抽象的技术概念变得生动有趣。主要技术栈包括&#xff1a; HTML5语义化结构Tailwind CSS框架Font Awesome图标库纯JavaScript交…

学习PaddlePaddle--环境配置-PyCharm + Conda​

第一阶段&#xff1a;安装与配置 Python 和 Conda​​ 虽然 PyCharm 可以管理环境&#xff0c;但我们先独立准备好 Conda 环境&#xff0c;这样更清晰可靠。 ​​1. 安装 Miniconda (Python 环境管理)​​ 1. ​​下载​​&#xff1a; • 访问 Miniconda 官网。 • 选择 ​​M…

【数据库】Sql Server数据库中isnull、iif、case when三种方式的使用和空值判断

大家好&#xff0c;我是全栈小5&#xff0c;欢迎来到《小5讲堂》。 这是《Sql Server》系列文章&#xff0c;每篇文章将以博主理解的角度展开讲解。 温馨提示&#xff1a;博主能力有限&#xff0c;理解水平有限&#xff0c;若有不对之处望指正&#xff01; 目录前言ISNULL用法c…

【蓝桥杯选拔赛真题64】C++最大空白区 第十四届蓝桥杯青少年创意编程大赛 算法思维 C++编程选拔赛真题解

C++最大空白区 第十四届蓝桥杯青少年创意编程大赛C++选拔赛真题 博主推荐 所有考级比赛学习相关资料合集【推荐收藏】 1、C++专栏 电子学会C++一级历年真题解析 电子学会C++二级历年真题解析

试用Augment编写python脚本实现智能家居3D环境交互响应

环境配置 VS Code中直接安装Augment扩展&#xff0c;然后邮箱登录就能获得7天的试用。 从如下位置安装3D建模软件Blender&#xff1a; https://www.blendercn.org/downloadme#xiazai Blender 是一款免费开源的 3D 创作套件。它支持整个三维流程&#xff1a;建模、绑定、动画…