自然语言处理之PyTorch实现词袋CBOW模型

在自然语言处理(NLP)领域,词向量(Word Embedding)是将文本转换为数值向量的核心技术。它能让计算机“理解”词语的语义关联,例如“国王”和“女王”的向量差可能与“男人”和“女人”的向量差相似。而Word2Vec作为经典的词向量训练模型,其核心思想是通过上下文预测目标词(或反之)。本文将以 --CBOW(连续词袋模型)为例,带你从代码到原理,一步步实现一个简单的词向量训练过程。

一、CBOW模型简介

CBOW(Continuous Bag-of-Words)是Word2Vec的两种核心模型之一。其核心思想是:给定目标词的上下文窗口内的所有词,预测目标词本身。例如,对于句子“We are about to study”,若上下文窗口大小为2(即目标词左右各取2个词),则当目标词是“about”时,上下文是“We, are, to, study”,模型需要根据这4个词预测出“about”。

CBOW的优势在于通过平均上下文词向量来预测目标词,计算效率高;缺点是对低频词不友好。本文将实现的CBOW模型包含词嵌入层、投影层和输出层,最终输出目标词的概率分布。

二、环境准备与数据预处理

2.1 进度条库安装

pip install torch numpy tqdm

2.2 语料库与基础设置

我们使用一段英文文本作为语料库,并定义上下文窗口大小(CONTEXT_SIZE=2,即目标词左右各取2个词):

