BERT 模型准备与转换详细操作流程

在尝试复现极客专栏《PyTorch 深度学习实战|24 | 文本分类:如何使用BERT构建文本分类模型?》时候,构建模型这一步骤专栏老师一笔带过,对于新手有些不友好,经过一阵摸索,终于调通了,现在总结一下整体流程。

1. 获取必要脚本文件

首先,我们需要从 Transformers 的 GitHub 仓库中找到相关文件:

# 克隆 Transformers 仓库
git clone https://github.com/huggingface/transformers.git
cd transformers

在仓库中,我们需要找到以下关键文件:

  • src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py(用于 TF1.x 模型)
  • src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py(用于 TF2.x 模型)
  • src/transformers/models/bert/modeling_bert.py(BERT 的 PyTorch 实现)

2. 下载预训练模型

接下来,我们需要下载 Google 提供的预训练 BERT 模型。根据你的需求,我们选择"BERT-Base, Multilingual Cased"版本,它支持104种语言。

访问 Google 的 BERT GitHub 页面:https://github.com/google-research/bert

在该页面中找到"BERT-Base, Multilingual Cased"的下载链接,或直接使用以下命令下载:

mkdir bert-base-multilingual-cased
cd bert-base-multilingual-cased# 下载模型文件
wget https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip
unzip multi_cased_L-12_H-768_A-12.zip

解压后,你会得到以下文件:

  • bert_model.ckpt.data-00000-of-00001
  • bert_model.ckpt.index
  • bert_model.ckpt.meta
  • bert_config.json
  • vocab.txt

3. 模型转换

现在,我们使用之前找到的转换脚本将 TensorFlow 模型转换为 PyTorch 格式:

# 回到 transformers 目录
cd ../transformers# 执行转换脚本(针对 TF2.x 模型)
python src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py \--tf_checkpoint_path ../bert-base-multilingual-cased/bert_model.ckpt \--bert_config_file ../bert-base-multilingual-cased/bert_config.json \--pytorch_dump_path ../bert-base-multilingual-cased/pytorch_model.bin

如果你下载的是 TF1.x 模型,则使用:

python src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py \--tf_checkpoint_path ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/bert_model.ckpt \--bert_config_file ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/bert_config.json \--pytorch_dump_path ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/pytorch_model.bin

注意,此处需要安装tensorflow。

4. 准备完整的 PyTorch 模型目录

转换完成后,我们需要确保模型目录包含所有必要文件:

cd ../bert-base-multilingual-cased# 复制 bert_config.json 为 config.json(Transformers 库需要)
cp bert_config.json config.json

现在,你的模型目录应该包含以下三个关键文件:

  1. config.json:模型配置文件,包含了所有用于训练的参数设置
  2. pytorch_model.bin:转换后的 PyTorch 模型权重文件
  3. vocab.txt:词表文件,用于识别模型支持的各种语言的字符

5. 验证模型转换成功

为了验证模型转换是否成功,我们可以编写一个简单的脚本来加载模型并进行测试:

from transformers import BertTokenizer, BertModel# 加载模型和分词器
model_path = "path/to/bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertModel.from_pretrained(model_path)# 测试多语言能力
texts = ["Hello, how are you?",  # 英语"你好,最近怎么样?",    # 中文"Hola, ¿cómo estás?"   # 西班牙语
]for text in texts:inputs = tokenizer(text, return_tensors="pt")outputs = model(**inputs)print(f"Text: {text}")print(f"Shape of last hidden states: {outputs.last_hidden_state.shape}")print("---")

6. 使用模型进行下游任务

现在你可以使用这个转换好的模型进行各种下游任务,如文本分类、命名实体识别等:

from transformers import BertTokenizer, BertForSequenceClassification
import torch# 加载模型和分词器
model_path = "path/to/bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(model_path)# 初始化分类模型(假设有2个类别)
model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2)# 准备输入
text = "这是一个测试文本"
inputs = tokenizer(text, return_tensors="pt")# 前向传播
outputs = model(**inputs)
logits = outputs.logits# 获取预测结果
predicted_class = torch.argmax(logits, dim=1).item()
print(f"预测类别: {predicted_class}")

