针对Helsinki-NLP/opus-mt-zh-en模型进行双向互翻的微调

引言
 题目听起来有点怪怪的,但是实际上就是对Helsinki-NLP/opus-mt-en-es模型进行微调。但是这个模型是单向的,只支持中到英的翻译,反之则不行。这样的话,如果要做中英双向互翻就需要两个模型,那模型体积直接大了两倍。尤其是部署到手机上,模型的体积是一个非常重要的考虑因素。于是自己就对这个模型的微调过程进行了一些改动,实现了单个模型进行双向互翻的能力。

原生模型
 这里给出原生模型的使用方法:

from transformers import AutoModel , AutoTokenizer,MarianMTModeltext ="你好,你是谁?"
name ='Helsinki-NLP/opus-mt-zh-en'
tokenizer = AutoTokenizer.from_pretrained(name)
model = MarianMTModel.from_pretrained(name)
input_ids = tokenizer.encode(text, return_tensors="pt")
outputs = model.generate(input_ids)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded)

需要改动的地方
 因为涉及到互翻,所以首先要告诉模型翻译的方向,具体就是在文本数据之前加一个目标语言的标识符,比如中翻英,原文“你好,你是谁?”,处理后就是“>>eng<< 你好,你是谁?”,英翻中则是“>>zho<< Hello,who are you?”

 因此就引出了一个问题,词表vocab.json中并没有“>>eng<<”和“>>zho<<”,那么分词就会出现问题。我尝试过两种方法来解决:

  • 首先是常规的解决办法,我最开始直接将这两个标识当做新的token加入词表中,最后也能跑通。这里只描述思想,具体的实现在下面的代码中会体现。
  • 然后就是我自己想的非常规方法,为啥自己又想了非常规的方法呢,那是因为我在训练好模型之后,要将模型转换为CT2的格式,但是这个转换过程中因为添加了2个新token导致了报错,搞了一圈也没有解决,于是直接把词表中两个极其罕见的token给删除了,用两个语言标识替代,这样既不会对翻译产生大的影响,又能完成模型格式的转换。当然,这是需要先改词表后进行微调,顺序不能反了。

解决办法一
 通过下面的代码微调之后,就能得到一个双向的翻译能力的模型了,使用的方法和原生模型一样,直接加载就能推理了。

