机器学习监督学习实战七:文本卷积神经网络TextCNN对中文短文本分类(15类)

  本文介绍了一个基于TextCNN模型的文本分类项目,使用今日头条新闻数据集进行训练和评估。项目包括数据获取、预处理、模型训练、评估测试等环节。数据预处理涉及清洗文本、中文分词、去除停用词、构建词汇表和向量化等步骤。TextCNN模型通过卷积层和池化层提取文本特征,并在训练过程中记录准确率和损失。最终,模型在测试集上达到了较高的准确率(84.06%),并生成了混淆矩阵可视化。项目还详细介绍了TextCNN模型的结构和创新点,以及数据预处理和模型训练的具体实现代码。完整代码已开源在个人GitHub:https://github.com/KLWU07/Chinese-text-classification-TextCNN

一、项目描述

1.数据获取

(1)今日头条新闻分类数据爬取(Python脚本)

  • 使用今日头条文本分类数据集,包含民生故事、文化、娱乐、体育、财经、房
    产、汽车、教育、科技、军事、旅游、国际、证券股票、农业、游戏15个类别,共30多万条数据。

(2)每条数据形式

6554371968739574280_!_102_!_news_entertainment_!_今年你看《复联3》哭得有多惨,明年你看《复联4》就会叫得有多爽_!_卡魔拉,灭霸,小丑,复仇者联盟3,理想主义者,复联3,复联4,雷神3,漫威,黑暗骑士

2.数据预处理

(1)清洗文本(保留中文字符)
(2)中文分词(使用jieba)
(3)去除停用词
(4)构建词汇表并转换为数字向量
(5)数据划分为训练集/验证集/测试集

3.模型训练

(1)使用TextCNN模型进行训练
(2)记录训练过程中的准确率和损失
(3)保存最佳模型

4.评估测试

(1)输出测试集的准确率、精确率、召回率和F1值
(2)生成混淆矩阵可视化

二、训练过程和结果

1.训练过程中的准确率和损失、混淆矩阵可视化

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

textCNN((embed): Embedding(165444, 64, padding_idx=1)(conv11): Conv2d(1, 16, kernel_size=(3, 64), stride=(1, 1))(conv12): Conv2d(1, 16, kernel_size=(4, 64), stride=(1, 1))(conv13): Conv2d(1, 16, kernel_size=(5, 64), stride=(1, 1))(dropout): Dropout(p=0.5, inplace=False)(fc1): Linear(in_features=48, out_features=15, bias=True)
)
Epoch: 1   [===========]  cost: 775.52s;  loss: 10520.5979;  train acc: 0.1589;  val acc:0.2256;
Epoch: 2   [===========]  cost: 941.70s;  loss: 9934.9760;  train acc: 0.2239;  val acc:0.2975;
Epoch: 3   [===========]  cost: 1309.58s;  loss: 9329.3478;  train acc: 0.2844;  val acc:0.3842;
Epoch: 4   [===========]  cost: 1619.02s;  loss: 8444.4464;  train acc: 0.3663;  val acc:0.4751;
Epoch: 5   [===========]  cost: 1837.63s;  loss: 7384.8970;  train acc: 0.4545;  val acc:0.5569;
Epoch: 6   [===========]  cost: 1989.67s;  loss: 6397.7872;  train acc: 0.5371;  val acc:0.6317;
Epoch: 7   [===========]  cost: 2095.66s;  loss: 5669.9501;  train acc: 0.6029;  val acc:0.6932;
Epoch: 8   [===========]  cost: 2142.67s;  loss: 5095.1675;  train acc: 0.6578;  val acc:0.7403;
Epoch: 9   [===========]  cost: 2126.22s;  loss: 4633.6697;  train acc: 0.6989;  val acc:0.7702;
Epoch: 10   [===========]  cost: 2101.42s;  loss: 4296.7775;  train acc: 0.7272;  val acc:0.7899;
Epoch: 11   [===========]  cost: 2099.69s;  loss: 4047.9119;  train acc: 0.7460;  val acc:0.8012;
Epoch: 12   [===========]  cost: 2085.49s;  loss: 3871.4451;  train acc: 0.7594;  val acc:0.8096;
Epoch: 13   [===========]  cost: 2152.68s;  loss: 3725.9579;  train acc: 0.7689;  val acc:0.8157;
Epoch: 14   [===========]  cost: 2178.29s;  loss: 3613.9486;  train acc: 0.7761;  val acc:0.8209;
Epoch: 15   [===========]  cost: 2303.36s;  loss: 3531.0598;  train acc: 0.7828;  val acc:0.8230;
Epoch: 16   [===========]  cost: 2183.05s;  loss: 3447.1700;  train acc: 0.7878;  val acc:0.8280;
Epoch: 17   [===========]  cost: 2213.03s;  loss: 3388.7178;  train acc: 0.7918;  val acc:0.8286;
Epoch: 18   [===========]  cost: 2349.19s;  loss: 3332.7325;  train acc: 0.7953;  val acc:0.8313;
Epoch: 19   [===========]  cost: 2508.73s;  loss: 3278.0177;  train acc: 0.7982;  val acc:0.8326;
Epoch: 20   [===========]  cost: 2579.64s;  loss: 3241.5735;  train acc: 0.8021;  val acc:0.8341;
Epoch: 21   [===========]  cost: 2616.97s;  loss: 3205.6926;  train acc: 0.8039;  val acc:0.8350;
Epoch: 22   [===========]  cost: 2644.12s;  loss: 3158.8484;  train acc: 0.8072;  val acc:0.8348;
Epoch: 23   [===========]  cost: 2949.09s;  loss: 3135.9280;  train acc: 0.8093;  val acc:0.8376;
Epoch: 24   [===========]  cost: 2474.57s;  loss: 3118.4267;  train acc: 0.8100;  val acc:0.8385;
Epoch: 25   [===========]  cost: 2544.15s;  loss: 3092.5366;  train acc: 0.8117;  val acc:0.8395;
Epoch: 26   [===========]  cost: 2468.19s;  loss: 3072.7564;  train acc: 0.8132;  val acc:0.8389;
Epoch: 27   [===========]  cost: 2449.40s;  loss: 3041.6649;  train acc: 0.8155;  val acc:0.8403;
Epoch: 28   [===========]  cost: 2472.81s;  loss: 3025.2191;  train acc: 0.8161;  val acc:0.8410;
Epoch: 29   [===========]  cost: 2481.49s;  loss: 3007.8602;  train acc: 0.8182;  val acc:0.8408;
Epoch: 30   [===========]  cost: 2424.77s;  loss: 2986.9278;  train acc: 0.8195;  val acc:0.8421;
test ...
test acc: 0.8406   precision: 0.8406   recall: 0.8406    f1: 0.8406 

