Bert项目--新闻标题文本分类

目录

技术细节

1、下载模型

 2、config文件

3、BERT 文本分类数据预处理流程

4、对输入文本进行分类

 5、计算模型的分类性能指标

6、模型训练

 7、基于BERT的文本分类预测接口

问题总结


技术细节

1、下载模型

文件名称--a0_download_model.py

使用 ModelScope 库从模型仓库下载 BERT-base-Chinese 预训练模型,并将其保存到本地指定目录。

# 模型下载
from modelscope import snapshot_downloadmodel_dir = snapshot_download('google-bert/bert-base-chinese', local_dir=r"D:\src\bert-base-chinese")

 2、config文件

数据加载与保存的路径

# 根目录self.root_path = 'E:/PythonLearning/full_mask_project2/'# 原始数据路径self.train_datapath = self.root_path + '01-data/train.txt'self.test_datapath = self.root_path + '01-data/test.txt'self.dev_datapath = self.root_path + '01-data/dev.txt'# 类别文档self.class_path = self.root_path + "01-data/class.txt"# 类别名列表self.class_list = [line.strip() for line in open(self.class_path, encoding="utf-8")]  # 类别名单# 模型训练保存路径self.model_save_path = self.root_path + "dm_03_bert/save_models/bert_classifer_model.pt"  # 模型训练结果保存路径

加载预训练Bert模型以及其分词器和配置文件

        self.bert_path = r"E:\PythonLearning\full_mark_Project\dm04_Bert\bert-base-chinese"  # 预训练BERT模型的路径self.bert_model = BertModel.from_pretrained(self.bert_path)  # 加载预训练BERT模型self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)  # BERT模型的分词器self.bert_config = BertConfig.from_pretrained(self.bert_path)  # BERT模型的配置

设置模型参数

        # todo 参数self.num_classes = len(self.class_list)  # 类别数self.num_epochs = 1  # epoch数self.batch_size = 64  # mini-batch大小self.pad_size = 32  # 每句话处理成的长度(短填长切)self.learning_rate = 5e-5  # 学习率

完整config代码

import torch
import datetime
from transformers import BertModel, BertTokenizer, BertConfig# 获取当前日期字符串
current_date = datetime.datetime.now().date().strftime("%Y%m%d")# 配置类
class Config(object):def __init__(self):"""配置类,包含模型和训练所需的各种参数。"""self.model_name = "bert"  # 模型名称# todo 路径# 根目录self.root_path = 'E:/PythonLearning/full_mask_project2/'# 原始数据路径self.train_datapath = self.root_path + '01-data/train.txt'self.test_datapath = self.root_path + '01-data/test.txt'self.dev_datapath = self.root_path + '01-data/dev.txt'# 类别文档self.class_path = self.root_path + "01-data/class.txt"# 类别名列表self.class_list = [line.strip() for line in open(self.class_path, encoding="utf-8")]  # 类别名单# 模型训练保存路径self.model_save_path = self.root_path + "dm_03_bert/save_models/bert_classifer_model.pt"  # 模型训练结果保存路径# 模型训练+预测的时候  训练设备,如果GPU可用,则为cuda,否则为cpuself.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.bert_path = r"E:\PythonLearning\full_mark_Project\dm04_Bert\bert-base-chinese"  # 预训练BERT模型的路径self.bert_model = BertModel.from_pretrained(self.bert_path)  # 加载预训练BERT模型self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)  # BERT模型的分词器self.bert_config = BertConfig.from_pretrained(self.bert_path)  # BERT模型的配置# todo 参数self.num_classes = len(self.class_list)  # 类别数self.num_epochs = 1  # epoch数self.batch_size = 64  # mini-batch大小self.pad_size = 32  # 每句话处理成的长度(短填长切)self.learning_rate = 5e-5  # 学习率# TODO 量化模型存放地址# 注意: 量化的时候模型需要的设备首选是cpuself.bert_model_quantization_model_path = self.root_path + "dm_03_bert/save_models/bert_classifer_quantization_model.pt"  # 模型训练结果保存路径if __name__ == '__main__':# 测试conf = Config()print(conf.device)print(conf.class_list)print(conf.tokenizer)input_size = conf.tokenizer.convert_tokens_to_ids(["你", "好", "中", "人"])print(input_size)print(conf.bert_model)print(conf.bert_config)

3、BERT 文本分类数据预处理流程

数据加载