注意事项

  1. 模型文件大小:BERT-Base 模型文件通常较大(约400MB+),请确保有足够的磁盘空间和内存。

  2. 路径问题:在执行转换脚本时,确保正确指定了所有文件的路径。

  3. 命名约定:Transformers 库期望配置文件名为 config.json,而不是 bert_config.json,所以需要进行复制或重命名。

  4. TensorFlow 版本:根据你下载的模型版本(TF1.x 或 TF2.x),选择正确的转换脚本。

  5. checkpoint 文件:转换脚本中的 --tf_checkpoint_path 参数应该指向不带后缀的 checkpoint 文件名(如 bert_model.ckpt),而不是具体的 .index.data 文件。

通过以上步骤,你就可以成功地将 Google 预训练的 BERT 模型转换为 PyTorch 格式,并在你的项目中使用它了。这个多语言版本的 BERT 模型支持 104 种语言,非常适合多语言自然语言处理任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/86195.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/86195.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

doris 和StarRocks 导入导出数据配置

一、StarRocks 导数据到hdfs EXPORT TABLE database.table TO “hdfs://namenode/tmp/demo/table” WITH BROKER ( “username”“username”, “password”“password” ); 二、StarRocks 导数据到oss EXPORT TABLE database.table TO “oss://broke/aa/” WITH BROKER ( “…

【HTTP】取消已发送的请求

场景 在页面中,可能会因为某些操作多次触发某个请求,如多次点击某按钮触发请求,实际上我们只需要最后一次请求的返回值,但是由于请求的耗时不一,请求未必会按发送的顺序返回,导致我们最终获取到的值 ≠ 最后…

JSON框架转化isSuccess()为sucess字段