三、文本卷积神经网路TextCNN模型解释

1.《Convolutional Neural Networks for Sentence Classification》论文创新点

(1)模型应用创新:将卷积神经网络(CNN)应用于自然语言处理(NLP)中的句子分类任务,打破了传统 NLP 任务主要依赖循环神经网络(RNN)及其变体的局面。
(2)输入表示创新:采用预训练的词向量(如 word2vec)作为输入,替代传统的 one - hot 编码。预训练词向量能更好地捕捉词的语义信息,提高了模型的泛化能力和性能。
(3)特征提取创新:使用多个不同尺寸的卷积核来提取句子中的关键信息,类似于多窗口大小的 ngram。不同大小的卷积核可以捕捉不同长度的局部相关性,从而更全面地提取句子的特征,提高模型的特征提取能力。
(4)模型结构创新:TextCNN 模型结构相对简单,计算快速、实现方便,且在准确性方面表现较高。它包含嵌入层、卷积层、池化层和全连接层等,通过简单的结构实现了高效的句子分类。
(5)训练策略创新:提出了多种训练策略,如 CNN - rand(基础模型中所有单词都是随机初始化,然后在训练期间进行修改)、CNN - static(带有来自 word2vec 的预训练向量,所有单词保持静态,只学习模型的其他参数)、CNN - non - static(预训练的向量针对每个任务进行微调)和 CNN - multichannel(有两组词向量模型)。

2.数据预处理模块

# 数据处理相关函数
def is_chinese(uchar):# 判断字符是否为中文
def reserve_chinese(content):# 保留文本中的中文字符
def getStopWords():# 加载停用词表
def dataParse(text, stop_words):# 解析文本数据,映射标签,清洗文本并分词
def getFormatData():# 处理原始数据,构建词表,生成词向量表示并保存

(1)中文分词:jieba分词

  • jieba是一个专为中文文本设计的分词库,其主要功能是将连续的中文文本拆分成有意义的词语序列。与英文不同,中文文本中词与词之间没有空格作为自然分隔符。例如:中文句子:“我喜欢自然语言处理”,分词结果:“我 / 喜欢 / 自然语言处理”。分词的准确性直接影响后续文本分析的质量,如关键词提取、情感分析、机器翻译等。
  • jieba分词基于以下技术:基于前缀词典的匹配算法、隐马尔可夫模型 (HMM)、动态规划优化
  • jieba 分词的主要模式:精确模式(默认),将文本精确地切分成词语,适合文本分析。全模式,将文本中所有可能的词语都扫描出来,速度快但可能产生冗余。搜索引擎模式,在精确模式基础上,对长词再次切分,适合搜索引擎分词。

