从代码学习深度学习 - 预训练word2vec PyTorch版

文章目录

  • 前言
  • 辅助工具
    • 1. 绘图工具 (`utils_for_huitu.py`)
    • 2. 数据处理工具 (`utils_for_data.py`)
    • 3. 训练辅助工具 (`utils_for_train.py`)
  • 预训练 Word2Vec - 主流程
    • 1. 环境设置与数据加载
    • 2. 跳元模型 (Skip-gram Model)
      • 2.1. 嵌入层 (Embedding Layer)
      • 2.2. 定义前向传播
    • 3. 训练
      • 3.1. 二元交叉熵损失
      • 3.2. 初始化模型参数
      • 3.3. 定义训练阶段代码
      • 3.4. 开始训练
    • 4. 应用词嵌入
  • 总结


前言

词嵌入(Word Embeddings)是自然语言处理(NLP)领域中的基石技术之一。它们将词语从稀疏的、高维的独热编码(one-hot encoding)表示转换为稠密的、低维的向量表示。这些向量能够捕捉词语之间的语义和句法关系,使得相似的词在向量空间中距离更近。Word2Vec是其中一种非常流行且有效的词嵌入算法,由Google的Tomas Mikolov等人在2013年提出。它主要包含两种模型架构:CBOW(Continuous Bag-of-Words,连续词袋模型)和Skip-gram(跳字模型)。

本篇博客将聚焦于Skip-gram模型,并结合**负采样(Negative Sampling)**这一重要的优化技巧,通过PyTorch框架从零开始实现一个Word2Vec模型。我们将详细探讨数据预处理的每一个步骤,如何构建模型,如何进行训练,以及训练完成后如何应用得到的词向量来寻找相似词。通过深入代码细节,我们希望能帮助读者更好地理解Word2Vec的内部工作原理及其在PyTorch中的实现。

我们将依赖一系列辅助脚本来处理数据、可视化训练过程以及进行模型训练。让我们一步步揭开Word2Vec的神秘面纱。

完整代码:下载链接

辅助工具

在构建和训练Word2Vec模型之前,我们首先介绍一下项目中用到的一些辅助Python脚本。这些脚本提供了数据加载、预处理、可视化以及训练监控等常用功能。

1. 绘图工具 (utils_for_huitu.py)

这个脚本主要封装了使用matplotlib进行绘图的常用函数,特别是在Jupyter Notebook环境中,它包含了一个Animator类,可以动态地展示训练过程中的损失变化。

