23 - HaLoAttention模块

论文《Scaling Local Self-Attention for Parameter Efficient Visual Backbones》

1、作用

HaloNet通过引入Haloing机制和高效的注意力实现,在图像识别任务中达到了最先进的准确性。这些模型通过局部自注意力机制,有效地捕获像素间的全局交互,同时通过分块和Haloing策略,显著提高了处理速度和内存效率。

2、机制

1、Haloing策略

为了克服传统自注意力的计算和内存限制,HaloNet采用了Haloing策略,将图像分割成多个块,并为每个块扩展一定的Halo区域,仅在这些区域内计算自注意力。这种方法减少了计算量,同时保持了较大的感受野。

2、多尺度特征层次

HaloNet构建了多尺度特征层次结构,通过分层采样和跨尺度的信息流,有效捕获不同尺度的图像特征,增强了模型对图像中对象大小变化的适应性。

3、高效的自注意力实现

通过改进的自注意力算法,包括非中心化的局部注意力和分层自注意力下采样操作,HaloNet在保持高准确性的同时,提高了训练和推理速度。

3、独特优势

1、参数效率

HaloNet通过局部自注意力机制和Haloing策略,大幅度减少了所需的计算量和内存需求,实现了与当前最佳卷积模型相当甚至更好的性能,但使用更少的参数。

2、适应多尺度

多尺度特征层次结构使得HaloNet能够有效处理不同尺度的对象,提高了对复杂视觉任务的适应性和准确性。

3、提升速度和效率

通过优化的自注意力实现,HaloNet在不牺牲准确性的前提下,实现了比现有技术更快的训练和推理速度,使其更适合实际应用。

4、代码

import torch
from torch import nn, einsum
import torch.nn.functional as Ffrom einops import rearrange, repeat# 将设备和数据类型转换为字典格式def to(x):return {'device': x.device, 'dtype': x.dtype}# 确保输入是元组形式
def pair(x):return (x, x) if not isinstance(x, tuple) else x# 在指定维度上扩展张量
def expand_dim(t, dim, k):t = t.unsqueeze(dim=dim)expand_shape = [-1] * len(t.shape)expand_shape[dim] = kreturn t.expand(*expand_shape)# 将相对位置编码转换为绝对位置编码
def rel_to_abs(x):b, l, m = x.shaper = (m + 1) // 2col_pad = torch.zeros((b, l, 1), **to(x))x = torch.cat((x, col_pad), dim=2)flat_x = rearrange(x, 'b l c -> b (l c)')flat_pad = torch.zeros((b, m - l), **to(x))flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)final_x = flat_x_padded.reshape(b, l + 1, m)final_x = final_x[:, :l, -r:]return final_x# 生成一维的相对位置logits
def relative_logits_1d(q, rel_k):b, h, w, _ = q.shaper = (rel_k.shape[0] + 1) // 2logits = einsum('b x y d, r d -> b x y r', q, rel_k)logits = rearrange(logits, 'b x y r -> (b x) y r')logits = rel_to_abs(logits)logits = logits.reshape(b, h, w, r)logits = expand_dim(logits, dim=2, k=r)return logits# 相对位置嵌入类
class RelPosEmb(nn.Module):def __init__(self,block_size,rel_size,dim_head):super().__init__()height = width = rel_sizescale = dim_head ** -0.5self.block_size = block_sizeself.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)def forward(self, q):block = self.block_sizeq = rearrange(q, 'b (x y) c -> b x y c', x=block)rel_logits_w = relative_logits_1d(q, self.rel_width)rel_logits_w = rearrange(rel_logits_w, 'b x i y j-> b (x y) (i j)')q = rearrange(q, 'b x y d -> b y x d')rel_logits_h = relative_logits_1d(q, self.rel_height)rel_logits_h = rearrange(rel_logits_h, 'b x i y j -> b (y x) (j i)')return rel_logits_w + rel_logits_h# HaloAttention类class HaloAttention(nn.Module):def __init__(self,*,dim,block_size,halo_size,dim_head=64,heads=8):super().__init__()assert halo_size > 0, 'halo size must be greater than 0'self.dim = dimself.heads = headsself.scale = dim_head ** -0.5self.block_size = block_sizeself.halo_size = halo_sizeinner_dim = dim_head * headsself.rel_pos_emb = RelPosEmb(block_size=block_size,rel_size=block_size + (halo_size * 2),dim_head=dim_head)self.to_q = nn.Linear(dim, inner_dim, bias=False)self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)self.to_out = nn.Linear(inner_dim, dim)def forward(self, x):# 验证输入特征图维度是否符合要求b, c, h, w, block, halo, heads, device = *x.shape, self.block_size, self.halo_size, self.heads, x.deviceassert h % block == 0 and w % block == 0, assert c == self.dim, f'channels for input ({c}) does not equal to the correct dimension ({self.dim})'q_inp = rearrange(x, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1=block, p2=block)kv_inp = F.unfold(x, kernel_size=block + halo * 2, stride=block, padding=halo)kv_inp = rearrange(kv_inp, 'b (c j) i -> (b i) j c', c=c)#生成查询、键、值q = self.to_q(q_inp)k, v = self.to_kv(kv_inp).chunk(2, dim=-1)# 拆分头部q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=heads), (q, k, v))# 缩放查询向量q *= self.scale# 计算注意力sim = einsum('b i d, b j d -> b i j', q, k)# 添加相对位置偏置sim += self.rel_pos_emb(q)# 掩码填充mask = torch.ones(1, 1, h, w, device=device)mask = F.unfold(mask, kernel_size=block + (halo * 2), stride=block, padding=halo)mask = repeat(mask, '() j i -> (b i h) () j', b=b, h=heads)mask = mask.bool()max_neg_value = -torch.finfo(sim.dtype).maxsim.masked_fill_(mask, max_neg_value)# 注意力机制attn = sim.softmax(dim=-1)# 聚合out = einsum('b i j, b j d -> b i d', attn, v)# 合并和组合头部out = rearrange(out, '(b h) n d -> b n (h d)', h=heads)out = self.to_out(out)# 将块合并回原始特征图out = rearrange(out, '(b h w) (p1 p2) c -> b c (h p1) (w p2)', b=b, h=(h // block), w=(w // block), p1=block,p2=block)return out# 输入 N C H W,  输出 N C H W
if __name__ == '__main__':block = HaloAttention(dim=512,block_size=2,halo_size=1, ).cuda()# 创建HaloAttention实例input = torch.rand(1, 512, 64, 64).cuda()# 创建随机输入output = block(input) # 前向传播print(output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/87456.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/87456.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025Mybatis最新教程(五)