import torch
import evaluate
import zhconv
from datasets import load_dataset, Dataset
import sacrebleu
import os
from transformers import (AutoTokenizer, MarianMTModel,Seq2SeqTrainer, Seq2SeqTrainingArguments,DataCollatorForSeq2Seq
)# 加载 tokenizer,并添加语言标签
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
special_tokens = [">>eng<<", ">>zho<<"]
tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})# 加载模型,并扩展嵌入层大小
model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
model.resize_token_embeddings(len(tokenizer))# 设置 token ID
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id'''
加载 Tatoeba 数据集(中英句对)
这里我使用的是公开的数据集,可以通过下面的代码来加载到本地。加载到本地之后就可以把data_files换成你自己的地址
import kagglehub
alvations_tatoeba_path = kagglehub.dataset_download('alvations/tatoeba')
'''
tatoeba_dataset = load_dataset("csv",data_files="./data/tatoeba-sentpairs.tsv",delimiter="\t",encoding="utf-8",split="train"
)# 过滤中英句对(zh→en 和 en→zh)
zh2en_dataset = tatoeba_dataset.filter(lambda x: x['SRC LANG'] == "cmn" and x['TRG LANG'] == 'eng')
en2zh_dataset = tatoeba_dataset.filter(lambda x: x['SRC LANG'] == "eng" and x['TRG LANG'] == 'cmn')# 预处理函数:添加语言标签 + 分词
def preprocess_zh2en(batch):# 添加目标语言 tokeninputs = [">>eng<< " + x for x in batch['SRC']]# 可选:进行繁转简inputs = [zhconv.convert(x, 'zh-cn') for x in inputs]targets = batch['TRG']# 编码inputs_encoded = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")outputs_encoded = tokenizer(targets, max_length=128, truncation=True, padding="max_length" )return {"input_ids": inputs_encoded["input_ids"],"attention_mask": inputs_encoded["attention_mask"],"decoder_input_ids": outputs_encoded["input_ids"],"decoder_attention_mask": outputs_encoded["attention_mask"],"labels": outputs_encoded["input_ids"].copy(),  # labels 通常跟 decoder_input_ids 相同(训练时用于 loss)}def preprocess_en2zh(batch):# 添加目标语言 tokeninputs = [">>zho<< " + x for x in batch['SRC']]# 可选:进行繁转简targets = batch['TRG']targets = [zhconv.convert(x, 'zh-cn') for x in targets]# 编码inputs_encoded = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")outputs_encoded = tokenizer(targets, max_length=128, truncation=True, padding="max_length" )return {"input_ids": inputs_encoded["input_ids"],"attention_mask": inputs_encoded["attention_mask"],"decoder_input_ids": outputs_encoded["input_ids"],"decoder_attention_mask": outputs_encoded["attention_mask"],"labels": outputs_encoded["input_ids"].copy(),  # labels 通常跟 decoder_input_ids 相同(训练时用于 loss)}# 数据清洗 + 映射分词
zh2en_dataset = zh2en_dataset.map(preprocess_zh2en, batched=True)
en2zh_dataset = en2zh_dataset.map(preprocess_en2zh, batched=True)# 合并中→英和英→中双向数据
combined_dataset = Dataset.from_dict({key: zh2en_dataset[key] + en2zh_dataset[key] for key in zh2en_dataset.features
})# 拆分训练集和测试集
split_dataset = combined_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]def compute_metrics(pred):pred_ids = pred.predictionslabel_ids = pred.label_idspred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)label_ids[label_ids == -100] = tokenizer.pad_token_idlabel_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)bleu = sacrebleu.corpus_bleu(pred_str, [label_str])# 保存验证的结果到本地文件,这样可以实时查看微调的效果save_dir = "./eval_logs"os.makedirs(save_dir, exist_ok=True)eval_id = f"step_{trainer.state.global_step}" if hasattr(trainer, "state") else "eval"output_file = os.path.join(save_dir, f"pred_vs_ref_{eval_id}.txt")with open(output_file, "w", encoding="utf-8") as f:for i, (pred, ref) in enumerate(zip(pred_str, label_str)):f.write(f"Sample {i + 1}:\n")f.write(f"Prediction: {pred}\n")f.write(f"Reference : {ref}\n")f.write("=" * 50 + "\n")return {"bleu": bleu.score}# 训练参数
training_args = Seq2SeqTrainingArguments(output_dir='./model/marian-zh-en-bidirectional',num_train_epochs=30,per_device_train_batch_size=16,per_device_eval_batch_size=16,logging_steps=50,save_steps=100,eval_steps=100,eval_strategy="steps",predict_with_generate=True,save_total_limit=10,report_to="tensorboard",  # 启用 TensorBoard 日志记录logging_dir='./logs',  # 指定 TensorBoard 日志的保存路径
)# 构建 Trainer
trainer = Seq2SeqTrainer(model=model,args=training_args,train_dataset=train_dataset.with_format("torch"),eval_dataset=eval_dataset.with_format("torch"),tokenizer=tokenizer,data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),compute_metrics=compute_metrics
)# 开始训练
trainer.train(resume_from_checkpoint=False)# 保存模型和 tokenizer
model.save_pretrained("./model/marian-zh-en-bidirectional")
tokenizer.save_pretrained("./model/marian-zh-en-bidirectional")

解决办法二
 上面是针对大众场景,具体的场景需要做具体的改动。本方法就是根据我的业务场景来修改的。

 方法一训练得到的模型是使用tokenizer来编解码,因为目标语言标识符已经加入到词表里了,所以编解码没问题。但是我转为CT2格式之后,分词使用的是sentencepiece模型,具体就是用source.spm、target.spm分别对中文和英文进行分词,然后根据共享词表转换为token的id。 共享词表中是有语言标识符的,但是sentencepiece模型里却没有添加两个新token,所以就无法识别,导致分词错误。我的做法就是推理的时候先不加目标语言的标识符,先分词,然后手动加上去。这样分词就不会出问题了,然后进行编码就能根据共享词表来编码了。

 还有一个问题就是,输入是中英混合的文本,这样sentencepiece分词器也无法正确识别,一个办法就是将中英文分开,分别进行分词,然后将分词的结果按顺序进行拼接。

 最后,以上都是基于不重新训练分词模型的做法,如果可以重新训练分词模型,那么就不需要搞上面哪些操作了。

