Pytorch中文文本分类

本文为🔗365天深度学习训练营内部文章

原作者:K同学啊

 将对中文文本进行分类,示例如下:

 

文本分类流程图

 

 

1.加载数据 

import time
import pandas as pd
import torch
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torchvision
from torchtext.data import to_map_style_dataset
from torchvision import transforms,datasets
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba
import warningswarnings.filterwarnings('ignore')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")'''
加载本地数据
'''
train_data = pd.read_csv('train.csv',sep='\t',header=None)
print(train_data.head())
# 构建数据集迭代器
def coustom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,y
# train_data[0]是第一列(通常是文本),train_data[1]是第二列(通常是标签)
train_iter = coustom_data_iter(train_data[0].values[:],train_data[1].values[:])

 

定义一个名为 coustom_data_iter 的函数,接收两个参数:

  • texts:文本数据(通常是句子或单词序列)

  • labels:对应的标签(分类任务中的目标值)

for x, y in zip(texts, labels):

  • zip(texts, labels):将 textslabels 按元素配对,返回一个迭代器,每次迭代返回 (text, label) 的组合。

  • 例如,如果 texts = ["hello", "world"]labels = [0, 1],那么 zip(texts, labels) 会生成 ("hello", 0)("world", 1)

yield x, y

  • yield 使这个函数变成一个 生成器(generator),每次迭代返回 (x, y) 对,而不是一次性返回所有数据。

  • 这种方式适合大数据集,因为它不会一次性加载所有数据到内存,而是按需生成。

2.数据预处理

1)构建词典  

# 中文分词方法
tokenizer = jieba.lcutdef yield_tokens(data_iter):for text,_ in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])   # 设置默认索引,如果找不到单词,则会选择默认索引label_name = list(set(train_data[1].values[:]))   # 将标签去重,添加到label_name列表中
print(label_name)text_pipeline = lambda x:vocab(tokenizer(x))
label_pipeline = lambda x:label_name.index(x)

 

  • def yield_tokens(data_iter): 定义一个生成器函数 yield_tokens,接收一个数据迭代器 data_iter(通常是 (text, label) 格式的迭代器)。

  • for text, _ in data_iter:

    • data_iter 每次返回 (text, label),这里用 _ 忽略标签(因为我们只需要文本)。

    • 例如,如果 data_iter[("hello world", 0), ("good morning", 1)],则 text 依次是 "hello world""good morning"

  • yield tokenizer(text)

    • tokenizer(text):对文本 text 进行分词(如拆分成单词列表)。

    • yield 返回分词后的结果(如 ["hello", "world"]["good", "morning"]),逐步生成数据流。

  • build_vocab_from_iterator

    • 是 PyTorch 的 torchtext.vocab 提供的函数,用于从迭代器构建词汇表。

    • 输入yield_tokens(data_iter) 生成的分词结果(如 ["hello", "world"], ["good", "morning"])。

    • 输出:一个 Vocab 对象,包含所有单词到索引的映射。

  • specials=["<unk>"]

    • 指定特殊符号 <unk>(unknown token),用于处理词汇表中不存在的单词。

    • 其他常见的特殊符号:

      • "<pad>":填充符号(用于统一序列长度)。

      • "<sos>":句子开始符。

      • "<eos>":句子结束符。

lambda 表达式的语法为:lambda arguments:expression 其中 arguments 是函数的参数,可以有多个参数,用逗号分隔。expression 是一个表达式,它定义了函数的返回值。 text_pipeline函数:将原始文本数据转换为整数列表,使用了之前构建的vocab词表和tokenizer分词器函数。具体来说,它接受一个字符串x作为输入,首先使用tokenizer将其分词,然后将每个词在vocab词表中的索引放入一个列表中返回。 label pipeline函数:将原始标签数据转换为整数,它接受一个字符串x作为输入,并使用 label_name.index(x)方法获取x在label name 列表中的索引作为输出。

2)生成数据批次和迭代器  

# 2.生成数据批次和迭代器
def collate_batch(batch):label_list,text_list,offsets = [],[],[0]for (_text,_label) in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text),dtype=torch.int64)text_list.append(processed_text)# 偏移量,即语句的总词汇量offsets.append(processed_text.size(0))label_list = torch.tensor(label_list,dtype=torch.int64)text_list = torch.cat(text_list)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)  # 返回维度dim中输入元素的累计和return text_list.to(device),label_list.to(device),offsets.to(device)# 数据加载器
dataloader = DataLoader(train_iter,batch_size=8,shuffle=False,collate_fn=collate_batch)
  • 输入batch 是一个列表,其中每个元素是 (_text, _label) 对(来自 train_iter)。

  • 初始化

    • label_list:存储批次的标签。

    • text_list:存储分词后的文本(转换为整数索引)。

    • offsets:存储每个文本的长度(用于后续拼接),初始值为 [0]

  • offsets 的用途:

    • 记录每个文本的累计长度,用于后续将多个文本拼接成一个一维张量时定位每个样本的起始位置。

  1. label_list

    • 将标签列表转换为 PyTorch 张量(形状为 [batch_size])。

  2. text_list

    • torch.cat(text_list):将所有文本的索引拼接成一个一维张量

      • 例如,如果有两个文本 [1, 2][3, 4, 5],结果为 [1, 2, 3, 4, 5]

  3. offsets

    • offsets[:-1]:去掉初始的 [0],保留每个文本的长度(如 [2, 3])。

    • .cumsum(dim=0):计算累计和,得到每个文本在 text_list 中的起始位置。

      • 例如,[2, 3][2, 5],表示:

        • 第一个文本在 text_list 中的位置是 0:2

        • 第二个文本的位置是 2:5

3.构建模型

首先对文本进行嵌入,然后对句子进行嵌入之后的结果进行均值整合

模型图如下:

 

 

# 1.定义模型
class TextClassificationModel(nn.Module):def __init__(self,vocab_size,embed_dim,num_class):super(TextClassificationModel,self).__init__()self.embedding = nn.EmbeddingBag(vocab_size,  # 词典大小embed_dim,    # 嵌入维度sparse=False)self.fc = nn.Linear(embed_dim,num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange,initrange)self.fc.weight.data.uniform_(-initrange,initrange)self.fc.bias.data.zero_()def forward(self,text,offsets):embedded = self.embedding(text,offsets)return self.fc(embedded)

 

self.embedding.weight.data.uniform_(-initrange,initrange)这段代码是在 PyTorch 框架下用于初始化神经网络的词嵌入层(embedding layer)权重的一种方法。这里使用了均匀分布的随机值来初始化权重,具体来说,其作用如下: self.embedding:这是神经网络中的词嵌入层(embeddinglayer)。词嵌入层的作用是将离散的单词表示(通常为整数索引)映射为固定大小的连续向量。这些向量捕捉了单词之间的语义关系,并作为网络的输入。 self.embedding.weight:这是词嵌入层的权重矩阵,它的形状为(vocab size,embedding _dim),其中 vocab size 是词汇表的大小,embedding dim 是嵌入向量的维度。

self.embedding.weight.data:这是权重矩阵的数据部分,我们可以在这里直接操作其底层的张量。 .uniform(-initrange,initrange):这是一个原地操作(in-place operation),用于将权重矩阵的值用一个均匀分布进行初始化。均匀分布的范围为[-initrange,initrange],其中 initrange 是一个正数。 通过这种方式初始化词嵌入层的权重,可以使得模型在训练开始时具有一定的随机性,有助于避免梯度消失或梯度爆炸等问题。在训练过程中,这些权重将通过优化算法不断更新,以捕捉到更好的单词表示。

# 2.定义实例
num_class = len(label_name)
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size,em_size,num_class).to(device)# 3.定义训练函数和评估函数
def train(dataloader):model.train()total_acc,train_loss,total_count = 0,0,0log_interval = 50start_time = time.time()for idx,(text,label,offsets) in enumerate(dataloader):predicted_label = model(text,offsets)optimzer.zero_grad()   # grad属性归零loss = criterion(predicted_label,label)   # 计算网络输出和真实值之间的差距loss.backward()   # 反向传播nn.utils.clip_grad_norm(model.parameters(),0.1)   # 梯度裁剪optimzer.step()  # 每一步自动更新# 记录acc与Losstotal_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch,idx,len(dataloader),total_acc/total_count,train_loss/total_count))total_acc,train_loss,total_count = 0,0,0start_time = time.time()def evaluate(dataloader):model.eval()total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (text,label, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label)  # 计算网络输出和真实值之间的差距# 记录acc与Losstotal_acc += (predicted_label.argmax(1) == label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc/total_count,train_loss/total_count

 torch.nn.utils.clip_grad_norm_(model.parameters(),0.1)是一个PyTorch函数,用于在训练神经网络时限制梯度的大小。这种操作被称为梯度裁剪(gradient clipping),可以防止梯度爆炸问题,从而提高神经网络的稳定性和性能。 在这个函数中: model.parameters()表示模型的所有参数。对于一个神经网络,参数通常包括权重和偏置项。0.1是一个指定的阈值,表示梯度的最大范数(L2范数)。如果计算出的梯度范数超过这个阈值,梯度会被缩放,使其范数等于阈值。 梯度裁剪的主要日的是防止梯度爆炸。梯度爆炸通常发生在训练深度神经网络时,尤其是在处理长序列数据的循环神经网络(RNN)中。当梯度爆炸时,参数更新可能会变得非常大,导致模型无法收敛或出现数值不稳定。通过限制梯度的大小,梯度裁剪有助于解决这些问题,使模型训练变得更加稳定。

4.训练模型

1)拆分数据集运行模型

 

EPOCHS = 10
LR = 5
BATCH_SIZE = 64criterion = torch.nn.CrossEntropyLoss()
optimzer = torch.optim.SGD(model.parameters(),lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimzer,1.0,gamma=0.1)
total_accu = None# 构建数据集
train_iter = coustom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)
split_train_,split_valid_ = random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])
train_dataloader = DataLoader(split_train_,batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_,batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_batch)for epoch in range(1,EPOCHS+1):epoch_start_time = time.time()train(train_dataloader)val_acc,val_loss = evaluate(valid_dataloader)# 获取当前的学习率lr = optimzer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-'*69)print('| epoch {:1d} | time:{:4.2f}s | ''valid_acc {:4.3f} valid_loss {:4.3f}'.format(epoch,time.time()-epoch_start_time,val_acc,val_loss))print('-'*69)

 torchtext.data.functional.to_map_style_dataset 函数的作用是将一个迭代式的数据集(lterable-style dataset)转换为映射式的数据集(Map-style dataset)。这个转换使得我们可以通过索引(例如:整数)更方便地访问数据集中的元素。 在 PyTorch 中,数据集可以分为两种类型:lterable-style和 Map-style。lterable-style 数据集实现了iter_()方法,可以迭代访问数据集中的元素,但不支持通过索引访问。而 Map-style 数据集实现了__getitem()和1en()方法,可以直接通过索引访问特定元素,并能获取数据集的大小。 TorchText 是 PyTorch 的一个扩展库,专注于处理文本数据。torchtext.data.functional 中的to map style dataset 函数可以帮助我们将一个 lterable-style 数据集转换为一个易于操作的 Map-style数据集。这样,我们可以通过索引直接访问数据集中的特定样本,从而简化了训练、验证和测试过程中的数据处理。

# 2.使用测试数据集评估模型
print('Checking the results of test dataset.')
test_acc,test_loss = evaluate(valid_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))# 3.测试指定数据
def predict(text,text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text))output = model(text,torch.tensor([0]))return output.argmax(1).item()ex_text = "随便播放一首陈奕迅的歌"
model = model.to("cpu")
print('该文本的类别是:%s'%label_name[predict(ex_text,text_pipeline)])

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/bicheng/82578.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

13.「极简」扣子(coze)教程 | 小程序UI设计进阶(三)让界面动起来,实操讲透“聚焦”事件

前一期大师兄介绍了扣子平台组件的两种状态“禁用”和“加载”。这两种方法使控件可以通过简单设置表示出更多的运行状态。今天大师兄将详细介绍控件的一种事件“聚焦”。 扣子&#xff08;coze&#xff09;编程 「极简」扣子(coze)教程 | 小程序UI设计进阶 II&#xff01;让…

剑指offer11_矩阵中的路径

矩阵中的路径 请设计一个函数&#xff0c;用来判断在一个矩阵中是否存在一条路径包含的字符按访问顺序连在一起恰好为给定字符串。 路径可以从矩阵中的任意一个格子开始&#xff0c;每一步可以在矩阵中向左&#xff0c;向右&#xff0c;向上&#xff0c;向下移动一个格子。 如…

腾讯2025年校招笔试真题手撕(三)

一、题目 今天正在进行赛车车队选拔&#xff0c;每一辆赛车都有一个不可以改变的速度。现在需要选取速度差距在10以内的车队&#xff08;车队中速度的最大值减去最小值不大于10&#xff09;&#xff0c;用于迎宾。车队的选拔按照的是人越多越好的原则&#xff0c;给出n辆车的速…

《三维点如何映射到图像像素?——相机投影模型详解》

引言 以三维投影介绍大多比较分散&#xff0c;不少小伙伴再面对诸多的坐标系转换中容易弄混&#xff0c;特别是再写代码的时候可能搞错&#xff0c;所有这篇文章帮大家完整的梳理3D视觉中的投影变换的全流程&#xff0c;一文弄清楚这个过程&#xff0c;帮助大家搞清坐标系转换…

Ini配置文件读写,增加备注功能

1.增加备注项写入 例: #节点备注 [A] #项备注 bbb1 ccc2 [B] bbb1 IniConfig2 ic new IniConfig2(); //首次写入 if (!ic.CanRead()) { ic.AddSectionReMarke("A", "节点备注"); ic.SetValue("A&qu…

OpenHarmony 5.0中状态栏添加以太网状态栏图标以及功能实现

目录 1.前置条件 2.方案 1.前置条件 首先以太网接口是有问题的,如下按照如下流程将以太网接口进行修复 OpenHarmony 以太网卡热插拔事件接口无效-CSDN博客 然后上述的接口可以了就可以通过这个接口获取以太网是否连接状态 要注意wifi连接的干扰和预置虚拟网口干扰 2.方案…

RNN GRU LSTM 模型理解

一、RNN 1. 在RNN中&#xff0c; 2. RNN是一个序列模型&#xff0c;与非序列模型不同&#xff0c;序列中的元素互相影响&#xff1a; 是由 计算得来的。 在前向传播中&#xff1a; 用于计算 和 用于计算 和 因此&#xff0c;当进行反向链式法则求导时候&#xf…

多路径传输(比如 MPTCP)控制实时突发

实时突发很难控制&#xff0c;因为 “实时” 和 “突发” 相互斥。实时要求避免排队&#xff0c;而突发必然要排队&#xff0c;最终的解决方案都指向找一个公说公有理&#xff0c;婆说婆有理的中间点&#xff0c;这并没解决问题&#xff0c;只是权衡了问题。 这种局部解决问题的…

函数式编程思想详解

函数式编程思想详解 1. 核心概念 不可变数据 (Immutable Data) 数据一旦创建&#xff0c;不可修改。任何操作均生成新数据&#xff0c;而非修改原数据。 优点&#xff1a;避免副作用&#xff0c;提升并发安全&#xff0c;简化调试。 Java实现&#xff1a;使用final字段、不可变…

iOS 主要版本发布历史

截至 2025 年 5 月&#xff0c;iOS 的最新正式版本是 iOS 18&#xff0c;于 2024 年 9 月 16 日 正式发布。此前的 iOS 17 于 2023 年 9 月 18 日 发布&#xff0c;并在 2024 年被 iOS 18 取代。(维基百科) &#x1f4f1; iOS 主要版本发布历史 以下是 iOS 各主要版本的发布日…

矩阵详解:线性代数在AI大模型中的核心支柱

&#x1f9d1; 博主简介&#xff1a;CSDN博客专家、CSDN平台优质创作者&#xff0c;高级开发工程师&#xff0c;数学专业&#xff0c;10年以上C/C, C#, Java等多种编程语言开发经验&#xff0c;拥有高级工程师证书&#xff1b;擅长C/C、C#等开发语言&#xff0c;熟悉Java常用开…

基于51单片机和8X8点阵屏、独立按键的飞行躲闪类小游戏

目录 系列文章目录前言一、效果展示二、原理分析三、各模块代码1、8X8点阵屏2、独立按键3、定时器04、定时器1 四、主函数总结 系列文章目录 前言 用的是普中A2开发板。 【单片机】STC89C52RC 【频率】12T11.0592MHz 【外设】8X8点阵屏、独立按键 效果查看/操作演示&#xff…

区块链可投会议CCF C--APSEC 2025 截止7.13 附录用率

Conference&#xff1a;32nd Asia-Pacific Software Engineering Conference (APSEC 2025) CCF level&#xff1a;CCF C Categories&#xff1a;软件工程/系统软件/程序设计语言 Year&#xff1a;2025 Conference time&#xff1a;December 2-5, 2025 in Macao SAR, China …

pdf图片导出(Visio\Origin\PPT)

一、Visio 导入pdf格式图片 1. 设计->大小&#xff0c;适应绘图。 2. 文件->导出&#xff0c;导出为pdf格式。 上面两部即可得到只包含图的部分的pdf格式。 如果出现的有默认白边&#xff0c;可以通过以下方式设置&#xff1a; 1. 文件->选项->自定义功能区->…

vector的实现

介绍 1. 本质与存储结构 动态数组实现&#xff1a;vector 本质是动态分配的数组&#xff0c;采用连续内存空间存储元素&#xff0c;支持下标访问&#xff08;如 vec[i]&#xff09;&#xff0c;访问效率与普通数组一致&#xff08;时间复杂度 O (1)&#xff09;。动态扩容机制&…

【Linux笔记】防火墙firewall与相关实验(iptables、firewall-cmd、firewalld)

一、概念 1、防火墙firewall Linux 防火墙用于控制进出系统的网络流量&#xff0c;保护系统免受未授权访问。常见的防火墙工具包括 iptables、nftables、UFW 和 firewalld。 防火墙类型 包过滤防火墙&#xff1a;基于网络层&#xff08;IP、端口、协议&#xff09;过滤流量&a…

el-date-picker 前端时间范围选择器

控制台参数&#xff1a; 前端代码&#xff1a;用数组去接受&#xff0c;同时用 value-format"YYYY-MM-DD" 格式化值为&#xff1a;年月日格式 <!-- 查询区域 --><transition name"fade"><div class"search" v-show"showSe…

在 macOS 上安装 jenv 管理 JDK 版本

在 macOS 上安装 jenv 并管理 JDK 版本 在开发 Java 应用程序时&#xff0c;你可能需要在不同的项目中使用不同版本的 JDK。手动切换 JDK 版本可能会很繁琐&#xff0c;但幸运的是&#xff0c;有一个工具可以简化这个过程&#xff1a;jenv。jenv 是一个流行的 Java 版本管理工…

2025年全国青少年信息素养大赛复赛C++集训(16):吃糖果2(题目及解析)

2025年全国青少年信息素养大赛复赛C集训&#xff08;16&#xff09;&#xff1a;吃糖果2&#xff08;题目及解析&#xff09; 题目描述 现有n(50 > n > 0)个糖果,每天只能吃2个或者3个&#xff0c;请计算共有多少种不同的吃法吃完糖果。 时间限制&#xff1a;1000 内存…

ARM笔记-嵌入式系统基础

第一章 嵌入式系统基础 1.1嵌入式系统简介 1.1.1嵌入式系统定义 嵌入式系统定义&#xff1a; 嵌入式系统是以应用为中心&#xff0c;以计算机技术为基础&#xff0c;软硬件可剪裁&#xff0c;对功能、可靠性、成本、体积、功耗等有严格要求的专用计算机系统 ------Any devic…