def load_raw_data(file_path):"""从指定文件中加载原始数据。处理文本文件,返回(文本, 标签类别索引)元组列表参数:file_path: 原始文本文件路径返回:list: 包含(文本, 标签类别索引)的元组列表,类别为int类型[('体验2D巅峰 倚天屠龙记十大创新概览', 8), ('60年铁树开花形状似玉米芯(组图)', 5)]"""result = []# 打印指定文件with open(file_path, 'r', encoding='utf-8') as f:# 使用tqdm包装文件读取迭代器,以便显示加载数据的进度条for line in tqdm(f, desc=f"加载原始数据{file_path}"):# 移除行两端的空白字符line = line.strip()# 跳过空行if not line:continue# 将行分割成文本和标签两部分text, label = line.split("\t")# 将标签转为int作为类别label = int(label)# 将文本和转换为整数的标签作为元组添加到数据列表中result.append((text, label))# 返回处理后的列表return result

数据集构建

# todo 2.自定义数据集
class TextDataset(Dataset):# 初始化数据def __init__(self, data_list):self.data_list = data_list# 返回数据集长度def __len__(self):return len(self.data_list)# 根据样本索引,返回对应的特征和标签def __getitem__(self, idx):text, label = self.data_list[idx]return text, label

批处理(padding/collate)

每当 DataLoader 从 Dataset 中取出一批batch 的原始数据后,
就会调用 collate_fn 来对这个 batch 进行统一处理(如填充、转换为张量等)。
  • batch 是一个包含多个 (文本, 标签) 元组的列表

  • 使用 zip(*batch) 将元组列表"转置"为两个元组:一个包含所有文本,一个包含所有标签

  • 例如:[("text1", 1), ("text2", 2)] → ("text1", "text2") 和 (1, 2)

  • add_special_tokens=True: 自动添加 [CLS] 和 [SEP] 等特殊token

  • padding='max_length': 将所有文本填充到固定长度 conf.pad_size

  • max_length=conf.pad_size: 设置最大长度

  • truncation=True: 如果文本超过最大长度则截断

  • return_attention_mask=True: 返回注意力掩码

  • input_ids: 文本转换为的数字token ID序列

  • attention_mask: 标记哪些位置是实际文本(1),哪些是填充部分(0)

def collate_fn(batch):"""对batch数据进行padding处理参数: batch: 包含(文本, 标签)元组的batch数据返回: tuple: 包含处理后的input_ids, attention_mask和labels的元组"""# todo 使用zip()将一批batch数据中的(text, label)元组拆分成两个独立的元组# texts = [item[0] for item in batch]# labels = [item[1] for item in batch]texts, labels = zip(*batch)# 对文本进行paddingtext_tokens = conf.tokenizer.batch_encode_plus(texts,add_special_tokens=True,  # 默认True,自动添加 [CLS] 和 [SEP]# padding=True,自动填充到数据中的最大长度       padding='max_length':填充到指定的固定长度padding='max_length',max_length=conf.pad_size,  # 设定目标长度truncation=True,  # 开启截断,防止超出模型限制return_attention_mask=True  # 请求返回注意力掩码,以区分输入中的有效信息和填充信息)# 从文本令牌中提取输入IDinput_ids = text_tokens['input_ids']# 从文本令牌中提取注意力掩码attention_mask = text_tokens['attention_mask']# 将输入的token ID列表转换为张量input_ids = torch.tensor(input_ids)# 将注意力掩码列表转换为张量attention_mask = torch.tensor(attention_mask)# 将标签列表转换为张量labels = torch.tensor(labels)# 返回转换后的张量元组return input_ids, attention_mask, labels

DataLoader 封装

def build_dataloader():# 加载原始数据train_data_list = load_raw_data(conf.train_datapath)dev_data_list = load_raw_data(conf.dev_datapath)test_data_list = load_raw_data(conf.test_datapath)# 构建训练集train_dataset = TextDataset(train_data_list)dev_dataset = TextDataset(dev_data_list)test_dataset = TextDataset(test_data_list)# 构建DataLoadertrain_dataloader = DataLoader(train_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)dev_dataloader = DataLoader(dev_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)test_dataloader = DataLoader(test_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)return train_dataloader, dev_dataloader, test_dataloader

完整代码