(2)去除停用词

  • 停用词(Stop Words)是自然语言处理(NLP)中一类被认为对文本分析价值较低的词汇。如中文中的 “的”“了”“在”“是”,英文中的 “the”“a”“is” 等。包括标点符号、连接词、代词等,本身不携带具体语义信息。去除停用词是文本预处理的核心步骤,通过 stopwords.txt 文件可高效管理停用词列表,配合 jieba 等分词工具,能显著提升文本分析的质量和效率。
  • 去除停用词的主要目的:减少数据量,提升处理效率;过滤无意义词汇,突出关键信息(如名词、动词等实词);避免停用词对文本分析(如关键词提取、情感分析)产生干扰。
import jiebadef filter_stopwords(text, stopwords):"""对文本分词并去除停用词"""# 分词words = jieba.lcut(text)# 过滤停用词filtered_words = [word for word in words if word not in stopwords and len(word) > 1]return filtered_words# 示例文本
text = "今天天气很好,我打算去公园散步。"
filtered_result = filter_stopwords(text, stopwords)
print(filtered_result)  # 输出:['今天', '天气', '很好', '打算', '公园', '散步']

(3)构建词表:按词频排序,高频词获得更小的索引

  • 在自然语言处理中,构建词表(Vocabulary)是将文本转换为计算机可处理格式的关键步骤。按词频排序并让高频词获得更小索引的策略,是一种常见且有效的词表构建方法。
  • 文本数字化的需求:神经网络等机器学习模型无法直接处理文本,需要将文本转换为数字表示。词表是文本与数字之间的映射桥梁,每个词被映射为唯一的整数索引。
  • 按词频排序:在大量文本中频繁出现的词,通常携带更关键的语义信息。更小的索引值占用更少的内存和计算资源,高频词优先使用小索引可提升效率。

(4)文本向量化:通过词汇表将词语映射为索引

  • 在自然语言处理中,将文本转换为词向量序列是连接原始文本与深度学习模型的关键步骤。这一过程将离散的文本符号转换为连续的数值表示,使模型能够捕捉文本中的语义信息。
1.分词:使用分词工具(如代码中的 jieba.lcut)将句子拆分为词语列表。
例如:"深度学习很强大"["深度", "学习", "很", "强大"]2.统计词频:遍历所有文本,统计每个词的出现频率(使用 Counter)。
例如:{"深度": 10, "学习": 8, "强大": 5, ...}3.排序和过滤:按词频从高到低排序(代码中 reverse=True),保留高频词以降低维度。低频词可视为噪声或替换为 <UNK>(未知词)。4.建立映射:为每个词分配唯一索引(从 1 开始,0 通常留给填充符 <PAD>)。
例如:"<PAD>": 0, "深度": 1, "学习": 2, "强大": 3, ..., "<UNK>": 100015.词到索引的转换:根据词汇表将每个词替换为对应的索引。处理未知词:若词不在词汇表中,使用 <UNK> 的索引(如 10001)。
例如:["深度", "学习", "强大"][1, 2, 3]6.序列对齐(Padding):使用 pad_sequence 将所有序列填充到相同长度(按最长序列或预设值)。
例如(填充到长度5):
[1, 2, 3][1, 2, 3, 0, 0]  # 0是<pad>的索引

(5) MAT 文件 (data.mat)

data = {'X': data,          # 文本的数字向量表示(填充后的索引序列)'label': labels,     # 文本对应的类别标签(数字编码)'num_words': len(my_vocab)  # 词汇表大小(总词数)
}
io.savemat('./dataset/data/data.mat', data)
  • data.mat 是通过 scipy.io.savemat() 生成的一个 MATLAB 格式的数据文件,它保存了预处理后的文本数据及其标签,用于后续的模型训练和评估。

  • 为什么使用 .mat 格式?

    • 兼容性: MATLAB和Python(通过 scipy.io)均可读写,便于跨平台使用。
    • 结构化存储: 适合存储多维数组和元数据(如词汇表大小)。
    • 效率: 二进制格式加载速度快,占用空间比纯文本小。
  • X(特征数据)
    形状: [num_samples, max_seq_len]
    num_samples: 样本数量(即多少条文本)。
    max_seq_len: 填充后的序列最大长度。每条文本被转换为固定长度的数字序列
    例如:

原始文本: ["深度", "学习", "强大"] → 分词后
词汇表: {"深度":1, "学习":2, "强大":3, "<PAD>":0}
填充后: [1, 2, 3, 0, 0]  # 填充到长度5
  • label(标签数据)
    形状: [num_samples,]
    每个文本的类别标签(通过 LabelEncoder 从字符串标签转换为数字)

  • num_words(词汇表大小)
    作用: 记录词汇表的总词数,用于初始化模型的嵌入层(nn.Embedding)。

数据字段维度说明
X382589×53共 382,589 条文本样本,每条文本被填充/截断为 53 个词的固定长度。
num_words165444词汇表总大小(包含所有唯一词 + 特殊符号如 和 )
label1×382589每条文本对应的类别标签(共 382,589 个标签,可能是 15 分类任务)
  • 53 是数据集中所有文本分词后的 最大词数(即至少有一条文本被分词为 53 个词)
text:"他是最帅反派专业户,演《古惑仔》大火,今病魔缠身可怜无人识!"
经过 jieba.lcut() 分词后可能变成:['他', '是', '最', '帅', '反派', '专业户', ',', '演', '《', '古惑仔', '》', '大火', ',', '今', '病魔', '缠身', '可怜', '无人', '识', '!']
分词数量:20 个(如果过滤掉标点符号和停用词,可能更少)。

3.TextCNN 的核心结构

(1)Embedding Layer(词嵌入层):将单词映射为稠密向量。
输入:词的整数索引(如通过词表映射得到的[1, 5, 10])。
输出:固定维度的连续向量(如 100 维的[0.2, -0.5, 0.1, …])。

(2)Convolutional Layer(卷积层):使用 不同尺寸的卷积核(如 [3,4,5])在词向量序列上滑动,提取局部特征。

  • 为什么使用卷积?

    • 捕获局部特征:文本中的语义单元(如短语、情感表达)通常是连续词的组合。
    • 参数共享:减少模型参数量,提高泛化能力。
    • 多尺度特征:不同大小的卷积核能捕获不同长度的语义模式。
  • 文本数据的二维表示
    虽然文本本质是一维序列,但在卷积操作中通常表示为二维张量,形状:[句子长度 × 词向量维度]。如一个包含 10 个词的句子,每个词用 300 维向量表示 → [10 × 300] 的二维矩阵

  • kernel_size=2:每次滑动查看 2 个连续词 的组合(如 [“深度”, “学习”])。
    kernel_size=3:每次滑动查看 3 个连续词 的组合(如 [“深度”, “学习”, “模型”])。
    kernel_size=4:每次滑动查看 4 个连续词 的组合(如 [“深度”, “学习”, “模型”, “强大”])。

  • 经过 ReLU 激活函数增强非线性。

Max-Pooling Layer(最大池化层)
对每个卷积核的输出进行 全局最大池化(Global Max Pooling),提取最重要的特征。
输出:[batch_size, num_filters](每个卷积核保留一个最大值)。

(3)Concatenation(特征拼接):将所有卷积核的池化结果拼接,形成固定长度的特征向量。

  • 池化层的核心功能:特征筛选与降维
    • 最大池化:在每个滑动窗口中选取最大值作为输出,本质是保留最突出的特征。
    • 降维效果:减少特征维度,降低计算量,同时避免过拟合。
  • 文本处理中的直观意义
    • 文本中的关键语义(如情感词、主题词)通常会产生较大的卷积响应值。
    • 最大池化相当于 “筛选” 出每个局部窗口中最关键的特征,忽略次要信息。
import torch
import torch.nn.functional as F# 假设卷积后的特征图(单通道)
feature_map = torch.tensor([[0.2, -0.5, 0.8, 0.1, -0.3],  # 长度为5的特征序列
], dtype=torch.float32).unsqueeze(0)  # [1, 1, 5]# 最大池化(窗口大小为5,覆盖整个序列)
pooled = F.max_pool1d(feature_map, kernel_size=5)
print(pooled)  # 输出: tensor([[[0.8]]])

该操作从长度为 5 的序列中提取最大值 0.8,作为该通道的最终特征。
(4)Fully Connected Layer(全连接层):将拼接后的特征映射到类别空间(Softmax 输出概率)。在文本分类模型中,全连接层 (Linear) 和 Dropout 层是构建分类器的关键组件。它们共同作用,将提取的文本特征映射到分类空间并防止过拟合。

四、完整代码

