transformers 笔记:自定义模型(配置+模型+注册为AutoCLass+本地保存加载)

  • Transformers 模型设计上是可定制的。
  • 每个模型的代码都包含在 Transformers 仓库的 model 子文件夹中(transformers/src/transformers/models at main · huggingface/transformers),每个模型文件夹通常包含:
    • modeling.py:定义模型结构与前向传播
    • configuration.py:定义模型的超参数配置

1 配置(Configuration)

1.1 自定义配置

  • 自定义配置类的要点:
    • 必须继承自 PretrainedConfig,以继承 from_pretrained()save_pretrained() 等功能;
    • 构造函数 __init__() 必须接收任意 **kwargs 并传给父类;
    • 添加 model_type 属性,以支持 AutoClass;
    • 可以加入参数校验逻辑。

1.2 保存配置

resnet50d_config = ResnetConfig(block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True)
resnet50d_config.save_pretrained("custom-resnet")

2 模型结构

  • 模型类需要继承自 PreTrainedModel,并接受配置对象作为输入
  • Transformers 约定模型的所有超参数由配置对象提供
  • 可以构建两种模型:

2.1 裸模型(输出隐藏状态)

2.2 带分类头的模型(支持 Trainer,输出 logits 和 loss)

2.3 加载预训练权重

import timmresnet50d = ResnetModel(resnet50d_config)
#此时 resnet50d.model 就是一个结构为 ResNet-50d 的模型,但权重是 随机初始化的,没有训练。pretrained_model = timm.create_model("resnet50d", pretrained=True)
#从 timm 加载已经训练好的 resnet50d 模型resnet50d.model.load_state_dict(pretrained_model.state_dict())

3 启用 AutoClass 支持

AutoClass API 能自动根据配置加载模型,简化用户调用

需要:

  1. 在配置类中加入 model_type

  2. 在模型类中加入 config_class

  3. 使用 AutoConfig.register()AutoModel.register() 注册。

from transformers import AutoConfig, AutoModel, AutoModelForImageClassificationAutoConfig.register("resnet", ResnetConfig)
#注册自定义配置类 ResnetConfig。
#"resnet" 是 ResnetConfig.model_type,它必须和配置类中的 model_type = "resnet" 一致。
#注册后,用户可以通过 AutoConfig.from_pretrained() 自动加载这个配置类。AutoModel.register(ResnetConfig, ResnetModel)
#把裸模型类 ResnetModel 绑定到 AutoModel。
'''
这样用户就可以用如下方式加载模型:
model = AutoModel.from_pretrained("your-username/custom-resnet50d", trust_remote_code=True)
'''AutoModelForImageClassification.register(ResnetConfig, ResnetModelForImageClassification)
#注册了你带分类头的模型 ResnetModelForImageClassification 到 AutoModelForImageClassification。
'''
用户可以像这样加载:
model = AutoModelForImageClassification.from_pretrained("your-username/custom-resnet50d", trust_remote_code=True
)
'''

4 本地保存& 加载特定模型

 假设已经定义和注册配置和模型,并加载了预训练权重

resnet50d_config = ResnetConfig(block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True)
#加载自定义configresnet50d = ResnetModelForImageClassification(resnet50d_config)
#加载自定义model# 加载预训练权重
import timm
pretrained = timm.create_model("resnet50d", pretrained=True)
resnet50d.model.load_state_dict(pretrained.state_dict())

注册 AutoClass 支持,保存 AutoClass 映射信息

resnet50d_config.register_for_auto_class()
resnet50d.register_for_auto_class("AutoModelForImageClassification")


保存模型和配置到本地

resnet50d.save_pretrained("custom-resnet50d/")
resnet50d_config.save_pretrained("custom-resnet50d/")

4.1 本地重新加载

from transformers import AutoModelForImageClassification# 加载模型
model = AutoModelForImageClassification.from_pretrained("custom-resnet50d/", trust_remote_code=True
)

由于使用的是自定义模型类,加载时一定要加上trust_remote_code=True

4.2 保存后的本地目录

