在 PyTorch 中借助 GloVe 词嵌入完成情感分析

一. Glove 词嵌入原理

GloVe是一种学习词嵌入的方法,它希望拟合给定上下文单词i时单词j出现的次数x_{ij}。使用的误差函数为:

\sum_{i=1}^{N}\sum_{j=1}^{N}f(x_{ij})(\theta _{j}^{T}e_{i}+b_{i}+b_{j}^{'}-logx_{ij})

其中N是词汇表大小,\theta ,b是线性层参数,e_{i} 是词嵌入。f(x)是权重项,用于平衡不同频率的单词对误差的影响,并消除log0时式子不成立情况。

GloVe作者提供了官方的预训练词嵌入(https://nlp.stanford.edu/projects/glove/ )。预训练的GloVe有好几个版本,按数据来源,可以分成:

  • 维基百科+gigaword(6B)
  • 爬虫(42B)
  • 爬虫(840B)
  • 推特(27B)

按照词嵌入向量的大小分,又可以分成50维,100维,200维等不同维度。

预训练GloVe的文件格式非常简明,一行代表一个单词向量,每行先是一个单词,再是若干个浮点数,表示该单词向量的每一个元素。

在Pytorch里,我们不必自己去下载解析GloVe,而是可以直接调用Pytorch库自动下载解析GloVe。首先我们要安装Pytorch的NLP库-- torchtext。

如上所述,GloVe版本可以由其数据来源和向量维数确定,在构建GloVe类时,要提供这两个参数,我们选择的是6B token,维度100的GloVe

调用glove.get_vecs_by_tokens,我们能够把token转换成GloVe里的向量。

import torch

from torchtext.vocab import GloVe

glove = GloVe(name='6B', dim=100)

# Get vectors

tensor = glove.get_vecs_by_tokens(['', '1998', '199999998', ',', 'cat'], True)

print(tensor)

PyTorch提供的这个函数非常方便。如果token不在GloVe里的话,该函数会返回一个全0向量。如果你运行上面的代码,可以观察到一些有趣的事:空字符串和199999998这样的不常见数字不在词汇表里,而1998这种常见的数字以及标点符号都在词汇表里。

GloVe类内部维护了一个矩阵,即每个单词向量的数组。因此,GloVe需要一个映射表来把单词映射成向量数组的下标。glove.itosglove.stoi完成了下标与单词字符串的相互映射。比如用下面的代码,我们可以知道词汇表的大小,并访问词汇表的前几个单词:

myvocab = glove.itos
print(len(myvocab))
print(myvocab[0], myvocab[1], myvocab[2], myvocab[3])

最后,我们来通过一个实际的例子认识一下词嵌入的意义。词嵌入就是向量,向量的关系常常与语义关系对应。利用词嵌入的相对关系,我们能够回答“x1之于y1,相当于x2之于谁?”这种问题。比如,男人之于女人,相当于国王之于王后。设我们要找的向量为y2,我们想让x1-y1=x2-y2,即找出一个和x2-(x1-y1)最相近的向量y2出来。这一过程可以用如下的代码描述:

def get_counterpart(x1, y1, x2):x1_id = glove.stoi[x1]y1_id = glove.stoi[y1]x2_id = glove.stoi[x2]#print("x1:",x1,"y1:",y1,"x2:",x2)x1, y1, x2 = glove.get_vecs_by_tokens([x1, y1, x2],True)#print("x1:",x1,"y1:",y1,"x2:",x2)target = x2 - x1 + y1max_sim =0 max_id = -1for i in range(len(myvocab)):vector = glove.get_vecs_by_tokens([myvocab[i]],True)[0]cossim = torch.dot(target, vector)if cossim > max_sim and i not in {x1_id, y1_id, x2_id}:max_sim = cossimmax_id = ireturn myvocab[max_id]
print(get_counterpart('man', 'woman', 'king'))
print(get_counterpart('more', 'less', 'long'))
print(get_counterpart('apple', 'red', 'banana'))

运行结果: 

queen

short

yellow

二.基于GloVe的情感分析

情感分析任务与数据集

和猫狗分类类似,情感分析任务是一种比较简单的二分类NLP任务:给定一段话,输出这段话的情感是积极的还是消极的。

比如下面这段话:

I went and saw this movie last night after being coaxed to by a few friends of mine. I'll admit that I was reluctant to see it because from what I knew of Ashton Kutcher he was only able to do comedy. I was wrong. Kutcher played the character of Jake Fischer very well, and Kevin Costner played Ben Randall with such professionalism. ......

这是一段影评,大意说,这个观众本来不太想去看电影,因为他认为演员Kutcher只能演好喜剧。但是,看完后,他发现他错了,所有演员都演得非常好。这是一段积极的评论。

1. 读取数据集:

import os 
from torchtext.data import get_tokenizerdef read_imdb(dir='aclImdb', split = 'pos', is_train=True):subdir = 'train' if is_train else 'test'dir = os.path.join(dir, subdir, split)lines = []for file in os.listdir(dir):with open(os.path.join(dir, file), 'rb') as f:line = f.read().decode('utf-8')lines.append(line)return lineslines = read_imdb()
print('Length of the file:', len(lines))
print('lines[0]:', lines[0])
tokenizer = get_tokenizer('basic_english')
tokens = tokenizer(lines[0])
print('lines[0] tokens:', tokens)

output: 

2.获取经GloVe预处理的数据

在这个作业里,模型其实很简单,输入序列经过词嵌入,送入单层RNN,之后输出结果。作业最难的是如何把token转换成GloVe词嵌入。

torchtext其实还提供了一些更方便的NLP工具类(Field,Vectors),用于管理向量。但是,这些工具需要一定的学习成本,后续学习pytorch时再学习。

Pytorch通常用nn.Embedding来表示词嵌入层。nn.Embedding其实就是一个矩阵,每一行都是一个词嵌入,每一个token都是整型索引,表示该token再词汇表里的序号。有了索引,有了矩阵就可以得到token的词嵌入了。但是有些token在词汇表中并不存在,我们得对输入做处理,把词汇表里没有的token转换成<unk>这个表示未知字符的特殊token。同时为了对齐序列的长度,我们还得添加<pad>这个特殊字符。而用glove直接生成的nn.Embedding里没有<unk>和<pad>字符。如果使用nn.Embedding的话,我们要编写非常复杂的预处理逻辑。

为此,我们可以用GloVe类的get_vecs_by_tokens直接获取token的词嵌入,以代替nn.Embedding。回忆一下前文提到的get_vecs_by_tokens的使用结果,所有没有出现的token都会被转换成零向量。这样,我们就不必操心数据预处理的事了。get_vecs_by_tokens应该发生在数据读取之后,可以直接被写在Dataset的读取逻辑里

from torch.utils.data import DataLoader, Dataset
from torchtext.data import get_tokenizer
from torchtext.vocab import GloVeclass IMDBDataset(Dataset):def __init__(self, is_train=True, dir = 'aclImdb'):super().__init__()self.tokenizer = get_tokenizer('basic_english')pos_lines = read_imdb(dir, 'pos', is_train)neg_lines = read_imdb(dir, 'neg', is_train)self.pos_length = len(pos_lines)self.neg_length = len(neg_lines)self.lines = pos_lines+neg_linesdef __len__(self):return self.pos_length + self.neg_lengthdef __getitem__(self, index):sentence = self.tokenizer(self.lines[index])x = glove.get_vecs_by_tokens(sentence)label = 1 if index < self.pos_length else 0return x, label

数据预处理的逻辑都在__getitem__里。每一段字符串会先被token化,之后由GLOVE.get_vecs_by_tokens得到词嵌入数组。 

3.对齐输入

使用一个batch的序列数据时常常会碰到序列不等长的问题。实际上利用Pytorch Dataloader的collate_fn机制有更简洁的实现方法。

from torch.nn.utils.rnn import pad_sequencedef get_dataloader(dir='aclImdb'):def collate_fn(batch):x, y = zip(*batch)x_pad = pad_sequence(x, batch_first=True)y = torch.Tensor(y)return x_pad, ytrain_dataloader = DataLoader(IMDBDataset(True, dir),batch_size=32,shuffle=True,collate_fn=collate_fn)test_dataloader = DataLoader(IMDBDataset(False, dir),batch_size=32,shuffle=True,collate_fn=collate_fn)return train_dataloader, test_dataloader

PyTorch DataLoader在获取Dataset的一个batch的数据时,实际上会先吊用Dataset.__getitem__获取若干个样本,再把所有样本拼接成一个batch,比如用__getitem__获取四个[4,3,10,10]这一个batch,可是序列数据通常长度不等,__getitem__可能会获得[10, 100][15, 100]这样不等长的词嵌入数组。

为了解决这个问题,我们要手动编写把所有张量拼成一个batch的函数。这个函数就是DataLoadercollate_fn函数。我们的collate_fn应该这样编写:

def collate_fn(batch):x, y = zip(*batch)x_pad = pad_sequence(x, batch_first=True)y = torch.Tensor(y)return x_pad, y

collate_fn的输入batch是每次__getitem__的结果的数组。比如在我们这个项目中,第一次获取了一个长度为10的积极的句子,__getitem__返回(Tensor[10, 100], 1);第二次获取了一个长度为15的消极的句子,__getitem__返回(Tensor[15, 100], 0)。那么,输入batch的内容就是:

[(Tensor[10, 100], 1), (Tensor[15, 100], 0)]

我们可以用x, y = zip(*batch)把它巧妙地转换成两个元组:

x = (Tensor[10, 100], Tensor[15, 100])
y = (1, 0)

之后,PyTorch的pad_sequence可以把不等长序列的数组按最大长度填充成一整个batch张量。也就是说,经过这个函数后,x_pad变成了:

x_pad = Tensor[2, 15, 100]

pad_sequencebatch_first决定了batch是否在第一维。如果它为False,则结果张量的形状是[15, 2, 100]

pad_sequence还可以决定填充内容,默认填充0。在我们这个项目中,被填充的序列已经是词嵌入了,直接用全零向量表示<pad>没问题。

有了collate_fn,构建DataLoader就很轻松了:

DataLoader(IMDBDataset(True, dir),batch_size=32,shuffle=True,collate_fn=collate_fn)

注意,使用shuffle=True可以令DataLoader随机取数据构成batch。由于我们的Dataset十分工整,前一半的标签是1,后一半是0,必须得用随机的方式去取数据以提高训练效率。 

4.模型

import torch.nn as nn
GLOVE_DIM = 100
GLOVE = GloVe(name = '6B', dim=GLOVE_DIM)
class RNN(torch.nn.Module):def __init__(self, hidden_units=64, dropout_rate = 0.5):super().__init__()self.drop = nn.Dropout(dropout_rate)self.rnn = nn.GRU(GLOVE_DIM, hidden_units, 1, batch_first=True)self.linear = nn.Linear(hidden_units,1)self.sigmoid = nn.Sigmoid()def forward(self, x:torch.Tensor):# x: [batch, max_word_length, embedding_length]emb = self.drop(x)output,_ = self.rnn(emb)output = output[:, -1]output = self.linear(output)output = self.sigmoid(output)return output

这里要注意一下,PyTorch的RNN会返回整个序列的输出。而在预测分类概率时,我们只需要用到最后一轮RNN计算的输出。因此,要用output[:, -1]取最后一次的输出。

5. 训练、测试、推理 

train_dataloader, test_dataloader = get_dataloader()
model = RNN()optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
citerion = torch.nn.BCELoss()for epoch in range(100):loss_sum = 0dataset_len = len(train_dataloader.dataset)for x, y in train_dataloader:batchsize = y.shape[0]hat_y = model(x)hat_y = hat_y.squeeze(-1)loss = citerion(hat_y, y)optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)optimizer.step()loss_sum += loss * batchsizeprint(f'Epoch{epoch}. loss :{loss_sum/dataset_len}')torch.save(model.state_dict(),'rnn.pth')

output: 

model.load_state_dict(torch.load('rnn.pth'))
accuracy = 0
dataset_len = len(test_dataloader.dataset)
model.eval()
for x, y in test_dataloader:with torch.no_grad():hat_y = model(x)hat_y.squeeze_(1)predictions = torch.where(hat_y>0.5,1,0)score = torch.sum(torch.where(predictions==y,1,0))accuracy += score.item()
accuracy /= dataset_lenprint(f'Accuracy:{accuracy}')   

Accuracy:0.90516

tokenizer = get_tokenizer('basic_english')
article = "U.S. stock indexes fell Tuesday, driven by expectations for tighter Federal Reserve policy and an energy crisis in Europe. Stocks around the globe have come under pressure in recent weeks as worries about tighter monetary policy in the U.S. and a darkening economic outlook in Europe have led investors to sell riskier assets."x = GLOVE.get_vecs_by_tokens(tokenizer(article)).unsqueeze(0)
with torch.no_grad():hat_y = model(x)
hat_y = hat_y.squeeze_().item()
result = 'positive' if hat_y > 0.5 else 'negative'
print(result)

negative

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/79706.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/79706.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

kotlin中 热流 vs 冷流 的本质区别

&#x1f525; 冷流&#xff08;Cold Flow&#xff09; vs 热流&#xff08;Hot Flow&#xff09;区别 特性冷流&#xff08;Cold Flow&#xff09;热流&#xff08;Hot Flow&#xff09;数据生产时机每次 collect 才开始执行启动时就开始生产、始终运行生命周期与 collect 者…

精益数据分析(44/126):深度解析媒体网站商业模式的关键要点

精益数据分析&#xff08;44/126&#xff09;&#xff1a;深度解析媒体网站商业模式的关键要点 在创业与数据分析的探索道路上&#xff0c;我们不断挖掘不同商业模式的核心要素&#xff0c;今天将深入剖析媒体网站商业模式。希望通过对《精益数据分析》相关内容的解读&#xf…

Android学习总结之Java和kotlin区别

一、空安全机制 真题 1&#xff1a;Kotlin 如何解决 Java 的 NullPointerException&#xff1f;对比两者在空安全上的设计差异 解析&#xff1a; 核心考点&#xff1a;Kotlin 可空类型系统&#xff08;?&#xff09;、安全操作符&#xff08;?./?:&#xff09;、非空断言&…

[Survey]Remote Sensing Temporal Vision-Language Models: A Comprehensive Survey

BaseInfo TitleRemote Sensing Temporal Vision-Language Models: A Comprehensive SurveyAdresshttps://arxiv.org/abs/2412.02573Journal/Time2024 arxivAuthor北航 上海AI LabCodehttps://github.com/Chen-Yang-Liu/Awesome-RS-Temporal-VLM 1. Introduction 传统遥感局限…

jmeter读取CSV文件中文乱码的解决方案

原因分析​ CSV文件出现中文乱码通常是因为文件编码与JMeter读取编码不一致。常见场景&#xff1a; 文件保存为GBK/GB2312编码&#xff0c;但JMeter以UTF-8读取。文件包含BOM头&#xff08;如Windows记事本保存的UTF-8&#xff09;&#xff0c;但JMeter未正确处理。脚本读取文…

Webview通信系统学习指南

Webview通信系统学习指南 一、定义与核心概念 1. 什么是Webview&#xff1f; 定义&#xff1a;Webview是移动端&#xff08;Android/iOS&#xff09;内置的轻量级浏览器组件&#xff0c;用于在原生应用中嵌入网页内容。作用&#xff1a;实现H5页面与原生应用的深度交互&…

【C++】C++中的命名/名字/名称空间 namespace

C中的命名/名字/名称空间 namespace 1、问题引入2、概念3、作用4、格式5、使用命名空间中的成员5.1 using编译指令&#xff08; 引进整个命名空间&#xff09; ---将这个盒子全部打开5.2 using声明使特定的标识符可用(引进命名空间的某个成员) ---将这个盒子中某个成员的位置打…

Arduino IDE中离线更新esp32 3.2.0版本的办法

在Arduino IDE中更新esp32-3.2.0版本是个不可能的任务&#xff0c;下载文件速度极慢。网上提供了离线的办法&#xff0c;提供了安装文件&#xff0c;但是没有3.2.0的版本。 下面提供了一种离线安装方法 一、腾讯元宝查询解决办法 通过打开开发板管理地址&#xff1a;通过在腾…

【工具使用-数据可视化工具】Apache Superset

1. 工具介绍 1.1. 简介 一个轻量级、高性能的数据可视化工具 官网&#xff1a;https://superset.apache.org/GitHub链接&#xff1a;https://github.com/apache/superset官方文档&#xff1a;https://superset.apache.ac.cn/docs/intro/ 1.2. 核心功能 丰富的可视化库&…

算法每日一题 | 入门-顺序结构-三角形面积

三角形面积 题目描述 一个三角形的三边长分别是 a、b、c&#xff0c;那么它的面积为 p ( p − a ) ( p − b ) ( p − c ) \sqrt{p(p-a)(p-b)(p-c)} p(p−a)(p−b)(p−c) ​&#xff0c;其中 p 1 2 ( a b c ) p\frac{1}{2}(abc) p21​(abc) 。输入这三个数字&#xff0c;…

MongoDB入门详解

文章目录 MongoDB下载和安装1.MongoDBCompass字段简介1.1 Aggregations&#xff08;聚合&#xff09;1.2 Schema&#xff08;模式分析&#xff09;1.3 Indexes&#xff08;索引&#xff09;1.4 Validation&#xff08;数据验证&#xff09; 2.增删改查操作2.1创建、删除数据库&…

从Oculus到Meta:Facebook实现元宇宙的硬件策略

Oculus的起步 Facebook在2014年收购了Oculus&#xff0c;这标志着其在虚拟现实&#xff08;VR&#xff09;领域的首次重大投资。Oculus Rift作为公司的旗舰产品&#xff0c;是一款高端的VR头戴设备&#xff0c;它为用户带来了沉浸式的体验。Facebook通过Oculus Rift&#xff0…

安装与配置Go语言开发环境 -《Go语言实战指南》

为了开始使用Go语言进行开发&#xff0c;我们首先需要正确安装并配置Go语言环境。Go的安装相对简单&#xff0c;支持多平台&#xff0c;包括Windows、macOS和Linux。本节将逐一介绍各平台的安装流程及环境变量配置方式。 一、Windows系统 1. 下载Go安装包 前往Go语言官网&…

网络的搭建

1、rpm rpm -ivh 2、yum仓库&#xff08;rpm包&#xff09;&#xff1a;网络源 ----》网站 本地源 ----》/dev/sr0 光盘映像文件 3、源码安装 源码安装&#xff08;编译&#xff09; 1、获取源码 2、检测环境生成Ma…

多元随机变量协方差矩阵

主要记录多元随机变量数字特征相关内容。 关键词&#xff1a;多元统计分析 一元随机变量 总体 随机变量Y 总体均值 μ E ( Y ) ∫ y f ( y ) d y \mu E(Y) \int y f(y) \, dy μE(Y)∫yf(y)dy 总体方差 σ 2 V a r ( Y ) E ( Y − μ ) 2 \sigma^2 Var(Y) E(Y - \…

Ros工作空间

工作空间其实放到嵌入式里就是相关的编程包 ------------------------------------- d第一个Init 就是类型的初始化 然后正常一个catkin_make 后 就会产生如devil之类的文件&#xff0c; 你需要再自己 终端 一个catkin_make install 一下 。这样对应install也会产生&#xf…

qt国际化翻译功能用法

文章目录 [toc]1 概述2 设置待翻译文本3 生成ts翻译源文件4 编辑ts翻译源文件5 生成qm翻译二进制文件6 加载qm翻译文件进行翻译 更多精彩内容&#x1f449;内容导航 &#x1f448;&#x1f449;Qt开发经验 &#x1f448; 1 概述 在 Qt 中&#xff0c;ts 文件和 qm 文件是用于国…

PyTorch 与 TensorFlow 中基于自定义层的 DNN 实现对比

深度学习双雄对决&#xff1a;PyTorch vs TensorFlow 自定义层大比拼 目录 深度学习双雄对决&#xff1a;PyTorch vs TensorFlow 自定义层大比拼一、TensorFlow 实现 DNN1. 核心逻辑 二、PyTorch 实现自定义层1. 核心逻辑 三、关键差异对比四、总结 一、TensorFlow 实现 DNN 1…

1ms城市算网稳步启航,引领数字领域的“1小时经济圈”效应

文 | 智能相对论 作者 | 陈选滨 为什么近年来国产动画、国产3A大作迎来了井喷式爆发&#xff1f;抛开制作水平以及市场需求的升级不谈&#xff0c;还有一个重要原因往往被大多数人所忽视&#xff0c;那就是新型信息的完善与成熟。 譬如&#xff0c;现阶段惊艳用户的云游戏以及…

【计算机视觉】语义分割:Segment Anything (SAM):通用图像分割的范式革命

Segment Anything&#xff1a;通用图像分割的范式革命 技术突破与架构创新核心设计理念关键技术组件 环境配置与快速开始硬件要求安装步骤基础使用示例 深度功能解析1. 多模态提示融合2. 全图分割生成3. 高分辨率处理 模型微调与定制1. 自定义数据集准备2. 微调训练配置 常见问…