# 加载数据工具类
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, TextDataset
from transformers.utils import PaddingStrategy
from config import Config
import time# 加载配置
conf = Config()# todo 加载并处理原始数据
def load_raw_data(file_path):"""从指定文件中加载原始数据。处理文本文件,返回(文本, 标签, 类别)元组列表参数:file_path: 文本文件路径返回:list: 包含(文本, 标签, 类别)的元组列表,类别为int类型"""result = []# 打印指定文件with open(file_path, 'r', encoding='utf-8') as f:# 使用tqdm包装文件读取迭代器,以便显示加载数据的进度条for line in tqdm(f, desc="加载数据..."):# 移除行两端的空白字符line = line.strip()# 跳过空行if not line:continue# 将行分割成文本和标签两部分text, label = line.split("\t")# 将标签转为int作为类别label = int(label)# 将文本和转换为整数的标签作为元组添加到数据列表中result.append((text, label))# 返回处理后的列表return result# todo 自定义数据集
class TextDataset(Dataset):# 初始化数据def __init__(self, data_list):self.data_list = data_list# 返回数据集长度def __len__(self):return len(self.data_list)# 根据样本索引,返回对应的特征和标签def __getitem__(self, idx):text, label = self.data_list[idx]return text, label# todo 批量处理数据
# 每当 DataLoader 从 Dataset 中取出一个 batch 的原始数据后,
# 就会调用 collate_fn 来对这个 batch 进行统一处理(如填充、转换为张量等)。
def collate_fn(batch):"""对batch数据进行padding处理参数: batch: 包含(文本, 标签)元组的batch数据返回: tuple: 包含处理后的input_ids, attention_mask和labels的元组"""# todo 使用zip()将一批batch数据中的(text, label)元组拆分成两个独立的元组# texts = [item[0] for item in batch]# labels = [item[1] for item in batch]texts, labels = zip(*batch)# 对文本进行paddingtext_tokens = conf.tokenizer.batch_encode_plus(texts,add_special_tokens=True,  # 默认True,自动添加 [CLS] 和 [SEP]padding='max_length',  # 固定长度max_length=conf.pad_size,  # 设定目标长度truncation=True,  # 开启截断,防止超出模型限制return_attention_mask=True  # 请求返回注意力掩码,以区分输入中的有效信息和填充信息)# 从文本令牌中提取输入IDinput_ids = text_tokens['input_ids']# 从文本令牌中提取注意力掩码attention_mask = text_tokens['attention_mask']# 将输入的token ID列表转换为张量input_ids = torch.tensor(input_ids)# 将注意力掩码列表转换为张量attention_mask = torch.tensor(attention_mask)# 将标签列表转换为张量labels = torch.tensor(labels)# 返回转换后的张量元组return input_ids, attention_mask, labels# todo 构建dataloader
def build_dataloader():# 加载原始数据train_data_list = load_raw_data(conf.train_datapath)dev_data_list = load_raw_data(conf.dev_datapath)test_data_list = load_raw_data(conf.test_datapath)# 构建训练集train_dataset = TextDataset(train_data_list)dev_dataset = TextDataset(dev_data_list)test_dataset = TextDataset(test_data_list)# 构建DataLoadertrain_dataloader = DataLoader(train_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)dev_dataloader = DataLoader(dev_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)test_dataloader = DataLoader(test_dataset, batch_size=conf.batch_size, shuffle=False, collate_fn=collate_fn)return train_dataloader, dev_dataloader, test_dataloaderif __name__ == '__main__':# 测试load_raw_data方法data_list = load_raw_data(conf.dev_datapath)print(data_list[:10])# 测试TextDataset类dataset = TextDataset(data_list)print(dataset[0])print(dataset[1])# 测试build_dataloader方法train_dataloader, dev_dataloader, test_dataloader = build_dataloader()print(len(train_dataloader))print(len(dev_dataloader))print(len(test_dataloader))# 测试collate_fn方法"""for i, batch in enumerate(train_dataloader)流程如下:1.DataLoader 从你的 Dataset 中取出一组索引;2.使用这些索引调用 Dataset.__getitem__ 获取原始样本;3.将这一组样本组成一个 batch(通常是 (text, label) 元组的列表);4.自动调用你传入的 collate_fn 函数来处理这个 batch 数据;5.返回处理后的 batch(如 input_ids, attention_mask, labels)供模型使用。"""for i, batch in enumerate(train_dataloader):print(len(batch))print(i)input_ids, attention_mask, labels = batch# print("input_ids: ", input_ids.tolist())print("input_ids.shape: ", input_ids.shape)# print("attention_mask: ", attention_mask.tolist())print("attention_mask.shape: ", attention_mask.shape)# print("labels: ", labels.tolist())print("labels.shape: ", labels.shape)break

