从认识AI开始-----解密LSTM:RNN的进化之路

前言

我在上一篇文章中介绍了 RNN,它是一个隐变量模型,主要通过隐藏状态连接时间序列,实现了序列信息的记忆与建模。然而,RNN在实践中面临严重的“梯度消失”与“长期依赖建模困难”问题:

  • 难以捕捉相隔很远的时间步之间的关系
  • 隐状态在不断更新中容易遗忘早期信息。

为了解决这些问题,LSTM(Long Short-Term Memory) 网络于 1997 年被 Hochreiter等人提出,该模型是对RNN的一次重大改进。


一、LSTM相比RNN的核心改进

接下来,我们通过对比RNN、LSTM,来看一下具体的改进:

模型特点优势缺点
RNN单一隐藏转态,时间步传递结构简答容易造成梯度消失/爆炸,对长期依赖差
LSTM多门控机制 + 单独的“记忆单元”解决长距离依赖问题,保留长期信息结构复杂,计算开销大

通过对比,我们可以发现,其实LSTM的核心思想是:引入了一个专门的“记忆单元”,在通过门控机制对信息进行有选择的保留、遗忘与更新


二、LSTM的核心结构

LSTM的核心结构如下图所示:

 如图可以轻松的看出,LSTM主要由门控机制和候选记忆单元组成,对于每个时间步,LSTM都会进行以下操作:

1. 忘记门

忘记门F_t主要的作用是:控制保留多少之前的记忆:

F_t=\sigma(X_t@W_{xf}+H_{t-1}@W_{hf}+b_f)

2. 输入门

输入门I_t主要的作用是:决定当前输入中哪些信息信息被写入记忆:

I_t=\sigma(X_t@W_{xi}+H_{t-1}@W_{hi}+b_i)

3. 候选记忆单元

\tilde C_t=tanh(X_t@W_{xc}+H_{t-1}@W_{hc}+b_c)

4. 输出门

输出门O_t的作用是:决定是是否使用隐状态:

O_t=\sigma(X_t@W_{xo}+H_{t-1}@W_{ho}+b_o)

5. 真正记忆单元

记忆单元( C_t )用于长期存储信息,解决RNN容易遗忘的问题:

C_t=F_t*C_{t-1}+I_t*\tilde C_{t}

7. 隐藏转态

H_t=O_t*tanh(C_t)

LSTM引入了专门的记忆单元 C_t  ,长期存储信息,解决了传统RNN容易遗忘的问题。


三、手写LSTM

通过上面的介绍,我们现在已经知道了LSTM的实现原理,现在,我们试着手写一个LSTM核心层:

首先,初始化需要训练的参数:

import torch
import torch.nn as nn
import torch.nn.functional as Fdef params(input_size, output_size, hidden_size):W_xi, W_hi, b_i = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_xf, W_hf, b_f = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_xo, W_ho, b_o = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_xc, W_hc, b_c = torch.randn(input_size, hidden_size) * 0.1, torch.randn(hidden_size, hidden_size) * 0.1, torch.zeros(hidden_size)W_hq = torch.randn(hidden_size, output_size) * 0.1b_q = torch.zeros(output_size)params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q]for param in params:param.requires_grad = Truereturn params

接着,我们需要初始化0时刻的隐藏转态:

import torchdef init_state(batch_size, hidden_size):return (torch.zeros((batch_size, hidden_size)), torch.zeros((batch_size, hidden_size)))

然后, 就是LSTM的核心操作:

import torch
import torch.nn as nn
def lstm(X, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params(H, C) = stateoutputs = []for x in X:I = torch.sigmoid(torch.mm(x, W_xi) + torch.mm(H, W_hi) + b_i)F = torch.sigmoid(torch.mm(x, W_xf) + torch.mm(H, W_hf) + b_f)O = torch.sigmoid(torch.mm(x, W_xo) + torch.mm(H, W_ho) + b_o)C_tilde = torch.tanh(torch.mm(x, W_xc) + torch.mm(H, W_hc) + b_c)C = F * C + I * C_tildeH = O * torch.tanh(C)Y = torch.mm(H, W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=1), (H, C)

四、使用Pytroch实现简单的LSTM

在Pytroch中,已经内置了lstm函数,我们只需要调用就可以实现上述操作:

import torch
import torch.nn as nnclass mylstm(nn.Module):def __init__(self, input_size, output_size, hidden_size):super(mylstm, self).__init__()self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x, h0, c0):out, (hn, cn) = self.lstm(x, h0, c0)out = self.fc(out)return out, (hn, cn)# 示例
input_size = 10
hidden_size = 20
output_size = 10
batch_size = 1
seq_len = 5
num_layer = 1 # lstm堆叠层数h0 = torch.zeros(num_layer, batch_size, hidden_size)
c0 = torch.randn(num_layer, batch_size, hidden_size)
x = torch.randn(batch_size, seq_len, hidden_size)model = mylstm(input_size=input_size, hidden_size=hidden_size, output_size=output_size)out, _ = model(x, (h0, c0))
print(out.shape)

