【NLP入门系列五】中文文本分类案例

在这里插入图片描述

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

博主简介:努力学习的22级本科生一枚 🌟​;探索AI算法,C++,go语言的世界;在迷茫中寻找光芒​🌸
博客主页:羊小猪~~-CSDN博客
内容简介:这一篇是NLP的入门项目,中文文本分类案例。
🌸箴言🌸:去寻找理想的“天空“”之城
上一篇内容:【NLP入门系列四】评论文本分类入门案例-CSDN博客
​💁​​💁​​💁​​💁​: NLP数据格式的构建确实比较难,这里卡住的主要是文本向量化的数据格式

文章目录

    • 1、数据准备
    • 2、类别标签化
    • 3、数据加载与词典构建
      • 数据加载
      • 构建词典
      • 文本向量化
      • 数据加载
    • 4、模型构建
    • 5、训练和测试函数
    • 6、模型训练
    • 7、结果展示
    • 8、结果测试

1、数据准备

import pandas as pd
import torchtext
import torch  
import torch.nn as nn 
from torch.utils.data import DataLoader, Dataset 
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iteratordevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')# 读取数据
data_df = pd.read_csv("./train.csv", sep='\t', header=None)# 添加标题
data_df.columns = ["content", "label"]data_df.head()
contentlabel
0还有双鸭山到淮阴的汽车票吗13号的Travel-Query
1从这里怎么回家Travel-Query
2随便播放一首专辑阁楼里的佛里的歌Music-Play
3给看一下墓王之王嘛FilmTele-Play
4我想看挑战两把s686打突变团竞的游戏视频Video-Play
t = data_df['label'].unique()
classesNum = len(t)
print("classes: ", t)
print("classes num: ",classesNum)
classes:  ['Travel-Query' 'Music-Play' 'FilmTele-Play' 'Video-Play' 'Radio-Listen''HomeAppliance-Control' 'Weather-Query' 'Alarm-Update' 'Calendar-Query''TVProgram-Play' 'Audio-Play' 'Other']
classes num:  12

2、类别标签化

from sklearn.preprocessing import LabelEncoder labels = data_df['label']# 创建LabelEncoder
label_encoder = LabelEncoder()# 拟合标签
encoded_list = label_encoder.fit_transform(labels)# 编码后标签
data_df["labelToNum"] = encoded_listclasses = {}
for name, idx in zip(label_encoder.classes_, range(len(label_encoder.classes_))):classes[idx] = nameprint(classes)
{0: 'Alarm-Update', 1: 'Audio-Play', 2: 'Calendar-Query', 3: 'FilmTele-Play', 4: 'HomeAppliance-Control', 5: 'Music-Play', 6: 'Other', 7: 'Radio-Listen', 8: 'TVProgram-Play', 9: 'Travel-Query', 10: 'Video-Play', 11: 'Weather-Query'}
data_df.head()
contentlabellabelToNum
0还有双鸭山到淮阴的汽车票吗13号的Travel-Query9
1从这里怎么回家Travel-Query9
2随便播放一首专辑阁楼里的佛里的歌Music-Play5
3给看一下墓王之王嘛FilmTele-Play3
4我想看挑战两把s686打突变团竞的游戏视频Video-Play10

3、数据加载与词典构建

数据加载

# 定义数据格式
class MyDataSet(Dataset):def __init__(self, dataframe):self.labels = dataframe["labelToNum"].tolist()self.texts = dataframe["content"].tolist()def __len__(self):return len(self.labels)def __getitem__(self, idx):return self.labels[idx], self.texts[idx]# 加载数据
data = MyDataSet(data_df)
for label, text in data:print(label)print(text)break
9
还有双鸭山到淮阴的汽车票吗13号的

构建词典

import jieba# 设置中文分词
tokenizer = jieba.lcut# 返回文本数据中词汇
def yield_tokens(data_iter):for _, text in data_iter:  # 注意返回类型# 分词text = tokenizer(text)yield text # 构建词典
vocab = build_vocab_from_iterator(yield_tokens(data), specials=["<unk>"])# 设置索引
vocab.set_default_index(vocab["<unk>"])print("Vocab size:", len(vocab))
Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\WY118C~1\AppData\Local\Temp\jieba.cache
Loading model cost 0.456 seconds.
Prefix dict has been built successfully.
Vocab size: 11147

文本向量化

# 向量化函数
text_vector = lambda x : vocab(tokenizer(x))
label_num = lambda x : int(x)# EmbeddingBag嵌入格式创建
def collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (label_, text_) in batch:# 标签label_list.append(label_num(label_))# 文本temp = torch.tensor(text_vector(text_), dtype=torch.int64)text_list.append(temp)# 偏移量offsets.append(temp.size(0))  # 注意:第一个维度哦label_list = torch.tensor(label_list, dtype=torch.int64)text_list = torch.cat(text_list)  # 堆叠, 一个一个维度堆叠,注意:这里易错,这里一定要明白这里的格式 “一行为一个文本”offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)return label_list.to(device), text_list.to(device), offsets.to(device)

