详解序数回归损失函数ordinal_regression_loss:原理与实现

在医疗 AI 领域,很多分类任务具有有序类别的特性,如疾病严重程度(轻度→中度→重度)、肿瘤分级(G1→G2→G3)等。这类任务被称为序数回归(Ordinal Regression),需要特殊的损失函数设计。本文将深入解析序数回归损失函数的原理及其实现代码。

一、序数回归与传统分类的区别

传统分类任务(如疾病类型识别)假设类别之间是无序的,而序数回归的类别具有自然顺序。例如:

  • 疾病严重程度:0(正常)→1(轻度)→2(中度)→3(重度)
  • 影像评分:1 分→2 分→3 分→4 分→5 分

对于这类任务,传统的交叉熵损失存在局限性:它只关注类别预测的正确性,而忽略了类别之间的顺序关系。例如,将真实标签为 "中度"(2)的样本预测为 "重度"(3),与预测为 "轻度"(1),在交叉熵损失中被视为同等错误,但实际上前者的错误程度更小。

二、序数回归损失函数的核心思想

序数回归损失函数的设计目标是:不仅要正确分类,还要保持类别之间的顺序关系。常见的实现方法有以下几种:

  1. 累积概率模型:将序数分类转化为一系列二分类问题
  2. 相邻类别比较:比较相邻类别的预测概率
  3. 距离敏感损失:惩罚与真实类别距离更远的错误预测

代码中实现的是累积概率模型,这是最常用的序数回归方法之一。

三、累积概率模型的数学原理

累积概率模型的核心思想是:将序数类别转化为一系列累积概率。对于有K个类别的问题,定义K-1个阈值cutspoints,,则样本属于类别k的概率为:,其中:

四、代码实现解析

下面详细解析序数回归损失函数的实现代码:

def ordinal_regression_loss(self, pred, label, num_classes, train_cutpoints=False, scale=20.0):# 1. 计算阈值(cutpoints)num_cutpoints = num_classes - 1#计算阈值数量cutpoints = torch.arange(num_cutpoints, device=pred.device).float() * scale / (num_classes - 2) - scale / 2cutpoints = nn.Parameter(cutpoints, requires_grad=train_cutpoints)# 2. 计算累积概率sigmoids = torch.sigmoid(cutpoints - pred)# 3. 构建概率矩阵:将累积概率转换为每个类别的概率link_mat = sigmoids[:, 1:] - sigmoids[:, :-1]  # 中间类别的概率link_mat = torch.cat((sigmoids[:, [0]],         # 第一个类别的概率link_mat,                 # 中间类别的概率(1 - sigmoids[:, [-1]])   # 最后一个类别的概率), dim=1)# 4. 数值稳定性处理:防止对数计算时出现NaNeps = 1e-15likelihoods = torch.clamp(link_mat, eps, 1 - eps)# 5. 计算负对数似然损失neg_log_likelihood = torch.log(likelihoods)if label is None:loss = 0else:loss = -torch.gather(neg_log_likelihood, 1, label).mean()return loss, likelihoods

五、关键步骤详解

1. 阈值(Cutpoints)计算
cutpoints = torch.arange(num_cutpoints, device=pred.device).float() * scale / (num_classes - 2) - scale / 2
  • 作用:生成均匀分布的阈值点,将连续空间划分为多个区间

例如:

  • 参数
    • scale:控制阈值的范围,默认 20.0
    • train_cutpoints:是否将阈值作为可训练参数(默认为 False)
  • 基础序列torch.arange(num_cutpoints):对于K个类别,生成序列[0,1,2,...,K-2]
  • 缩放因子scale / (num_classes - 2)调整阈值之间的间隔
  • 线性变换* scale / (num_classes - 2) - scale / 2:将基础序列映射到 [-scale/2, scale/2] 区间。

这两行代码的核心是将连续的预测空间均匀划分为多个有序区间,每个区间对应一个类别。通过调整 scale 参数,可以控制区间的宽度,适应不同的任务需求。当 train_cutpoints=True 时,模型会在训练过程中自动学习最优的阈值位置,进一步提升序数回归的性能。

2. 累积概率计算
sigmoids = torch.sigmoid(cutpoints - pred)
  • 作用:将模型预测值与阈值的差值通过 sigmoid 函数转换为累积概率
  • 示例:对于 3 个类别(2 个阈值),累积概率为:

将模型输出的抽象分数 pred,通过与阈值 cutpoints 的比较,转换为 “属于某个类别或更低等级” 的概率。这个概率越接近 1,说明 pred 越可能落在该类别或更低等级的区间里。

3. 类别概率矩阵构建
link_mat = sigmoids[:, 1:] - sigmoids[:, :-1]
link_mat = torch.cat((sigmoids[:, [0]], link_mat, 1 - sigmoids[:, [-1]]), dim=1)

  • sigmoids[:, 1:] → 取所有样本的第二个及以后的累积概率
  • sigmoids[:, :-1] → 取所有样本的第一个及以前的累积概率