总结

在现实中,LSTM的实际应用场景很多,比如语言模型、文本生成、时间序列预测、情感分析等长序列任务重,这是因为相比于RNN而言,LSTM能够更高地捕捉长期依赖,而且也更好的缓解了梯度消失问题;但是由于LSTM引入了三个门控机制,导致参数量比RNN要多,训练慢。

总的来说,LSTM是对传统RNN的一次革命性升级,引入门控机制和记忆单元,使模型能够选择性地记忆与遗忘,从而有效地捕捉长距离依赖。尽管LSTM近年来Transformer所取代,但LSTM依然是理解深度学习序列模型不可绕开的一环,有时在其他任务上甚至优于Transformer。


如果小伙伴们觉得本文对各位有帮助,欢迎:👍点赞 | ⭐ 收藏 |  🔔 关注。我将持续在专栏《人工智能》中更新人工智能知识,帮助各位小伙伴们打好扎实的理论与操作基础,欢迎🔔订阅本专栏,向AI工程师进阶!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/news/907638.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

接地气的方式认识JVM(一)

最近在学jvm,浮于表面的学了之后,发现jvm并没有我想象中的那么神秘,这篇文章将会用接地气的方式来说一说这些jvm的相关概念以及名词解释。 带着下面两个问题来阅读 认识了解JVM大致有什么在代码运行时的都在背后做了什么 JVM是个啥&#xf…

Next.js 15 与 Apollo Client 的现代集成及性能优化