# -*- coding: utf-8 -*-
from jieba import lcut
from torchtext.vocab import vocab
from collections import OrderedDict, Counter
from torchtext.transforms import VocabTransform
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from sklearn.preprocessing import LabelEncoder
import scipy.io as io
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import torch.nn as nn
from torch.optim import Adam
import numpy as np
from utils import metrics, safeCreateDir
import time
from sklearn.metrics import ConfusionMatrixDisplay
from matplotlib import pyplot as plt
import seaborn as sns
import torch
from torch.nn import functional as F
import math
from sklearn.metrics import confusion_matrix
import os# 数据处理
def is_chinese(uchar):if (uchar >= '\u4e00' and uchar <= '\u9fa5'):return Trueelse:return Falsedef reserve_chinese(content):content_str = ''for i in content:if is_chinese(i):content_str += ireturn content_strdef getStopWords():file = open('./dataset/stopwords.txt', 'r', encoding='utf8')words = [i.strip() for i in file.readlines()]file.close()return wordsdef dataParse(text, stop_words):label_map = {'news_story': 0, 'news_culture': 1, 'news_entertainment': 2,'news_sports': 3, 'news_finance': 4, 'news_house': 5, 'news_car': 6,'news_edu': 7, 'news_tech': 8, 'news_military': 9, 'news_travel': 10,'news_world': 11, 'stock': 12, 'news_agriculture': 13, 'news_game': 14}_, _, label, content, _ = text.split('_!_')label = label_map[label]content = reserve_chinese(content)words = lcut(content)words = [i for i in words if not i in stop_words]return words, int(label)def getFormatData():file = open('./dataset/data/toutiao_cat_data.txt', 'r', encoding='utf8')texts = file.readlines()file.close()stop_words = getStopWords()all_words = []all_labels = []for text in texts:content, label = dataParse(text, stop_words)if len(content) <= 0:continueall_words.append(content)all_labels.append(label)ws = sum(all_words, [])set_ws = Counter(ws)keys = sorted(set_ws, key=lambda x: set_ws[x], reverse=True)dict_words = dict(zip(keys, list(range(1, len(set_ws) + 1))))ordered_dict = OrderedDict(dict_words)my_vocab = vocab(ordered_dict, specials=['<UNK>', '<SEP>'])vocab_transform = VocabTransform(my_vocab)vector = vocab_transform(all_words)vector = [torch.tensor(i) for i in vector]pad_seq = pad_sequence(vector, batch_first=True)labelencoder = LabelEncoder()labels = labelencoder.fit_transform(all_labels)data = pad_seq.numpy()data = {'X': data,'label': labels,'num_words': len(my_vocab)}io.savemat('./dataset/data/data.mat', data)# 数据集加载
class Data(Dataset):def __init__(self, mode='train'):data = io.loadmat('./dataset/data/data.mat')self.X = data['X']self.y = data['label']self.num_words = data['num_words'].item()train_X, val_X, train_y, val_y = train_test_split(self.X, self.y.squeeze(), test_size=0.3, random_state=1)val_X, test_X, val_y, test_y = train_test_split(val_X, val_y, test_size=0.5, random_state=1)if mode == 'train':self.X = train_Xself.y = train_yelif mode == 'val':self.X = val_Xself.y = val_yelif mode == 'test':self.X = test_Xself.y = test_ydef __getitem__(self, item):return self.X[item], self.y[item]def __len__(self):return self.X.shape[0]class getDataLoader():def __init__(self, batch_size):train_data = Data('train')val_data = Data('val')test_data = Data('test')self.traindl = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)self.valdl = DataLoader(val_data, batch_size=batch_size, shuffle=True, num_workers=4)self.testdl = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)self.num_words = train_data.num_words# 定义网络结构
class textCNN(nn.Module):def __init__(self, param):super(textCNN, self).__init__()ci = 1kernel_num = param['kernel_num']kernel_size = param['kernel_size']vocab_size = param['vocab_size']embed_dim = param['embed_dim']dropout = param['dropout']class_num = param['class_num']self.param = paramself.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=1)self.conv11 = nn.Conv2d(ci, kernel_num, (kernel_size[0], embed_dim))self.conv12 = nn.Conv2d(ci, kernel_num, (kernel_size[1], embed_dim))self.conv13 = nn.Conv2d(ci, kernel_num, (kernel_size[2], embed_dim))self.dropout = nn.Dropout(dropout)self.fc1 = nn.Linear(len(kernel_size) * kernel_num, class_num)def init_embed(self, embed_matrix):self.embed.weight = nn.Parameter(torch.Tensor(embed_matrix))@staticmethoddef conv_and_pool(x, conv):x = conv(x)x = F.relu(x.squeeze(3))x = F.max_pool1d(x, x.size(2)).squeeze(2)return xdef forward(self, x):x = self.embed(x)x = x.unsqueeze(1)x1 = self.conv_and_pool(x, self.conv11)x2 = self.conv_and_pool(x, self.conv12)x3 = self.conv_and_pool(x, self.conv13)x = torch.cat((x1, x2, x3), 1)x = self.dropout(x)logit = F.log_softmax(self.fc1(x), dim=1)return logitdef init_weight(self):for m in self.modules():if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, math.sqrt(2. / n))if m.bias is not None:m.bias.data.zero_()elif isinstance(m, nn.Linear):m.weight.data.normal_(0, 0.01)m.bias.data.zero_()def plot_acc(train_acc):sns.set(style='darkgrid')plt.figure(figsize=(10, 7))x = list(range(len(train_acc)))plt.plot(x, train_acc, alpha=0.9, linewidth=2, label='train acc')plt.xlabel('Epoch')plt.ylabel('Acc')plt.legend(loc='best')plt.savefig('results/acc.png', dpi=400)def plot_loss(train_loss):sns.set(style='darkgrid')plt.figure(figsize=(10, 7))x = list(range(len(train_loss)))plt.plot(x, train_loss, alpha=0.9, linewidth=2, label='train loss')plt.xlabel('Epoch')plt.ylabel('loss')plt.legend(loc='best')plt.savefig('results/loss.png', dpi=400)# 定义训练过程
class Trainer():def __init__(self):safeCreateDir('results/')self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')self._init_data()self._init_model()def _init_data(self):data = getDataLoader(batch_size=64)self.traindl = data.traindlself.valdl = data.valdlself.testdl = data.testdlself.num_words = data.num_wordsdef _init_model(self):self.textCNN_param = {'vocab_size': self.num_words,'embed_dim': 64,'class_num': 15,"kernel_num": 16,"kernel_size": [3, 4, 5],"dropout": 0.5,}self.net = textCNN(self.textCNN_param)self.opt = Adam(self.net.parameters(), lr=1e-4, weight_decay=5e-4)self.cri = nn.CrossEntropyLoss()def save_model(self):save_dir = 'saved_dict'if not os.path.exists(save_dir):os.makedirs(save_dir)torch.save(self.net.state_dict(), os.path.join(save_dir, 'cnn.pt'))def load_model(self):save_dir = 'saved_dict'model_path = os.path.join(save_dir, 'cnn.pt')if not os.path.exists(model_path):raise FileNotFoundError(f"Model file not found at {model_path}")self.net.load_state_dict(torch.load(model_path))def train(self, epochs):print('init net...')self.net.init_weight()print(self.net)patten = 'Epoch: %d   [===========]  cost: %.2fs;  loss: %.4f;  train acc: %.4f;  val acc:%.4f;'train_accs = []c_loss = []for epoch in range(epochs):cur_preds = np.empty(0)cur_labels = np.empty(0)cur_loss = 0start = time.time()for batch, (inputs, targets) in enumerate(self.traindl):inputs = inputs.to(self.device)targets = targets.to(self.device)self.net.to(self.device)pred = self.net(inputs)loss = self.cri(pred, targets)self.opt.zero_grad()loss.backward()self.opt.step()cur_preds = np.concatenate([cur_preds, pred.cpu().detach().numpy().argmax(axis=1)])cur_labels = np.concatenate([cur_labels, targets.cpu().numpy()])cur_loss += loss.item()acc, precision, f1, recall = metrics(cur_preds, cur_labels)val_acc, val_precision, val_f1, val_recall = self.val()train_accs.append(acc)c_loss.append(cur_loss)end = time.time()print(patten % (epoch + 1, end - start, cur_loss, acc, val_acc))self.save_model()plot_acc(train_accs)plot_loss(c_loss)@torch.no_grad()def val(self):self.net.eval()cur_preds = np.empty(0)cur_labels = np.empty(0)for batch, (inputs, targets) in enumerate(self.valdl):inputs = inputs.to(self.device)targets = targets.to(self.device)self.net.to(self.device)pred = self.net(inputs)cur_preds = np.concatenate([cur_preds, pred.cpu().detach().numpy().argmax(axis=1)])cur_labels = np.concatenate([cur_labels, targets.cpu().numpy()])acc, precision, f1, recall = metrics(cur_preds, cur_labels)self.net.train()return acc, precision, f1, recall@torch.no_grad()def test(self):print("test ...")self.load_model()patten = 'test acc: %.4f   precision: %.4f   recall: %.4f    f1: %.4f    'self.net.eval()cur_preds = np.empty(0)cur_labels = np.empty(0)for batch, (inputs, targets) in enumerate(self.testdl):inputs = inputs.to(self.device)targets = targets.to(self.device)self.net.to(self.device)pred = self.net(inputs)cur_preds = np.concatenate([cur_preds, pred.cpu().detach().numpy().argmax(axis=1)])cur_labels = np.concatenate([cur_labels, targets.cpu().numpy()])acc, precision, f1, recall = metrics(cur_preds, cur_labels)cv_conf = confusion_matrix(cur_preds, cur_labels)labels11 = ['story', 'culture', 'entertainment', 'sports', 'finance','house', 'car', 'edu', 'tech', 'military','travel', 'world', 'stock', 'agriculture', 'game']fig, ax = plt.subplots(figsize=(15, 15))disp = ConfusionMatrixDisplay(confusion_matrix=cv_conf, display_labels=labels11)disp.plot(cmap="Blues", values_format='', ax=ax)plt.savefig("results/ConfusionMatrix.png", dpi=400)self.net.train()print(patten % (acc, precision, recall, f1))if __name__ == "__main__":getFormatData()  # 数据预处理:数据清洗和词向量trainer = Trainer()trainer.train(epochs=30)  # 数据训练trainer.test()  # 测试

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/87600.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/87600.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

