Pytorch中torch.where()函数详解和实战示例

torch.where() 是 PyTorch 中非常常用的一个函数,功能类似于 NumPy 的 where,用于条件筛选或三元选择操作。在深度学习训练、掩码操作、损失函数处理等场景中非常常见。


一、基本语法

torch.where(condition, x, y)
  • condition:一个布尔张量(torch.bool 类型),和 xy 的 shape 必须可广播。
  • x:满足条件时取的值。
  • y:不满足条件时取的值。

二、功能说明

  • 如果只传入一个参数 conditiontorch.where(condition) 将返回 非零元素的坐标(类似 nonzero())。

  • 如果传入三个参数 condition, x, y,则类似于三元表达式:

    result = x if condition else y
    

三、示例详解

示例 1:三元选择(条件替换)

import torcha = torch.tensor([1, 2, 3, 4])
b = torch.tensor([10, 20, 30, 40])
cond = torch.tensor([True, False, True, False])out = torch.where(cond, a, b)
print(out)  # tensor([ 1, 20,  3, 40])

解释:满足 cond=True 的地方取 a,否则取 b


示例 2:只有 condition 参数,返回索引

x = torch.tensor([[0, 1], [2, 0]])
pos = torch.where(x > 0)print(pos)
# 输出: (tensor([0, 1]), tensor([1, 0]))
# 表示非零位置是 [0,1] 和 [1,0]

如果你希望将这些坐标转换成可访问的形式:

coordinates = list(zip(pos[0].tolist(), pos[1].tolist()))
# [(0, 1), (1, 0)]

示例 3:广播行为支持

a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[100]])
mask = torch.tensor([[True, False], [False, True]])out = torch.where(mask, a, b)
print(out)
# tensor([[  1, 100],
#         [100,   4]])

示例 4:用于神经网络中的掩码操作(常见)

logits = torch.tensor([0.2, -0.5, 0.7, -1.0])
mask = logits > 0
output = torch.where(mask, logits, torch.zeros_like(logits))
print(output)  # tensor([0.2000, 0.0000, 0.7000, 0.0000])

示例 5:替代负无穷值(比如处理 log(0) 的场景)

eps = 1e-6
x = torch.tensor([0.5, 0.0, 1.0])
safe_x = torch.where(x > 0, x, torch.tensor(eps))
logx = torch.log(safe_x)
print(logx)

四、常见用途总结

场景用法示例
条件替换torch.where(mask, a, b)
去除负值/NaN/0torch.where(x > 0, x, eps)
多分类掩码处理torch.where(onehot_mask, pred, 0)
找到满足条件的索引idxs = torch.where(x > 0.5)
广播与标量搭配torch.where(mask, x, torch.tensor(0))

六、实战示例

下面是 torch.where()分类问题损失函数图像掩码 场景下的实战用法示例和解释,非常适合深度学习任务中使用。


1. 分类问题中使用 torch.where(二分类/多分类掩码)

场景:筛选预测为正类的样本做统计
import torch# 假设 logits 是模型输出,labels 是真实标签
logits = torch.tensor([0.9, 0.3, 0.8, 0.1])
labels = torch.tensor([1, 0, 1, 0])# 二分类掩码
positive_mask = labels == 1# 只保留正样本对应的预测概率
positive_preds = torch.where(positive_mask, logits, torch.tensor(0.0))
print(positive_preds)
# tensor([0.9000, 0.0000, 0.8000, 0.0000])

2. 自定义损失函数中的 torch.where

示例:加权 BCE Loss,给予正类更高的权重
def weighted_bce_loss(pred, target, pos_weight=2.0):bce_loss = -(target * torch.log(pred + 1e-6) + (1 - target) * torch.log(1 - pred + 1e-6))weights = torch.where(target == 1, torch.tensor(pos_weight), torch.tensor(1.0))weighted_loss = weights * bce_lossreturn weighted_loss.mean()# 示例输入
pred = torch.tensor([0.9, 0.2, 0.7, 0.1])  # 模型预测
target = torch.tensor([1.0, 0.0, 1.0, 0.0])  # 标签loss = weighted_bce_loss(pred, target)
print(loss)

