从代码学习深度学习 - 情感分析:使用卷积神经网络 PyTorch版

文章目录

  • 前言
  • 加载数据集
  • 一维卷积
  • 最大时间汇聚层
  • textCNN模型
    • 定义模型
    • 加载预训练词向量
    • 训练和评估模型
  • 总结


前言

在之前的章节中,我们探讨了如何使用循环神经网络(RNN)来处理序列数据。今天,我们将探索另一种强大的模型——卷积神经网络(CNN)——并将其应用于自然语言处理中的经典任务:情感分析。

你可能会觉得奇怪,CNN不是主要用于图像处理的吗?确实,CNN在计算机视觉领域取得了巨大的成功,它通过二维卷积核捕捉图像的局部特征(如边缘、纹理)。但如果我们换个角度思考,文本序列可以被看作是一维的“图像”,其中每个词元(token)就是一个“像素”。这样,我们就可以使用一维卷积来捕捉文本中的局部模式,比如由相邻单词组成的n-gram

本篇博客将详细介绍如何使用 textCNN 模型,这是一种专为文本分类设计的CNN架构。我们将基于IMDb电影评论数据集,训练一个能够判断评论是正面还是负面的模型。整个流程如下图所示,我们将使用预训练的GloVe词向量作为输入,将其送入textCNN模型,最终得到情感分类结果。

在这里插入图片描述

让我们开始吧!首先,我们需要加载所需的数据集。

完整代码:下载链接

加载数据集

我们仍然使用IMDb电影评论数据集。通过我们预先准备好的 utils_for_data.load_data_imdb 辅助函数,我们可以方便地加载训练和测试数据迭代器,以及一个根据训练数据构建好的词汇表(vocab)。

import torch
import utils_for_data
from torch import nnbatch_size = 64
train_iter, test_iter, vocab = utils_for_data.load_data_imdb(batch_size)

一维卷积

在深入textCNN模型之前,我们先来回顾一下一维卷积是如何工作的。它本质上是二维卷积在只有一个维度(时间或序列步长)上的特例。

如下图所示,卷积窗口(或称为卷积核)在一个一维输入张量上从左到右滑动。在每个位置,输入子张量与核张量进行逐元素相乘,然后求和,得到输出张量中对应位置的一个标量值。例如,图中第一个输出值 2 是通过 0*1 + 1*2 = 2 计算得出的。

在这里插入图片描述

我们可以通过代码来实现这个一维互相关(corr1d)运算,代码中详尽的注释解释了每一步的维度变化和操作目的。

import torchdef corr1d(X, K):"""实现一维互相关(卷积)运算参数:X: 输入张量,维度为 (n,) 其中 n 是输入序列的长度K: 卷积核张量,维度为 (w,) 其中 w 是卷积核的长度返回:Y: 输出张量,维度为 (n - w + 1,) 其中 n-w+1 是输出序列的长度"""# 获取卷积核的长度,维度: 标量w = K.shape[0]# 创建输出张量,长度为输入长度减去卷积核长度加1# Y的维度: (X.shape[0] - w + 1,)Y = torch.zeros((X.shape[0] - w + 1))# 遍历输出张量的每个位置for i in range(Y.shape[0]):# 在第i个位置进行卷积运算# X[i: i + w] 的维度: (w,) - 提取输入序列的一个窗口# K 的维度: (w,) - 卷积核# 两者逐元素相乘后求和得到标量结果Y[i] = (X[i: i + w] * K).sum()return Y# 测试代码
# X: 输入张量,维度 (7,) - 包含7个元素的一维张量
X = torch.tensor([0, 1, 2, 3, 4, 5, 6])# K: 卷积核张量,维度 (2,) - 包含2个元素的一维张量  
K = torch.tensor([1, 2])# 调用函数进行一维卷积运算
# 输出结果的维度: (7 - 2 + 1,) = (6,)
result = corr1d(X, K)
print(result)

输出结果与预期一致:

tensor([ 2.,  5.,  8., 11., 14., 17.])

在NLP中,词嵌入通常是多维的,这意味着我们的输入有多个通道。一维卷积同样可以处理多通道输入。此时,卷积核也需要有相同数量的输入通道。运算时,对每个通道分别执行一维互相关,然后将所有通道的结果相加,得到一个单通道的输出。

在这里插入图片描述

下面是多输入通道一维互相关的实现。