iot-dc3 项目Bug修复保姆喂奶级教程

一.Uncaught (in promise) ReferenceError: TinyArea is not defined 1.触发场景 前端设备模块,点击关联模板、关联位号、设备数据,无反应,一直切不过去,没有报错通知,F12查看控制台报错如下: 2.引起原因 前端导入的库为"@antv/g2": "^5.3.0",在 P…

Spring Boot + MyBatis Plus + SpringAI + Vue 毕设项目开发全解析(源码)

前言 前些天发现了一个巨牛的人工智能免费学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站 Spring Boot MyBatis Plus SpringAI Vue 毕设项目开发全解析 目录 一、项目概述与技术选型 项目背景与需求分析技术栈选择…

Vitess数据库部署与运维深度指南:构建可伸缩、高可用与安全的云原生数据库

摘要 Vitess是一个为MySQL和MariaDB设计的云原生、水平可伸缩的分布式数据库系统&#xff0c;它通过分片&#xff08;sharding&#xff09;实现无限扩展&#xff0c;同时保持对应用程序的透明性&#xff0c;使其无需感知底层数据分布。该项目于2019年从云原生计算基金会&#…

SpringAI+DeepSeek大模型应用开发——6基于MongDB持久化对话

持久化对话 默认情况下&#xff0c;聊天记忆存储在内存中ChatMemory chatMemory new InMemoryChatMemory()。 如果需要持久化存储&#xff0c;可以实现一个自定义的聊天记忆存储类&#xff0c;以便将聊天消息存储在你选择的任何持久化存储介质中。 MongoDB 文档型数据库&…