第5章 ORM映射 5.1 MyBatis自动ORM失效 MyBatis只能自动维护库表”列名“与”属性名“相同时的对应关系,二者不同时,无法自动ORM。 自动ORM失效建表 create table t_managers(mgr_id int primary key auto_increment,mgr_name varchar(50),mgr_pwd varchar(50) ); 添加数据…

解决lombok注解失效问题

Lombok 注解失效是 Java 开发中的常见问题,通常由依赖配置、IDE 支持或构建工具设置引起。最近在拉取别人springboot3jdk21版本的项目时遇到了lombok注解失效,导致项目无法启动的问题,以下是我的解决方案: 首先检查idea 的lombok…

3分钟搭建LarkXR实时云渲染PaaS平台,实现各类3D/XR应用的一键推流

LarkXR是由Paraverse平行云自主研发的去中心化实时云渲染平台,以其卓越的性能和丰富完备的功能插件,引领3D/XR云化行业风向标。LarkXR适用于3D/XR开发者、设计师、终端用户等创新用户,可以在零硬件负担下,轻松实现超高清低时延的3…

vue3 watch监视详解

watch监视 一 &#xff1a;watch监视{ref}定义的基本类型结构 <template><div class"person"><h1>情况一:watch监视{ref}定义的基本类型结构</h1><h1>当前的和为{{ sum }}</h1><button click"changeSum">点我…

TensorFlow Serving学习笔记2: 模型服务

本文深入剖析 TensorFlow Serving 的核心架构与实现机制&#xff0c;结合源码分析揭示其如何实现高可用、动态更新的生产级模型服务。 一、TensorFlow Serving 核心架构 1.1 分层架构设计 TensorFlow Serving 采用模块化分层设计&#xff0c;各组件职责分明&#xff1a; 组件…

共享云桌面为什么能打败传统电脑

近年来&#xff0c;随着云桌面技术的快速发展&#xff0c;共享云桌面作为一种新型的计算模式&#xff0c;正在逐步改变人们的工作和生活方式。它凭借其独特的优势&#xff0c;正在逐步取代传统电脑&#xff0c;成为企业和个人用户的新选择。之所以在部分场景中展现出替代传统电…

B站PWN教程笔记-12

完结撒花。 今天还是以做题为主。 fmtstruaf 格式化字符串USER AFTER FREE 首先补充一个背景知识&#xff0c;指针也是有数据类型的&#xff0c;不同数据类型的指针xx&#xff0c;所加的字节数也不一样&#xff0c;其实是指针指的项目的下一项。如int a[20]&#xff0c;a是…

零基础设计模式——总结与进阶 - 3. 学习资源与下一步

第五部分&#xff1a;总结与进阶 - 3. 学习资源与下一步 到这里&#xff0c;你已经完成了设计模式主要内容的学习。但这仅仅是一个开始&#xff0c;设计模式的精髓在于实践和持续学习。本节将为你提供一些优质的学习资源和后续学习的建议&#xff0c;帮助你在这条道路上走得更…