4、对输入文本进行分类

  • 流程

    1. 加载预训练 BERT 模型(BertModel)作为特征提取器。

    2. 添加全连接层(nn.Linear)进行类别预测。

    3. 提供测试代码,演示从文本输入到预测结果的全过程。

  • outputs 是一个包含多个输出的元组,对于分类任务我们主要关注:

    • last_hidden_state: 序列中每个 token 的隐藏状态 (shape: [batch_size, sequence_length, hidden_size])

    • pooler_output: 经过池化的整个序列的表示 (shape: [batch_size, hidden_size])

假设:

  • batch_size = 32

  • sequence_length = 128 (经过 padding/truncation 后的长度)

  • hidden_size = 768 (BERT-base 的隐藏层大小)

  • num_classes = 10

数据流过程:

  1. 输入 input_ids[32, 128]

  2. 输入 attention_mask[32, 128]

  3. BERT 输出 pooler_output[32, 768]

  4. 全连接层输出 logits[32, 10]

import torch
import torch.nn as nn
from transformers import BertModel
from config import Config# 加载配置
config = Config()# 定义bert模型
class BertClassifier(nn.Module):def __init__(self):# 初始化父类类的构造函数super().__init__()# 下面的BertModel是从transformers库中加载的预训练模型# config.bert_path是预训练模型的路径self.bert = BertModel.from_pretrained(config.bert_path)# 定义全连接层(fc),用于分类任务# 输入尺寸是Bert模型隐藏层的大小,即768(对于Base模型)# 输出尺寸是类别数量,由config.num_classes指定self.fc = nn.Linear(config.bert_config.hidden_size, config.num_classes)def forward(self, input_ids, attention_mask):# 使用BERT模型处理输入的token ID和注意力掩码,获取BERT模型的输出# outputs是: _,pooledoutputs = self.bert(input_ids=input_ids,  # 输入的token IDattention_mask=attention_mask  # 注意力掩码用于区分有效token和填充token)# print(outputs) # 观察结果# 通过全连接层对BERT模型的输出进行分类logits = self.fc(outputs.pooler_output)# 返回分类的logits(未归一化的预测分数)return logits# 测试以上模型
if __name__ == '__main__':# 测试model = BertClassifier()# 加载from transformers import BertTokenizertokenizer = BertTokenizer.from_pretrained(config.bert_path)# 示例文本texts = ["我喜欢你", "今天天气真好"]# 编码文本encoded_inputs = tokenizer(texts,# padding=True,  #  所有的填充到文本最大长度padding="max_length", # 所有的填充到指定的max_length长度truncation=True, # 如果超出指定的max_length长度,则截断max_length=10,return_tensors="pt"  # 返回 pytorch 张量,"pt" 时,分词器会将输入文本转换为模型可接受的格式)# 获取 input_ids 和 attention_maskinput_ids = encoded_inputs["input_ids"]attention_mask = encoded_inputs["attention_mask"]print('input_ids:', input_ids)print('attention_mask:', attention_mask)print('======================================')# 预测logits = model(input_ids, attention_mask)print(logits)  # 每一行对应一个样本,每个数字表示该样本属于某一类别的“得分”(logit),没有经过 softmax 归一化。print('-------------------------------')# 获取预测概率probs = torch.softmax(logits, dim=-1)print(probs)  # 归一化后该样本属于某类的概率(范围在 0~1 之间),概率最高的就是预测结果print('-------------------------------')preds = torch.argmax(probs, dim=-1)print(preds)  # 得到每个样本的预测类别。表示两个输入文本被模型预测为类别 6(从 0 开始计数)。

 5、计算模型的分类性能指标

功能:传入模型、数据、设备,返回分类报告

完整代码:

import torch
from sklearn.metrics import classification_report, f1_score, accuracy_score, precision_score, recall_score
from tqdm import tqdmdef model2dev(model, data_loader, device):"""在验证或测试集上评估 BERT 分类模型的性能。参数:model (nn.Module): BERT 分类模型。data_loader (DataLoader): 数据加载器(验证或测试集)。device (str): 设备("cuda" 或 "cpu")。返回:tuple: (分类报告, F1 分数, 准确度, 精确度,召回率)- report: 分类报告(包含每个类别的精确度、召回率、F1 分数等)。- f1score: 微平均 F1 分数。- accuracy: 准确度。- precision: 微平均精确度- recall: 微平均召回率"""# todo 1. 设置模型为评估模式(禁用 dropout,并改变batch_norm行为)model.eval()# 2. 初始化列表,存储预测结果和真实标签all_preds, all_labels = [], []# 3. todo torch.no_grad()禁用梯度计算以提高效率并减少内存占用with torch.no_grad():# 4. 遍历数据加载器,逐批次进行预测for i, batch in enumerate(tqdm(data_loader, desc="验证集评估中...")):# 4.1 提取批次数据并移动到设备input_ids, attention_mask, labels = batchinput_ids = input_ids.to(device)attention_mask = attention_mask.to(device)labels = labels.to(device)# 4.2 前向传播:模型预测outputs = model(input_ids, attention_mask=attention_mask)# 4.3 获取预测结果(最大 logits分数 对应的类别)y_pred_list = torch.argmax(outputs, dim=1)# 4.4 存储预测和真实标签all_preds.extend(y_pred_list.cpu().tolist())all_labels.extend(labels.cpu().tolist())# 5. 计算分类报告、F1 分数、准确率,精确率,召回率report = classification_report(all_labels, all_preds)f1score = f1_score(all_labels, all_preds, average='macro')accuracy = accuracy_score(all_labels, all_preds)precision = precision_score(all_labels, all_preds, average='macro')recall = recall_score(all_labels, all_preds, average='macro')# 6. 返回评估结果return report, f1score, accuracy, precision, recall

6、模型训练

口诀:15241

加载配置文件和参数(1)

# todo 加载配置对象,包含模型参数、路径等
conf = Config()
# todo 导入数据处理工具类
from a1_dataloader_utils import build_dataloader
# todo 导入bert模型
from a2_bert_classifer_model import BertClassifier

准备数据(4)

    # todo 1、准备数据train_dataloader, dev_dataloader, test_dataloader = build_dataloader()# todo 2、准备模型# 2.1初始化bert分类模型model = BertClassifier()# 2.2将模型移动到指定的设备model.to(conf.device)# todo 3.准备损失函数loss_fn = nn.CrossEntropyLoss()# todo 4.准备优化器optimizer = AdamW(model.parameters(), lr=conf.learning_rate)

模型训练

外层遍历轮次,内层遍历批次

前向传播:模型预测、计算损失

反向传播:梯度清零、计算梯度、参数更新

# todo 5.开始训练模型# 初始化F1分数,用于保存最好的模型best_f1 = 0.0# todo 5.1 外层循环遍历每个训练轮次#  (每次需要设置训练模式,累计损失,预存训练集测试和真实标签)for epoch in range(conf.num_epochs):# 设置模型为训练模式model.train()# 初始化累计损失,初始化训练集预测和真实标签total_loss = 0.0train_preds, train_labels = [], []# todo 5.2 内层循环遍历训练DataLoader每个批次for i, batch in enumerate(tqdm(train_dataloader, desc="训练集训练中...")):# 提取批次数据并移动到设备input_ids, attention_mask, labels = batchinput_ids = input_ids.to(conf.device)attention_mask = attention_mask.to(conf.device)labels = labels.to(conf.device)# todo 前向传播:模型预测logits = model(input_ids, attention_mask)# todo 计算损失loss = loss_fn(logits, labels)# 累计损失total_loss += loss.item()# todo 获取预测结果(最大logits对应的类别)y_pred_list = torch.argmax(logits, dim=1)# todo 存储预测和真实标签,用于计算训练集指标train_preds.extend(y_pred_list.cpu().tolist())train_labels.extend(labels.cpu().tolist())# todo 梯度清零optimizer.zero_grad()# todo 反向传播:计算梯度loss.backward()# todo 参数更新:根据梯度更新模型参数optimizer.step()

验证评估

            # todo 每10个批次或一个轮次结束,计算训练集指标if (i + 1) % 10 == 0 or i == len(train_dataloader) - 1:# 计算准确率和f1值acc = accuracy_score(train_labels, train_preds)f1 = f1_score(train_labels, train_preds, average='macro')# 获取batch_count,并计算平均损失batch_count = i % 10 + 1avg_loss = total_loss / batch_count# todo 打印训练信息print(f"\n轮次: {epoch + 1}, 批次: {i + 1}, 损失: {avg_loss:.4f}, acc准确率:{acc:.4f}, f1分数:{f1:.4f}")# todo 清空累计损失和预测和真实标签total_loss = 0.0train_preds, train_labels = [], []# todo 每100个批次或一个轮次结束,计算验证集指标,打印,保存模型if (i + 1) % 100 == 0 or i == len(train_dataloader) - 1:# 计算在测试集的评估报告,准确率,精确率,召回率,f1值report, f1score, accuracy, precision, recall = model2dev(model, dev_dataloader, conf.device)print("验证集评估报告:\n", report)print(f"验证集的f1: {f1score:.4f}, accuracy:{accuracy:.4f}, precision:{precision:.4f}, recall:{recall:.4f}")# todo 将模型再设置为训练模式model.train()# todo 如果验证F1分数优于历史最佳,保存模型if f1score > best_f1:# 更新历史最佳F1分数best_f1 = f1score# 保存模型torch.save(model.state_dict(), conf.model_save_path)print("保存模型成功, 当前f1分数:", best_f1)

