Table of Contents
- Preface
- I. DeepSeek MoE
- II. MoE Architecture
- 1. Expert
- 2. Gate
- 3. MoE Module
- III. Auxiliary Loss
- Summary
Preface
MoE (Mixture of Experts) has gradually seen wide adoption in LLMs, and engineering and deployment support for it keeps growing. This article records the basic building blocks and principles of MoE, using the MoE implementation in DeepSeek as the example.
MoE can essentially be viewed as an MLP layer, except that it behaves like a different MLP for each token. Suppose an MoE module has 64 experts; that is equivalent to 64 parallel MLP layers. When a token comes in for processing, the few experts that are currently most suitable are selected automatically to do the computation, rather than one fixed MLP.
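As a rough illustration of this idea (a toy sketch with made-up sizes and top-k, not DeepSeek's actual code), per-token top-k routing over 64 experts can be expressed as:

```python
import torch

# Toy illustration: 64 experts, each token picks its top 6 (numbers are made up).
n_experts, top_k, dim = 64, 6, 16
tokens = torch.randn(10, dim)                  # 10 tokens
router = torch.nn.Linear(dim, n_experts)       # produces one score per expert

scores = router(tokens).softmax(dim=-1)        # routing probabilities, shape (10, 64)
weights, indices = scores.topk(top_k, dim=-1)  # per-token expert ids and mixing weights

# Each token is then processed only by its `indices` experts, and the expert
# outputs are combined with `weights`, instead of always using one fixed MLP.
```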
I. DeepSeek MoE
The overall model architecture is as follows:
```
Transformer(
  (embed): ParallelEmbedding()
  (layers): ModuleList(
    (0): Block(
      (attn): MLA(
        (wq): ColumnParallelLinear()
        (wkv_a): Linear()
        (kv_norm): RMSNorm()
        (wkv_b): ColumnParallelLinear()
        (wo): RowParallelLinear()
      )
      (ffn): MLP(
        (w1): ColumnParallelLinear()
        (w2): RowParallelLinear()
        (w3): ColumnParallelLinear()
      )
      (attn_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
    (1-26): 26 x Block(
      (attn): MLA(
        (wq): ColumnParallelLinear()
        (wkv_a): Linear()
        (kv_norm): RMSNorm()
        (wkv_b): ColumnParallelLinear()
        (wo): RowParallelLinear()
      )
      (ffn): MoE(
        (gate): Gate()
        (experts): ModuleList(
          (0-63): 64 x Expert(
            (w1): Linear()
            (w2): Linear()
            (w3): Linear()
          )
        )
        (shared_experts): MLP(
          (w1): ColumnParallelLinear()
          (w2): RowParallelLinear()
          (w3): ColumnParallelLinear()
        )
      )
      (attn_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (head): ColumnParallelLinear()
)
```
II. MoE Architecture
From the architecture above we can see that the MoE module replaces the MLP module of the original Transformer block (here only layer 0 keeps a dense MLP, layers 1-26 use MoE); it mainly consists of a Gate and a set of Experts.
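The per-layer switch between a dense MLP and MoE might be wired up roughly like the sketch below. It reuses the classes shown in this post; treat the exact forward signature and the `n_dense_layers` cut-off parameter as assumptions based on the printed architecture rather than a verbatim copy of the reference code.

```python
class Block(nn.Module):
    """One Transformer block: MLA attention plus either a dense MLP or an MoE FFN."""
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.attn = MLA(args)
        # The printout above shows layer 0 with a plain MLP and layers 1-26 with MoE;
        # `n_dense_layers` is the assumed config knob expressing that cut-off.
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual connections (attention arguments omitted for brevity).
        x = x + self.attn(self.attn_norm(x))
        x = x + self.ffn(self.ffn_norm(x))
        return x
```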
1. Expert
The Expert code is as follows:
```python
class Expert(nn.Module):
    """Expert layer for Mixture-of-Experts (MoE) models.

    Attributes:
        w1 (nn.Module): Linear layer for input-to-hidden transformation.
        w2 (nn.Module): Linear layer for hidden-to-output transformation.
        w3 (nn.Module): Additional linear layer for feature transformation.
    """
    def __init__(self, dim: int, inter_dim: int):
        """Initializes the Expert layer.

        Args:
            dim (int): Input and output dimensionality.
            inter_dim (int): Hidden layer dimensionality.
        """
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for the Expert layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after expert computation.
        """
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```
As you can see, an Expert is essentially just a small feed-forward network: three linear layers combined through a SiLU gate (SwiGLU), i.e. the same structure as an ordinary MLP block.
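As a quick shape check (toy sizes; this assumes the project's Linear wrapper behaves like torch.nn.Linear), an Expert maps dim → inter_dim → dim, so it can stand in wherever an ordinary FFN would go:

```python
import torch

expert = Expert(dim=16, inter_dim=64)   # toy sizes, not the real model config
x = torch.randn(4, 16)                  # a batch of 4 token vectors
print(expert(x).shape)                  # torch.Size([4, 16]): output width equals input width
# Internally this computes SwiGLU: w2(silu(w1(x)) * w3(x))
```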
2. Gate
Selecting the right experts inside the MoE module is handled by the Gate.
The Gate code is as follows:
```python
class Gate(nn.Module):
    """Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.

    Attributes:
        dim (int): Dimensionality of input features.
        topk (int): Number of top experts activated for each input.
        n_groups (int): Number of groups for routing.
        topk_groups (int): Number of groups to route inputs to.
        score_func (str): Scoring function ('softmax' or 'sigmoid').
        route_scale (float): Scaling factor for routing weights.
        weight (torch.nn.Parameter): Learnable weights for the gate.
        bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
    """
    def __init__(self, args: ModelArgs):
        """Initializes the Gate module.

        Args:
            args (ModelArgs): Model arguments containing gating parameters.
        """
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.n_groups = args.n_expert_groups
        self.topk_groups = args.n_limited_groups
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for the gating mechanism.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
        """
        scores = linear(x, self.weight)
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1, dtype=torch.float32)
        else:
            scores = scores.sigmoid()
        original_scores = scores
        if self.bias is not None:
            scores = scores + self.bias
        if self.n_groups > 1:
            scores = scores.view(x.size(0), self.n_groups, -1)
            if self.bias is None:
                group_scores = scores.amax(dim=-1)
            else:
                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
            indices = group_scores.topk(self.topk_groups, dim=-1)[1]
            mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
            scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
        indices = torch.topk(scores, self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)
        if self.score_func == "sigmoid":
            weights /= weights.sum(dim=-1, keepdim=True)
        weights *= self.route_scale
        return weights.type_as(x), indices
```
The Gate takes the input feature vector and produces a score for every routed expert.
Depending on the configuration, it then decides whether to apply group-limited routing: the experts are partitioned into groups, each group is scored (by its best expert, or by the sum of its top-2 scores when a bias is used), only the top `topk_groups` groups are kept, and the final top-k experts are selected from the surviving groups.
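The group-limited part can be reproduced in isolation with dummy numbers. This is only a toy sketch of the masking logic (tiny sizes, no bias term), not the real configuration:

```python
import torch

n_experts, n_groups, topk_groups, topk = 8, 4, 2, 3    # toy configuration
scores = torch.rand(2, n_experts)                      # 2 tokens, one score per expert

# 1. Score each group by its best expert.
group_scores = scores.view(2, n_groups, -1).amax(dim=-1)        # shape (2, 4)
keep_groups = group_scores.topk(topk_groups, dim=-1)[1]         # best 2 groups per token

# 2. Mask out experts that live in the discarded groups.
mask = torch.ones(2, n_groups, dtype=torch.bool).scatter_(1, keep_groups, False)
masked = scores.view(2, n_groups, -1).masked_fill(mask.unsqueeze(-1), float("-inf"))

# 3. Pick the final top-k experts only from the surviving groups.
weights, indices = masked.flatten(1).topk(topk, dim=-1)
print(indices)   # expert ids, guaranteed to come from at most `topk_groups` groups
```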
3. MoE Module
With the scores computed by the Gate, the MoE computation itself can be carried out. The MoE module is as follows:
```python
class MoE(nn.Module):
    """Mixture-of-Experts (MoE) module.

    Attributes:
        dim (int): Dimensionality of input features.
        n_routed_experts (int): Total number of experts in the model.
        n_local_experts (int): Number of experts handled locally in distributed systems.
        n_activated_experts (int): Number of experts activated for each input.
        gate (nn.Module): Gating mechanism to route inputs to experts.
        experts (nn.ModuleList): List of expert modules.
        shared_experts (nn.Module): Shared experts applied to all inputs.
    """
    def __init__(self, args: ModelArgs):
        """Initializes the MoE module.

        Args:
            args (ModelArgs): Model arguments containing MoE parameters.
        """
        super().__init__()
        self.dim = args.dim
        assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
        self.n_routed_experts = args.n_routed_experts
        self.n_local_experts = args.n_routed_experts // world_size
        self.n_activated_experts = args.n_activated_experts
        self.experts_start_idx = rank * self.n_local_experts
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        self.gate = Gate(args)
        self.experts = nn.ModuleList([
            Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
            for i in range(self.n_routed_experts)
        ])
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for the MoE module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after expert routing and computation.
        """
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x)
        y = torch.zeros_like(x)
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for i in range(self.experts_start_idx, self.experts_end_idx):
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx]) * weights[idx, top, None]
        z = self.shared_experts(x)
        if world_size > 1:
            dist.all_reduce(y)
        return (y + z).view(shape)
```
To make this easier to follow, here is a rough sketch.
The gate gives, for every token, the indices of the selected experts and the corresponding scores.
The sketch shows a simple index matrix for topk=3: for each token, three experts are selected, and each selected expert carries a score that is used for weighting.
The forward pass then loops over all (local) experts; using the indices matrix it finds every token that should be handled by expert i (x[idx]), runs those tokens through that expert, and accumulates the weighted result into y as the new features.
Once y is obtained, x is also fed through the shared_experts to get z, and the two are summed to produce the final MoE output.
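The dispatch-and-combine step described above can be reproduced on its own with dummy tensors. This is a toy sketch (plain nn.Linear modules stand in for the experts, and the gate outputs are faked), intended only to show how indices and weights drive the loop:

```python
import torch

n_tokens, dim, n_experts, topk = 6, 8, 4, 2                        # toy sizes
x = torch.randn(n_tokens, dim)
weights = torch.rand(n_tokens, topk)                               # fake gate weights
indices = torch.rand(n_tokens, n_experts).topk(topk, dim=-1)[1]    # 2 distinct expert ids per token
experts = [torch.nn.Linear(dim, dim) for _ in range(n_experts)]    # stand-in experts

y = torch.zeros_like(x)
for i in range(n_experts):
    # Which (token, slot) pairs picked expert i?
    idx, top = torch.where(indices == i)
    if idx.numel() == 0:
        continue
    # Run those tokens through expert i and accumulate the weighted output.
    y[idx] += experts[i](x[idx]) * weights[idx, top, None]

print(y.shape)  # torch.Size([6, 8])
# The real module then adds shared_experts(x) and, with expert parallelism,
# all-reduces y across ranks before reshaping back to the input shape.
```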
III. Auxiliary Loss
One more point needs to be added: when training with MoE, an auxiliary loss is also used. In Mixture-of-Experts (MoE) models, the main purpose of the auxiliary loss is load balancing, namely:
- Balancing the experts' load: preventing all tokens from being assigned to a handful of experts, which would leave those experts overloaded while the rest sit idle.
- Stabilizing training: preventing the router (which assigns tokens to experts) from learning an unhealthy distribution, e.g. all tokens collapsing onto a single expert, or an extremely sparse token distribution.
Each token's router scores (router_logits) express that token's preference over which experts it should be routed to.
To balance the load, two quantities need to be tracked:
- the fraction of tokens assigned to each expert (fraction_tokens_per_expert);
- the fraction of the total routing probability (softmax(router_logits)) that each expert receives (fraction_prob_per_expert).
An intuitive way to understand the aux loss is as follows:
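A common formulation built from exactly these two quantities is the Switch-Transformer-style load-balancing loss. The sketch below is illustrative only (it is not DeepSeek's exact training code, which is not part of the inference repository shown above):

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits: torch.Tensor, topk: int) -> torch.Tensor:
    """Switch-Transformer-style auxiliary loss: n_experts * sum_i f_i * P_i.

    router_logits: (n_tokens, n_experts) raw gate scores.
    """
    n_experts = router_logits.size(-1)
    probs = router_logits.softmax(dim=-1)                            # routing probabilities
    # f_i: fraction of tokens whose top-k selection includes expert i.
    topk_idx = probs.topk(topk, dim=-1)[1]
    dispatch = F.one_hot(topk_idx, n_experts).sum(dim=1).float()     # (n_tokens, n_experts), 0/1 per expert
    fraction_tokens_per_expert = dispatch.mean(dim=0) / topk
    # P_i: average routing probability assigned to expert i.
    fraction_prob_per_expert = probs.mean(dim=0)
    # Minimized when both fractions are uniform (1 / n_experts), i.e. balanced usage.
    return n_experts * (fraction_tokens_per_expert * fraction_prob_per_expert).sum()
```

Intuitively, fraction_tokens_per_expert counts the actual (non-differentiable) assignments while fraction_prob_per_expert is the differentiable routing probability; their product is smallest when both are spread evenly over the experts, so the gradient pushes the router toward balanced expert usage.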
Summary
MoE is an impressive architectural innovation in today's large models, especially considering how homogeneous current LLM structures have become. It genuinely opens up more possibilities for the future development of LLMs.