import torchdef corr1d_multi_in(X, K):"""实现多输入通道的一维互相关(卷积)运算参数:X: 多通道输入张量,维度为 (c, n) 其中 c 是输入通道数,n 是每个通道的序列长度K: 多通道卷积核张量,维度为 (c, w) 其中 c 是输入通道数,w 是卷积核的长度返回:result: 输出张量,维度为 (n - w + 1,) 其中 n-w+1 是输出序列的长度"""# 遍历X和K的第0维(通道维),对每个通道分别进行一维卷积,然后求和# X的维度: (c, n) - c个通道,每个通道长度为n# K的维度: (c, w) - c个通道,每个通道的卷积核长度为w# zip(X, K) 将对应通道的输入和卷积核配对# 每次corr1d(x, k)的结果维度: (n - w + 1,)# sum()将所有通道的结果相加,最终输出维度: (n - w + 1,)return sum(corr1d(x, k) for x, k in zip(X, K))# 测试代码
# X: 多通道输入张量,维度 (3, 7) - 3个输入通道,每个通道包含7个元素
X = torch.tensor([[0, 1, 2, 3, 4, 5, 6],[1, 2, 3, 4, 5, 6, 7],[2, 3, 4, 5, 6, 7, 8]])# K: 多通道卷积核张量,维度 (3, 2) - 3个通道,每个通道的卷积核长度为2
K = torch.tensor([[1, 2], [3, 4], [-1, -3]])# 调用函数进行多通道一维卷积运算
# 输出结果的维度: (7 - 2 + 1,) = (6,)
result = corr1d_multi_in(X, K)
print(result)

输出结果:

tensor([ 2.,  8., 14., 20., 26., 32.])

有趣的是,多输入通道的一维互相关等价于单输入通道的二维互相关,只要将二维卷积核的高度设置为与输入张量的高度相同即可,如下图所示。

在这里插入图片描述

最大时间汇聚层

在卷积层之后,textCNN使用了一个称为最大时间汇聚层(Max-over-time Pooling)的关键组件。卷积操作的输出长度依赖于输入序列和卷积核的宽度,导致不同卷积核产生的输出序列长度不同。最大时间汇聚层的作用是在时间步(序列长度)维度上取最大值。这相当于从每个卷积核提取的特征图中,只保留最强烈的信号。无论输入序列多长,经过这个操作后,每个通道都只会输出一个标量值,从而解决了不同卷积核输出维度不一的问题,并生成了用于分类的固定长度的特征向量。

textCNN模型

理解了一维卷积和最大时间汇聚后,我们就可以构建textCNN模型了。整个模型的架构如下图所示:

在这里插入图片描述

