【Torch】nn.Embedding算法详解

1. 定义

nn.Embedding 是 PyTorch 中的 查表式嵌入层(lookup‐table),用于将离散的整数索引(如词 ID、实体 ID、离散特征类别等)映射到一个连续的、可训练的低维向量空间。它通过维护一个形状为 (num_embeddings, embedding_dim) 的权重矩阵,实现高效的“索引 → 向量”转换。

2. 输入与输出

  • 输入

    • 类型:整型张量(torch.longtorch.int64),必须是 LongTensor,其他类型会报错。
    • 形状:任意形状 (*, L),其中最内层长度 L 常为序列长度,前面的 * 可以是 batch 及其他维度。
    • 取值范围0 ≤ index < num_embeddings;超出范围会抛出 IndexError
  • 输出

    • 类型:浮点型张量(与权重相同的 dtype,默认为 torch.float32)。
    • 形状(*, L, embedding_dim);就是在输入张量后追加一个维度 embedding_dim
    • 语义:若输入某位置的值为 j,则该位置对应输出就是权重矩阵的第 j 行。

3. 底层原理

  1. 查表操作 vs. One-hot 乘法

    • 直观上,Embedding 相当于:
      output = one_hot ( i n p u t ) × W \text{output} = \text{one\_hot}(input) \;\times\; W output=one_hot(input)×W
      其中 W(num_embeddings×embedding_dim) 的权重矩阵。
    • 为避免显式构造稀疏的 one-hot 张量,PyTorch 直接根据索引做“取行”操作,效率更高、内存更省。
  2. 梯度更新

    • 稠密模式(默认):整个 W 都有梯度缓冲,优化器根据梯度更新所有行。
    • 稀疏模式sparse=True):仅对被索引过的行计算和存储梯度,可配合 optim.SparseAdam 高效更新,适合超大字典(百万级以上)但每次只访问少量索引的场景。
  3. 范数裁剪

    • 若指定 max_norm,每次前向都会对输出向量(即对应的行)做范数裁剪,保证其 L-norm_type 范数不超过 max_norm,有助于防止某些频繁访问的词向量过大。
  4. 权重初始化

    • 默认初始化使用均匀分布:
      W i , j ∼ U ( − 1 num_embeddings , 1 num_embeddings ) W_{i,j} \sim \mathcal{U}\Bigl(-\sqrt{\tfrac{1}{\text{num\_embeddings}}},\;\sqrt{\tfrac{1}{\text{num\_embeddings}}}\Bigr) Wi,jU(num_embeddings1 ,num_embeddings1 )
    • 可以通过 _weight 参数传入外部预训练权重(如 Word2Vec、GloVe 等)。

4. 构造函数参数详解

参数类型及默认说明
num_embeddingsint必填。嵌入表行数,等于类别总数(最大索引 + 1)。
embedding_dimint必填。每个向量的维度。
padding_idxintNone默认 None。指定该索引对应行始终输出全零,并且该行的梯度永远为 0,适合做序列填充。
max_normfloatNone默认 None。若设为数值,每次前向时对取出的向量做范数裁剪(L-norm_typemax_norm)。
norm_typefloat,默认 2max_norm 配合使用时定义范数类型,如 1-范数、2-范数等。
scale_grad_by_freqbool,默认 False若为 True,在反向传播阶段按照索引在 batch 中出现的频次对梯度做缩放(出现越多,梯度越小),有助于高频词的梯度平滑。
sparsebool,默认 False若为 True,开启稀疏更新,仅对被访问行生成梯度;必须配合 optim.SparseAdam 使用,不支持常规稠密优化器。
_weightTensorNone若提供,则用此张量(形状应为 (num_embeddings, embedding_dim))作为权重初始化,否则随机初始化。

5. 使用示例

import torch
import torch.nn as nn# 1. 参数设定
vocab_size = 10000   # 词表大小
embed_dim  = 300     # 嵌入维度# 2. 创建 Embedding 层
embedding = nn.Embedding(num_embeddings=vocab_size,embedding_dim=embed_dim,padding_idx=0,         # 将 0 作为填充索引,输出全 0max_norm=5.0,          # 向量范数不超过 5norm_type=2.0,scale_grad_by_freq=True,sparse=False
)# 3. 构造输入
# batch_size=2, seq_len=6
input_ids = torch.tensor([[  1, 234,  56, 789,   0,  23],[123,   4, 567,   8,   9,   0],
], dtype=torch.long)# 4. 前向计算
# 输出 shape = [2, 6, 300]
output = embedding(input_ids)
print(output.shape)  # -> torch.Size([2, 6, 300])

加载并冻结预训练权重

import numpy as np# 假设有预训练权重 pre_trained.npy,shape=(10000,300)
weights = torch.from_numpy(np.load("pre_trained.npy"))
embed_pre = nn.Embedding(num_embeddings=vocab_size,embedding_dim=embed_dim,_weight=weights
)
# 冻结所有权重
embed_pre.weight.requires_grad = False