4.数值稳定性处理:防止对数计算时出现NaN

在深度学习中,当计算概率的对数时(如交叉熵损失中的 log(p)),如果概率 p 非常接近 0(如 1e-20),会导致以下问题:

  1. 数值下溢:计算机无法精确表示极小数,可能返回 0
  2. 对数计算错误log(0) 会返回负无穷(-inf
  3. 梯度爆炸:反向传播时,-inf 的梯度会导致参数更新异常

同样,当概率 p 接近 1 时,1-p 接近 0,也会引发类似问题。

  • torch.clamp(input, min, max) 将输入张量的每个元素限制在 [min, max] 范围内
  • 确保所有概率值在 [1e-15, 1-1e-15] 之间,避免过于接近 0 或 1

5. 负对数似然损失计算
neg_log_likelihood = torch.log(likelihoods)
loss = -torch.gather(neg_log_likelihood, 1, label).mean()
  • 作用:计算每个样本的真实类别对应的负对数概率,并取平均

通过最大似然估计,让模型预测的真实类别概率最大化。具体步骤为:

  1. 计算对数似然:将概率转换为对数空间
  2. 按标签选择:提取真实类别对应的对数似然
  3. 取负平均:转换为损失(越小越好)

六、为什么选择序数回归损失?

在医疗分类任务中,序数回归损失有以下优势:

  1. 利用顺序信息:充分利用类别之间的顺序关系,提高模型对程度差异的敏感性
  2. 减少信息损失:相比将序数问题简单视为分类问题,保留了更多结构信息
  3. 更好的校准:输出的概率具有更明确的临床意义(如疾病严重程度的概率)
  4. 提升性能:在序数分类任务中,通常比传统分类损失取得更好的性能

七、实践建议

  1. 阈值初始化

    • 代码中的线性初始化是常用方法,但对于特定任务,可根据先验知识自定义阈值
    • train_cutpoints=True时,模型会学习最优阈值位置
  2. 模型输出设计

    • 模型最后一层应输出单个连续值(而非类别概率),作为序数回归的预测值
    • 可通过全连接层实现:nn.Linear(input_dim, 1)
  3. 超参数调整

    • scale参数影响阈值的分布范围,需根据具体任务调整
    • 对于严重不平衡的序数类别,可考虑加权损失
  4. 评估指标

    • 除准确率外,建议使用 Kendall's τ 或 Spearman 相关性等评估顺序一致性
    • 医学场景中,还需关注不同严重程度类别的敏感性和特异性

八、总结

序数回归损失函数为具有顺序关系的医疗分类任务提供了更合适的优化目标。通过将类别转化为累积概率,它不仅能正确分类,还能保持类别之间的顺序关系,特别适合疾病严重程度分级、影像评分等医疗场景。

在实际应用中,可根据任务特点调整阈值初始化方式和损失函数参数,结合适当的评估指标,构建更符合临床需求的医疗 AI 模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/914933.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/914933.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL增查

建完库与建完表后后:1.分别查询student表和score表的所有记录student表:score表:2.查询student表的第2条到5条记录SELECT * FROM student LIMIT 1,4;3.从student表中查询计算机系和英语系的学生的信息SELECT * FROM student-> WHERE department IN (计算机系, 英…

二分答案之最大化最小值

参考资料来源灵神在力扣所发的题单,仅供分享学习笔记和记录,无商业用途。 核心思路:本质上是求最大 应用场景:在满足条件的最小值区间内使最大化 检查函数:保证数据都要大于等于答案 补充:为什么需要满…

OCR 赋能档案数字化:让沉睡的档案 “活” 起来

添加图片注释,不超过 140 字(可选)企业产品档案包含设计图纸、检测报告、生产记录等,传统数字化仅靠扫描存档,后续检索需人工逐份翻阅,效率极低。​OCR 产品档案解决方案直击痛点:通过智能识别技…

力扣118.杨辉三角

思路1.新建一个vector的vector2.先把空间开出来,然后再把里面的值给一个个修改开空间的手段:new、构造函数、reserve、resize因为我们之后要修改里面的数据,这就意味着我们需要去读取这个数据并修改,如果用reserve的话&#xff0c…

Python 网络爬虫 —— 提交信息到网页

一、模块核心逻辑“提交信息到网页” 是网络交互关键环节,借助 requests 库的 post() 函数,能模拟浏览器向网页发数据(如表单、文件 ),实现信息上传,让我们能与网页背后的服务器 “沟通”,像改密…

SpringMVC4

一、SpringMVC 注解与项目开发流程1.1注解的生命周期- Target、Retention 等元注解:- Target(ElementType.TYPE) :说明这个注解只能用在类、接口上。- Retention(RetentionPolicy.RUNTIME) :说明注解在运行时保留,能通过反射获取…

数据结构排序算法总结(C语言实现)

以下是常见排序算法的总结及C语言实现,包含时间复杂度、空间复杂度和稳定性分析:1. 冒泡排序 (Bubble Sort)思想:重复比较相邻元素,将较大元素向后移动。 时间复杂度:O(n)(最好O(n),最坏O(n)) 空…

嵌入式学习-PyTorch(2)-day19

很久没有学了,期间打点滴打了一个多星期,太累了,再加上学了一下Python语法基础,再终于开始重新学习pytorchtensorboard 的使用import torch from torch.utils.tensorboard import SummaryWriter writer SummaryWriter("logs…

Prompt Engineering 快速入门+实战案例

资料来源:火山引擎-开发者社区 引言 什么是 prompt A prompt is an input to a Generative AI model, that is used to guide its output. Prompt engineering is the process of writing effective instructions for a model, such that it consistently generat…

「源力觉醒 创作者计划」_文心开源模型(ERNIE-4.5-VL-28B-A3B-PT)使用心得

文章目录背景操作流程开源模型选择算力服务器平台开通部署一个算力服务器登录GPU算力服务器进行模型的部署FastDeploy 快速部署服务安装paddlepaddle-gpu1. 降级冲突的库版本安装fastdeploy直接部署模型(此处大约花费15分钟时间)放行服务端口供公网访问最…

P10719 [GESP202406 五级] 黑白格

题目传送门 前言:不是这样例有点过分了哈: 这是我没考虑到无解的情况的得分: 这是我考虑了的得分: 总而言之,就是一个Subtask 你没考虑无解的情况(除了Subtask #0),就会WA一大片,然后这个Subt…

AWS RDS PostgreSQL可观测性最佳实践

AWS RDS PostgreSQL 介绍AWS RDS PostgreSQL 是亚马逊云服务(AWS)提供的托管型 PostgreSQL 数据库服务。托管服务:AWS 管理数据库的底层基础设施,包括硬件、操作系统、数据库引擎等,用户无需自行维护。高性能&#xff…

C++——set,map的模拟实现

文章目录前言红黑树的改变set的模拟实现基本框架迭代器插入源码map模拟实现基础框架迭代器插入赋值重载源码测试代码前言 set,map底层使用红黑树这种平衡二叉搜索树来组织元素 ,这使得set, map能够提供对数时间复杂度的查找、插入和删除操作。 下面都是基…

LabVIEW液压机智能监控

​基于LabVIEW平台,结合西门子、研华等硬件,构建液压机实时监控系统。通过 OPC 通信技术实现上位机与 PLC 的数据交互,解决传统监控系统数据采集滞后、存储有限、参数调控不便等问题,可精准采集冲压过程中的位置、速度、压力等参数…

15. 什么是 xss 攻击?怎么防护

总结 跨站脚本攻击&#xff0c;注入恶意脚本敏感字符转义&#xff1a;“<”,“/”前端可以抓包篡改主要后台处理&#xff0c;转义什么是 XSS 攻击&#xff1f;怎么防护 概述 XSS&#xff08;Cross-Site Scripting&#xff0c;跨站脚本攻击&#xff09;是一种常见的 Web 安全…

更换docker工作目录

使用环境 由于默认系统盘比较小docker镜像很容易就占满&#xff0c;需要挂载新的磁盘修改docker的默认工作目录 环境&#xff1a;centos7 docker默认工作目录: /var/lib/docker/ 新的工作目录&#xff1a;/home/docker-data【自己手动创建&#xff0c;一般挂在新加的磁盘下面】…

算法学习笔记:26.二叉搜索树(生日限定版)——从原理到实战,涵盖 LeetCode 与考研 408 例题

二叉搜索树&#xff08;Binary Search Tree&#xff0c;简称 BST&#xff09;是一种特殊的二叉树&#xff0c;因其高效的查找、插入和删除操作&#xff0c;成为计算机科学中最重要的数据结构之一。BST 的核心特性是 “左小右大”&#xff0c;这一特性使其在数据检索、排序和索引…

共生型企业:驾驭AI自动化(事+AI)与人类增强(人+AI)的双重前沿

目录 引言&#xff1a;人工智能的双重前沿 第一部分&#xff1a;自动化范式&#xff08;事AI&#xff09;——重新定义卓越运营 第一章&#xff1a;智能自动化的机制 第二章&#xff1a;自动化驱动的行业转型 第三章&#xff1a;自动化的经济演算 第二部分&#xff1a;协…

TypeScript的export用法

在 TypeScript 中&#xff0c;export 用于将模块中的变量、函数、类、类型等暴露给外部使用。export 语法允许将模块化的代码分割并在其他文件中导入。 1. 命名导出&#xff08;Named Export&#xff09; 命名导出是 TypeScript 中最常见的一种导出方式&#xff0c;它允许你导出…

数据结构-2(链表)

一、思维导图二、链表的反转def reverse(self):"""思路&#xff1a;1、设置previous_node、current、next_node三个变量,目标是将current和previous_node逐步向后循环并逐步进行反转,知道所有元素都被反转2、但唯一的问题是&#xff1a;一旦current.next反转为向…