使用 torch.where 为正样本动态加权。


3. 图像处理中的 torch.where(掩码操作)

示例:用掩码提取前景像素或设定背景为固定值
import torch# 假设灰度图像:H×W
image = torch.tensor([[100, 120, 130],[90, 0, 50],[255, 200, 180]
], dtype=torch.float32)# 二值掩码(比如图像分割输出)
mask = image > 100# 将背景像素设为0(即屏蔽背景)
masked_image = torch.where(mask, image, torch.tensor(0.0))
print(masked_image)

输出:

tensor([[  0., 120., 130.],[  0.,   0.,   0.],[255., 200., 180.]])

避免除以 0:
denominator = torch.where(denom != 0, denom, torch.tensor(1e-6))
替代 NaN 或 Inf:
x = torch.tensor([1.0, float('nan'), 2.0, float('inf')])
cleaned = torch.where(torch.isfinite(x), x, torch.tensor(0.0))

4.语义分割 中的应用示例:

  • 忽略某些像素(如 ignore index);
  • 可视化前景掩码;
  • 动态计算某些类的准确率、IoU;
  • 针对背景与前景的不同加权 loss;
  • 将预测 mask 显示成彩色图像。

1. 忽略标签为 255 的像素(ignore_index

import torchpred = torch.tensor([[1, 2, 0],[0, 1, 2]
])
target = torch.tensor([[1, 255, 0],[0, 1, 255]
])# 忽略标签为 255 的像素
mask = target != 255
correct = torch.where(mask, pred == target, torch.tensor(False))# 精度统计(不含 ignore)
acc = correct.sum().float() / mask.sum()
print("Accuracy (excluding ignore_index):", acc.item())

** 2. 前景/背景 mask 处理(比如用于 loss)**

# 假设标签中 0 为背景,1 为前景
label = torch.tensor([[0, 0, 1],[1, 1, 0]
])is_foreground = torch.where(label == 1, torch.tensor(1.0), torch.tensor(0.0))
print(is_foreground)
# tensor([[0., 0., 1.],
#         [1., 1., 0.]])

你可以用这个 mask 计算前景区域的 loss 或平均值。


** 3. 可视化分割结果(转彩色)**

import torch# 假设预测结果为标签图(整数类 id)
label_map = torch.tensor([[0, 1, 2],[1, 2, 0]
], dtype=torch.int64)# 假设有 3 个类,对应 RGB 颜色如下
colors = torch.tensor([[0, 0, 0],       # class 0 -> 黑色[255, 0, 0],     # class 1 -> 红色[0, 255, 0],     # class 2 -> 绿色
], dtype=torch.uint8)# 转换为彩色图像
color_image = colors[label_map]
print(color_image.shape)  # torch.Size([2, 3, 3]),对应 H x W x C

如果你要保存图像(使用 OpenCV):

import cv2
cv2.imwrite("seg_output.png", color_image.numpy())

** 4. 类别不平衡:前景加权 loss**

# logits 为 (N, C, H, W),labels 为 (N, H, W)
def weighted_cross_entropy(logits, labels, fg_weight=5.0, ignore_index=255):N, C, H, W = logits.shapelogits_flat = logits.permute(0, 2, 3, 1).reshape(-1, C)labels_flat = labels.reshape(-1)# 计算 lossloss = torch.nn.functional.cross_entropy(logits_flat, labels_flat,reduction='none', ignore_index=ignore_index)# 生成权重:前景类别加权weights = torch.ones_like(labels_flat, dtype=torch.float32)weights = torch.where(labels_flat != 0, torch.tensor(fg_weight), torch.tensor(1.0))weights = torch.where(labels_flat == ignore_index, torch.tensor(0.0), weights)loss = loss * weightsreturn loss.sum() / (weights.sum() + 1e-6)

小结:语义分割中常见的 torch.where() 用法

任务示例
忽略 ignore_index 像素mask = label != 255
筛选前景像素fg_mask = torch.where(label == 1, 1.0, 0.0)
自定义 loss 权重weights = torch.where(label == 1, 5.0, 1.0)
彩色可视化分割图colors[label_map]
分割输出中统计精度correct = torch.where(mask, pred == target, 0)

五、补充说明

  • condition 必须是 bool 类型张量。
  • xy 的形状需 可广播
  • torch.where 是支持 GPU 的(放在 cuda() 后依然生效)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/87916.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/87916.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于Hadoop的公共自行车数据分布式存储和计算平台的设计与实现

文章目录 有需要本项目的代码或文档以及全部资源,或者部署调试可以私信博主项目介绍:基于Hadoop的公共自行车数据分布式存储与计算平台设计与实现数据介绍数据预处理 Hadoop 集群的几个主要节点介绍1. NameNode(主节点)2. DataNod…

Java项目:基于SSM框架实现的程序设计课程可视化教学系统【ssm+B/S架构+源码+数据库+毕业论文】

摘 要 使用旧方法对程序设计课程可视化教学信息进行系统化管理已经不再让人们信赖了,把现在的网络信息技术运用在程序设计课程可视化教学信息的管理上面可以解决许多信息管理上面的难题,比如处理数据时间很长,数据存在错误不能及时纠正等问题…

Unity 实现 NPC 随机漫游行为的完整指南

在游戏开发中,NPC(非玩家角色)的行为逻辑对于营造真实、沉浸式的游戏世界至关重要。一个常见但又极具表现力的需求是:让 NPC 在场景中自然地随机移动,仿佛它们有自己的意识和目的地。 本文将详细介绍如何使用 Unity 的…

重新学习Vue中的按键监听和鼠标监听

文章目录按键事件1. 使用 keyup.enter 修饰符2. 使用 v-on 监听键盘事件3. 在组件上监听原生事件Vue 2Vue 34. 全局监听键盘事件注意事项鼠标事件1. 基本鼠标事件监听常用鼠标事件2. 事件修饰符3. 鼠标按键检测4. 鼠标位置信息5. 自定义指令监听鼠标事件6. 组合鼠标事件7. 性能…

vue2启动问题以及解决方案

vue2启动时:ERROR Invalid options in vue.config.js: "typescript.validate.enable" is not allowed如果需要在 VSCode 中控制 TypeScript 验证:在项目根目录创建 .vscode/settings.json 文件(如不存在)添加以下配置&a…

Vue响应式系统:从原理到核心API全解析

响应式原理 响应式机制的主要功能就是,可以把普通的JavaScript对象封装成为响应式对象,拦截数据的读取和设置操作,实现依赖数据的自动化更新。 Q: 如何才能让JavaScript对象变成响应式对象? 首先需要认识响应式数据和副作用函数…

水下目标检测:突破与创新

水下目标检测技术背景 水下环境带来独特挑战:光线衰减导致对比度降低,散射引发图像模糊,色偏使颜色失真。动态水流造成目标形变,小目标(如1010像素海胆)检测困难。声呐与光学数据融合可提升精度&#xff0…

高通SG882G平台(移远):2、使用docker镜像编译

其实之前已经编译过了。今日搜索时发现,只有当时解决问题的汇总,没有操作步骤。于是记录下来。 建议使用Ubuntu20 LTS。 安装docker $ sudo apt update $ sudo apt install docker.io $ sudo docker -v Docker version 27.5.1, build 27.5.1-0ubuntu3…

轻松上手:使用Nginx实现高效负载均衡

接上一篇《轻松上手:Nginx服务器反向代理配置指南》后,我们来探讨一下如何使用Nginx实现高效负载均衡。 在当今高并发、大流量的互联网环境下,单台服务器早已无法满足业务需求。想象一下:一次电商平台的秒杀活动、一个热门应用的…

身份证号码+姓名认证接口-身份证二要素核验

身份证号实名认证服务接口采用身份证号码、姓名二要素核验的方式,能够快速确认用户身份。无论是新用户注册,还是老用户重要操作的身份复核,只需输入姓名及身份证号,瞬间即可得到 “一致” 或 “不一致” 的核验结果。这一过程高效…

自动驾驶基本概念

目录 自动驾驶汽车(Autonomous Vehicles ) 单车智能 车联网 智能网联(单车智能车联网) 自动驾驶关键技术 环境感知与定位 车辆运动感知 车辆运动感知 路径规划与决策 自动驾驶发展历程 自动驾驶应用场景 自动驾驶路测…

提示词框架(10)--COAST

目前,有很多提示词框架都叫COAST,但是每个的解释都不同,出现很了很多解释和演化版本,不要在意这些小事,我们都是殊途同归--让AI更好的完成任务COAST框架,比较适合需要详细背景和技术支持的任务,…

基于selenium实现大麦网自动抢票脚本教程

闲来无事,打开大麦网发现现在大多数演唱票都需要手机端才能抢票,仅有很少一部分支持pc端用网页去抢票,但正所谓:道高一尺,魔高一丈,解决这个反爬问题,我们可以采用Airtest连接仿真机来模拟手机端…

2048小游戏实现

2048小游戏实现 将创建一个完整的2048小游戏,包含游戏核心逻辑和美观的用户界面。设计思路 4x4网格布局响应式设计,适配不同设备分数显示和最高分记录键盘控制(方向键)和触摸滑动支持游戏状态提示(胜利/失败&#xff0…

Windows VMWare Centos Docker部署Springboot + mybatis + MySql应用

前置文章 Windows VMWare Centos环境下安装Docker并配置MySqlhttps://blog.csdn.net/u013224722/article/details/148928081 Windows VMWare Centos Docker部署Springboot应用https://blog.csdn.net/u013224722/article/details/148958480 Windows VMWare Centos Docker部署…

【科普】Cygwin与wsl与ssh连接ubuntu有什么区别?DIY机器人工房

Cygwin、WSL(Windows Subsystem for Linux)和通过 SSH 连接 Ubuntu 是三种在 Windows 环境下与类 Unix/Linux 系统交互的工具,但它们的本质、运行环境、功能范围有显著区别。以下从核心定义、关键差异和适用场景三个维度详细说明:…

Web前端数据可视化:ECharts高效数据展示完全指南

Web前端数据可视化:ECharts高效数据展示完全指南 当产品经理拿着一堆密密麻麻的Excel数据走向你时,你知道又到了"化腐朽为神奇"的时刻。数据可视化不仅仅是把数字变成图表那么简单,它是将复杂信息转化为直观洞察的艺术。 在过去两…

# IS-IS 协议 | LSP 传输与链路状态数据库同步机制

略作整理,待校。 SRM 和 SSN 标志的作用 SRM 标志 功能:SRM 标志用于跟踪路由器从一个接口向邻居发送链路状态协议数据单元(LSP)的状态。作用:确保 LSP 的正确传输和状态跟踪。 SSN 标志 广播网络 功能&#xff1…

Windows DOS CMD 100

1. systeminfo:显示系统详细信息(安装日期/补丁/内存等) 2. sfc /scannow:扫描并修复系统文件损坏 [管理员] 3. chkdsk /f:检查磁盘错误并修复(需重启) [管理员] 4. cleanmgr:启动…

HTML初学者第三天

<1>文档类型声明标签——<!DOCTYPE><!DOCTYPE>文档声明&#xff0c;作用是告诉浏览器使用哪种HTML版本来显示网页。<!DOCTYPE html>这句代码的意思是&#xff1a;当前页面采用的是HTML5版本来显示网页。注意&#xff1a;-<!DOCTYPE>声明位于文档…