import torch
import evaluate
import zhconv
from datasets import load_dataset, Dataset
import sacrebleu
import os
from transformers import (AutoTokenizer, MarianMTModel,Seq2SeqTrainer, Seq2SeqTrainingArguments,DataCollatorForSeq2Seq
)# 加载 tokenizer
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")# 加载模型
model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-zh-en")# 设置 token ID
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id'''
加载 Tatoeba 数据集(中英句对)
这里我使用的是公开的数据集,可以通过下面的代码来加载到本地。加载到本地之后就可以把data_files换成你自己的地址
import kagglehub
alvations_tatoeba_path = kagglehub.dataset_download('alvations/tatoeba')
'''
tatoeba_dataset = load_dataset("csv",data_files="./data/tatoeba-sentpairs.tsv",delimiter="\t",encoding="utf-8",split="train"
)# 过滤中英句对(zh→en 和 en→zh)
zh2en_dataset = tatoeba_dataset.filter(lambda x: x['SRC LANG'] == "cmn" and x['TRG LANG'] == 'eng')
en2zh_dataset = tatoeba_dataset.filter(lambda x: x['SRC LANG'] == "eng" and x['TRG LANG'] == 'cmn')# 预处理函数:添加语言标签 + 分词
def preprocess_zh2en(batch):# 添加目标语言 tokeninputs = [">>eng<< " + x for x in batch['SRC']]# 可选:进行繁转简inputs = [zhconv.convert(x, 'zh-cn') for x in inputs]targets = batch['TRG']# 编码inputs_encoded = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")outputs_encoded = tokenizer(targets, max_length=128, truncation=True, padding="max_length" )return {"input_ids": inputs_encoded["input_ids"],"attention_mask": inputs_encoded["attention_mask"],"decoder_input_ids": outputs_encoded["input_ids"],"decoder_attention_mask": outputs_encoded["attention_mask"],"labels": outputs_encoded["input_ids"].copy(),  # labels 通常跟 decoder_input_ids 相同(训练时用于 loss)}def preprocess_en2zh(batch):# 添加目标语言 tokeninputs = [">>zho<< " + x for x in batch['SRC']]# 可选:进行繁转简targets = batch['TRG']targets = [zhconv.convert(x, 'zh-cn') for x in targets]# 编码inputs_encoded = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")outputs_encoded = tokenizer(targets, max_length=128, truncation=True, padding="max_length" )return {"input_ids": inputs_encoded["input_ids"],"attention_mask": inputs_encoded["attention_mask"],"decoder_input_ids": outputs_encoded["input_ids"],"decoder_attention_mask": outputs_encoded["attention_mask"],"labels": outputs_encoded["input_ids"].copy(),  # labels 通常跟 decoder_input_ids 相同(训练时用于 loss)}# 数据清洗 + 映射分词
zh2en_dataset = zh2en_dataset.map(preprocess_zh2en, batched=True)
en2zh_dataset = en2zh_dataset.map(preprocess_en2zh, batched=True)# 合并中→英和英→中双向数据
combined_dataset = Dataset.from_dict({key: zh2en_dataset[key] + en2zh_dataset[key] for key in zh2en_dataset.features
})# 拆分训练集和测试集
split_dataset = combined_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]def compute_metrics(pred):pred_ids = pred.predictionslabel_ids = pred.label_idspred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)label_ids[label_ids == -100] = tokenizer.pad_token_idlabel_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)bleu = sacrebleu.corpus_bleu(pred_str, [label_str])# 保存验证的结果到本地文件,这样可以实时查看微调的效果save_dir = "./eval_logs"os.makedirs(save_dir, exist_ok=True)eval_id = f"step_{trainer.state.global_step}" if hasattr(trainer, "state") else "eval"output_file = os.path.join(save_dir, f"pred_vs_ref_{eval_id}.txt")with open(output_file, "w", encoding="utf-8") as f:for i, (pred, ref) in enumerate(zip(pred_str, label_str)):f.write(f"Sample {i + 1}:\n")f.write(f"Prediction: {pred}\n")f.write(f"Reference : {ref}\n")f.write("=" * 50 + "\n")return {"bleu": bleu.score}# 训练参数
training_args = Seq2SeqTrainingArguments(output_dir='./model/marian-zh-en-bidirectional',num_train_epochs=30,per_device_train_batch_size=16,per_device_eval_batch_size=16,logging_steps=50,save_steps=100,eval_steps=100,eval_strategy="steps",predict_with_generate=True,save_total_limit=10,report_to="tensorboard",  # 启用 TensorBoard 日志记录logging_dir='./logs',  # 指定 TensorBoard 日志的保存路径
)# 构建 Trainer
trainer = Seq2SeqTrainer(model=model,args=training_args,train_dataset=train_dataset.with_format("torch"),eval_dataset=eval_dataset.with_format("torch"),tokenizer=tokenizer,data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),compute_metrics=compute_metrics
)# 开始训练
trainer.train(resume_from_checkpoint=False)# 保存模型和 tokenizer
model.save_pretrained("./model/marian-zh-en-bidirectional")
tokenizer.save_pretrained("./model/marian-zh-en-bidirectional")