# 导入必要的包
import matplotlib.pyplot as plt  # 用于创建和操作 Matplotlib 图表
from matplotlib_inline import backend_inline  # 用于在Jupyter中设置Matplotlib输出格式
from IPython import display  # 用于后续动态显示(如 Animator)
import torch  # 导入PyTorch库,用于处理张量类型的图像
import numpy as np  # 导入NumPy,可能用于数据处理
import matplotlib as mpl  # 导入Matplotlib主模块,用于设置图像属性def set_figsize(figsize=(3.5, 2.5)):"""设置matplotlib图形的大小参数:figsize: tuple[float, float] - 图形大小,形状为 (宽度, 高度),单位为英寸输出:无返回值"""plt.rcParams['figure.figsize'] = figsize  # 设置图形默认大小def use_svg_display():"""使用 SVG 格式在 Jupyter 中显示绘图输入:无输出:无返回值"""backend_inline.set_matplotlib_formats('svg')  # 设置 Matplotlib 使用 SVG 格式def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):"""设置 Matplotlib 的轴  输入:axes: Matplotlib 的轴对象  # 输入参数:轴对象xlabel: x 轴标签  # 输入参数:x 轴标签ylabel: y 轴标签  # 输入参数:y 轴标签xlim: x 轴范围  # 输入参数:x 轴范围ylim: y 轴范围  # 输入参数:y 轴范围xscale: x 轴刻度类型  # 输入参数:x 轴刻度类型yscale: y 轴刻度类型  # 输入参数:y 轴刻度类型legend: 图例标签列表  # 输入参数:图例标签输出:无返回值  # 函数无显式返回值"""axes.set_xlabel(xlabel)  # 设置 x 轴标签axes.set_ylabel(ylabel)  # 设置 y 轴标签axes.set_xscale(xscale)  # 设置 x 轴刻度类型axes.set_yscale(yscale)  # 设置 y 轴刻度类型axes.set_xlim(xlim)  # 设置 x 轴范围axes.set_ylim(ylim)  # 设置 y 轴范围if legend:  # 检查是否提供了图例标签axes.legend(legend)  # 如果有图例,则设置图例axes.grid()  # 为轴添加网格线class Animator:"""在动画中绘制数据,仅针对一张图的情况"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):"""初始化 Animator 类 输入:xlabel: x 轴标签,默认为 None  # 输入参数:x 轴标签ylabel: y 轴标签,默认为 None  # 输入参数:y 轴标签legend: 图例标签列表,默认为 None  # 输入参数:图例标签xlim: x 轴范围,默认为 None  # 输入参数:x 轴范围ylim: y 轴范围,默认为 None  # 输入参数:y 轴范围xscale: x 轴刻度类型,默认为 'linear'  # 输入参数:x 轴刻度类型yscale: y 轴刻度类型,默认为 'linear'  # 输入参数:y 轴刻度类型fmts: 绘图格式元组,默认为 ('-', 'm--', 'g-.', 'r:')  # 输入参数:线条格式nrows: 子图行数,默认为 1  # 输入参数:子图行数ncols: 子图列数,默认为 1  # 输入参数:子图列数figsize: 图像大小元组,默认为 (3.5, 2.5)  # 输入参数:图像大小输出:无返回值  # 方法无显式返回值定义位置::numref:`sec_softmax_scratch`  # 指明定义的参考位置"""if legend is None:  # 检查 legend 是否为 Nonelegend = []  # 如果为 None,则初始化为空列表use_svg_display()  # 设置绘图显示为 SVG 格式self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)  # 创建绘图对象和子图if nrows * ncols == 1:  # 判断是否只有一个子图self.axes = [self.axes, ]  # 如果是单个子图,将 axes 转为列表self.config_axes = lambda: set_axes(  # 定义 lambda 函数配置坐标轴self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)  # 调用 set_axes 设置参数self.X, self.Y, self.fmts = None, None, fmts  # 初始化数据和格式属性def add(self, x, y):"""向图表中添加多个数据点  输入:x: x 轴数据点  # 输入参数:x 轴数据y: y 轴数据点  # 输入参数:y 轴数据输出:无返回值  # 方法无显式返回值"""if not hasattr(y, "__len__"):  # 检查 y 是否具有长度属性(是否可迭代)y = [y]  # 如果不可迭代,将 y 转为单元素列表n = len(y)  # 获取 y 的长度if not hasattr(x, "__len__"):  # 检查 x 是否具有长度属性x = [x] * n  # 如果不可迭代,将 x 扩展为与 y 同长度的列表if not self.X:  # 检查 self.X 是否已初始化self.X = [[] for _ in range(n)]  # 如果未初始化,为每条线创建空列表if not self.Y:  # 检查 self.Y 是否已初始化self.Y = [[] for _ in range(n)]  # 如果未初始化,为每条线创建空列表for i, (a, b) in enumerate(zip(x, y)):  # 遍历 x 和 y 的数据对if a is not None and b is not None:  # 检查数据点是否有效self.X[i].append(a)  # 将 x 数据点添加到对应列表self.Y[i].append(b)  # 将 y 数据点添加到对应列表self.axes[0].cla()  # 清除当前轴的内容for x, y, fmt in zip(self.X, self.Y, self.fmts):  # 遍历所有数据和格式self.axes[0].plot(x, y, fmt)  # 绘制每条线self.config_axes()  # 调用 lambda 函数配置坐标轴display.display(self.fig)  # 显示当前图形display.clear_output(wait=True)  # 标记当前输出为待清除,但由于 wait=True,它不会立即清除,而是等待下一次 display.display()。def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):"""绘制列表长度对的直方图,用于比较两组列表中元素长度的分布参数:legend: list[str] - 图例标签,形状为 (2,),分别对应xlist和ylist的标签xlabel: str - x轴标签ylabel: str - y轴标签xlist: list[list] - 第一组列表,形状为 (样本数量, 每个样本的元素数)ylist: list[list] - 第二组列表,形状为 (样本数量, 每个样本的元素数)输出:无返回值,但会显示生成的直方图"""set_figsize()  # 设置图形大小# plt.hist返回的三个值:# n: list[array] - 每个bin中的样本数量,形状为 (2, bin数量)# bins: array - bin的边界值,形状为 (bin数量+1,)# patches: list[list[Rectangle]] - 直方图的矩形对象,形状为 (2, bin数量)_, _, patches = plt.hist([[len(l) for l in xlist], [len(l) for l in ylist]])  # 绘制两组数据长度的直方图plt.xlabel(xlabel)  # 设置x轴标签plt.ylabel(ylabel)  # 设置y轴标签# 为第二组数据(ylist)的直方图添加斜线图案,以区分两组数据for patch in patches[1].patches:  # patches[1]是ylist对应的矩形对象列表patch.set_hatch('/')  # 设置填充图案为斜线plt.legend(legend)  # 添加图例