Next.js 15 与 Apollo Client 的现代集成及性能优化 目录 技术演进集成实践性能优化应用案例未来趋势 技术演进 Next.js 15 核心特性对开发模式的革新 Next.js 15 通过引入 App Router、服务器组件(Server Components)和客户端组件(Clie…

无人机桥梁3D建模、巡检、检测的航线规划

无人机桥梁3D建模、巡检、检测的航线规划 无人机在3D建模、巡检和检测任务中的航线规划存在显著差异,主要体现在飞行高度、航线模式、精度要求和传感器配置等方面。以下是三者的详细对比分析: 1. 核心目标差异 任务类型主要目标典型应用场景3D建模 生成…

Hive数据倾斜问题深度解析与实战优化指南

一、数据倾斜现象的本质与危害 数据倾斜是Hive在MapReduce计算过程中,​部分Key对应的数据量远超其他Key,导致少数Reducer任务处理时间远高于其他任务的性能瓶颈问题。典型表现为: ​作业进度卡在99%​​:99%的Reducer已完成,剩余1%持续数小时​资源利用率失衡​:部分节…

VRRP 原理与配置:让你的网络永不掉线!

VRRP 原理与配置:让你的网络永不掉线! 一. VRRP 是什么,为什么需要它?二. VRRP 的核心概念三. VRRP 的工作原理四. 华为设备 VRRP 配置步骤 (主备模式)4.1 拓扑示例4.2 🛠 配置步骤 五. VRRP 配…

解决开发者技能差距:AI 在提升效率与技能培养中的作用

企业在开发者人才方面正面临双重挑战。一方面,IDC 预测,到2025年,全球全职开发者将短缺400万人;另一方面,一些行业巨头已暂停开发者招聘,转而倚重人工智能(AI)来满足开发需求。这不禁…

痛点即爆点?如何挖掘客户的痛点和需求?

销售的核心在于精准洞察客户需求与痛点,并运用专业能力为其提供定制化解决方案,从而消除客户顾虑、解决问题,最终实现双赢。而快速识别客户痛点,不仅是成交的关键,更是建立专业形象、赢得客户信任的核心能力。那么&…

云服务器如何自动更新系统并保持安全?

云服务器自动更新系统是保障安全、修补漏洞的重要措施。下面是常见 Linux 系统(如 Ubuntu、Debian、CentOS)和 Windows 服务器自动更新的做法和建议: 1. Linux 云服务器自动更新及安全维护 Ubuntu / Debian 系统 手动更新命令 sudo apt up…

fvm install 下载超时 过慢 fvm常用命令、flutter常用命令

Git 配置问题 确保 Git 使用的是 HTTPS,而不是 SSH。如果你有 .gitconfig,确保没有配置奇怪的代理: git config --global --get http.proxy git config --global --get https.proxy如果有代理设置且不需要,取消代理:…

多语种OCR识别系统,引领文字识别新时代

在全球化与数字化深度融合的今天,语言障碍成为企业跨国协作、信息管理的一大挑战。无论是跨国合同签署、多语言档案管理,还是跨境商务沟通,高效精准的文字识别技术已成为刚需。中安智能OCR多语种识别系统应运而生,凭借其强大的光学…

Pyenv 使用指南:多版本 Python 环境管理

目录 Pyenv 是什么?安装 Pyenv管理 Python 版本虚拟环境管理项目级 Python 版本控制高级技巧常见问题解决最佳实践 Pyenv 是什么? Pyenv 是一个强大的 Python 版本管理工具,允许你: 在同一台机器上安装多个 Python 版本轻松切换…

Windows 11 家庭版 安装Docker教程

Windows 家庭版需要通过脚本手动安装 Hyper-V 一、前置检查 1、查看系统 快捷键【winR】,输入“control” 【控制面板】—>【系统和安全】—>【系统】 2、确认虚拟化 【任务管理器】—【性能】 二、安装Hyper-V 1、创建并运行安装脚本 在桌面新建一个 .…

leetcode:479. 最大回文数乘积(python3解法,数学相关算法题)

难度:简单 给定一个整数 n ,返回 可表示为两个 n 位整数乘积的 最大回文整数 。因为答案可能非常大,所以返回它对 1337 取余 。 示例 1: 输入:n 2 输出:987 解释:99 x 91 9009, 9009 % 1337 …

VR看房系统,新生代看房新体验

VR看房系统的概念 虚拟现实(VirtualReality,VR)看房系统,是近年来随着科技进步在房地产行业中兴起的一种创新看房方式。看房系统利用先进的计算机技术模拟出一个三维环境,使用户能够身临其境地浏览和体验房源,无需亲自…

栈与队列:数据结构的有序律动

在数据结构的舞台上,栈与队列宛如两位优雅的舞者,以独特的节奏演绎着数据的进出规则。它们虽不像顺序表与链表那般复杂多变,却有着令人着迷的简洁与实用,在众多程序场景中发挥着不可或缺的作用。今天,就让我们一同去探…

Flutte ListView 列表组件

目录 1、垂直列表 1.1 实现用户中心的垂直列表 2、垂直图文列表 2.1 动态配置列表 2.2 for循环生成一个动态列表 2.3 ListView.builder配置列表 列表布局是我们项目开发中最常用的一种布局方式。Flutter中我们可以通过ListView来定义列表项,支持垂直和水平方向展示…

跟Gemini学做PPT-模板样式的下载

好的,这里有一些推荐的网站,您可以在上面找到PPT目录样式和模板的灵感: SlideModel (slidemodel.com) 提供各种预先设计的目录幻灯片模板。这些模板100%可编辑,可用于PowerPoint和Google Slides。您可以找到不同项目数量&#xff…

【Netty系列】Reactor 模式 1

目录 一、Reactor 模式的核心思想 二、Netty 中的 Reactor 模式实现 1. 服务端代码示例 2. 处理请求的 Handler 三、运行流程解析(结合 Reactor 模式) 四、关键点说明 五、与传统模型的对比 六、总结 Reactor 模式是 Netty 高性能的核心设计思想…

LDAP(Lightweight Directory Access Protocol,轻量级目录访问协议)认证

理解 LDAP(Lightweight Directory Access Protocol,轻量级目录访问协议)认证,核心在于将其看作一种用于查询和验证用户身份信息的标准协议,类似于一个专门为“查找”优化的电子电话簿系统。以下是分层解析:…

LeetCodeHot100_0x09

LeetCodeHot100_0x09 70. 最小栈数据结构实现 求解思路: 一开始想着只用一个最小栈结构不就实现了,结果测试的时候发现,在pop元素后,它的最小值有可能不受影响,但是只用一个最小栈的话,最小值一定是作为栈…