LLM基础5_从零开始实现 GPT 模型

基于GitHub项目:https://github.com/datawhalechina/llms-from-scratch-cn

设计 LLM 的架构

GPT 模型基于 Transformer 的 decoder-only 架构,其主要特点包括:

  • 顺序生成文本

  • 参数数量庞大(而非代码量复杂)

  • 大量重复的模块化组件

以 GPT-2 small 模型(124M 参数)为例,其配置如下:

GPT_CONFIG_124M = {"vocab_size": 50257,  # BPE 分词器词表大小"ctx_len": 1024,      # 最大上下文长度"emb_dim": 768,       # 嵌入维度"n_heads": 12,        # 注意力头数量"n_layers": 12,       # Transformer 块层数"drop_rate": 0.1,     # Dropout 比例"qkv_bias": False     # QKV 计算是否使用偏置
}

GPT 模型基本结构

cfg是配置实例

import torch.nn as nnclass GPTModel(nn.Module):def __init__(self, cfg):super().__init__()# Token 嵌入层self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])# 位置嵌入层self.pos_emb = nn.Embedding(cfg["ctx_len"], cfg["emb_dim"])# Dropout 层self.drop_emb = nn.Dropout(cfg["drop_rate"])# 堆叠n_layers相同的Transformer 块self.trf_blocks = nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])# 最终层归一化self.final_norm = LayerNorm(cfg["emb_dim"])# 输出层self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)def forward(self, in_idx):batch_size, seq_len = in_idx.shape# Token 嵌入tok_embeds = self.tok_emb(in_idx)# 位置嵌入pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))# 组合嵌入x = tok_embeds + pos_embedsx = self.drop_emb(x)# 通过 Transformer 块x = self.trf_blocks(x)# 最终归一化x = self.final_norm(x)# 输出 logitslogits = self.out_head(x)return logits

 层归一化 (Layer Normalization)

层归一化将激活值规范化为均值为 0、方差为 1 的分布,加速模型收敛:

class LayerNorm(nn.Module):def __init__(self, emb_dim):super().__init__()self.eps = 1e-5    # 防止除零错误的标准设定值self.scale = nn.Parameter(torch.ones(emb_dim))  #可学习缩放参数,初始化为全1向量self.shift = nn.Parameter(torch.zeros(emb_dim)) #可学习平移参数,初始化为全0向量def forward(self, x):mean = x.mean(dim=-1, keepdim=True)    #计算均值 μ,沿最后一维,保持维度var = x.var(dim=-1, keepdim=True, unbiased=False)    #计算方差 σ²,同均值维度,有偏估计(分母n)norm_x = (x - mean) / torch.sqrt(var + self.eps)    #标准化计算,分母添加ε防溢出return self.scale * norm_x + self.shift    #仿射变换,恢复模型表达能力

GELU 激活函数与前馈网络

GPT 使用 GELU(高斯误差线性单元)激活函数:

场景ReLU 的行为GELU 的行为
处理微弱负信号直接丢弃(可能丢失细节)部分保留(如:保留 30% 的信号强度)
遇到强烈正信号完全放行几乎完全放行(保留 95% 以上)
训练稳定性容易在临界点卡顿平滑过渡,减少训练震荡
应对复杂模式需要堆叠更多层数单层就能捕捉更细腻的变化
class GELU(nn.Module):def __init__(self):super().__init__()def forward(self, x):return 0.5 * x * (1 + torch.tanh(torch.sqrt(torch.tensor(2.0 / torch.pi)) * (x + 0.044715 * torch.pow(x, 3))))

前馈神经网络实现:

class FeedForward(nn.Module):def __init__(self, cfg):super().__init__()self.layers = nn.Sequential(nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),GELU(),nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),nn.Dropout(cfg["drop_rate"]))def forward(self, x):return self.layers(x)

Shortcut 连接

Shortcut 连接(残差连接)解决深度网络中的梯度消失问题:

class TransformerBlock(nn.Module):def __init__(self, cfg):super().__init__()self.att = MultiHeadAttention(d_in=cfg["emb_dim"],d_out=cfg["emb_dim"],block_size=cfg["ctx_len"],num_heads=cfg["n_heads"], dropout=cfg["drop_rate"],qkv_bias=cfg["qkv_bias"])self.ff = FeedForward(cfg)self.norm1 = LayerNorm(cfg["emb_dim"])self.norm2 = LayerNorm(cfg["emb_dim"])self.drop_resid = nn.Dropout(cfg["drop_rate"])def forward(self, x):# 注意力块的残差连接shortcut = xx = self.norm1(x)x = self.att(x)x = self.drop_resid(x)x = x + shortcut# 前馈网络的残差连接shortcut = xx = self.norm2(x)x = self.ff(x)x = self.drop_resid(x)x = x + shortcutreturn x

Transformer 块整合