CONTEXT_SIZE = 2  # 上下文窗口大小(左右各2个词)
raw_text = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules
called a program. People create programs to direct processes. In effect,
we conjure the spirits of the computer with our spells.""".split()  # 按空格分割成单词列表

2.3 构建词汇表与映射

为了将文本转换为模型可处理的数值,需要先构建词汇表(所有唯一词),并为每个词分配唯一索引:

vocab = set(raw_text)  # 去重后的词汇表(集合)
vocab_size = len(vocab)  # 词汇表大小(本文示例中为49)# 词到索引的映射(如:"We"→0,"are"→1)
word_to_idx = {word: i for i, word in enumerate(vocab)}
# 索引到词的反向映射(如:0→"We",1→"are")
idx_to_word = {i: word for i, word in enumerate(vocab)}

三、生成训练数据:上下文-目标词对

CBOW的训练数据是“上下文词列表”与“目标词”的配对。例如,若目标词是raw_text[i],则上下文是[raw_text[i-2], raw_text[i-1], raw_text[i+1], raw_text[i+2]](假设窗口大小为2)。

3.1 数据生成逻辑

通过遍历语料库,跳过前CONTEXT_SIZE和后CONTEXT_SIZE个词(避免越界),生成上下文-目标词对:

data = []
for i in range(CONTEXT_SIZE, len(raw_text) - CONTEXT_SIZE):# 左上下文:取i-2, i-1(j从0到1,2-j对应2,1)left_context = [raw_text[i - (2 - j)] for j in range(CONTEXT_SIZE)]# 右上下文:取i+1, i+2(j从0到1,i+j+1对应i+1, i+2)right_context = [raw_text[i + j + 1] for j in range(CONTEXT_SIZE)]context = left_context + right_context  # 合并上下文(共4个词)target = raw_text[i]  # 目标词(当前中心词)data.append((context, target))

3.2 示例验证

i=2为例:

  • 左上下文:i-2=0(“We”),i-1=1(“are”)→ ["We", "are"]
  • 右上下文:i+1=3(“to”),i+2=4(“study”)→ ["to", "study"]
  • 上下文合并:["We", "are", "to", "study"]
  • 目标词:raw_text[2](“about”)

四、CBOW模型实现(PyTorch)

4.1 模型结构设计

CBOW模型的核心是通过上下文词的词向量预测目标词。模型结构包含三层:

  1. 词嵌入层(Embedding):将词的索引映射为低维稠密向量(如10维)。
  2. 投影层(Linear):将拼接后的词向量投影到更高维度(如128维),增加非线性表达能力。
  3. 输出层(Linear):将投影后的向量映射回词汇表大小,通过Softmax输出目标词的概率分布。
import torch
import torch.nn as nn
import torch.nn.functional as Fclass CBOW(nn.Module):  # 神经网络def __init__(self, vocab_size, embedding_dim):super(CBOW, self).__init__()  # 父类的初始化self.embeddings = nn.Embedding(vocab_size, embedding_dim)self.proj = nn.Linear(embedding_dim, 128)self.output = nn.Linear(128, vocab_size)def forward(self, inputs):embeds = sum(self.embeddings(inputs)).view(1, -1)out = F.relu(self.proj(embeds))  # nn.relu() 激活层out = self.output(out)nll_prob = F.log_softmax(out, dim=-1)  # softmax交叉熵return nll_prob

五、模型训练与优化

5.1 初始化模型与超参数

设置词向量维度(embedding_dim=10)、学习率(lr=0.001)、训练轮数(epochs=200),并初始化模型、优化器和损失函数:

vocab_size = 49
model = CBOW(vocab_size, 10).to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)
losses = []# 存储损失的集合  losses: []
loss_function = nn.NLLLoss()        #NLLLoss损失函数(当分类列表非常多的情况),将多个类

5.2 训练循环逻辑

遍历每个训练轮次(epoch),对每个上下文-目标词对进行前向传播、损失计算、反向传播和参数更新:

model.train()
for epoch in tqdm(range(200)):#开始训练total_loss = 0for context, target in data:context_vector = make_context_vector(context, word_to_idx).to(device)target = torch.tensor([word_to_idx[target]]).to(device)# 开始前向传播train_predict = model(context_vector)  # 可以不写forward,torch的内置功能,loss = loss_function(train_predict, target)  # 计算损失# 反向传播optimizer.zero_grad()  # 梯度值清零loss.backward()  # 反向传播计算得到每个参数的梯度optimizer.step()  # 根据梯度更新网络参数total_loss += loss.item()

六、词向量提取与应用

训练完成后,模型的词嵌入层(model.embeddings.weight)中存储了每个词的向量表示。我们可以将其提取并保存,用于后续任务(如文本分类、相似度计算)。

6.1 提取词向量

# 将词向量从GPU移至CPU,并转换为NumPy数组
W = model.embeddings.weight.cpu().detach().numpy()
print("词向量矩阵形状:", W.shape)  # (vocab_size, embedding_dim) → (49, 10)

6.2 生成词-向量映射字典

word_2_vec = {}
for word, idx in word_to_idx.items():word_2_vec[word] = W[idx]  # 每个词对应词向量矩阵中的一行
print("示例词向量('process'):", word_2_vec["process"])

6.3 保存词向量

使用np.savez保存词向量矩阵,方便后续加载使用:

import numpy as np
np.savez('word2vec实现.npz', word_vectors=W)  # 保存为npz文件# 加载验证
data = np.load('word2vec实现.npz')
loaded_vectors = data['word_vectors']
print("加载的词向量形状:", loaded_vectors.shape)  # 应与原始矩阵一致

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/95787.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/95787.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TCP, 三次握手, 四次挥手, 滑动窗口, 快速重传, 拥塞控制, 半连接队列, RST, SYN, ACK

目录 TCP 是什么:面向连接 可靠 字节流三次握手:为什么不是两次四次挥手与 TIME_WAIT:谁等谁序列号/确认号与去重、排序、确认重传机制:超时重传与快速重传滑动窗口与流量控制拥塞控制:慢启动/拥塞避免/快重传/快恢…

CentOS 7.2 虚机 ssh 登录报错在重启后无法进入系统

文章目录前言1. 故障描述2. 故障诊断3. 故障原因4. 解决方案总结前言 上周帮用户处理了一个 linux 虚拟机在重启后无法正常进入操作系统的故障,觉得比较有意思,在这里分享给大家。 1. 故障描述 事情的起因是一台系统版本为 CentOS 7.2 的 VMware 虚拟机…

《从使用到源码:OkHttp3责任链模式剖析》

一 从使用开始0.依赖引入implementation ("com.squareup.okhttp3:okhttp:3.14.7")1.创建OkHttpClient实例方式一:直接使用默认配置的Builder//从源码可以看出,当我们直接new创建OkHttpClient实例时,会默认给我们配置好一个Builder …

安装3DS MAX 2026后,无法运行,提示缺少.net core的解决方案

今天安装了3DS MAX 2026(俗称3DMAX),安装完毕后死活运行不了。提示如下: 大意是找不到所需的.NET Core 8库文件。后来搜索了下,各种文章说.NET CORE和.NET FRAMEWORK不是一个东西。需要单独下载安装。然后根据提示&…

FastAPI + LangChain 和 Spring AI + LangChain4j

FastAPI+LangChain和Spring AI+LangChain4j这两个技术组合进行详细对比。 核心区别: 特性维度 FastAPI + LangChain (Python栈) Spring AI + LangChain4j (Java栈) 技术栈 Python生态 (FastAPI, LangChain) Java生态 (Spring Boot, Spring AI, LangChain4j) 核心设计哲学 灵活…

Apache 2.0 开源协议详解:自由、责任与商业化的完美平衡-优雅草卓伊凡

Apache 2.0 开源协议详解:自由、责任与商业化的完美平衡-优雅草卓伊凡引言由于我们优雅草要推出收银系统,因此要采用开源代码,卓伊凡目前看好了一个产品是apache 2.0协议,因此我们有必要深刻理解apache 2.0协议避免触犯版权问题。…

自学嵌入式第37天:MQTT协议

一、MQTT(消息队列遥测传输协议Message Queuing Telemetry Transport)1.MQTT是应用层的协议,是一种基于发布/订阅模式的“轻量级”通讯协议,建构于TCP/IP协议上,可以以极少的代码和有限的带宽为连接远程设备提供实时可…

RabbitMQ--延时队列总结

一、延迟队列概念 延迟队列(Delay Queue)是一种特殊类型的队列,队列中的元素需要在指定的时间点被取出和处理。简单来说,延时队列就是存放需要在某个特定时间被处理的消息。它的核心特性在于“延迟”——消息在队列中停留一段时间…

Java 提取 PDF 文件内容:告别手动复制粘贴,拥抱自动化解析!

在日常工作中,我们经常需要处理大量的 PDF 文档,无论是提取报告中的关键数据,还是解析合同中的重要条款,手动复制粘贴不仅效率低下,还极易出错。当面对海量的 PDF 文件时,这种传统方式更是让人望而却步。那…

关键字 const

Flutter 是一个使用 Dart 语言构建的 UI 工具包,因此它完全遵循 Dart 的语法和规则。Dart 中的 const 是语言层面的特性,而 Flutter 因其声明式 UI 和频繁重建的特性,将 const 的效能发挥到了极致。Dart 中的 const(语言层面&…

Ubuntu22.04中使用cmake安装abseil-cpp库

Ubuntu22.04中使用cmake安装abseil-cpp库 关于Abseil库 Abseil 由 Google 的基础 C 和 Python 代码库组成,包括一些正支撑着如 gRPC、Protobuf 和 TensorFlow 等开源项目并一起 “成长” 的库。目前已开源 C 部分,Python 部分将在后续开放。 Abseil …

FreeRTOS项目(序)目录

这章是整个专栏的目录,负责记录这个小项目的开发日志和目录。附带总流程图。 目录 项目简介 专栏目录 开发日志 总流程图 项目简介 本项目基于STM32C8T6核心板和FreeRTOS,实现一些简单的功能。以下为目前已实现的功能。 (1&#xff09…

Python 多任务编程:进程、线程与协程全面解析

目录 一、多任务基础:并发与并行 1. 什么是多任务 2. 两种表现形式 二、进程:操作系统资源分配的最小单位 1. 进程的概念 2. 多进程实现多任务 2.1 基础示例:边听音乐边敲代码 2.2 带参数的进程任务 2.3 进程编号与应用注意点 2.3.…

ADSL技术

<摘要> ADSL&#xff08;非对称数字用户线路&#xff09;是一种利用传统电话线实现宽带上网的技术。其核心原理是频率分割&#xff1a;将一根电话线的频带划分为语音、上行数据&#xff08;慢&#xff09;和下行数据&#xff08;快&#xff09;三个独立频道&#xff0c;从…

信号衰减中的分贝到底是怎么回事

问题&#xff1a;在一个低通滤波中&#xff0c;经常会看到一个值-3dB&#xff08;-3分贝&#xff09;&#xff0c;到底是个什么含义&#xff1f; 今天我就来粗浅的讲解这个问题。 在低通滤波器中&#xff0c;我们说的 “截止频率”&#xff08;或叫 - 3dB 点&#xff09;&…

工具分享--IP与域名提取工具2.0

基于原版的基础上新增了一个功能点:IP-A段过滤&#xff0c;可以快速把内网192、170、10或者其它你想要过滤掉的IP-A段轻松去掉&#xff0c;提高你的干活效率&#xff01;&#xff01;&#xff01; 界面样式预览&#xff1a;<!DOCTYPE html> <html lang"zh-CN&quo…

如何通过日志先行原则保障数据持久化:Redis AOF 和 MySQL redo log 的对比

在分布式系统或数据库管理系统中&#xff0c;日志先行原则&#xff08;Write-Ahead Logging&#xff0c;WAL&#xff09; 是确保数据一致性、持久性和恢复能力的重要机制。通过 WAL&#xff0c;系统能够在发生故障时恢复数据&#xff0c;保证数据的可靠性。在这篇博客中&#x…

临床研究三千问——临床研究体系的3个维度(8)

在上周的文章中&#xff0c;我们共同探讨了1345-10战策的“临床研究的起点——如何提出一个犀利的临床与科学问题”。问题固然是灵魂&#xff0c;但若没有坚实的骨架与血肉&#xff0c;灵魂便无所依归。今天&#xff0c;我们将深入“1345-10战策”中的“3”&#xff0c;即支撑起…

AI+预测3D新模型百十个定位预测+胆码预测+去和尾2025年9月7日第172弹

从今天开始&#xff0c;咱们还是暂时基于旧的模型进行预测&#xff0c;好了&#xff0c;废话不多说&#xff0c;按照老办法&#xff0c;重点8-9码定位&#xff0c;配合三胆下1或下2&#xff0c;杀1-2个和尾&#xff0c;再杀4-5个和值&#xff0c;可以做到100-300注左右。(1)定位…

万字详解网络编程之socket

一&#xff0c;socket简介1.什么是socketsocket通常也称作"套接字"&#xff0c;⽤于描述IP地址和端⼝&#xff0c;是⼀个通信链的句柄&#xff0c;应用程序通常通过"套接字"向⽹络发出请求或者应答⽹络请求。⽹络通信就是两个进程间的通信&#xff0c;这两…