基于训练好的模型我还搞了一套使用C++来推理的代码,方面在更多的平台使用,具体可以在github上搜"xinliu9451/Opus-Mt_Bidirectional_Translation"。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/pingmian/82901.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Object转Map集合

对象与 Map 转换详解&#xff1a; Object.entries() 和 Object.fromEntries() 1&#xff0c;Object.fromEntries() 的主要用途就是将键值对集合&#xff08;如 Map&#xff09;转换为普通对象。 2&#xff0c;Object.entries() 返回一个二维数组&#xff0c;其中每个子数组包…

优先队列用法

第 5 行定义了一个队首是最大值的优先队列,第 10 行的输出如下: 27 - wuhan 21 - shanghai 11 - beijing 第 13 行定义了一个队首是最小值的优先队列,第 19 行的输出如下: 11 - beijing 21 - shanghai 27 - wuhan #include <bits/stdc.h> using namespace std; int…

Spring Boot3.4.1 集成redis

Spring Boot3.4.1 集成redis 第一步 引入依赖 <!-- redis 缓存操作 --> <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId> </dependency> <!-- pool 对象池 …

Replacing iptables with eBPF in Kubernetes with Cilium

source: https://archive.fosdem.org/2020/schedule/event/replacing_iptables_with_ebpf/attachments/slides/3622/export/events/attachments/replacing_iptables_with_ebpf/slides/3622/Cilium_FOSDEM_2020.pdf 使用Cilium&#xff0c;结合eBPF、Envoy、Istio和Hubble等技术…

英一真题阅读单词笔记 05年

2005 年 Text 1 第一段 序号 单词 音标 词义 1 fat [ft] a. 丰厚的&#xff0c;巨额的&#xff1b;肥胖的 2 pay [peɪ] n. 薪水 3 rise [raɪz] n. 上涨&#xff0c;增加&#xff1b;斜坡 4 pleasure [pleʒə(r)] n. 快乐&#xff1b;乐事 5 pleasure a…

FastAPI集成APsecheduler的BackgroundScheduler+mongodb(精简)

项目架构&#xff1a; FastAPI(folder) >app(folder) >core(folder) >models(folder) >routers(folder) >utils(folder) main.py(file) 1 utils文件夹下新建schedulers.py from apscheduler.schedulers.background import BackgroundScheduler from apschedu…

聊一聊接口测试中耗时请求如何合理安排?

目录 一、异步处理与轮询机制 轮询检查机制 二、 并行化测试执行 三、模拟与桩技术&#xff08;Mock/Stub&#xff09; 四、动态超时与重试策略 五、测试架构设计优化 分层测试策略 并行化执行 网络优化 六、测试用例分层管理 金字塔策略 七、 缓存与数据复用 响应…

深入详解DICOMweb:WADO与STOW-RS的技术解析与实现

&#x1f9d1; 博主简介&#xff1a;CSDN博客专家、CSDN平台优质创作者&#xff0c;高级开发工程师&#xff0c;数学专业&#xff0c;10年以上C/C, C#, Java等多种编程语言开发经验&#xff0c;拥有高级工程师证书&#xff1b;擅长C/C、C#等开发语言&#xff0c;熟悉Java常用开…

Splunk Validated Architecture (SVA):构建企业级可观测性与安全的基石

Splunk Validated Architecture (SVA) 是 Splunk 官方提供的一套经过严格测试、性能验证和最佳实践指导的参考架构蓝图。它并非单一固定方案&#xff0c;而是根据企业数据规模、性能需求、高可用性目标和合规要求&#xff0c;提供一系列可落地的部署模型。SVA 的核心价值在于为…