将多头注意力与前馈网络整合为 Transformer 块:

class TransformerBlock(nn.Module):def __init__(self, cfg):super().__init__()self.att = MultiHeadAttention(d_in=cfg["emb_dim"],d_out=cfg["emb_dim"],block_size=cfg["ctx_len"],num_heads=cfg["n_heads"], dropout=cfg["drop_rate"],qkv_bias=cfg["qkv_bias"])self.ff = FeedForward(cfg)self.norm1 = LayerNorm(cfg["emb_dim"])self.norm2 = LayerNorm(cfg["emb_dim"])self.drop_resid = nn.Dropout(cfg["drop_rate"])def forward(self, x):# 注意力块的残差连接shortcut = xx = self.norm1(x)x = self.att(x)x = self.drop_resid(x)x = x + shortcut# 前馈网络的残差连接shortcut = xx = self.norm2(x)x = self.ff(x)x = self.drop_resid(x)x = x + shortcutreturn x

完整 GPT 模型实现

class GPTModel(nn.Module):def __init__(self, cfg):super().__init__()self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])self.pos_emb = nn.Embedding(cfg["ctx_len"], cfg["emb_dim"])self.drop_emb = nn.Dropout(cfg["drop_rate"])self.trf_blocks = nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])self.final_norm = LayerNorm(cfg["emb_dim"])self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)def forward(self, in_idx):batch_size, seq_len = in_idx.shapetok_embeds = self.tok_emb(in_idx)pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))x = tok_embeds + pos_embedsx = self.drop_emb(x)x = self.trf_blocks(x)x = self.final_norm(x)logits = self.out_head(x)return logits

文本生成

使用贪婪解码生成文本:

def generate_text_simple(model, idx, max_new_tokens, context_size):for _ in range(max_new_tokens):# 截断超过上下文长度的部分idx_cond = idx[:, -context_size:]with torch.no_grad():logits = model(idx_cond)# 获取最后一个 token 的 logitslogits = logits[:, -1, :]  probas = torch.softmax(logits, dim=-1)idx_next = torch.argmax(probas, dim=-1, keepdim=True)idx = torch.cat((idx, idx_next), dim=1)return idx

使用示例:

# 初始化模型
model = GPTModel(GPT_CONFIG_124M)# 设置评估模式
model.eval()# 生成文本
start_context = "Every effort moves you"
encoded = tokenizer.encode(start_context)
encoded_tensor = torch.tensor(encoded).unsqueeze(0)generated = generate_text_simple(model=model,idx=encoded_tensor,max_new_tokens=10,context_size=GPT_CONFIG_124M["ctx_len"]
)decoded_text = tokenizer.decode(generated.squeeze(0).tolist())
print(decoded_text)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/web/83648.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android 中 linux 命令查询设备信息

一、getprop 命令 在 Linux 系统中, getprop 命令通常用于获取 Android 设备的系统属性,这些属性包括设备型号、Android 版本、电池状态等。 1、获取 Android 版本号 adb shell getprop ro.build.version.release2、获取设备型号 adb shell getprop …

26考研 | 王道 | 计算机组成原理 | 六、总线