4.3 为什么要保存config?

  • config 是必须保存的,因为 AutoModel 是依赖 config.json 来决定加载哪个模型类。
  • AutoModel.from_pretrained("path_or_repo")背后的机制是
    • 先加载配置文件 config.json
      • config = AutoConfig.from_pretrained("path_or_repo")
    • 根据 config.model_type 决定使用哪个模型类
      • "model_type": "resnet" → 查找注册的 ResnetModel
    • 再加载权重文件(.bin 或 .safetensors)到模型中

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/88109.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/88109.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java工具类,对象List提取某个属性为List,对象List转为对象Map其中某个属性作为Key值

Java工具类package org.common;import lombok.extern.slf4j.Slf4j;import java.util.*; import java.util.stream.Collectors;Slf4j public final class CollectorHelper {/*** param element* param propertyName* param <E>* return*/public static <E> List toL…

ATE FT ChangeKit学习总结-20250630

目录 一、基本概念 二、主要特点 三、BOM LIST Shuttle Hot Plate Dock Plate Contactor 四、设计要点 五、参考文献与链接 一、基本概念 Change Kit在半导体封装测试领域中是一个重要的组件,它作为Handler(自动化分类机)的配套治具,在芯片测试过程中发挥着关键作…

【网络协议安全】任务14:路由器DHCP_AAA_TELNET配置

本文档将详细介绍在华为 eNSP 仿真环境中&#xff0c;实现路由器 DHCP 服务器功能、AAA 认证以及 TELNET 远程登录配置的完整步骤&#xff0c;适用于华为 VRP 系统路由器。 一、配置目标 路由器作为 DHCP 服务器&#xff0c;为局域网内的设备自动分配 IP 地址、子网掩码、网关…

深度探索:现代翻译技术的核心算法与实践(第一篇)

引言:翻译技术的演进之路 从早期的基于规则的机器翻译(RBMT)到统计机器翻译(SMT),再到如今主导行业的神经机器翻译(NMT),翻译技术已经走过了漫长的发展道路。现代翻译系统不仅能够处理简单的句子,还能理解上下文、识别领域术语,甚至捕捉微妙的文化差异。 本系列文章将带…

玩转Docker | 使用Docker部署NotepadMX笔记应用程序

玩转Docker | 使用Docker部署NotepadMX笔记应用程序 前言一、NotepadMX介绍工具简介主要特点二、系统要求环境要求环境检查Docker版本检查检查操作系统版本三、部署NotepadMX服务下载NotepadMX镜像编辑部署文件创建容器检查容器状态检查服务端口安全设置四、访问NotepadMX服务访…

Web前端:not(否定伪类选择器)

not&#xff08;否定伪类选择器&#xff09;CSS中的 :not() 选择器是⼀个否定伪类选择器&#xff0c;它⽤于选择不符合给定选择器的元素。这是⼀种排除特定元素的⽅法&#xff0c;可以⽤来简 化复杂的选择器&#xff0c;提⾼ CSS 规则的灵活性和精确性。:not() 选择器的基本语法…

【BTC】比特币网络

目录 一、比特币网络架构 1.1 节点加入与离开 二、消息传播方式 三、交易处理机制 四、网络传播问题 五、实际应用问题及解决 本章节讲比特币网络的工作原理&#xff0c;讲解新创建的区块是如何在比特币网络中传播的。 一、比特币网络架构 比特币工作在应用层&#xff…

Clickhouse 的历史发展过程

20.5.3 开始支持多线程20.6.3 支持explainmysql 20.8 实时同步mysql&#x1f4cc; ‌一、早期版本阶段&#xff08;1.1.x系列&#xff09;‌‌版本范围‌&#xff1a;1.1.54245&#xff08;2017-07&#xff09;→ 1.1.54394&#xff08;2018-07&#xff09;‌核心特性‌&#x…

玩转n8n工作流教程(一):Windows系统本地部署n8n自动化工作流(n8n中文汉化)

在Windows系统下使用 Docker 本地部署N8N中文版的具体操作&#xff0c;进行了详尽阐述&#xff0c;玩转n8n工作流教程系列内容旨在手把手助力从0开始一步一步深入学习n8n工作流。想研究n8n工作流自动化的小伙伴们可以加个关注一起学起来。后续也会持续分享n8n自动化工作流各种玩…

mini-program01の系统认识微信小程序开发

一、官方下载并安装 1、下载&#xff08;I选了稳定版&#xff09; https://developers.weixin.qq.com/miniprogram/dev/devtools/download.htmlhttps://developers.weixin.qq.com/miniprogram/dev/devtools/download.html 2、安装&#xff08;A FEW MOMENT LATER&#xff09;…

如何将 Java 项目打包为可执行 JAR 文件

如何将 Java 项目打包为可执行 JAR 文件我将详细介绍将 Java 项目打包为可执行 JAR 文件的完整流程&#xff0c;包括使用 IDE 和命令行两种方法。方法一&#xff1a;使用 IntelliJ IDEA 打包步骤 1&#xff1a;配置项目结构打开项目点击 File > Project Structure在 Project…

【Starrocks 异常解决】-- mysql flink sync to starrocks row error

1、异常信息 flink 1.20 starrocks 3.3.0 mysql 8.0 errorLog: Error: Target column count: 35 doesnt match source value column count: 28. Column separator: \t, Row delimiter: \n. Row: 2025-05-22 6 23400055 214 dssd 1 1 1928 mm2er 360 20000.00000000 1…

Jenkins 使用宿主机的Docker

背景&#xff1a;不想在Jenkins 内部安装Docker,想直接使用Jenkins服务所在的系统安装的docker当你在 Jenkins 中执行 docker 命令时&#xff0c;实际上是通过 Docker 客户端与 Docker 守护进程进行通信。Docker 客户端和守护进程之间的通信是通过一个名为 /var/run/docker.soc…

工具+服务双驱动:创客匠人打造中医IP差异化竞争力

一、技术工具场景化定制&#xff1a;中医专业的可视化破圈在中医IP同质化严重的行业现状下&#xff0c;创客匠人以场景化技术工具破解专业传播难题。系统内置的“体质测试”模块可生成个性化调理报告&#xff0c;“案例库”支持前后对比图上传&#xff0c;“直播问诊”自动添加…

JVM对象分配内存如何保证线程安全?

大家好&#xff0c;我是锋哥。今天分享关于【JVM对象分配内存如何保证线程安全&#xff1f;】面试题。希望对大家有帮助&#xff1b; JVM对象分配内存如何保证线程安全&#xff1f; 超硬核AI学习资料&#xff0c;现在永久免费了&#xff01; 在Java中&#xff0c;JVM&#xf…

机器学习中的数据对齐

文章目录前言数据集怎么理解数据数据对齐为什么偏偏是这样对齐&#xff1f;前言 在神经网络中&#xff0c;我们往往会根据数据集构建训练集、测试集&#xff0c;有时会有验证集。但是&#xff0c;在构建完成后&#xff0c;如果直接将这些数据直接扔进模型训练&#xff0c;输入…

机器学习:更多分类回归算法之决策树、SVM、KNN

下面介绍的这几种算法&#xff0c;既能用于回归问题又能用于分类问题&#xff0c;接下来了解下吧。 决策树 可参考&#xff1a; 决策树&#xff08;Decision Tree&#xff09; | 菜鸟教程 决策树&#xff08;Decision Tree&#xff09;是一种常用的监督学习算法&#xff0c;可用…

Vue 整合 Vue Flow:从零构建交互式流程图

目录引言目的适用场景环境准备基础组件 (index.vue)自定义组件 (矩形、菱形等)RectangleNode.vue (矩形节点)&#xff1a;DiamondNode.vue (菱形节点)&#xff1a;ImageNode(自定义图片节点):操作实现 (#操作实现) 拖拽节点 (#拖拽节点) 连线 (多连接点) 删除节点 …

C# WPF - Prism 学习篇:搭建项目(一)

一、前期准备开发工具&#xff1a;Visual Studio 2022二、创建项目1、创建WPF 应用“WpfApp.StudyDemo”&#xff1a;2、项目结构如下&#xff1a; 三、安装 Prism1、选中项目“WpfApp.PrismDemo”&#xff0c;在右键菜单中选择“管理 NuGet 程序包(N)...”。2、在搜索框中输入…

单片机 基于rt-thread 系统 使用 CCM内存

一、开发环境 开发板&#xff1a;野火stm32f407 系统&#xff1a;rt-thread V4.1.1 二、链接脚本配置 ; ************************************************************* ; *** Scatter-Loading Description File generated by uVision *** ; ****************************…