在您的描述中,BankInfoVO子类返回的JSON中出现了"success": true字段,但类本身没有定义这个字段。这通常是由以下原因之一造成的: 原因分析及解决方案 序列化框架的Getter自动推导 Java序列化框架(如Jackson/Gson&…

Ragflow 源码:task_executor.py

目录 介绍主要功能核心组件 流程图核心代码解释1. 系统架构与核心组件2. 核心处理流程3. 高级处理能力4. 关键创新点5. 容错与监控机制6. 性能优化技巧 介绍 task_executor.py 是RAGFlow系统中的任务执行器(Task Executor)核心部分,主要负责文档的解析、分块(chunk…

创客匠人联盟生态:重构家庭教育知识变现的底层逻辑

在《家庭教育促进法》推动行业刚需化的背景下,单一个体 IP 的增长天花板日益明显。创客匠人提出的 “联盟生态思维”,正推动家庭教育行业从 “单打独斗” 转向 “矩阵作战”,其核心在于通过工具整合资源,将 “同行竞争” 转化为 “…

【Docker基础】Docker容器管理:docker stop详解

目录 1 Docker容器生命周期概述 2 docker stop命令深度解析 2.1 命令基本语法 2.2 命令执行流程 2.3 stop与kill的区别 3 docker stop的工作原理 3.1 工作流程 3.2 详细工作流程 3.3 信号处理机制 4 docker stop的使用场景与最佳实践 4.1 典型使用场景 场景1&#…

rules写成动态

拖拽排序和必填校验联动(rules写到computed里) computed: {rules() {const rules {};this.form.feedList.forEach((item, idx) > {rules[feedList.${idx}] [{ required: true, message: 路线评价动态${idx 1}待填写,请填写完毕提交, trigger: change }];});re…

The Open Group开放流程自动化™ 论坛(OPAF)发布组织最新进展报告

除埃克森美孚(ExxonMobil)的成就外,开放流程自动化™ 论坛(OPAF)的最新论坛报告显示,该组织其他成员也在多个领域取得进展。 “我们祝贺埃克森美孚,因为他们证明了在前线、创收的工艺操作中部署…

线程的基本控制

线程终止 exit是危险的 如果进程中的任意一个线程调用了exit,那么整个进程终止。 不终止进程的退出方式 普通单个线程的退出方法,以下方法退出不会导致进程终止: (1)从启动例程中返回,返回值是线程的退出…

DeepSeek+WinForm串口通讯实战

前言 在现代软件开发中,串口通讯仍然是工业自动化、物联网设备和嵌入式系统的重要通信方式。随着.NET技术的发展,特别是.NET 5/.NET 6的跨平台能力,传统的WinForm应用现在可以通过现代UI框架实现真正的跨平台串口通讯。本文将深入探讨三种主…

针对数据仓库方向的大数据算法工程师面试经验总结

⚙️ 一、技术核心考察点 数据建模能力 星型 vs 雪花模型:面试官常要求对比两种模型。星型模型(事实表冗余维度表)查询性能高但存储冗余;雪花模型(规范化维度表)减少冗余但增加JOIN复杂度。需结合场景选择&…

Nuxt3 Cannot read properties of undefined (reading ‘createElement‘)

你遇到的 TypeError: Cannot read properties of undefined (reading createElement) 这个报错,通常是由于在 Nuxt3 或 Vue3 项目中,某些地方尝试访问 document.createElement 或类似 DOM API,但此时 document 还未定义(比如在服务…

正则表达式匹配实现

直接上代码 using Microsoft.AspNetCore.Mvc; using System.Text.RegularExpressions;namespace SaaS.OfficialWebSite.Web.Controllers {public class RegController : Controller{public IActionResult Index(){return View();}[HttpPost]public IActionResult TestRegex([F…

API测试工具Parasoft SOAtest:应对API变化,优化测试执行

API频繁变更给测试工作带来诸多挑战,如手动排查变更影响耗时费力、测试用例维护繁琐易出错等。Parasoft SOAtest作为一款企业级API测试工具,通过自动扫描API接口、智能分析变更影响、优化测试,执行以及支持测试用例共享与版本控制等功能&…

mysql 数据库连接 -h localhost 和 -h 127.0.0.1 区别是什么

对于 mysql 数据库, 在 my.conf 中指定的client 端口是 3358,实际的mysql server 的端口监听在 3306, mysql -h localhost 可以居然可以连接成功; mysql -h 127.0.0.1 连接失败提示Can’t connect to MySQL server on 127.0.0.1&a…

Educational Codeforces Round 180 (Rated for Div. 2) A-D

A.Race 题目大意 给你两个x,y,终点会在二点之间随机出现,alice在点a,假设alice和bob有相同的速度(距离更短者用时更少),问对于bob是否存在一点,无论终点是x还是y,他都能比alice更快到达 思路 如果alice在…

python requests post请求

在Python中,使用requests库进行POST请求是一种常见的操作,用于向服务器发送数据。下面是如何使用requests库进行POST请求的步骤: 安装requests库 如果你还没有安装requests库,可以通过pip安装: pip install requests…

Postman中设置定时自动运行接口测试

‌创建测试集合‌ 将需每日运行的接口组织到Collection中,并配置好测试脚本和断言。 ‌配置定时运行‌ 打开目标Collection → 点击 ‌Run‌ 按钮在Collection Runner页面底部选择 ‌Schedule runs‌关键配置: Frequency: Daily // 选择每日执行 Time…

multiprocessing.pool和multiprocessing.Process

在CPU密集型任务中,Python的multiprocessing模块是突破GIL限制的关键工具。multiprocessing.Pool(进程池)和multiprocessing.Process(独立进程)是最常用的两种并行化方案,但其设计思想和适用场景截然不同。…

容器技术技术入门与 Docker 环境部署

目录 一:Docker概述 1、 Docker的优势: (1)环境一致性 (2)隔离性 (3)资源高效 (4)便捷性和可扩展性 2、Docker容器与传统虚拟机的区别 3、Docker的应用…