解读

  • set_figsizeuse_svg_display 用于基础的Matplotlib绘图设置。
  • set_axes 是一个通用的函数,用于配置图表的坐标轴标签、范围、刻度类型和图例。
  • Animator 类是实现动态绘图的关键。在训练循环中,我们可以周期性地调用其add方法,传入当前的训练轮次(或迭代次数)和对应的损失值(或其他指标)。Animator会清除旧的图像并重新绘制,从而在Jupyter Notebook中形成动画效果,直观地展示训练趋势。
  • show_list_len_pair_hist 函数用于绘制两个列表集合中,各子列表长度分布的直方图,方便进行数据分析和比较。

2. 数据处理工具 (utils_for_data.py)

这个脚本是Word2Vec数据预处理的核心,包含了从读取原始文本、构建词汇表、下采样、生成中心词-上下文词对、负采样到最终打包成PyTorch DataLoader的完整流程。

from collections import Counter  # 导入 Counter 类
from collections import Counter  # 用于词频统计
import torch  # PyTorch 核心库
from torch.utils import data  # PyTorch 数据加载工具
import numpy as np  # NumPy 用于数组操作
import random  # 导入随机模块,用于下采样和负采样
import math  # 导入数学函数模块,用于概率计算
import osdef count_corpus(tokens):"""统计词元的频率参数:tokens: 词元列表,可以是:- 一维列表,例如 ['a', 'b']- 二维列表,例如 [['a', 'b'], ['c']]返回值:Counter: Counter 对象,统计每个词元的出现次数"""# 如果输入为空列表,直接返回空计数器if not tokens:  # 等价于 len(tokens) == 0return Counter()# 检查输入是否为二维列表if isinstance(tokens[0], list):# 将二维列表展平为一维列表flattened_tokens = [token for sublist in tokens for token in sublist]else:# 如果是一维列表,直接使用原列表flattened_tokens = tokens# 使用 Counter 统计词频并返回return Counter(flattened_tokens)class Vocab:"""文本词表类,用于管理词元及其索引的映射关系"""def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):"""初始化词表Args:tokens: 输入的词元列表,可以是1D或2D列表,默认为空列表min_freq: 词元最小出现频率,小于此频率的词元将被忽略,默认为0reserved_tokens: 预留的特殊词元列表(如'<pad>'),默认为空列表"""# 处理默认参数self.tokens = tokens if tokens is not None else []self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []# 统计词元频率并按频率降序排序counter = self._count_corpus(self.tokens)self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)# 初始化词表,'<unk>'为未知词元,索引为0self.idx_to_token = ['<unk>'] + self.reserved_tokensself.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}# 添加满足最小频率要求的词元到词表for token, freq in self._token_freqs:if freq < min_freq:breakif token not in self.token_to_idx:self.idx_to_token.append(token)self.token_to_idx[token] = 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/news/906810.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python实现对大批量Word文档进行自动添加页码(16)