完整代码

import torch
import torch.nn as nn
from torch.optim import AdamW
from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm
from model2dev_utils import model2dev
from config import Config
from a1_dataloader_utils import build_dataloader
# 忽略的警告信息
import warningswarnings.filterwarnings("ignore")
# todo 导入bert模型
from a2_bert_classifer_model import BertClassifier# 加载配置对象,包含模型参数、路径等
conf = Config()def model2train():"""训练 BERT 分类模型并在验证集上评估,保存最佳模型。参数:无显式参数,所有配置通过全局 conf 对象获取。返回:无返回值,训练过程中保存最佳模型到指定路径。"""# todo 1、准备数据train_dataloader, dev_dataloader, test_dataloader = build_dataloader()# todo 2、准备模型# 2.1初始化bert分类模型model = BertClassifier()# 2.2将模型移动到指定的设备model.to(conf.device)# todo 3.准备损失函数loss_fn = nn.CrossEntropyLoss()# todo 4.准备优化器optimizer = AdamW(model.parameters(), lr=conf.learning_rate)# todo 5.开始训练模型# 初始化F1分数,用于保存最好的模型best_f1 = 0.0# todo 5.1 外层循环遍历每个训练轮次#  (每次需要设置训练模式,累计损失,预存训练集测试和真实标签)for epoch in range(conf.num_epochs):# 设置模型为训练模式model.train()# 初始化累计损失,初始化训练集预测和真实标签total_loss = 0.0train_preds, train_labels = [], []# todo 5.2 内层循环遍历训练DataLoader每个批次for i, batch in enumerate(tqdm(train_dataloader, desc="训练集训练中...")):# 提取批次数据并移动到设备input_ids, attention_mask, labels = batchinput_ids = input_ids.to(conf.device)attention_mask = attention_mask.to(conf.device)labels = labels.to(conf.device)# todo 前向传播:模型预测outputs = model(input_ids, attention_mask)# todo 计算损失loss = loss_fn(outputs, labels)# todo 累计损失total_loss += loss.item()# todo 获取预测结果(最大logits对应的类别)y_pred_list = torch.argmax(outputs, dim=1)# todo 存储预测和真实标签,用于计算训练集指标train_preds.extend(y_pred_list.cpu().tolist())train_labels.extend(labels.cpu().tolist())# todo 梯度清零optimizer.zero_grad()# todo 反向传播:计算梯度loss.backward()# todo 参数更新:根据梯度更新模型参数optimizer.step()# todo 每10个批次或一个轮次结束,计算训练集指标if (i + 1) % 10 == 0 or i == len(train_dataloader) - 1:# 计算准确率和f1值acc = accuracy_score(train_labels, train_preds)f1 = f1_score(train_labels, train_preds, average='macro')# 获取batch_count,并计算平均损失batch_count = i % 10 + 1avg_loss = total_loss / batch_count# todo 打印训练信息print(f"\n轮次: {epoch + 1}, 批次: {i + 1}, 损失: {avg_loss:.4f}, acc准确率:{acc:.4f}, f1分数:{f1:.4f}")# todo 清空累计损失和预测和真实标签total_loss = 0.0train_preds, train_labels = [], []# todo 每100个批次或一个轮次结束,计算验证集指标,打印,保存模型if (i + 1) % 100 == 0 or i == len(train_dataloader) - 1:# 计算在测试集的评估报告,准确率,精确率,召回率,f1值report, f1score, accuracy, precision, recall = model2dev(model, dev_dataloader, conf.device)print("验证集评估报告:\n", report)print(f"验证集f1: {f1score:.4f}, accuracy:{accuracy:.4f}, precision:{precision:.4f}, recall:{recall:.4f}")# 将模型设置为训练模式model.train()# todo 如果验证F1分数优于历史最佳,保存模型if f1score > best_f1:# 更新历史最佳F1分数best_f1 = f1score# 保存模型torch.save(model.state_dict(), conf.model_save_path)print("保存模型成功, 当前f1分数:", best_f1)if __name__ == '__main__':model2train()

 7、基于BERT的文本分类预测接口

  1. 加载预训练好的BERT分类模型(BertClassifier)。

  2. 对输入文本进行分词、编码、转换为张量。

  3. 使用模型预测文本类别。

  4. 返回类别名称(如 "education"

import torch
from a2_bert_classifer_model import BertClassifier
from config import Config# 加载配置
conf = Config()# todo 准备模型
model = BertClassifier()
# todo 加载模型参数
model.load_state_dict(torch.load(conf.model_save_path))
# todo 添加模型到指定设备
model.to(conf.device)
# todo 设置模型为评估模式
model.eval()# TODO 定义predict_fun函数预测函数
def predict_fun(data_dict):"""根据用户录入数据,返回分类信息:param 参数 data_dict: {"text":"状元心经:考前一周重点是回顾和整理"}:return: 返回 data_dict: {"text":"状元心经:考前一周重点是回顾和整理", "pred_class":"education"}"""# todo 获取文本text = data_dict['text']# todo 将文本转为idtext_tokens = conf.tokenizer.batch_encode_plus([text],padding="max_length",max_length=conf.pad_size,pad_to_max_length=True)# todo 获取input_ids和attention_maskinput_ids = text_tokens['input_ids']attention_mask = text_tokens['attention_mask']# todo 将input_ids和attention_mask转为tensor, 并指定到设备input_ids = torch.tensor(input_ids).to(conf.device)attention_mask = torch.tensor(attention_mask).to(conf.device)# todo 设置不进行梯度计算(在该上下文中禁用梯度计算,提升推理速度并减少内存占用)with torch.no_grad():# 前向传播(模型预测)output = model(input_ids, attention_mask)# 获取预测类别索引张量output = torch.argmax(output, dim=1)# 获取预测类别索引标量pred_idx = output.item()# 获取类别名称pred_class = conf.class_list[pred_idx]print(pred_class)# 将预测结果添加到data_dict中data_dict['pred_class'] = pred_class# 返回data_dictreturn data_dictif __name__ == '__main__':data_dict = {'text': '状元心经:考前一周重点是回顾和整理'}print(predict_fun(data_dict))

问题总结

持续更新中......

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/90765.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/90765.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

sendfile系统调用及示例

好的,我们继续学习 Linux 系统编程中的重要函数。这次我们介绍 sendfile 函数,它是一个高效的系统调用,用于在两个文件描述符之间直接传输数据,通常用于将文件内容发送到网络套接字,而无需将数据从内核空间复制到用户空…

数据结构习题--删除排序数组中的重复项

数据结构习题–删除排序数组中的重复项 给你一个 非严格递增排列 的数组 nums ,请你 原地 删除重复出现的元素,使每个元素 只出现一次 ,返回删除后数组的新长度。元素的 相对顺序 应该保持 一致 。然后返回 nums 中唯一元素的个数。 方法&…

Docker的容器设置随Docker的启动而启动

原因也比较简单,在docker run 的时候没有设置–restartalways参数。 容器启动时,需要增加参数 –restartalways no - 容器退出时,不重启容器; on-failure - 只有在非0状态退出时才从新启动容器; always - 无论退出状态…

JWT安全机制与最佳实践详解

JWT(JSON Web Token) 是一种开放标准(RFC 7519),用于在各方之间安全地传输信息作为紧凑且自包含的 JSON 对象。它被广泛用于身份验证(Authentication)和授权(Authorization&#xff…

如何解决pip安装报错ModuleNotFoundError: No module named ‘ipython’问题

【Python系列Bug修复PyCharm控制台pip install报错】如何解决pip安装报错ModuleNotFoundError: No module named ‘ipython’问题 摘要 在开发过程中,我们常常会遇到pip install报错的问题,其中一个常见的报错是 ModuleNotFoundError: No module named…

从三维Coulomb势到二维对数势的下降法推导

题目 问题 7. 应用 9.1.4 小节描述的下降法,但针对二维的拉普拉斯方程,并从三维的 Coulomb 势出发 KaTeX parse error: Invalid delimiter: {"type":"ordgroup","mode":"math","loc":{"lexer&qu…

直播一体机技术方案解析:基于RK3588S的硬件架构特性​

硬件配置​​主控平台​​▸ 搭载瑞芯微RK3588S旗舰处理器(四核A762.4GHz 四核A55)▸ 集成ARM Mali-G610 MP4 GPU 6TOPS算力NPU▸ 双通道LPDDR5内存 UFS3.1存储组合​​专用加速单元​​→ 板载视频采集模块:支持4K60fps HDMI环出采集→ 集…

【氮化镓】GaN取代GaAs作为空间激光无线能量传输光伏转换器材料

2025年7月1日,西班牙圣地亚哥-德孔波斯特拉大学的Javier F. Lozano等人在《Optics and Laser Technology》期刊发表了题为《Gallium nitride: a strong candidate to replace GaAs as base material for optical photovoltaic converters in space exploration》的文章,基于T…

直播美颜SDK动态贴纸模块开发指南:从人脸关键点识别到3D贴合

很多美颜技术开发者好奇,如何在直播美颜SDK中实现一个高质量的动态贴纸模块?这不是简单地“贴图贴脸”,而是一个融合人脸关键点识别、实时渲染、贴纸驱动逻辑、3D骨骼动画与跨平台性能优化的系统工程。今天,就让我们从底层技术出发…

学习游戏制作记录(剑投掷技能)7.26

1.实现瞄准状态和接剑状态准备好瞄准动画,投掷动画和接剑动画,并设置参数AimSword和CatchSword投掷动画在瞄准动画后,瞄准结束后才能投掷创建PlayerAimSwordState脚本和PlayerCatchSwordState脚本并在Player中初始化:PlayerAimSwo…

【c++】问答系统代码改进解析:新增日志系统提升可维护性——关于我用AI编写了一个聊天机器人……(14)

在软件开发中,代码的迭代优化往往从提升可维护性、可追踪性入手。本文将详细解析新增的日志系统改进,以及这些改进如何提升系统的实用性和可调试性。一、代码整体背景代码实现了一个基于 TF-IDF 算法的问答系统,核心功能包括:加载…

visual studio2022编译unreal engine5.4.4源码

UE5系列文章目录 文章目录 UE5系列文章目录 前言 一、ue5官网 二.编译源码中遇到的问题 前言 一、ue5官网 UE5官网 UE5源码下载地址 这样虽然下载比较快,但是不能进行代码git管理,以后如何虚幻官方有大的版本变动需要重新下载源码,所以我们还是最好需要visual studio2022…

vulhub Earth靶场攻略

靶场下载 下载链接:https://download.vulnhub.com/theplanets/Earth.ova 靶场使用 将压缩包解压到一个文件夹中,右键,用虚拟机打开,就创建成功了,然后启动虚拟机: 这时候靶场已经启动了,咱们现…

Python训练Day24

浙大疏锦行 元组可迭代对象os模块

Spring核心:Bean生命周期、外部化配置与组件扫描深度解析

Bean生命周期 说明 程序中的每个对象都有生命周期,对象的创建、初始化、应用、销毁的整个过程称之为对象的生命周期; 在对象创建以后需要初始化,应用完成以后需要销毁时执行的一些方法,可以称之为是生命周期方法; 在sp…

日语学习-日语知识点小记-进阶-JLPT-真题训练-N1阶段(1):2017年12月-JLPT-N1

日语学习-日语知识点小记-进阶-JLPT-真题训练-N1阶段(1):2017年12月-JLPT-N1 1、前言(1)情况说明(2)工程师的信仰(3)真题训练2、真题-2017年12月-JLPT-N1(1&a…

(一)使用 LangChain 从零开始构建 RAG 系统|RAG From Scratch

RAG 的主要动机 大模型训练的时候虽然使用了庞大的世界数据,但是并没有涵盖用户关心的所有数据, 其预训练令牌(token)数量虽大但相对这些数据仍有限。另外大模型输入的上下文窗口越来越大,从几千个token到几万个token,…

OpenCV学习探秘之一 :了解opencv技术及架构解析、数据结构与内存管理​等基础

​一、OpenCV概述与技术演进​ 1.1技术历史​ OpenCV(Open Source Computer Vision Library)是由Intel于1999年发起创建的开源计算机视觉库,后来交由OpenCV开源社区维护,旨在为计算机视觉应用提供通用基础设施。经历20余年发展&…

什么是JUC

摘要 Java并发工具包JUC是JDK5.0引入的重要并发编程工具,提供了更高级、灵活的并发控制机制。JUC包含锁与同步器(如ReentrantLock、Semaphore等)、线程安全队列(BlockingQueue)、原子变量(AtomicInteger等…

零基础学后端-PHP语言(第二期-PHP基础语法)(通过php内置服务器运行php文件)

经过上期的配置,我们已经有了php的开发环境,编辑器我们继续使用VScode,如果是新来的朋友可以看这期文章来配置VScode 零基础学前端-传统前端开发(第一期-开发软件介绍与本系列目标)(VScode安装教程&#x…