Mac电脑-音视频剪辑编辑-Final Cut Pro X(fcpx)

Final Cut Pro Mac是一款专业的视频剪辑工具&#xff0c;专为苹果用户设计。 它具备强大的视频剪辑、音轨、图形特效和调色功能&#xff0c;支持整片输出&#xff0c;提升创作效率。 经过Apple芯片优化&#xff0c;利用Metal引擎动力&#xff0c;可处理更复杂的项目&#xff…

不同程度多径效应影响下的无线通信网络电磁信号仿真数据生成程序

生成.mat数据&#xff1a; %创建时间&#xff1a;2025年6月19日 %zhouzhichao %遍历生成不同程度多径效应影响的无线通信网络拓扑推理数据用于测试close all clearsnr 40; n 30;dataset_n 100;for bias 0.1:0.1:0.9nodes_P ones(n,1);Sampling_M 3000;%获取一帧信号及对…

Eureka 和 Feign(二)

Eureka 和 Feign 是 Spring Cloud 微服务架构中协同工作的两个核心组件&#xff0c;它们的关系可以通过以下比喻和详解来说明&#xff1a; 关系核心&#xff1a;服务发现 → 动态调用 组件角色核心功能Eureka服务注册中心服务实例的"电话簿"Feign声明式HTTP客户端根…

Springboot仿抖音app开发之RabbitMQ 异步解耦(进阶)

Springboot仿抖音app开发之评论业务模块后端复盘及相关业务知识总结 Springboot仿抖音app开发之粉丝业务模块后端复盘及相关业务知识总结 Springboot仿抖音app开发之用短视频务模块后端复盘及相关业务知识总结 Springboot仿抖音app开发之用户业务模块后端复盘及相关业务知识…

1.部署KVM虚拟化平台