前言 本文是该专栏的第16篇,后面会持续分享Python办公自动化干货知识,记得关注。 在处理word文档的时候,相信或多或少都遇到过这样的需求——需要对“目标word文档,自动添加页码”。 换言之,如果有大批量的word文档文件需要你添加页码,这个时候最聪明的办法就是使用“程…

云原生安全:Linux命令行操作全解析

&#x1f525;「炎码工坊」技术弹药已装填&#xff01; 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 ——从基础概念到安全实践的完整指南 一、基础概念 1. Shell与终端交互 Shell是Linux命令行的解释器&#xff08;如Bash、Zsh&#xff09;&#xff0c;负…

Day 34

GPU训练 要让模型在 GPU 上训练&#xff0c;主要是将模型和数据迁移到 GPU 设备上。 在 PyTorch 里&#xff0c;.to(device) 方法的作用是把张量或者模型转移到指定的计算设备&#xff08;像 CPU 或者 GPU&#xff09;上。 对于张量&#xff08;Tensor&#xff09;&#xff1…

C++笔试题(金山科技新未来训练营):

题目分布&#xff1a; 17道单选&#xff08;每题3分&#xff09;3道多选题&#xff08;全对3分&#xff0c;部分对1分&#xff09;2道编程题&#xff08;每一道20分&#xff09;。 不过题目太多&#xff0c;就记得一部分了&#xff1a; 单选题&#xff1a; static变量的初始…

Spark(29)基础自定义分区器

&#xff08;一&#xff09;什么是分区 【复习提问&#xff1a;RDD的定义是什么&#xff1f;】 在 Spark 里&#xff0c;弹性分布式数据集&#xff08;RDD&#xff09;是核心的数据抽象&#xff0c;它是不可变的、可分区的、里面的元素并行计算的集合。 在 Spark 中&#xf…

python打卡训练营打卡记录day35

知识点回顾&#xff1a; 三种不同的模型可视化方法&#xff1a;推荐torchinfo打印summary权重分布可视化进度条功能&#xff1a;手动和自动写法&#xff0c;让打印结果更加美观推理的写法&#xff1a;评估模式 作业&#xff1a;调整模型定义时的超参数&#xff0c;对比下效果 1…

【MySQL】07.表内容的操作

1. insert 我们先创建一个表结构&#xff0c;这部分操作我们使用这张表完成我们的操作&#xff1a; mysql> create table student(-> id int primary key auto_increment,-> name varchar(20) not null,-> qq varchar(20) unique-> ); Query OK, 0 rows affec…

使用SQLite Expert个人版VACUUM功能修复数据库

使用SQLite Expert个人版VACUUM功能修复数据库 一、SQLite Expert工具简介 SQLite Expert 是一款功能强大的SQLite数据库管理工具&#xff0c;分为免费的个人版&#xff08;Personal Edition&#xff09;和收费的专业版&#xff08;Professional Edition&#xff09;。其核心功…

LM-BFF——语言模型微调新范式

gpt3&#xff08;GPT3——少样本示例推动下的通用语言模型雏形)结合提示词和少样本示例后&#xff0c;展示出了强大性能。但大语言模型的训练门槛太高&#xff0c;普通研究人员无力&#xff0c;LM-BFF(Making Pre-trained Language Models Better Few-shot Learners)的作者受gp…

遥感解译项目Land-Cover-Semantic-Segmentation-PyTorch之二训练模型

遥感解译项目Land-Cover-Semantic-Segmentation-PyTorch之一推理模型 背景 上一篇文章了解了这个项目的环境安装和模型推理,这篇文章介绍下如何训练这个模型,添加类别 下载数据集 在之前的一篇文章中,也有用到这个数据集 QGIS之三十六Deepness插件实现AI遥感训练模型 数…

【NLP 71、常见大模型的模型结构对比】

