PyTorch 损失函数详解:从理论到实践

目录

一、损失函数的基本概念

二、常用损失函数及实现

1. 均方误差损失(MSELoss)

2. 平均绝对误差损失(L1Loss/MAELoss)

3. 交叉熵损失(CrossEntropyLoss)

4. 二元交叉熵损失(BCELoss)

三、损失函数选择指南

四、损失函数在训练中的应用

五、总结


损失函数是深度学习模型训练的核心组件,它量化了模型预测值与真实值之间的差异,指导模型参数的更新方向。本文将结合 PyTorch 代码实例,详细讲解常用损失函数的原理、适用场景及实现方法。

一、损失函数的基本概念

损失函数(Loss Function)又称代价函数(Cost Function),是衡量模型预测结果与真实标签之间差异的指标。在模型训练过程中,通过优化算法(如梯度下降)最小化损失函数,使模型逐渐逼近最优解。

损失函数的选择取决于具体任务类型:

  • 回归任务:预测连续值(如房价、温度)
  • 分类任务:预测离散类别(如图片分类、垃圾邮件识别)
  • 其他任务:如生成任务、序列标注等

二、常用损失函数及实现

1. 均方误差损失(MSELoss)

均方误差损失是回归任务中最常用的损失函数,计算预测值与真实值之间平方差的平均值。

数学公式MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 其中,y_i为真实值,\hat{y}_i为预测值,n为样本数量。

代码实现

import torch
import torch.nn as nn# 初始化MSE损失函数
mse_loss = nn.MSELoss()# 示例数据
y_true = torch.tensor([3.0, 5.0, 2.5])  # 真实值
y_pred = torch.tensor([2.5, 5.0, 3.0])  # 预测值# 计算损失
loss = mse_loss(y_pred, y_true)
print(f'MSE Loss: {loss.item()}')  # 输出:MSE Loss: 0.0833333358168602

特点

  • 对异常值敏感,因为会对误差进行平方
  • 是凸函数,存在唯一全局最小值
  • 适用于大多数回归任务

2. 平均绝对误差损失(L1Loss/MAELoss)

平均绝对误差计算预测值与真实值之间绝对差的平均值,对异常值的敏感性低于 MSE。

数学公式MAE = \frac{1}{n} \sum_{i=1}^{n} |y_i - \hat{y}_i|

代码实现

# 初始化L1损失函数
l1_loss = nn.L1Loss()# 计算损失
loss = l1_loss(y_pred, y_true)
print(f'L1 Loss: {loss.item()}')  # 输出:L1 Loss: 0.25

特点

  • 对异常值更稳健
  • 梯度在零点处不连续,可能影响收敛速度
  • 适用于存在异常值的回归场景

3. 交叉熵损失(CrossEntropyLoss)

交叉熵损失是多分类任务的标准损失函数,在 PyTorch 中内置了 Softmax 操作,直接作用于模型输出的 logits。

数学公式CrossEntropyLoss = -\sum_{i=1}^{C} y_i \log(\hat{y}_i) 其中,C为类别数,y_i为真实标签的 one-hot 编码,\hat{y}_i为经过 Softmax 处理的预测概率。

代码实现

def test_cross_entropy():# 模型输出的logits(未经过softmax)logits = torch.tensor([[1.5, 2.0, 0.5], [0.5, 1.0, 1.5]])# 真实标签(类别索引)labels = torch.tensor([1, 2])  # 第一个样本属于类别1,第二个样本属于类别2# 初始化交叉熵损失函数criterion = nn.CrossEntropyLoss()loss = criterion(logits, labels)print(f'Cross Entropy Loss: {loss.item()}')  # 输出:Cross Entropy Loss: 0.6422222256660461test_cross_entropy()

计算过程解析

  1. 对 logits 应用 Softmax 得到概率分布
  2. 计算真实类别对应的负对数概率
  3. 取平均值作为最终损失

特点

  • 自动包含 Softmax 操作,无需手动添加
  • 适用于多分类任务(类别互斥)
  • 标签格式为类别索引(非 one-hot 编码)

4. 二元交叉熵损失(BCELoss)

二元交叉熵损失用于二分类任务,需要配合 Sigmoid 激活函数使用,确保输入值在 (0,1) 范围内。

数学公式BCELoss = -\frac{1}{n} \sum_{i=1}^{n} [y_i \log(\hat{y}_i) + (1-y_i) \log(1-\hat{y}_i)]

代码实现

def test_bce_loss():# 模型输出(已通过sigmoid处理)y_pred = torch.tensor([[0.7], [0.2], [0.9], [0.7]])# 真实标签(0或1)y_true = torch.tensor([[1], [0], [1], [0]], dtype=torch.float)# 方法1:使用BCELossbce_loss = nn.BCELoss()loss1 = bce_loss(y_pred, y_true)# 方法2:使用functional接口loss2 = nn.functional.binary_cross_entropy(y_pred, y_true)print(f'BCELoss: {loss1.item()}')  # 输出:BCELoss: 0.47234177589416504print(f'Functional BCELoss: {loss2.item()}')  # 输出:Functional BCELoss: 0.47234177589416504test_bce_loss()