Armv7l或树莓派32位RPI 4B编译faiss

pip3 install faiss-cpu当然找不到预编译的包 手动下载 git clone https://github.com/facebookresearch/faiss.git cd faiss #能需要切换到特定版本标签&#xff0c;例如 v1.7.1&#xff0c;这个版本Cmake 3.18可以过&#xff0c;因为apt install安装的cmake只更新到这里&am…

C++之string的模拟实现

string 手写C字符串类类的基本结构与成员变量一、构造函数与析构函数二、赋值运算符重载三、迭代器支持四、内存管理与扩容机制五、字符串操作函数六、运算符重载总结 手写C字符串类 从零实现一个简易版std::string 类的基本结构与成员变量 namespace zzh { class string { …

修改Docker镜像源

配置文件位置&#xff1a; sudo vim /etc/docker/daemon.json Docker 或 containerd 的镜像加速器配置&#xff0c;旨在提高从 Docker Hub 拉取镜像的速度。 { "features": { "buildkit": true, "containerd-snapshotter": true }, …

服务器带宽线路的区别(GIA、CN2、BGP、CMI等)

服务器带宽线路的区别&#xff08;GIA、CN2、BGP、CMI等&#xff09; 一、BGP线路 1. 定义与技术特点 BGP&#xff08;Border Gateway Protocol&#xff0c;边界网关协议&#xff09;是一种用于不同自治系统&#xff08;AS&#xff09;之间交换路由信息的协议&#xff0c;属…

从0到1搭建AI绘画模型:Stable Diffusion微调全流程避坑指南

从0到1搭建AI绘画模型&#xff1a;Stable Diffusion微调全流程避坑指南 系统化学习人工智能网站&#xff08;收藏&#xff09;&#xff1a;https://www.captainbed.cn/flu 文章目录 从0到1搭建AI绘画模型&#xff1a;Stable Diffusion微调全流程避坑指南摘要引言一、数据集构…

VSCode + GD32F407 构建烧录

前言 最近调试一块 GD32F407VET6&#xff08;168Mhz&#xff0c;8Mhz晶振&#xff09; 板子时&#xff0c;踩了一些“启动失败”的坑。本以为是时钟配置有误&#xff0c;最后发现是链接脚本&#xff08;.ld 文件&#xff09;没有配置好&#xff0c;导致程序根本没能正常执行 ma…

AI绘画提示词:从零开始掌握Prompt Engineering的艺术

文章目录 什么是AI绘画提示词&#xff1f;提示词的基本结构主体描述场景/背景风格指定技术参数负面提示人物肖像模板风景模板 高级技巧权重调整混合风格颜色控制情绪氛围 常见问题与解决方法手部变形问题构图不理想风格不够突出 提示词示例库科幻场景奇幻人物静物画 结语 在当今…

在 Linux 上安装 Minikube:轻松搭建本地 Kubernetes 单节点集群

&#x1f525;「炎码工坊」技术弹药已装填&#xff01; 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 一、Minikube 是什么&#xff1f; Minikube 是 Kubernetes 官方推出的轻量级工具&#xff0c;专为开发者设计&#xff0c;用于在本地快速搭建单节点 Kube…

day41 python图像识别任务

目录 一、数据预处理&#xff1a;为模型打下坚实基础 二、模型构建&#xff1a;多层感知机的实现 三、训练过程&#xff1a;迭代优化与性能评估 四、测试结果&#xff1a;模型性能的最终检验 五、总结与展望 在深度学习的旅程中&#xff0c;多层感知机&#xff08;MLP&…

JS数组 concat() 与扩展运算符的深度解析与最佳实践

文章目录 前言一、语法对比1. Array.prototype.concat()2. 扩展运算符&#xff08;解构赋值&#xff09; 二、性能差异&#xff08;大规模数组&#xff09;关键差异原因 三、适用场景建议总结 前言 最近工作中遇到了一个大规模数组合并相关的问题&#xff0c;在数据合并时有些…

一套qt c++的串口通信

实现了创建线程使用串口的功能 具备功能: 1.线程使用串口 2.定时发送队列内容&#xff0c;防止粘包 3.没处理接收粘包&#xff0c;根据你的需求来&#xff0c;handleReadyRead函数中&#xff0c;可以通过m_receiveBuffer来缓存接收&#xff0c;然后拆分数据来处理 源码 seri…