6. 注意事项

  1. 类型与范围
    • 输入必须为 LongTensor,且所有索引满足 0 ≤ index < num_embeddings
  2. Padding 与 Mask
    • 仅指定 padding_idx 会返回零向量,但上游网络(如 RNN、Transformer)还需显式 mask,避免无效位置影响注意力或累积状态。
  3. 性能考量
    • max_norm 每次前向都做范数计算和裁剪,若不需要可关闭以提升速度。
  4. 稀疏更新限制
    • sparse=True 可节省内存,但只支持 SparseAdam,且在 GPU 上效率有时不如稠密模式。
  5. EmbeddingBag
    • 对于可变长度序列的 sum/mean/power-mean 汇聚,可使用 nn.EmbeddingBag,避免中间张量开销。
  6. 分布式与大词表
    • 在分布式训练时,可将嵌入表切分到多个进程上(torch.nn.parallel.DistributedDataParallel + torch.nn.Embedding 支持参数分布式)。
    • 超大词表(千万级)时,可考虑动态加载、分布式哈希表或专用库(如 DeepSpeed 的嵌入稀疏优化)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/912175.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/912175.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

cdq 三维偏序应用 / P4169 [Violet] 天使玩偶/SJY摆棋子

最近学了 cdq 分治想来做做这道题&#xff0c;结果被有些毒瘤的代码恶心到了。 /ll 题目大意&#xff1a;一开始给定一些平面中的点。然后给定一些修改和询问&#xff1a; 修改&#xff1a;增加一个点。询问&#xff1a;给定一个点&#xff0c;求离这个点最近&#xff08;定义…

System.Threading.Tasks 库简介

System.Threading.Tasks 是 .NET 中任务并行库(Task Parallel Library, TPL)的核心组件&#xff0c;它提供了基于任务的异步编程模型&#xff0c;是现代 .NET 并发编程的基础。 设计原理 1. 核心目标 抽象并发工作&#xff1a;将并发操作抽象为"任务"概念 资源高效…

Python爬虫实战:研究jieba相关技术

1. 引言 1.1 研究背景与意义 随着互联网技术的飞速发展,网络新闻已成为人们获取信息的主要渠道之一。每天产生的新闻文本数据量呈爆炸式增长,如何从海量文本中高效提取有价值的信息,成为信息科学领域的重要研究课题。文本分析技术通过对文本内容的结构化处理和语义挖掘,能…

github 淘金技巧

1. 效率&#xff0c;搜索&#xff0c;先不管。后面再说。 2. 分享的话&#xff0c; 其实使用默认的分享功能也行。也是后面再说。此 app &#xff0c; 今天先做到这里。 下面我们再聊点其他东西。其实我还想问&#xff0c;这个事情&#xff0c;其他人是否也做了&#xff0c; ht…

RAG技术发展综述

摘要 检索增强生成&#xff08;Retrieval-Augmented Generation, RAG&#xff09;技术已成为大语言模型应用的核心技术栈。RAG有效解决了LLM的幻觉问题、知识截止和实时更新挑战&#xff0c;目前正处于全面产业化阶段。本文系统性地分析RAG的全栈技术架构&#xff0c;包括检索…

集群聊天服务器---muduo库(3)

使用muduo网络库进行编译和链接的示例 项目的目录结构 bin: 存放可执行文件。 lib: 存放库文件。 include: 存放头文件。 src: 存放源代码文件。 build: 存放编译生成的中间文件。 example: 存放示例代码。 thirdparty: 存放第三方库。 CMakeLists.txt: CMake构建系统…

双核SOC/5340 应用和网络核间通讯

1&#xff1a; 可以在 nRF Connect SDK 文件夹结构的 samples/ipc/ipc_service 下找到示例&#xff0c;应用和网络核心在由 CONFIG_APP_IPC_SERVICE_SEND_INTERVAL 选项指定的时隙内相互发送数据。可以更改该值并观察每个核心的吞吐量如何变化 nRF5340 DK 可以使用 RPMsg 或 IC…

Spring Cloud Ribbon核心负载均衡算法详解

Ribbon 作为 Spring Cloud 生态中的客户端负载均衡工具&#xff0c;提供多种动态负载均衡算法&#xff0c;根据后端服务状态智能分配请求。其核心算法及适用场景如下&#xff1a; &#x1f9e0; 一、Ribbon 负载均衡算法 算法名称工作原理引用来源轮询 (RoundRobinRule)按服务…

网站图片过于太大影响整体加载响应速度怎么办? Typecho高级图像处理插件

文章目录 LeleImges - Typecho高级图像处理插件 🖼️插件介绍 📝插件架构 🏗️主要功能 ✨性能优势 🚀系统要求 📋安装方法 📥详细配置说明 ⚙️图片质量设置 🎚️最大宽度/高度限制 📏压缩格式选择 🗜️压缩方法选择 🔧GIF处理方式 🎞️备份源文件 💾…