变种:BCEWithLogitsLoss
对于未经过 Sigmoid 处理的 logits,推荐使用BCEWithLogitsLoss,它内部会自动应用 Sigmoid,数值稳定性更好:

# 对于logits输入(未经过sigmoid)
logits = torch.tensor([[0.8], [-0.5], [1.2], [0.6]])
bce_with_logits_loss = nn.BCEWithLogitsLoss()
loss = bce_with_logits_loss(logits, y_true)

三、损失函数选择指南

任务类型推荐损失函数特点
回归任务MSELoss对异常值敏感,适用于大多数回归场景
回归任务(含异常值)L1Loss对异常值稳健,梯度不连续
多分类任务CrossEntropyLoss内置 Softmax,处理互斥类别
二分类任务BCELoss/BCEWithLogitsLoss配合 Sigmoid 使用,输出概率值
多标签分类BCEWithLogitsLoss每个类别独立判断,可同时属于多个类别

四、损失函数在训练中的应用

以图像分类任务为例,展示损失函数在完整训练流程中的使用:

import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 定义简单的全连接网络
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(28*28, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28*28)x = torch.relu(self.fc1(x))x = self.fc2(x)  # 输出logits,不使用softmaxreturn x# 初始化模型、损失函数和优化器
model = SimpleNet()
criterion = nn.CrossEntropyLoss()  # 多分类任务
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
def train(epochs=5):model.train()for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()# 打印每轮的平均损失avg_loss = running_loss / len(train_loader)print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')train()

五、总结

损失函数的选择直接影响模型的训练效果和收敛速度,关键要点:

  1. 回归任务优先选择 MSELoss,存在异常值时考虑 L1Loss
  2. 多分类任务使用 CrossEntropyLoss,无需手动添加 Softmax
  3. 二分类任务推荐使用 BCEWithLogitsLoss,数值稳定性更好
  4. 训练过程中需监控损失变化,判断模型是否收敛或过拟合

合理选择损失函数并配合适当的优化器,才能充分发挥模型的学习能力。在实际应用中,可根据具体任务特点和数据分布尝试不同的损失函数,选择表现最佳的方案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/89559.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/89559.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MinIO深度解析:从核心特性到Spring Boot实战集成

在当今数据爆炸的时代,海量非结构化数据的存储与管理成为企业级应用的关键挑战。传统文件系统在TB级数据面前捉襟见肘,而昂贵的云存储服务又让中小企业望而却步。MinIO作为一款开源高性能对象存储解决方案,正以其独特的技术优势成为开发者的首…

腾讯云服务上下载docker以及使用Rabbitmq的流程

执行以下命令,添加 Docker 软件源并配置为腾讯云源。sudo yum-config-manager --add-repohttps://mirrors.cloud.tencent.com/docker-ce/linux/centos/docker-ce.repo sudo sed -i "s/download.docker.com/mirrors.tencentyun.com\/docker-ce/g" /etc/yu…

UE5 一些关于过场动画sequencer,轨道track的一些Python操作

删除多余的轨道 import unreal def execute():movie_scene_actors []sequence_assets []data 0.0# 获取编辑器实用工具库lib unreal.EditorUtilityLibrary()selected_assets lib.get_selected_assets()for asset in selected_assets:if asset.get_class() unreal.LevelS…

前端性能优化“核武器”:新一代图片格式(AVIF/WebP)与自动化优化流程实战

前端性能优化“核武器”:新一代图片格式(AVIF/WebP)与自动化优化流程实战 当你的页面加载时间超过3秒时,用户的跳出率会飙升到40%以上。而在所有的前端性能优化手段中,图片优化无疑是投入产出比最高的一环。一张未经优化的巨大图片&#xff0…

单元测试学习+AI辅助单测

标题单元测试衡量指标具体测试1、Resource2、MockBean3、Test4、Test模板5、单测示例H2数据库JSON1、使用方式AI辅助单测使用方法单元测试 单元测试一般指程序员在写好代码后,提交测试前,需要验证自己的代码是否可以正常工作,同时将自己的代…

Spring Cloud Gateway与Envoy Sidecar在微服务请求路由中的架构设计分享

Spring Cloud Gateway与Envoy Sidecar在微服务请求路由中的架构设计分享 在现代微服务架构中,请求路由层承担着流量分发、安全鉴权、流量控制等多重职责。传统的单一网关方案往往面临可扩展性和可维护性挑战。本文将从真实生产环境出发,分享如何结合Spri…

GitHub Pages+Jekyll 静态网站搭建(二)

