PyTorch中三角函数与特殊运算详解和实战场景示例

在 PyTorch 中,三角函数(如 sin, cos, tan 等)和一些特殊数学运算(如双曲函数、反三角函数、hypot, atan2, clamp, lerp, sigmoid, softplus, special 模块等)被广泛用于科学计算、机器学习、深度学习中的前向推理或梯度计算中。


PyTorch 中的三角函数与特殊运算详解


1. 常见三角函数

torch.sin(input)

  • 定义:对输入张量每个元素计算正弦值
  • 参数input (Tensor) – 输入张量(角度以弧度为单位)
  • 返回值:返回一个和 input 同形状的张量,每个元素是 sin(x)
  • 示例
import torch
x = torch.tensor([0, torch.pi/2, torch.pi])
y = torch.sin(x)
print(y)  # tensor([0.0000, 1.0000, 0.0000])

torch.cos(input)

  • 类似于 sin,返回余弦值

torch.tan(input)

  • 返回正切值(注意:tan(π/2) 会出现无穷大)

torch.asin(input), torch.acos(input), torch.atan(input)

  • 反三角函数,输出的是角度(单位:弧度)
  • 输入值需在定义域内(asin/acos 的输入必须在 [-1, 1]

torch.atan2(input, other)

  • 定义:返回 atan(input / other),但能处理分母为 0 的情形,按象限返回正确角度

  • 参数

    • input: 分子
    • other: 分母
  • 示例

a = torch.tensor([1.0])
b = torch.tensor([1.0])
print(torch.atan2(a, b))  # 输出 π/4 ≈ 0.7854

torch.hypot(x, y)

  • 定义:计算 sqrt(x² + y²)
  • 常用于二维向量的模长(欧几里得范数)
x = torch.tensor([3.0])
y = torch.tensor([4.0])
print(torch.hypot(x, y))  # 输出 tensor([5.])

2. 双曲函数

torch.sinh, torch.cosh, torch.tanh

x = torch.tensor([0.0, 1.0])
print(torch.sinh(x))  # tensor([0.0000, 1.1752])
print(torch.tanh(x))  # tensor([0.0000, 0.7616])

3. 特殊数学函数(常用于神经网络中)

torch.sigmoid(input)

  • 返回:1 / (1 + exp(-x))
  • 常用于二分类模型输出层
x = torch.tensor([-1.0, 0.0, 1.0])
print(torch.sigmoid(x))  # tensor([0.2689, 0.5000, 0.7311])

torch.nn.functional.softplus(input)

  • 类似于平滑版的 ReLU:log(1 + exp(x))
  • 可用于避免 ReLU 的非可导性问题
import torch.nn.functional as F
x = torch.tensor([-1.0, 0.0, 1.0])
print(F.softplus(x))  # tensor([0.3133, 0.6931, 1.3133])

torch.lerp(start, end, weight)

  • 线性插值:(1 - weight) * start + weight * end
a = torch.tensor([0.0])
b = torch.tensor([10.0])
print(torch.lerp(a, b, 0.3))  # tensor([3.])

torch.clamp(input, min, max)

  • 限制张量的最小最大范围
x = torch.tensor([-2.0, 0.5, 3.0])
print(torch.clamp(x, min=0.0, max=1.0))  # tensor([0.0, 0.5, 1.0])

4. torch.special 模块(高级特殊函数)

import torch.special as S# gamma 函数:S.gamma(x) = (x-1)!
x = torch.tensor([1.0, 2.0, 3.0])
print(S.gamma(x))  # tensor([1., 1., 2.])# erf 函数(高斯误差函数)
print(S.erf(torch.tensor([0.0, 1.0])))  # tensor([0.0000, 0.8427])

常用 torch.special 函数包括:

函数名功能
special.erf高斯误差函数
special.gammalngamma 函数的对数
special.logitsigmoid 的反函数
special.expitsigmoid 函数(别名)
special.i0第一类贝塞尔函数
special.digammagamma 函数的导数

总结对比表

函数作用说明
torch.sin/cos/tan三角函数输入为弧度,输出 [-1, 1]
torch.asin/acos/atan反三角函数返回弧度
torch.atan2atan(y/x) 并考虑象限更安全的除法处理
torch.hypot计算 √(x² + y²)常用于距离计算
torch.sigmoidS型函数常用于分类神经网络
torch.softplus平滑 ReLU输出始终 > 0
torch.clamp限定区间防止梯度爆炸或数值异常
torch.lerp线性插值用于图像插值或平滑过渡
torch.special特殊函数模块包含 gamma、贝塞尔等高级函数

使用 PyTorch 计算三角函数的梯度(自动求导演示)

PyTorch 中只要一个张量启用了 .requires_grad=True,就可以使用 .backward() 自动求导。


示例 1:sin(x) 的梯度(导数为 cos(x)

import torchx = torch.tensor([torch.pi / 4], requires_grad=True)  # 45度,sin(π/4) ≈ 0.7071
y = torch.sin(x)       # y = sin(x)
y.backward()           # 计算 dy/dx
print("y =", y.item())
print("dy/dx =", x.grad.item())  # 应该为 cos(π/4) ≈ 0.7071

输出(大约):

y = 0.7071
dy/dx = 0.7071

示例 2:tanh(x) 的梯度(导数为 1 - tanh²(x)

x = torch.tensor([1.0], requires_grad=True)
y = torch.tanh(x)
y.backward()
print("y =", y.item())
print("dy/dx =", x.grad.item())  # 约为 1 - tanh(x)^2 ≈ 0.4199

输出:

y = 0.7616
dy/dx = 0.4199

示例 3:sigmoid(x) 的梯度(导数为 s(x)*(1-s(x))

x = torch.tensor([0.0], requires_grad=True)
y = torch.sigmoid(x)
y.backward()
print("y =", y.item())            # 0.5
print("dy/dx =", x.grad.item())  # 0.5 * (1 - 0.5) = 0.25

示例 4:组合函数 + 多元素求导

x = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
y = torch.sin(x) + torch.cos(x)
z = y.sum()     # 标量才能 .backward()
z.backward()
print("x.grad =", x.grad)

输出大约为:

x.grad ≈ tensor([cos(0.1) - sin(0.1),cos(0.2) - sin(0.2),cos(0.3) - sin(0.3)])

注意事项

条件说明
requires_grad=True启用自动求导
输出必须是标量才能 .backward()否则需手动传入梯度向量
.grad 只能用于叶子节点变量中间变量(如 y)无 .grad
多次 .backward() 需使用 retain_graph=True否则计算图会被释放

示例 5:对 special.erf() 求梯度

import torch.special as Sx = torch.tensor([0.5], requires_grad=True)
y = S.erf(x)
y.backward()
print("dy/dx =", x.grad.item())  # d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2)

梯度验证工具(额外推荐)

可以使用 torch.autograd.gradcheck 做数值梯度验证(需要使用 double 类型):

from torch.autograd import gradcheckx = torch.tensor([0.5], dtype=torch.double, requires_grad=True)
test = gradcheck(torch.sin, (x,), eps=1e-6, atol=1e-4)
print("Sin gradcheck:", test)

实际神经网络中使用示例

下面我们深入介绍 在实际神经网络中如何使用 PyTorch 的三角函数和特殊函数(如 sin、sigmoid、softplus 等),特别是它们在 loss 函数、自定义激活、周期性建模等场景中的实际用法


场景 1:周期性数据建模(比如时间、角度)用 sin/cos

应用背景:

对于 角度、时间(24小时) 等周期性输入,用 sin/cos 编码能避免 0° 和 360° 被认为相差很远的问题。


示例:用 sin/cos 编码输入 + 简单回归网络

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SinCosNet(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(2, 1)  # 输入是 sin 和 cos,输出是预测值def forward(self, x_deg):x_rad = x_deg * torch.pi / 180x = torch.stack([torch.sin(x_rad), torch.cos(x_rad)], dim=1)  # [N, 2]out = self.fc(x)return out# 模拟输入:角度(以度为单位)
x_deg = torch.tensor([0.0, 90.0, 180.0, 270.0]).reshape(-1, 1)
y_true = torch.tensor([0.0, 1.0, 0.0, -1.0]).reshape(-1, 1)  # 例如正弦值作为目标model = SinCosNet()
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)# 前向 + 反向 + 更新
for epoch in range(200):pred = model(x_deg)loss = loss_fn(pred, y_true)optimizer.zero_grad()loss.backward()optimizer.step()print("预测值:", model(x_deg).detach().squeeze())

场景 2:使用 softplus 替代 ReLU 防止死神经

class SoftplusNet(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(10, 64)self.fc2 = nn.Linear(64, 1)def forward(self, x):x = F.softplus(self.fc1(x))  # 使用 softplus 而不是 ReLUreturn self.fc2(x)

适用场景:

  • 高稳定性训练
  • 避免 ReLU 导致的神经元“死亡”
  • sigmoid 搭配时,数值更平滑

场景 3:自定义 Loss 函数中使用 torch.atan2, sin, cos

示例:角度差异损失(用于姿态估计、旋转预测)

def angular_loss(pred_angle, target_angle):"""计算两个角度之间的最小差(结果范围在 [-pi, pi])"""diff = pred_angle - target_anglediff = torch.atan2(torch.sin(diff), torch.cos(diff))  # wrap 到 [-pi, pi]return torch.mean(diff ** 2)  # MSE 损失

应用背景:

  • 预测角度(如相机旋转、姿态)
  • 避免直接使用 pred - target 导致 359° 和 0° 差异非常大

场景 4:使用 sigmoid + torch.special.logit 实现稳定反向函数 Loss

import torch.special as Sdef binary_target_loss(pred, target):pred = torch.clamp(pred, 1e-6, 1 - 1e-6)  # 避免 logit 无穷大loss = torch.mean((S.logit(pred) - S.logit(target)) ** 2)return loss

应用背景:

  • 对于二分类输出,可以用 logit 空间计算差异,更稳定、更敏感

场景 5:周期函数的拟合(神经网络拟合 sin 函数)

import matplotlib.pyplot as pltclass SineFitter(nn.Module):def __init__(self):super().__init__()self.net = nn.Sequential(nn.Linear(1, 64),nn.Tanh(),nn.Linear(64, 64),nn.Tanh(),nn.Linear(64, 1))def forward(self, x):return self.net(x)x = torch.linspace(-2*torch.pi, 2*torch.pi, 200).unsqueeze(1)
y = torch.sin(x)model = SineFitter()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()for epoch in range(1000):pred = model(x)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()# 可视化
plt.plot(x.detach(), y, label="Ground Truth")
plt.plot(x.detach(), model(x).detach(), label="Predicted")
plt.legend()
plt.title("Sine Function Fitting")
plt.show()

总结:三角函数/特殊函数在神经网络中的常见用途

应用领域使用的函数说明
角度建模sin, cos, atan2编码/解码周期性角度
激活函数softplus, tanh, sigmoid平滑激活、避免死神经
自定义损失函数atan2, logit, erf更稳定地处理周期性/概率性误差
函数拟合sin, special.*网络学习任意复杂函数,特别是周期/光波类
数据归一化或变换clamp, lerp, special.logit控制数据范围,避免梯度爆炸或损失异常

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/92197.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/92197.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

论文阅读: Mobile Edge Intelligence for Large LanguageModels: A Contemporary Survey

地址:Mobile Edge Intelligence for Large Language Models: A Contemporary Survey 摘要 设备端大型语言模型(LLMs)指在边缘设备上运行 LLMs,与云端模式相比,其成本效益更高、延迟更低且更能保护隐私,因…

JavaWeb(苍穹外卖)--学习笔记17(Websocket)

前言 本篇文章是学习B站黑马程序员苍穹外卖的学习笔记📑。我的学习路线是Java基础语法-JavaWeb-做项目,管理端的功能学习完之后,就进入到了用户端微信小程序的开发,🙌用户下单并且支付成功后,需要第一时间通…

WebForms 简介

WebForms 简介 概述 WebForms 是微软公司推出的一种用于构建动态网页和应用程序的技术。自 2002 年推出以来,WebForms 成为 ASP.NET 技术栈中重要的组成部分。它允许开发者以类似于桌面应用程序的方式创建交互式网页,极大地提高了 Web 开发的效率和体验。 WebForms 的工作…

vsCode软件中JS文件中启用Emmet语法支持(React),外加安装两个常用插件

1.点击vsCode软件中的设置(就是那个齿轮图标),如下图2.在搜索框中输入emmet,然后点击添加项,填写以下值:项:javascript 值:javascriptreact。如下图3.可以安装两个常用插件&#xf…

【第2话:基础知识】 自动驾驶中的世界坐标系、车辆坐标系、相机坐标系、像素坐标系概念及相互间的转换公式推导

自动驾驶中的坐标系概念及相互间的转换公式推导 在自动驾驶系统中,多个坐标系用于描述车辆、传感器和环境的相对位置。这些坐标系之间的转换是实现定位、感知和控制的关键。下面我将逐步解释常见坐标系的概念,并推导相互转换的公式。推导基于标准几何变换…

深度拆解Dify:开源LLM开发平台的架构密码与技术突围

注:此文章内容均节选自充电了么创始人,CEO兼CTO陈敬雷老师的新书《GPT多模态大模型与AI Agent智能体》(跟我一起学人工智能)【陈敬雷编著】【清华大学出版社】 清华《GPT多模态大模型与AI Agent智能体》书籍配套视频课程【陈敬雷…

tomcat处理请求流程

1.浏览器在请求一个servlet时,会按照HTTP协议构造一个HTTP请求,通过Socket连接发送给Tomcat. 2.Tomcat通过不同的IO模型接收到Socket的字节流数据。 3.接收到数据后,按照HTTP协议解析字节流,得到HttpServletRequest对象 4.通过HttpServletRequest对象,也就是请求信息,找到该请求…

【音视频】WebRTC 一对一通话-信令服

一、服务器配置 服务器在Ubuntu下搭建,使用C语言实现,由于需要使用WebSocket和前端通讯,同时需要解析JSON格式,因此引入了第三方库:WebSocketpp和nlonlohmann,这两个库的具体配置方式可以参考我之前的博客…

Spring(以 Spring Boot 为核心)与 JDK、Maven、MyBatis-Plus、Tomcat 的版本对应关系及关键注意事项

以下是 Spring(以 Spring Boot 为核心)与 JDK、Maven、MyBatis-Plus、Tomcat 的版本对应关系及关键注意事项,基于最新技术生态整理: 一、Spring Boot 与 JDK 版本对应 Spring Boot 2.x 系列 最低要求:JDK 1.8推荐版本…

03-基于深度学习的钢铁缺陷检测-yolo11-彩色版界面

目录 项目介绍🎯 功能展示🌟 一、环境安装🎆 环境配置说明📘 安装指南说明🎥 环境安装教学视频 🌟 二、系统环境(框架/依赖库)说明🧱 系统环境与依赖配置说明&#x1f4c…

24. 前端-js框架-Vue

文章目录前言一、Vue介绍1. 学习导图2. 特点3. 安装1. 方式一:独立版本2. 方式二:CDN方法3. 方式三:NPM方法(推荐使用)4. 搭建Vue的开发环境(大纲)5. 工程结构6. 安装依赖资源7. 运行项目8. Vue…

Spring 的依赖注入DI是什么?

口语化答案好的,面试官,依赖注入(Dependency Injection,简称DI)是Spring框架实现控制反转(IoC)的主要手段。DI的核心思想是将对象的依赖关系从对象内部抽离出来,通过外部注入的方式提…

汇川PLC通过ModbusTCP转Profinet网关连接西门子PLC配置案例

本案例是汇川的PLC通过开疆智能研发的ModbusTCP转Profient网关读写西门子1200PLC中的数据。汇川PLC作为ModbusTCP的客户端网关作为服务器,在Profinet一侧网关作为从站接收1200PLC的数据并转成ModbusTCP协议被汇川PLC读取。配置过程:汇川PLC配置Modbus TC…

【计组】数据的表示与运算

机器数与真值机器数真值编码原码特点表示范围求真值方法反码特点补码特点表示范围求真值方法移码特点表示范围求真值方法相互转换原码<->补码补码<->移码原码<->反码反码<->补码移位左移右移逻辑右移算术右移符号扩展零扩展整数小数符号扩展运算器部件…

视频水印技术中的变换域嵌入方法对比分析

1. 引言 随着数字视频技术的快速发展和网络传输的普及,视频内容的版权保护问题日益突出。视频水印技术作为一种有效的版权保护手段,通过在视频中嵌入不可见或半可见的标识信息,实现对视频内容的所有权认证、完整性验证和盗版追踪。在视频水印技术的发展历程中,变换域水印因…

电动汽车电池管理系统设计与实现

电动汽车电池管理系统设计与实现 1. 引言 电动汽车电池管理系统(BMS)是确保电池组安全、高效运行的关键组件。本文将详细介绍一个完整的BMS系统的MATLAB实现,包括状态估计(SOC/SOH)、参数监测、电池平衡和保护功能。系统设计为模块化结构,便于扩展和参数调整。 2. 系统架构…

JVM(Java Virtual Machine,Java 虚拟机)超详细总结

一、JVM的基础概念1、概述JVM是 Java 程序的运行基础环境&#xff0c;是 Java 语言实现 “一次编写&#xff0c;到处运行” &#xff08;"write once , run anywhere. "&#xff09;特性的关键组件&#xff0c;具体从以下几个方面来理解&#xff1a;概念层面JVM 是一…

Balabolka软件调用微软离线自然语音合成进行文字转语音下载安装教程

首先&#xff0c;需要准备安装包 Balabolka NaturalVoiceSAPIAdapterMicrosoftWindows.Voice.zh-CN.Xiaoxiao.1_1.0.9.0_x64__cw5n1h2txyewy.Msix MicrosoftWindows.Voice.zh-CN.Yunxi.1_1.0.4.0_x64__cw5n1h2txyewy.Msix借助上面这个工具&#xff1a;NaturalVoiceSAPIAdapter&…

Java修仙之路,十万字吐血整理全网最完整Java学习笔记(高级篇)

导航&#xff1a; 【Java笔记踩坑汇总】Java基础JavaWebSSMSpringBootSpringCloud瑞吉外卖/谷粒商城/学成在线设计模式面试题汇总性能调优/架构设计源码解析 推荐视频&#xff1a; 黑马程序员全套Java教程_哔哩哔哩 尚硅谷Java入门视频教程_哔哩哔哩 推荐书籍&#xff1a; 《Ja…

接口测试用例和接口测试模板

一、简介 3天精通Postman接口测试&#xff0c;全套项目实战教程&#xff01;&#xff01;接口测试区别于传统意义上的系统测试&#xff0c;下面介绍接口测试用例和接口测试报告。 二、接口测试用例模板 功能测试用例最重要的两个因素是测试步骤和预期结果&#xff0c;接口测试…