三到五年的深耕&#xff0c;足够让你成为一个你想成为的人 —— 25.5.8 模型名称位置编码Transformer结构多头机制Feed Forward层设计归一化层设计线性层偏置项激活函数训练数据规模及来源参数量应用场景侧重GPT-5 (OpenAI)RoPE动态相对编码混合专家架构&#xff08;MoE&#…

[250521] DBeaver 25.0.5 发布:SQL 编辑器、导航器全面升级,新增 Kingbase 支持!

目录 DBeaver 25.0.5 发布&#xff1a;SQL 编辑器、导航器全面升级&#xff0c;新增 Kingbase 支持&#xff01; DBeaver 25.0.5 发布&#xff1a;SQL 编辑器、导航器全面升级&#xff0c;新增 Kingbase 支持&#xff01; 近日&#xff0c;DBeaver 发布了 25.0.5 版本&#xf…

服务器硬盘虚拟卷的处理

目前的情况是需要删除逻辑卷&#xff0c;然后再重新来弄一遍。 数据已经备份好了&#xff0c;所以不用担心数据会丢失。 查看服务器的具体情况 使用 vgdisplay 操作查看服务器的卷组情况&#xff1a; --- Volume group ---VG Name vg01System IDFormat …

Flutter 中 build 方法为何写在 StatefulWidget 的 State 类中

Flutter 中 build 方法为何写在 StatefulWidget 的 State 类中 在 Flutter 中&#xff0c;build 方法被设计在 StatefulWidget 的 State 类中而非 StatefulWidget 类本身&#xff0c;这种设计基于几个重要的架构原则和实际考量&#xff1a; 1. 核心设计原因 1.1 生命周期管理…

传统医疗系统文档集中标准化存储和AI智能化更新路径分析

引言 随着医疗数智化建设的深入推进&#xff0c;传统医疗系统如医院信息系统(HIS)、临床信息系统(CIS)、护理信息系统(NIS)、影像归档与通信系统(PACS)和实验室信息系统(LIS)已经成为了现代医疗机构不可或缺的技术基础设施。这些系统各自承担着不同的功能&#xff0c;共同支撑…

探索常识性概念图谱:构建智能生活的知识桥梁

目录 一、知识图谱背景介绍 &#xff08;一&#xff09;基本背景 &#xff08;二&#xff09;与NLP的关系 &#xff08;三&#xff09;常识性概念图谱的引入对比 二、常识性概念图谱介绍 &#xff08;一&#xff09;常识性概念图谱关系图示例 &#xff08;二&#xff09…

Linux/aarch64架构下安装Python的Orekit开发环境

1.背景 国产化趋势越来越强&#xff0c;从软件到硬件&#xff0c;从操作系统到CPU&#xff0c;甚至显卡&#xff0c;就产生了在国产ARM CPU和Kylin系统下部署Orekit的需求&#xff0c;且之前的开发是基于Python的&#xff0c;需要做适配。 2.X86架构下安装Python/Orekit开发环…

Ctrl+鼠标滚动阻止页面放大/缩小

项目场景&#xff1a; 提示&#xff1a;这里简述项目相关背景&#xff1a; 一般在我们做大屏的时候&#xff0c;不希望Ctrl鼠标上下滚动的时候页面会放大/缩小&#xff0c;那么在有时候&#xff0c;又不希望影响到别的页面&#xff0c;比如说这个大屏是在另一个管理后台中&am…

MySQL——复合查询表的内外连

目录 复合查询 回顾基本查询 多表查询 自连接 子查询 where 字句中使用子查询 单行子查询 多行子查询 多列子查询 from 字句中使用子查询 合并查询 实战OJ 查找所有员工入职时候的薪水情况 获取所有非manager的员工emp_no 获取所有员工当前的manager 表的内外…

聊一下CSS中的标准流,浮动流,文本流,文档流

在网络上关于CSS的文章中&#xff0c;有时候能听到“标准流”&#xff0c;“浮动流”&#xff0c;“定位流”等等词语&#xff0c;还有像“文档流”&#xff0c;“文本流”等词&#xff0c;这些流是什么意思&#xff1f;它们是CSS中的一些布局方案和特性。今天我们就来聊一下CS…