多模态大语言模型arxiv论文略读(125)

Uni-Med: A Unified Medical Generalist Foundation Model For Multi-Task Learning Via Connector-MoE ➡️ 论文标题&#xff1a;Uni-Med: A Unified Medical Generalist Foundation Model For Multi-Task Learning Via Connector-MoE ➡️ 论文作者&#xff1a;Xun Zhu, Yi…

【学习笔记】NLP 基础概念

1.1 什么是 NLP 定义&#xff1a; 自然语言处理&#xff08;NLP&#xff09;**是一种让计算机理解、解释和生成人类语言的技术。它是人工智能领域中极为活跃且重要的研究方向&#xff0c;旨在模拟人类对语言的认知和使用过程 特点&#xff1a; 多学科交叉&#xff1a;结合计…

RNN为什么不适合大语言模型

在自然语言处理&#xff08;NLP&#xff09;领域中&#xff0c;循环神经网络&#xff08;RNN&#xff09;及衍生架构&#xff08;如LSTM&#xff09;采用序列依序计算的模式&#xff0c;这种模式之所以“限制了计算机并行计算能力”&#xff0c;核心原因在于其时序依赖的特性&a…

微信小程序一款不错的文字动画

效果图 .js Page({data: {list:[],animation:[text-left,text-right,text-top,text-bottom],text:[[春眠不觉晓&#xff0c;处处闻啼鸟。,夜来风雨声&#xff0c;花落知多少。 ],[床前明月光&#xff0c;疑是地上霜。,举头望明月&#xff0c;低头思故乡。],[千山鸟飞绝&#…

循环神经网络(RNN):序列数据处理的强大工具

在人工智能和机器学习的广阔领域中&#xff0c;处理和理解序列数据一直是一个重要且具有挑战性的任务。循环神经网络&#xff08;Recurrent Neural Network&#xff0c;RNN&#xff09;作为一类专门设计用于处理序列数据的神经网络&#xff0c;在诸多领域展现出了强大的能力。从…

手机SIM卡通话中随时插入录音语音片段(Windows方案)

手机SIM卡通话中随时插入录音语音片段&#xff08;Windows方案&#xff09; --本地AI电话机器人 上一篇&#xff1a;手机SIM卡通话中随时插入录音语音片段&#xff08;Android方案&#xff09;​​​​​​​ 下一篇&#xff1a;​​​​​​​编写中 一、前言 书接上文《手…

阿里云通义大模型:AI浪潮中的领航者

通义大模型初印象 在当今 AI 领域蓬勃发展的浪潮中&#xff0c;阿里云通义大模型宛如一颗璀璨的明星&#xff0c;迅速崛起并占据了重要的地位。随着人工智能技术的不断突破&#xff0c;大模型已成为推动各行业数字化转型和创新发展的核心驱动力。通义大模型凭借其强大的技术实…

【算法篇】逐步理解动态规划模型7(两个数组dp问题)

目录 两个数组dp问题 1.最长公共子序列 2.不同的子序列 3.通配符匹配 本文旨在通过对力扣上三道题进行讲解来让大家对使用动态规划解决两个数组的dp问题有一定思路&#xff0c;培养大家对状态定义&#xff0c;以及状态方程书写的思维。 顺序&#xff1a; 题目链接-》算法思…

什么是 HTTP Range 请求(范围请求)

HTTP Range 请求&#xff0c;即范围请求&#xff0c;是一种 HTTP 请求方法&#xff0c;允许客户端请求资源的部分数据。这种请求在处理大型文件&#xff08;如视频、音频、或大文件下载&#xff09;时特别有用&#xff0c;因为它可以有效地进行断点续传和按需加载数据&#xff…

java集合(十) ---- LinkedList 类

目录 十、LinkedList 类 10.1 位置 10.2 特点 10.3 与 ArrayList 的区别 10.4 构造方法 10.5 常用方法 十、LinkedList 类 10.1 位置 LinkedList 类位于 java.util 包下 10.2 特点 是 List 接口的实现类是 Deque 接口的实现类底层使用双向循环链表结构 10.3 与 Arra…

kafka消费的模式及消息积压处理方案

目录 1、kafka消费的流程 2、kafka的消费模式 2.1、点对点模式 2.2、发布-订阅模式 3、consumer消息积压 3.1、处理方案 3.2、积压量 4、消息过期失效 5、kafka注意事项 Kafka消费积压(Consumer Lag)是指消费者处理消息的速度跟不上生产者发送消息的速度&#xff0c;导致消息在…

RAG实践:Routing机制与Query Construction策略

Routing机制与Query Construction策略 前言RoutingLogical RoutingChatOpenAIStructuredRouting DatasourceConclusion Semantic RoutingEmbedding & LLMPromptRounting PromptConclusion Query ConstructionGrab Youtube video informationStructuredPrompt GithubReferen…