VUE3入门很简单(1)--- 响应式对象

前言 重要提示&#xff1a;文章只适合初学者&#xff0c;不适合专家&#xff01;&#xff01;&#xff01; 什么是响应式对象&#xff1f; 在Vue3中&#xff0c;响应式对象就是这种智能温控器。当你修改JavaScript对象的数据时&#xff0c;Vue会自动更新网页上显示的内容&am…

广州华锐互动携手中石油:AR 巡检系统实现重大突破​

广州华锐互动在 AR 技术领域的卓越成就&#xff0c;通过一系列与知名企业、机构的成功合作案例得以充分彰显。其中&#xff0c;与中石油的合作项目堪称经典&#xff0c;展现了广州华锐互动运用 AR 技术解决实际难题、达成目标的强大实力。​ 中石油作为能源行业的巨擘&#xff…

权威认证!华宇TAS应用中间件荣获CCRC“中间件产品安全认证”

近日&#xff0c;华宇TAS应用中间件顺利通过了中国网络安全审查认证和市场监管大数据中心(CCRC)的信息安全认证&#xff0c;获得了IT产品信息安全认证证书。此次获证&#xff0c;标志着华宇TAS应用中间件在安全性、可靠性及合规性等方面达到行业领先水平&#xff0c;可以为政企…

BI财务分析 – 反映盈利水平利润占比的指标如何分析(下)

之前的文章重点把构成销售净利率、主营业务利润率、成本费用利润率、营业利润率、销售毛利率的分母像销售收入、营业收入、主营业务收入净额、成本费用总额做了比较细致的说明&#xff0c;把这几个基本的概念搞明白后&#xff0c;再来看这几个指标就比较容易理解了。 销售净利…

竹云受邀出席华为开发者大会,与华为联合发布海外政务数字化解决方案

6月20日-22日&#xff0c;华为开发者大会&#xff08;HDC 2025&#xff09;在东莞松山湖盛大召开。作为华为一年一度面向全球开发者的顶级科技盛会&#xff0c;今年的HDC不仅带来了HarmonyOS 6.0 Beta版本、盘古大模型5.5等多项重磅技术和产品更新&#xff0c;更聚集了全球极客…

AI助力游戏设计——从灵感到行动-靠岸篇

OK&#xff0c;朋友&#xff0c;如果你到了这里&#xff0c;那就证明这趟旅程&#xff0c;快要到岸了。 首先&#xff0c;恭喜你&#xff0c;到了需要这一步的时候。其实&#xff0c;如果你有一天真的用到了&#xff0c;希望你可以回来打个卡。行了&#xff0c;不废话&#xf…

vue将页面导出pdf,vue导出pdf ,使用html2canvas和jspdf组件

vue导出pdf 需求&#xff1a;需要前端下载把当前html下载成pdf文件–有十八页超长&#xff0c;之前使用vue-html2pdf组件&#xff0c;但是这个组件有长度限制和比较新浏览器版本限制&#xff0c;所以改成使用html2canvas和jspdf组件 方法&#xff1a; 1、第一步&#xff1a;我…

024 企业客户管理系统技术解析:基于 Spring Boot 的全流程管理平台

企业客户管理系统技术解析&#xff1a;基于Spring Boot的全流程管理平台 在企业数字化转型的浪潮中&#xff0c;高效的客户管理系统成为提升企业竞争力的关键工具。本文将深入解析基于Java和Spring Boot框架构建的企业客户管理系统&#xff0c;该系统涵盖员工管理、客户信息管…

JavaScript性能优化代码示例

JavaScript性能优化实战大纲 性能优化的核心目标 减少加载时间、提升渲染效率、降低内存占用、优化交互响应 代码层面的优化实践 避免全局变量污染&#xff0c;使用局部变量和模块化开发 减少DOM操作频率&#xff0c;批量处理DOM更新 使用事件委托替代大量事件监听器 优化循…

树的重心(双dfs,换根)

思路&#xff1a; 基于树形 DP 的两次遍历&#xff08;第一次dfs计算以某个初始根&#xff08;这里选了 1&#xff09;为根时各子树的深度和与节点数&#xff0c;第二次zy进行换根操作&#xff0c;更新每个节点作为根时的深度和&#xff09; 换根原理&#xff1a; 更换主根&…

官方App Store,直链下载macOS ,无需Apple ID,macOS10.10以上.

前言 想必很多人都有过维修老旧Mac的体验,也有过想要重装macos的体验. 尤其是前者,想要重装或者升级系统,由于官方已经无法更新,必须下载iSo镜像 这时就会遇到死循环:想要更新macOS ,必须先使用更高版本的App Store,但要使用更高版本的App Store,必须先更新macOS !!! 如果想…