数据加载

# 分割数据
train_size = int(len(data) * 0.8)
test_size = len(data) - train_size
train_data, test_data = torch.utils.data.random_split(data, [train_size, test_size])batch_size = 16# 动态加载
train_dl = DataLoader(train_data,batch_size=batch_size,shuffle=True,collate_fn=collate_batch
)test_dl = DataLoader(test_data,batch_size=batch_size,shuffle=False,collate_fn=collate_batch
)

4、模型构建

class TextModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super().__init__()# 注意:这里是简单入门案例,没用rnn、lstm这些,如果用这些模型这需要用embedding,才能更好捕捉序列信息self.embeddingBag = nn.EmbeddingBag(vocab_size,  # 词典大小embed_dim,   # 嵌入维度sparse=False)self.fc = nn.Linear(embed_dim, num_class)self.init_weights()# 初始化权重def init_weights(self):initrange = 0.5self.embeddingBag.weight.data.uniform_(-initrange, initrange)  # 初始化权重范围self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()  # 偏置置为0def forward(self, text, offsets):embedding = self.embeddingBag(text, offsets)return self.fc(embedding)
vocab_len = len(vocab)
embed_dim = 64  # 嵌入到64维度中
model = TextModel(vocab_size=vocab_len, embed_dim=embed_dim, num_class=classesNum).to(device=device)

5、训练和测试函数

def train(model, dataset, optimizer, loss_fn):size = len(dataset.dataset)num_batch = len(dataset)train_acc = 0train_loss = 0for _, (label, text, offset) in enumerate(dataset):label, text, offset = label.to(device), text.to(device), offset.to(device)predict_label = model(text, offset)loss = loss_fn(predict_label, label)# 求导与反向传播optimizer.zero_grad()loss.backward()optimizer.step()train_acc += (predict_label.argmax(1) == label).sum().item()train_loss += loss.item()train_acc /= size train_loss /= num_batchreturn train_acc, train_lossdef test(model, dataset, loss_fn):size = len(dataset.dataset)batch_size = len(dataset)test_acc, test_loss = 0, 0with torch.no_grad():for _, (label, text, offset) in enumerate(dataset):label, text, offset = label.to(device), text.to(device), offset.to(device)predict = model(text, offset)loss = loss_fn(predict, label) test_acc += (predict.argmax(1) == label).sum().item()test_loss += loss.item()test_acc /= size test_loss /= batch_sizereturn test_acc, test_loss

6、模型训练

import copy# 超参数设置
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.01)  # 动态调整学习率epochs = 10train_acc, train_loss, test_acc, test_loss = [], [], [], []best_acc = 0for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(model, train_dl, optimizer, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)model.eval()epoch_test_acc, epoch_test_loss = test(model, test_dl, loss_fn)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)if best_acc is not None and epoch_test_acc > best_acc:# 动态调整学习率scheduler.step()best_acc = epoch_test_accbest_model = copy.deepcopy(model)  # 保存模型# 当前学习率lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,  epoch_test_acc*100, epoch_test_loss, lr))# 保存最佳模型到文件
path = './best_model.pth'
torch.save(best_model.state_dict(), path) # 保存模型参数
Epoch: 1, Train_acc:61.3%, Train_loss:1.450, Test_acc:77.2%, Test_loss:0.867, Lr:5.00E-01
Epoch: 2, Train_acc:80.4%, Train_loss:0.713, Test_acc:83.1%, Test_loss:0.585, Lr:5.00E-01
Epoch: 3, Train_acc:85.3%, Train_loss:0.516, Test_acc:85.7%, Test_loss:0.477, Lr:5.00E-01
Epoch: 4, Train_acc:88.2%, Train_loss:0.410, Test_acc:87.9%, Test_loss:0.414, Lr:5.00E-01
Epoch: 5, Train_acc:90.4%, Train_loss:0.338, Test_acc:89.3%, Test_loss:0.379, Lr:5.00E-03
Epoch: 6, Train_acc:92.1%, Train_loss:0.293, Test_acc:89.4%, Test_loss:0.378, Lr:5.00E-03
Epoch: 7, Train_acc:92.2%, Train_loss:0.291, Test_acc:89.5%, Test_loss:0.377, Lr:5.00E-03
Epoch: 8, Train_acc:92.2%, Train_loss:0.290, Test_acc:89.4%, Test_loss:0.376, Lr:5.00E-03
Epoch: 9, Train_acc:92.2%, Train_loss:0.289, Test_acc:89.3%, Test_loss:0.376, Lr:5.00E-03
Epoch:10, Train_acc:92.3%, Train_loss:0.288, Test_acc:89.3%, Test_loss:0.375, Lr:5.00E-03

