LLM:MoE原理与实现探索

文章目录

  • 前言
  • 一、Deepseek Moe
  • 二. Moe架构
    • 1. Expert
    • 2. Gate
    • 3. MoE Module
  • 三、Auxiliary Loss
  • 总结


前言

MoE(Mixture of Experts) 已经逐渐在LLM中广泛应用,其工程部署相关目前也有了越来越多的支持,本文主要记录一下MoE的基本模块构造与原理。以Deepseek中的MoE构造为例。


MoE本质上可以看作一个MLP层,不过是对每个token都不一样的MLP层,假设一个MoE模块存在64个expert,相当于有64个并行MLP层,当有一个token送入进行处理时,会自动选择当前最合适的几个expert进行运算,而不是一个固定个MLP。

一、Deepseek Moe

模型整体架构如下

Transformer((embed): ParallelEmbedding()(layers): ModuleList((0): Block((attn): MLA((wq): ColumnParallelLinear()(wkv_a): Linear()(kv_norm): RMSNorm()(wkv_b): ColumnParallelLinear()(wo): RowParallelLinear())(ffn): MLP((w1): ColumnParallelLinear()(w2): RowParallelLinear()(w3): ColumnParallelLinear())(attn_norm): RMSNorm()(ffn_norm): RMSNorm())(1-26): 26 x Block((attn): MLA((wq): ColumnParallelLinear()(wkv_a): Linear()(kv_norm): RMSNorm()(wkv_b): ColumnParallelLinear()(wo): RowParallelLinear())(ffn): MoE((gate): Gate()(experts): ModuleList((0-63): 64 x Expert((w1): Linear()(w2): Linear()(w3): Linear()))(shared_experts): MLP((w1): ColumnParallelLinear()(w2): RowParallelLinear()(w3): ColumnParallelLinear()))(attn_norm): RMSNorm()(ffn_norm): RMSNorm()))(norm): RMSNorm()(head): ColumnParallelLinear()
)

二. Moe架构

从上面的架构可以发现, moe作为一个模块替换了原始Transformer架构mlp模块,主要包含一个Gate和Expert。

1. Expert

Expert代码如下:


class Expert(nn.Module):"""Expert layer for Mixture-of-Experts (MoE) models.Attributes:w1 (nn.Module): Linear layer for input-to-hidden transformation.w2 (nn.Module): Linear layer for hidden-to-output transformation.w3 (nn.Module): Additional linear layer for feature transformation."""def __init__(self, dim: int, inter_dim: int):"""Initializes the Expert layer.Args:dim (int): Input and output dimensionality.inter_dim (int): Hidden layer dimensionality."""super().__init__()self.w1 = Linear(dim, inter_dim)self.w2 = Linear(inter_dim, dim)self.w3 = Linear(dim, inter_dim)def forward(self, x: torch.Tensor) -> torch.Tensor:"""Forward pass for the Expert layer.Args:x (torch.Tensor): Input tensor.Returns:torch.Tensor: Output tensor after expert computation."""return self.w2(F.silu(self.w1(x)) * self.w3(x))

可以看到Expert本质上就是一个简单的线性层

2. Gate

如何在当前MoE模块中选择合适的Expert则是通过Gate操作来完成的

Gate代码如下:


class Gate(nn.Module):"""Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.Attributes:dim (int): Dimensionality of input features.topk (int): Number of top experts activated for each input.n_groups (int): Number of groups for routing.topk_groups (int): Number of groups to route inputs to.score_func (str): Scoring function ('softmax' or 'sigmoid').route_scale (float): Scaling factor for routing weights.weight (torch.nn.Parameter): Learnable weights for the gate.bias (Optional[torch.nn.Parameter]): Optional bias term for the gate."""def __init__(self, args: ModelArgs):"""Initializes the Gate module.Args:args (ModelArgs): Model arguments containing gating parameters."""super().__init__()self.dim = args.dimself.topk = args.n_activated_expertsself.n_groups = args.n_expert_groupsself.topk_groups = args.n_limited_groupsself.score_func = args.score_funcself.route_scale = args.route_scaleself.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else Nonedef forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:"""Forward pass for the gating mechanism.Args:x (torch.Tensor): Input tensor.Returns:Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices."""scores = linear(x, self.weight)if self.score_func == "softmax":scores = scores.softmax(dim=-1, dtype=torch.float32)else:scores = scores.sigmoid()original_scores = scoresif self.bias is not None:scores = scores + self.biasif self.n_groups > 1:scores = scores.view(x.size(0), self.n_groups, -1)if self.bias is None:group_scores = scores.amax(dim=-1)else:group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)indices = group_scores.topk(self.topk_groups, dim=-1)[1]mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)indices = torch.topk(scores, self.topk, dim=-1)[1]weights = original_scores.gather(1, indices)if self.score_func == "sigmoid":weights /= weights.sum(dim=-1, keepdim=True)weights *= self.route_scalereturn weights.type_as(x), indices