GitHub PagesJekyll 静态网站搭建(二)GitHub PagesJekyll 静态网站搭建(二内容简介搭建模板网站部署工作流程GitHub PagesJekyll 静态网站搭建(二 内容简介 🚩 Tech Contents 该文主要涉及Jekyll主题的下载与使用。Gi…

Django 实战:I18N 国际化与本地化配置、翻译与切换一步到位

文章目录一、国际化与本地化介绍定义相关概念二、安装配置安装 gettext配置 settings.py三、使用国际化视图中使用序列化器和模型中使用四、本地化操作创建或更新消息文件消息文件说明编译消息文件五、项目实战一、国际化与本地化介绍 定义 国际化和本地化的目标,…

通过国内扣子(Coze)搭建智能体并接入discord机器人

国内的扣子是无法直接授权给discord的,但是用国外的coze的话,大模型调用太贵,如果想要接入国外的平台,那就需要通过调用API来实现。 1.搭建智能体(以工作流模式为例) 首先,我们需要在扣子平台…

【办公类-107-02】20250719视频MP4转gif(削减MB)

背景需求 最近在写第五届智慧项目结题(一共3篇)写的昏天黑地,日以继夜。 我自己《基于“AI技术”的幼儿园教学资源开发和运用》提到了AI绘画、AI视频和AI编程。 为了更好的展示AI编程的状态,我在WORD里面插入了MP4转gif的动图。 【教学类-75-04】20241023世界名画-《蒙…

一文讲清楚React的render优化,包括shouldComponentUpdate、PureComponent和memo

文章目录一文讲清楚React的render优化,包括shouldComponentUpdate、PureComponent和memo1. React的渲染render机制2. shouldComponentUpdate2.1 先上单组件渲染,验证state变化2.2 上父子组件,验证props2. PureComponent2.1 单组件验证state2.…

物联网iot、mqtt协议与华为云平台的综合实践(万字0基础保姆级教程)

本学期的物联网技术与应用课程,其结课设计内容包含:mqtt、华为云、PyQT5和MySQL等结合使用,完成了从华为云配置产品信息以及转发规则,到mqtt命令转发,再到python编写逻辑代码实现相关功能,最后用PyQT5实现面…

使用IntelliJ IDEA和Maven搭建SpringBoot集成Fastjson项目

使用IntelliJ IDEA和Maven搭建SpringBoot集成Fastjson项目 下面我将详细介绍如何在IntelliJ IDEA中使用Maven搭建一个集成Fastjson的SpringBoot项目,包含完整的环境配置和代码实现。 一、环境准备 软件要求 IntelliJ IDEA 2021.x或更高版本JDK 1.8或更高版本&#x…

Java从入门到精通!第九天, 重点!(集合(一))

十一、集合1. 为什么要使用集合(1) 数组存在的弊端1) 数组在初始化之后,长度就不能改变,不方便扩展。2) 数组中提供的属性和方法比较少,不便于进行添加、删除、修改等操作,并且效率不高,同时无法直接存储元素的个数。3…

为什么使用时序数据库

为什么使用时序数据库? 时序数据库(Time-Series Database, TSDB)是专为时间序列数据优化的数据库,相比传统关系型数据库(如MySQL)或NoSQL数据库(如MongoDB),它在以下方面…

计算机网络:(十一)多协议标记交换 MPLS

计算机网络:(十一)多协议标记交换 MPLS前言一、传统网络的问题二、MPLS:给数据包贴个“标签”三、MPLS的工作流程1. 入站2. 中间3. 出站四、MPLS的能力前言 前面我们讲解了计算机网络中网络层的相关知识,包括网络层转发…

docker run elasticsearch 报错

谷粒商城 p103 前提条件: 下载镜像文件 #存储和检索数据 docker pull elasticsearch:7.4.2 #可视化检索数据 docker pull kibana:7.4.2 创建挂载的文件和配置 mkdir -p /mydata/elasticsearch/config mkdir -p /mydata/elasticsearch/data echo "http.h…

巧用Callbre RVE生成DRC HTML report及CTO的使用方法

对于后端版图人员,在芯片TO前的LV signoff阶段,犹如一段漫长而有期待的朝圣之旅,需要耐心,毅力和信心,在庞杂的DRC中找到一条收敛之路。为了让此路更为清晰收敛,Calibre提供了一套可追溯对比的富文本方式-H…

产品需求文档(PRD)格式全解析:从 RP 到 Word 的选择与实践

产品需求文档(PRD)的形式多种多样,但核心目标始终一致:清晰传递产品需求,让团队高效协作。不同公司对 PRD 的格式要求可能不同,有的偏爱直接在原型工具中撰写,有的则习惯用 Word 整理归档。本文…

【C++】入门阶段

一、初始化C中的初始化指为变量赋予初始值的过程。初始化方式多样,适用于不同场景。char cha0; char chb{0}; char chc(\0); char chdcha; char che{};注意事项优先使用列表初始化({}),避免窄化转换风险。在c11中{ }在变量&#x…