7、结果展示

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epoch_length = range(epochs)plt.figure(figsize=(12, 3))plt.subplot(1, 2, 1)
plt.plot(epoch_length, train_acc, label='Train Accuaray')
plt.plot(epoch_length, test_acc, label='Test Accuaray')
plt.legend(loc='lower right')
plt.title('Accurary')plt.subplot(1, 2, 2)
plt.plot(epoch_length, train_loss, label='Train Loss')
plt.plot(epoch_length, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Loss')plt.show()


在这里插入图片描述

8、结果测试

model.load_state_dict(torch.load("./best_model.pth"))
model.eval() # 模型评估# 测试句子
test_sentence = "还有双鸭山到淮阴的汽车票吗13号的"# 转换为 token
token_ids = vocab(tokenizer(test_sentence))   # 切割分词--> 词典序列
text = torch.tensor(token_ids, dtype=torch.long).to(device)  # 转化为tensor
offsets = torch.tensor([0], dtype=torch.long).to(device)# 测试,注意:不需要反向求导
with torch.no_grad():output = model(text, offsets)predicted_label = output.argmax(1).item()print(f"预测类别: {classes[predicted_label]}")
预测类别: Travel-Query

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/913165.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/913165.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【微信小程序】- 监听全局globalData数据

【微信小程序】- 监听全局globalData数据 数据劫持&#xff08;Object.defineProperty&#xff09;实现适用场景 数据劫持&#xff08;Object.defineProperty&#xff09; 实现 通过拦截 globalData 的属性读写实现自动监听&#xff0c;适合精确监听特定变量。 ​实现步骤​&…

高速公路闲置土地资源化利用:广西浦北互通3MW分布式光伏监控实践

摘要&#xff1a; 分布式光伏项目在清洁能源转型中扮演重要角色&#xff0c;其创新的空间利用模式有助于缓解能源开发与土地资源间的矛盾。广西大唐至浦北高速公路&#xff08;浦北互通&#xff09;项目&#xff0c;利用高速公路沿线闲置空地建设光伏电站&#xff0c;发挥了分布…

【Linux网络编程】网络基础

目录 计算机网络背景 初识协议 网络协议 协议分层 OSI七层模型 TCP/IP五层(或四层)模型 再识协议 为什么要有TCP/IP协议&#xff1f; 什么是TCP/IP协议&#xff1f; 重谈协议 网络传输基本流程 局域网传输流程 跨网络传输流程 Socket编程预备 理解源IP地址与目的…

BlenderBot对话机器人大模型Facebook开发

文章目录 &#x1f680; BlenderBot 的关键特性&#x1f9ea; 版本与改进&#x1f4ca; 应用实例 Blender是搅拌机&#xff0c;果汁机&#xff0c;混合机的意思。 BlenderBot 是由 Facebook AI Research (FAIR) 开发的一种先进的对话生成模型。它旨在通过融合多种对话技能&…

60天python训练计划----day59

在之前的学习中&#xff0c;我们层层递进的介绍了时序模型的发展&#xff0c;从AR到MA到ARMA&#xff0c;再到ARIMA。本质就是把数据处理的操作和模型结合在一起了&#xff0c;实际上昨天提到的季节性差分也可以合并到模型中&#xff0c;让流程变得更加统一。 季节性差分用S来…

学习日志05 python

我相信事在人为&#xff0c;人定胜天&#xff0c;现在还是在基础语法上面打转&#xff0c;还是会提出一些很低级的很基础的问题&#xff0c;不要着急&#xff0c;波浪式前进、螺旋式上升的过程吧&#xff0c;虽然现在的确是很绝望吧...... 今天要做一个练习&#xff1a;编写猜…

LiteHub中间件之gzip算法

gzip算法理论部分LZ777算法霍夫曼编码算法改进型的LZ777算法代码实现压缩对象gzip实现运行分析日志查看wireshark抓包查看后台管理界面查看理论部分 gzip是一种无损压缩算法&#xff0c;其基础为Deflate&#xff0c;Deflate是LZ77与哈弗曼编码的一个组合体。它的基本原理是&…

java+vue+SpringBoo校园失物招领网站(程序+数据库+报告+部署教程+答辩指导)

源代码数据库LW文档&#xff08;1万字以上&#xff09;开题报告答辩稿ppt部署教程代码讲解代码时间修改工具 技术实现 开发语言&#xff1a;后端&#xff1a;Java 前端&#xff1a;vue框架&#xff1a;springboot数据库&#xff1a;mysql 开发工具 JDK版本&#xff1a;JDK1.…

Qt Quick 与 QML(五)qml中的布局

QML布局系统主要分为三大类&#xff1a;锚布局、定位器布局、布局管理器。一、锚布局&#xff08;Anchors&#xff09;通过定义元素与其他元素或父容器的锚点关系实现精确定位&#xff0c;支持动态调整。核心特性属性‌‌作用‌‌示例‌anchors.left左边缘对齐目标元素anchors.…

【Java|集合类】list遍历的6种方式

本文主要是总结一下Java集合类中List接口的遍历方式&#xff0c;以下面的list为例&#xff0c;为大家讲解遍历list的6种方式。 List<Integer> list new ArrayList<>();list.add(1);list.add(2);list.add(3);list.add(4);list.add(5);文章目录1.直接输出2.for循环遍…

博弈论基础-笔记

取石子1 性质一&#xff1a;12345可以确定先手赢&#xff0c;6不论取那个质数都输&#xff0c;789 10 11可以分别取12345变成6 性质二&#xff1a;6的倍数一定不能取出之后还是6的倍数&#xff08;不能转换输态&#xff09; #include <bits/stdc.h> using namespace st…

多任务学习-ESMM

简介 ESMM&#xff08;Entire Space Multi-task Model&#xff09;是2018年阿里巴巴提出的多任务学习模型。基于共享的特征表达和在用户整个行为序列空间上的特征提取实现对CTR、CVR的联合训练 解决的问题 SSB&#xff08;sample selection bias&#xff09; 如下图1所示&am…

K8S 集群配置踩坑记录

系统版本&#xff1a;Ubuntu 22.04.5-live-server-amd64 K8S 版本&#xff1a;v1.28.2 Containerd 版本&#xff1a; 1.7.27 kubelet logs kuberuntime_sandbox.go:72] "Failed to create sandbox for pod" err"rpc error: code Unknown desc failed to cre…

超滤管使用与操作流程-实验操作013

超滤管使用与操作流程 超滤管&#xff08;或蛋白浓缩管&#xff09;是一种重要的实验设备&#xff0c;广泛应用于分离与纯化大分子物质&#xff0c;尤其是蛋白质、多糖和核酸等。其工作原理依赖于超滤技术&#xff0c;通过半透膜对分子进行筛分&#xff0c;精准地将大分子物质…

GitHub已破4.5w star,从“零样本”到“少样本”TTS,5秒克隆声音,冲击传统录音棚!

嗨&#xff0c;我是小华同学&#xff0c;专注解锁高效工作与前沿AI工具&#xff01;每日精选开源技术、实战技巧&#xff0c;助你省时50%、领先他人一步。&#x1f449;免费订阅&#xff0c;与10万技术人共享升级秘籍&#xff01;你是否为录音成本高、声音不灵活、又想为多语言…

【中文核心期刊推荐】《遥感信息》

《遥感信息》&#xff08;CN&#xff1a;11-5443/P&#xff09;是一份具有较高学术价值的双月刊期刊&#xff0c;自创刊以来&#xff0c;凭借新颖的选题和广泛的报道范围&#xff0c;兼顾了大众服务和理论深度&#xff0c;深受学术界和广大读者的关注与好评。 该期刊创办于1986…

uniapp微信小程序css中background-image失效问题

项目场景&#xff1a;提示&#xff1a;这里简述项目相关背景&#xff1a;在用uniapp做微信小程序的时候&#xff0c;需要一张背景图&#xff0c;用的是当时做app的时候的框架&#xff0c;但是&#xff0c;在class的样式中background-image失效了&#xff0c;查了后才知道&#…

iOS App无源码安全加固实战:如何对成品IPA实现结构混淆与资源保护

在很多iOS项目交付中&#xff0c;开发者或甲方并不总能拿到应用源码。例如外包项目交付成品包、历史项目维护、或者仅负责分发渠道的中间商&#xff0c;都需要在拿到成品ipa文件后对其进行安全加固。然而传统的源码级混淆方法&#xff08;如LLVM Obfuscator、Swift Obfuscator&…

Java 中的 ArrayList 和 LinkedList 区别详解(源码级理解)

&#x1f680; Java 中的 ArrayList 和 LinkedList 区别详解&#xff08;源码级理解&#xff09; 在日常 Java 开发中&#xff0c;ArrayList 和 LinkedList 是我们经常用到的两种 List 实现。虽然它们都实现了 List 接口&#xff0c;但在底层结构、访问效率、插入/删除操作、扩…

使用OpenLayers调用geoserver发布的wms服务

1.前端vue3调用代码 <template><div><div ref"mapContainer" class"map"></div></div> </template><script setup lang"ts"> import { ref, onMounted } from "vue"; import Map from &quo…