26考研 | 王道 | 计算机组成原理 | 六、总线 文章目录 26考研 | 王道 | 计算机组成原理 | 六、总线6.1 总线概述1. 总线概述2. 总线的性能指标 6.2 总线仲裁(考纲没有,看了留个印象)6.3 总线操作和定时6.4 总线标准(考纲没有&…

SE(Secure Element)加密芯片与MCU协同工作的典型流程

以下是SE(Secure Element)加密芯片与MCU协同工作的典型流程,综合安全认证、数据保护及防篡改机制: 一、基础认证流程(参数保护方案) 密钥预置‌ SE芯片与MCU分别预置相同的3DES密钥(Key1、Key2…

数据库——MongoDB

一、介绍 1. MongoDB 概述 MongoDB 是一款由 C 语言编写的开源 NoSQL 数据库,采用分布式文件存储设计。作为介于关系型和非关系型数据库之间的产品,它是 NoSQL 数据库中最接近传统关系数据库的解决方案,同时保留了 NoSQL 的灵活性和扩展性。…

WebSocket 前端断连原因与检测方法

文章目录 前言WebSocket 前端断连原因与检测方法常见 WebSocket 断连原因及检测方式聊天系统场景下的断连问题与影响行情推送场景下的断连问题与影响React 前端应对断连的稳健策略自动重连机制的设计与节流控制心跳机制的实现与保持连接存活连接状态管理与 React 集成错误提示与…

2025年真实面试问题汇总(三)

线上数据库数据丢失如何恢复 线上数据库数据丢失的恢复方法需要根据数据丢失原因、备份情况及数据库类型(如MySQL、SQL Server、PostgreSQL等)综合处理,以下是通用的分步指南: 一、紧急止损:暂停写入,防止…

Android音视频多媒体开源框架基础大全

安卓多媒体开发框架中,从音频采集,视频采集,到音视频处理,音视频播放显示分别有哪些常用的框架?分成六章,这里一次帮你总结完。 音视频的主要流程是录制、处理、编解码和播放显示。本文也遵循这个流程展开…

安卓上架华为应用市场、应用宝、iosAppStore上架流程,保姆级记录(1)

上架前请准备好apk、备案、软著、企业开发者账号!!!其余准备好app相关的截图、介绍、测试账号,没讲解明白的评论区留言~ 华为应用市场 1、登录账号 打开 华为开发者平台 https://developer.huawei.com/consumer/cn/ 2.登录企…

【Docker】docker 常用命令

目录 一、镜像管理 二、容器操作 三、网络管理 四、存储卷管理 五、系统管理 六、Docker Compose 常用命令 一、镜像管理 命令参数解说示例说明docker pull镜像名:标签docker pull nginx:alpine拉取镜像(默认从 Docker Hub)docker images-a&#x…

OSPF域内路由

简介 Router-LSA Router-LSA(Router Link State Advertisement)是OSPF(Open Shortest Path First)协议中的一种链路状态通告(LSA),它由OSPF路由器生成,用于描述路由器自身的链路状态…

torch 高维矩阵乘法分析,一文说透

文章目录 简介向量乘法二维矩阵乘法三维矩阵乘法广播 高维矩阵乘法开源 简介 一提到矩阵乘法,大家对于二维矩阵乘法都很了解,即 A 矩阵的行乘以 B 矩阵的列。 但对于高维矩阵乘法可能就不太清楚,不知道高维矩阵乘法是怎么在计算。 建议使用…

瑞萨RA-T系列芯片马达类工程TCM加速化设置

本篇介绍在使用RA8-T系列芯片,建立马达类工程应用时,如何将电流环部分的指令和变量设置到TCM单元,以提高电流环执行速度,从而提高系统整体的运行性能,在伺服和高端工业领域有很高的实用价值。本文以RA8T1为范例&#x…

获取Unity节点路径

解决目的: 避免手动拼写节点路径的时候,出现路径错误导致获取不到节点的情况。解决效果: 添加如下脚本之后,将自动复制路径到剪贴板中,在代码中通过 ctrlv 粘贴路径代码如下: public class CustomMenuItems…

Docker 安装 Oracle 12C

镜像 https://docker.aityp.com/image/docker.io/truevoly/oracle-12c:latest docker pull swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/truevoly/oracle-12c:latest docker tag swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/truevoly/oracle-12c:latest d…

Linux内核网络协议注册与初始化:从proto_register到tcp_v4_init_sock的深度解析

一、协议注册:proto_register的核心使命 在Linux网络协议栈中,proto_register是协议初始化的基石,主要完成三项关键任务: Slab缓存创建(内存管理核心) prot->slab = kmem_cache_create_usercopy(prot->name, prot->obj_size, ...); if (prot->twsk_prot) pr…

GD32 MCU的真随机数发生器(TRNG)

GD32 MCU的真随机数发生器(TRNG) 文章目录 GD32 MCU的真随机数发生器(TRNG)一、定义与核心特征二、物理机制:量子与经典随机性三、生成方法四、应用场景五、与伪随机数的对比六、局限性⚙️ 七、物理熵源原理🔧 八、硬件实现流程(以GD32F450 GD32L233为例)8.1. **初始…

Vulkan学习笔记6—渲染呈现

一、渲染循环核心 while (!glfwWindowShouldClose(window)) {glfwPollEvents();helloTriangleApp.drawFrame(); // 绘制帧} 在 Vulkan 中渲染帧包含一组常见的步骤 等待前一帧完成(vkWaitForFences) 从交换链获取图像(vkAcquireNextImageKH…

React第六十二节 Router中 createStaticRouter 的使用详解

前言 createStaticRouter 是 React Router 专为 服务端渲染(SSR) 设计的 API,用于在服务器端处理路由匹配和数据加载。它在构建静态 HTML 响应时替代了客户端的 BrowserRouter,确保 SSR 和客户端 Hydration 的路由状态一致。 一…

qt 双缓冲案例对比

双缓冲 1.双缓冲原理 单缓冲:在paintEvent中直接绘制到屏幕,绘制过程被用户看到 双缓冲:先在redrawBuffer绘制到缓冲区,然后一次性显示完整结果 代码结构 单缓冲:所有绘制逻辑在paintEvent中 双缓冲:绘制…

华为云AI开发平台ModelArts

华为云ModelArts:重塑AI开发流程的“智能引擎”与“创新加速器”! 在人工智能浪潮席卷全球的2025年,企业拥抱AI的意愿空前高涨,但技术门槛高、流程复杂、资源投入巨大的现实,却让许多创新构想止步于实验室。数据科学家…