Gate函数通过输入特征向量,输出分数

请添加图片描述

基于当前分数根据情况确定是否采用分组路由, 若采用分组路由,则会将多个expert分组,然后取每个组的最高分

请添加图片描述


3. MoE Module

有了Gate计算的分数,就可以进行MoE的计算了,MoE模块如下

class MoE(nn.Module):"""Mixture-of-Experts (MoE) module.Attributes:dim (int): Dimensionality of input features.n_routed_experts (int): Total number of experts in the model.n_local_experts (int): Number of experts handled locally in distributed systems.n_activated_experts (int): Number of experts activated for each input.gate (nn.Module): Gating mechanism to route inputs to experts.experts (nn.ModuleList): List of expert modules.shared_experts (nn.Module): Shared experts applied to all inputs."""def __init__(self, args: ModelArgs):"""Initializes the MoE module.Args:args (ModelArgs): Model arguments containing MoE parameters."""super().__init__()self.dim = args.dimassert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"self.n_routed_experts = args.n_routed_expertsself.n_local_experts = args.n_routed_experts // world_sizeself.n_activated_experts = args.n_activated_expertsself.experts_start_idx = rank * self.n_local_expertsself.experts_end_idx = self.experts_start_idx + self.n_local_expertsself.gate = Gate(args)self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else Nonefor i in range(self.n_routed_experts)])self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)def forward(self, x: torch.Tensor) -> torch.Tensor:"""Forward pass for the MoE module.Args:x (torch.Tensor): Input tensor.Returns:torch.Tensor: Output tensor after expert routing and computation."""shape = x.size()x = x.view(-1, self.dim)weights, indices = self.gate(x)y = torch.zeros_like(x)counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()for i in range(self.experts_start_idx, self.experts_end_idx):if counts[i] == 0:continueexpert = self.experts[i]idx, top = torch.where(indices == i)y[idx] += expert(x[idx]) * weights[idx, top, None]z = self.shared_experts(x)if world_size > 1:dist.all_reduce(y)return (y + z).view(shape)

为了方便理解,绘制了个草图

请添加图片描述

通过gate函数可以获得每个token所选的专家索引与对应的score
如上图中画了一个简单的当topk=3时的索引矩阵,针对每个token,选择三个expert进行处理,每个expert对应了一个分数用于加权。

循环所有expert,基于indices矩阵找出需要第i个expert处理的所有token ——> x[idx] ,经过expert处理后加权赋予y 作为新的特征。

获得y后,在将x送入shared_expert 获得z, 最后两者相加获得最终MoE的特征。
请添加图片描述

三、Auxiliary Loss

这里还需要补充一点,在使用MoE的同时,需要使用Auxiliary Loss。在混合专家 (MoE, Mixture of Experts) 模型中,Auxiliary Loss(辅助损失) 的主要作用是 负载均衡,即:

平衡专家的负载: 避免所有的 token 都被分配到少数几个专家上,导致这些专家过度繁忙,而其他专家闲置。

稳定训练过程: 防止路由器(Router,负责将 token 分配给专家)学到某种不健康的分布,例如所有 token 都集中到一个专家,或 token 分布极其稀疏。

每个 token 的路由器分数(router_logits)表示该 token 应该路由到哪些专家的偏好。
为了负载均衡,需要统计:
每个专家获得分配的 token 的比例(fraction_tokens_per_expert)。
路由概率(softmax(router_logits))中每个专家的总概率比例(fraction_prob_per_expert)。

请添加图片描述

关于aux loss的直观理解如下:
请添加图片描述


总结

MoE作为目前大模型在结构推理上的一大创新还是很厉害的,比较现在的LLM结构同质化蛮严重的。MoE确实给LLM未来的发展带来更多可能性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/92894.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/92894.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于领域事件驱动的微服务架构设计与实践

引言&#xff1a;为什么你的微服务总是"牵一发而动全身"&#xff1f; 在复杂的业务系统中&#xff0c;你是否遇到过这样的困境&#xff1a;修改一个订单服务&#xff0c;却导致支付服务异常&#xff1b;调整库存逻辑&#xff0c;用户服务开始报错。这种"蝴蝶效应…

如何使用curl编程来下载文件

libcurl 是一个功能强大的跨平台网络传输库&#xff0c;支持多种协议。 本篇来介绍libcul的C语言编程&#xff0c;实现一个文件下载的功能。 1 curl基础介绍 1.1 核心数据结构 1.1.1 CURL句柄 CURL是libcurl 的核心句柄&#xff0c;每个请求对应一个 CURL 实例&#xff0c;…

大语言模型提示工程与应用:ChatGPT提示工程技术指南

ChatGPT提示工程 学习目标 在本课程中&#xff0c;我们将学习更多关于ChatGPT的最新提示工程技术。 相关知识点 ChatGPT提示工程 学习内容 1 ChatGPT提示工程 ChatGPT是OpenAI研发的新型对话模型&#xff0c;具备多轮对话能力。该模型通过人类反馈强化学习(RLHF)训练&am…

能力评估:如何系统评估你的技能和经验

能力评估&#xff1a;如何系统评估你的技能和经验 作为一名38岁的互联网研发老兵&#xff0c;你已经积累了丰富的经验&#xff0c;包括技术深度、项目管理、团队协作等。但能力评估不是一次性事件&#xff0c;而是持续过程&#xff0c;帮助你识别优势、短板&#xff0c;并为职业…

鸿蒙开发中所有自定义装饰器的完整案例解析--涵盖 16 个核心装饰器的详细用法和实战场景

以下是鸿蒙开发中 所有自定义装饰器的完整案例解析 和 终极总结指南&#xff0c;涵盖 16 个核心装饰器的详细用法和实战场景&#xff1a; 一、终极总结表&#xff1a;16大装饰器全景图 装饰器类别V1V2核心作用典型场景Component组件定义✅❌创建标准组件业务UI组件ComponentV2…

【C++】哈希表的实现(unordered_map和unordered_set的底层)

文章目录 目录 文章目录 前言 一、unordered_set和unordered_map介绍 二、哈希表的介绍 三、哈希冲突的解决方法 1.开放定址法 2.链地址法 四、两种哈希表代码实现 总结 前言 前面我们学习了红黑树&#xff0c;红黑树就是map和set的底层&#xff0c;本篇文章带来的是unordered…

欧拉公式的意义

欧拉公式的意义 欧拉公式&#xff08;Euler’s Formula&#xff09;是数学中最重要的公式之一&#xff0c;它将复数、指数函数和三角函数紧密联系在一起。其基本形式为&#xff1a; eiθcos⁡θisin⁡θ e^{i\theta} \cos \theta i \sin \theta eiθcosθisinθ 当 θπ\thet…

Linux Docker 运行SQL Server

在Linux操作系统&#xff0c;已安装docker&#xff0c;现在以docker compose方式&#xff0c;安装一个最新版SQL Server 2022的数据库。 # 建个目录&#xff08;请不要照抄&#xff0c;我的数据盘在/data&#xff0c;你可以改为/opt&#xff09; mkdir /data/sqlserver# 进入目…

C++:stack_queue(2)实现底层

文章目录一.容器适配器1. 本质&#xff1a;2. 接口&#xff1a;3. 迭代器&#xff1a;4. 功能&#xff1a;二.deque的简单介绍1.概念与特性2.结构与底层逻辑2.1 双端队列&#xff08;deque&#xff09;结构&#xff1a;2.2 deque的内部结构2.3 deque的插入与删除操作&#xff1…

Lightroom 安卓版 + Windows 版 + Mac 版全适配,编辑管理一站式,专业摄影后期教程

软件是啥样的​ Adobe Lightroom 这软件&#xff0c;在安卓手机、Windows 电脑和 Mac 电脑上都能用。不管是喜欢拍照的人&#xff0c;还是专门搞摄影的&#xff0c;用它都挺方便&#xff0c;能一站式搞定照片编辑、整理和分享这些事儿。 ****下载地址 分享文件&#xff1a;【Li…

office卸载不干净?Office356卸载不干净,office强力卸载软件下载

微软官方认可的卸载工具&#xff0c;支持彻底清除Office组件及注册表残留。需要以管理员身份运行&#xff0c;选择“移除Office”功能并确认操作。 Office Tool Plus安装地址获取 点击这里获取&#xff1a;Office Tool Plus 1、双击打开软件 image 2、选择左右的工具箱&…

互联网企业慢性死亡的招聘视角分析:从岗位割裂看战略短视

内容简介&#xff1a; 一个猎头和HR的简单拒绝&#xff0c;揭示了中国互联网企业人才观念的深层问题。通过分析岗位过度细分现象&#xff0c;本文探讨了战略短视、内斗文化和核心竞争力缺失如何导致企业慢性死亡&#xff0c;并提出了系统性的解决方案。#互联网企业 #人才招聘 #…

OpenBMC中phosphor-dbus-interfaces深度解析:架构、原理与应用实践

引言 在OpenBMC生态系统中&#xff0c;phosphor-dbus-interfaces作为D-Bus接口定义的核心组件&#xff0c;扮演着系统各模块间通信"契约"的关键角色。本文将基于OpenBMC源码&#xff0c;从架构设计、实现原理到实际应用三个维度&#xff0c;全面剖析这一基础组件的技…

驾驶场景玩手机识别准确率↑32%:陌讯动态特征融合算法实战解析

原创声明本文为原创技术解析文章&#xff0c;核心技术参数与架构设计参考自《陌讯技术白皮书》&#xff0c;转载请注明出处。一、行业痛点&#xff1a;驾驶场景行为识别的现实挑战根据交通运输部道路运输司发布的《驾驶员不安全行为研究报告》显示&#xff0c;驾驶过程中使用手…

Mysql——单表最多数据量多少需要分表

目录 一、MySql单表最多数据量多少需要分表 1.1、阿里开发公约 1.2、一个三层的B+树,它最多可以存储多少数据量 1.3、示例 1.3.1、示例表中一行的数据占多少字节数 1.3.2、示例表中一页里面最多可以存多少条记录 1.3.3、按示例表计算,一个三层的B+树,可以放多少条100字节的数…

scikit-learn/sklearn学习|岭回归解读

【1】引言 前序学习进程中&#xff0c;对用scikit-learn表达线性回归进行了初步解读。 线性回归能够将因变量yyy表达成由自变量xxx、线性系数矩阵www和截距bbb组成的线性函数式&#xff1a; y∑i1nwi⋅xibwTxby\sum_{i1}^{n}w_{i}\cdot x_{i}bw^T{x}byi1∑n​wi​⋅xi​bwTxb实…

基于Django的图书馆管理系统的设计与实现

基于Django的图书馆管理系统的设计与实现、

ComfyUI版本更新---解决ComfyUI的节点不兼容问题

前言&#xff1a; 新版本的COMFYUI与节点容易出现不兼容的问题,会导致整个系统崩掉。 目录 一、前期准备工作&#xff1a;虚拟环境配置 为什么需要虚拟环境&#xff1f; 具体操作步骤 二、常见问题解决方案 1、工作流输入输出图像不显示问题 2、工作流不能拖动&#xff0…

生产管理ERP系统|物联及生产管理ERP系统|基于SprinBoot+vue的制造装备物联及生产管理ERP系统设计与实现(源码+数据库+文档)

生产管理ERP系统 目录 基于SprinBootvue的制造装备物联及生产管理ERP系统设计与实现 一、前言 二、系统设计 三、系统功能设计 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取&#xff1a; 博主介绍&#xff1a;✌️大厂码农|毕…

Numpy科学计算与数据分析:Numpy数组创建与应用入门

Numpy数组创建实战 学习目标 通过本课程的学习&#xff0c;学员将掌握使用Numpy库创建不同类型的数组的方法&#xff0c;包括一维数组、多维数组、全零数组、全一阵列、空数组等。本课程将通过理论讲解与实践操作相结合的方式&#xff0c;帮助学员深入理解Numpy数组的创建过程…