一.KVM原理简介 广义的KVM实际上包含两部分&#xff0c;一部分是基于Linux内核支持的KVM内核模块&#xff0c;另一部分就是经过简化和修改的Qemuo KVM内核模块是模拟处理器和内存以支持虚拟机的运行&#xff0c;Qemu主要处理丨℃以及为用户提供一个用户空间工具来进行虚拟机的…

优化与管理数据库连接池

优化与管理数据库连接池 在现代高并发系统中,数据库连接池是保障数据库访问性能的核心组件之一。合理配置、优化和管理连接池,可以有效缓解连接创建成本高、连接频繁断开重连等问题,从而提升系统整体的响应速度与稳定性。 数据库连接池的作用与价值 数据库连接池的核心思…

实现回显服务器(基于UDP)

目录 一.回显服务器的基本概念 二.回显服务器的简单示意图 三.实现回显服务器&#xff08;基于UDP&#xff09;必须要知道的API 1.DatagramSocket 2.DatagramPacket 3.InetSocketAddress 4.二者区别 1. 功能职责 2. 核心作用 3. 使用场景流程 四.实现服务器端的主…

LabVIEW电液伺服阀自动测试

针对航空航天及工业液压领域电液伺服阀测试需求&#xff0c;采用 LabVIEW 图形化编程平台&#xff0c;集成 NI、GE Druck 等品牌硬件&#xff0c;构建集静态特性&#xff08;流量/ 压力 / 泄漏&#xff09;与动态特性&#xff08;频率响应&#xff09;测试于一体的自动化系统&a…

性能优化 - 高级进阶: Spring Boot服务性能优化

文章目录 Pre引言&#xff1a;为何提前暴露指标与分析的重要性指标暴露与监控接入Prometheus 集成 性能剖析工具&#xff1a;火焰图与 async-profilerasync-profiler 下载与使用结合 Flame 图优化示例 HTTP 及 Web 层优化CDN 与静态资源加速Cache-Control/Expires 在 Nginx 中配…

力扣网C语言编程题:除自身以外数组的乘积

一. 简介 本文记录力扣网上涉及数组方面的编程题&#xff0c;主要以 C语言实现。 二. 力扣上C语言编程题&#xff1a;涉及数组 题目&#xff1a;除自身以外数组的乘积 给你一个整数数组 nums&#xff0c;返回 数组 answer &#xff0c;其中 answer[i] 等于 nums 中除 nums[i…

SpringBoot扩展——发送邮件!

发送邮件 在日常工作和生活中经常会用到电子邮件。例如&#xff0c;当注册一个新账户时&#xff0c;系统会自动给注册邮箱发送一封激活邮件&#xff0c;通过邮件找回密码&#xff0c;自动批量发送活动信息等。邮箱的使用基本包括这几步&#xff1a;先打开浏览器并登录邮箱&…

【html】iOS26 液态玻璃实现效果

<!DOCTYPE html> <html lang"zh"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /><title>液体玻璃效果演示</title><style>bo…

探索算法秘境:量子随机游走算法及其在图论问题中的创新应用

目录 ​编辑 一、量子随机游走算法的起源与原理 二、量子随机游走算法在图论问题中的创新应用 三、量子随机游走算法的优势与挑战 四、结语 在算法研究的浩瀚星空中&#xff0c;总有一些领域如同遥远星系&#xff0c;闪烁着神秘而诱人的光芒。今天&#xff0c;我们将一同深…

C# 一维数组和矩形数组全解析

在编程的世界里&#xff0c;数组是一种非常重要的数据结构。今天&#xff0c;我们就来详细了解一下一维数组和矩形数组。 数组基础认知 数组实例是从 System.Array 继承类型的对象。由于它从 BCL 基类派生而来&#xff0c;所以继承了许多有用的成员&#xff1a; Rank 属性&a…

WebStorm编辑器侧边栏

目录 编辑器侧边栏行号配置行号隐藏行号 代码折叠侧边栏图标书签添加匿名书签添加助记符书签 运行和调试管理断点配置断点图标 版本控制配置Git Blame注释 编辑器侧边栏 编辑器左侧的垂直区域。当编写代码时&#xff0c;提供重要信息和操作图标。外观和行为可以根据你的喜好进…

腾讯云TCCA认证考试报名 - TDSQL数据库交付运维工程师(PostgreSQL版)

数据库交付运维工程师-腾讯云TDSQL(PostgreSQL版)认证 适合人群&#xff1a; 适合从事TDSQL(PostgreSQL版)交付、运维、售前咨询以及TDSQL(PostgreSQL版)相关项目的管理人员。 认证考试 单选*40道多选*20道 成绩查询 70分及以上通过认证&#xff0c;官网个人中心->认证考…