输入是一个句子,每个词元由一个多维向量表示。我们定义了多种不同宽度的卷积核(

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/86640.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/86640.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入解析分布式训练基石:ps-lite源码实现原理

分布式机器学习框架是现代推荐、广告和搜索系统的核心支撑。面对海量训练数据和高维稀疏特征,参数服务器(Parameter Server, PS) 架构应运而生。作为早期经典实现的ps-lite因其简洁性和完整性,成为理解PS原理的绝佳切入点。本文将…

IDEA 插件开发:Internal Actions 与 UI Inspector 快速定位 PSI

在开发 IntelliJ 平台插件的过程中,你常常需要搞清楚 某个 IDE 弹框背后是如何操作 PSI(Program Structure Interface) 的。下面这篇笔记将介绍如何通过 Internal Actions、UI Inspector 以及调试技巧快速定位 PSI 调用链。 1. 启用 Internal…

26考研|数学分析:多元函数微分学

前言 本章我们将进行多元函数微分学的学习,多元函数微分学与一元函数微分学相对应,涉及到可微性、中值定理、泰勒公式等诸多问题的探讨与研究,本章难度较大,在学习过程中需要进行深度思考与分析,才能真正掌握这一章的…

数星星--二分

https://www.matiji.net/exam/brushquestion/17/4498/F16DA07A4D99E21DFFEF46BD18FF68AD 二分思路不难&#xff0c;关键的区间内个数的确定 #include<bits/stdc.h> using namespace std; #define N 100011 #define inf 0x3f3f3f3f typedef long long ll; typedef pair&…

Oracle/PostgreSQL/MSSQL/MySQL函数实现对照表

函数列表清单 函数作用OraclePOSTGRESQLMSSQLMYSQL求字符串长度LENGTH(str)LENGTH(str)LEN(str)LENGTH(str)字符切割SUBSTR(str,index,length)SUBSTR(str,index,length)SUBSTRING(str,index,length)SUBSTRING(str,index,length)字符串连接str1||str2||str3...strNstr1||str2||…

pycharm客户端安装教程

二、 pycharm客户端安装 打开pycharm官网&#xff1a;https://www.jetbrains.com/pycharm/download/?sectionwindows 选择其他版本 选择2018社区版本&#xff0c;点击下载 双击下载的安装程序(第一个弹框允许)&#xff0c;选择下一步 更改安装路径&#xff0c;在pycah…

博图SCL语言中用户自定义数据类型(UDT)使用详解

博图SCL语言中用户自定义数据类型&#xff08;UDT&#xff09;使用详解 一、UDT概述 用户自定义数据类型&#xff08;UDT&#xff09;是TIA Portal中强大的结构化工具&#xff0c;允许将多个相关变量组合成单一数据结构。UDT本质是可重用的数据模板&#xff0c;具有以下核心优…

Vscode自定义代码快捷方式

首选项>配置代码片段 >新建全局代码片段 (也可以选择你的语言 为了避免有的时候不生效 选择全局代码) {"console.log": { //名字"prefix": "log",//prefix 快捷键 &#xff1a; log"body": ["console.log($1);", //b…

ESP32 008 MicroPython Web框架库 Microdot 实现的网络文件服务器

以下是整合了所有功能的完整 main.py(在ESP32 007 MicroPython 适用于 Python 和 MicroPython 的小型 Web 框架库 Microdot基础上)&#xff0c;实现了&#xff1a; Wi‑Fi 自动连接&#xff08;支持静态 IP&#xff09;&#xff1b;SD 卡挂载&#xff1b;从 /sd/www/ 读取 HTML…

Mcp-git-ingest Quickstart

目录 配置例子 文档github链接&#xff1a;git_ingest.md 配置 {"mcpServers": {"mcp-git-ingest": {"command": "uvx","args": ["--from", "githttps://github.com/adhikasp/mcp-git-ingest", "…

(LeetCode 面试经典 150 题) 27.移除元素

目录 题目&#xff1a; 题目描述&#xff1a; 题目链接&#xff1a; 思路&#xff1a; 核心思路&#xff1a; 思路详解&#xff1a; 样例模拟&#xff1a; 代码&#xff1a; C代码&#xff1a; Java代码&#xff1a; 题目&#xff1a; 题目描述&#xff1a; 题目链接…

MySQL之事务原理深度解析

MySQL之事务原理深度解析 一、事务基础&#xff1a;ACID特性的本质1.1 事务的定义与核心作用1.2 ACID特性的内在联系 二、原子性与持久性的基石&#xff1a;日志系统2.1 Undo Log&#xff1a;原子性的实现核心2.2 Redo Log&#xff1a;持久性的保障2.3 双写缓冲&#xff08;Dou…

JUC:5.start()与run()

这两个方法都可以使线程进行运行&#xff0c;但是start只能用于第一次运行线程&#xff0c;后续要继续运行该线程需要使用run()方法。如果多次运行start()方法&#xff0c;会出现报错。 初次调用线程使用run()方法&#xff0c;无法使线程运行。 如果你对一个 Thread 实例直接调…

微服务中解决高并发问题的不同方法!

如果由于流量大而在短时间内几乎同时发出请求&#xff0c;或者由于服务器不稳定而需要很长时间来处理请求&#xff0c;并发问题可能会导致数据完整性问题。 示例问题情况 让我们假设有一个逻辑可以检索产品的库存并将库存减少一个&#xff0c;如上所述。此时&#xff0c;两个请…

【2025CCF中国开源大会】OpenChain标准实践:AI时代开源软件供应链安全合规分论坛重磅来袭!

点击蓝字 关注我们 CCF Opensource Development Committee 在AI时代&#xff0c;软件供应链愈发复杂&#xff0c;从操作系统到开发框架&#xff0c;从数据库到人工智能工具&#xff0c;开源无处不在。AI 与开源生态深度融合&#xff0c;在为软件行业带来前所未有的创新效率的同…

[Java实战]springboot3使用JDK21虚拟线程(四十)

[Java实战]springboot3使用JDK21虚拟线程(四十) 告别线程池爆满、内存溢出的噩梦!JDK21 虚拟线程让高并发连接变得触手可及。本文将带你深入实战,见证虚拟线程如何以极低资源消耗轻松应对高并发压测。 一、虚拟线程 传统 Java 线程(平台线程)与 OS 线程 1:1 绑定,创建和…

SpringBoot 中使用 @Async 实现异步调用​

​ ​ SpringBoot 中使用 Async 实现异步调用 一、Async 注解的使用场合​二、Async 注解的创建与调试​三、Async 注解的注意事项​四、总结​ 在高并发、高性能要求的应用场景下&#xff0c;异步处理能够显著提升系统的响应速度和吞吐量。Spring Boot 提供的 Async 注解为开…

CMOS SENSOR HDR场景下MIPI 虚拟端口的使用案例

CMOS SENSOR HDR场景下MIPI 虚拟端口的使用案例 文章目录 CMOS SENSOR HDR场景下MIPI 虚拟端口的使用案例📷 **一、HDR模式下的虚拟通道核心作用**⚙️ **二、典型应用案例****1. 车载多目HDR系统****2. 工业检测多模态HDR****3. 手机多摄HDR合成**🔧 **三、实现关键技术点…

RJ45 以太网与 5G 的原理解析及区别

一、RJ45 以太网的原理 1. RJ45 接口与以太网的关系 RJ45 是一种标准化的网络接口&#xff0c;主要用于连接以太网设备&#xff08;如电脑、路由器&#xff09;&#xff0c;其物理形态为 8 针模块化接口&#xff0c;适配双绞线&#xff08;如 CAT5、CAT6 网线&#xff09;。以…

valkey之sdscatrepr 函数优化解析

一、函数功能概述 sds sdscatrepr(sds s, const char *p, size_t len)函数的核心功能是将字符串p追加到字符串s中。在追加过程中&#xff0c;它会对字符串p中的字符进行判断&#xff0c;使用isprint()函数识别不可打印字符&#xff0c;并对这些字符进行转义处理&#xff0c;确…