李沐--动手学深度学习 LSTM

1.从零开始实现LSTM

#从零开始实现长短期记忆网络
import torch
from torch import nn
from d2l import torch as d2l#加载时光机器数据集
batch_size,num_steps = 32,35
train_iter,vocab = d2l.load_data_time_machine(batch_size,num_steps)#1.定义和初始化模型参数:
#超参数num_hiddens定义隐藏单元的数量。按照标准差0.01的高斯分布初始化权重,并将偏置项设为0。
def get_lstm_params(vocab_size,num_hiddens,device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape,device=device)*0.01def three():return (normal((num_inputs,num_hiddens)),normal((num_hiddens,num_hiddens)),torch.zeros(num_hiddens,device=device))W_xi,W_hi,b_i = three() #输入门参数W_xf,W_hf,b_f = three() #遗忘门参数W_xo,W_ho,b_o = three() #输出门参数W_xc,W_hc,b_c = three() #候选记忆元参数#输出层参数W_hq = normal((num_hiddens,num_outputs))b_q = torch.zeros(num_outputs,device=device)#附加梯度params = [W_xi,W_hi,b_i,W_xf,W_hf,b_f,W_xo,W_ho,b_o,W_xc,W_hc,b_c,W_hq,b_q]for param in params:param.requires_grad_(True)return params
#2.定义模型
#在初始化函数中,长短期记忆网络的隐状态需要返回一个额外的记忆元,单元的值为0,形状为(批量大小,隐藏单元数)
def init_lstm_state(batch_size,num_hiddens,device):return (torch.zeros((batch_size,num_hiddens),device=device),torch.zeros((batch_size,num_hiddens),device=device))
#实际模型的定义与前面讨论的一样:提供三个门和一个额外的记忆元。
#只有隐状态才会传递到输出层,而记忆元mathbf{C}_t不直接参与输出计算。
def lstm(inputs,state,params):[W_xi,W_hi,b_i,W_xf,W_hf,b_f,W_xo,W_ho,b_o,W_xc,W_hc,b_c,W_hq,b_q] = params(H,C) = stateoutputs = []for X in inputs:I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs,dim=0),(H,C)
#3.训练和预测
vocab_size,num_hiddens,device = len(vocab),256,d2l.try_gpu()
num_epochs,lr = 500,1
model = d2l.RNNModelScratch(len(vocab),num_hiddens,device,get_lstm_params,init_lstm_state,lstm)
print(d2l.train_ch8(model,train_iter,vocab,lr,num_epochs,device))
d2l.plt.show()

2.简洁实现LSTM

#简洁实现长短期记忆网络
import torch
from torch import nn
from d2l import torch as d2l#加载时光机器数据集
batch_size,num_steps = 32,35
train_iter,vocab = d2l.load_data_time_machine(batch_size,num_steps)vocab_size,num_hiddens,device = len(vocab),256,d2l.try_gpu()
num_epochs,lr = 500,1num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs,num_hiddens)
model = d2l.RNNModel(lstm_layer,len(vocab))
model = model.to(device)
print(d2l.train_ch8(model,train_iter,vocab,lr,num_epochs,device))
d2l.plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/news/909515.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

面经的疑难杂症

1.介绍一下虚拟地址,虚拟地址是怎么映射到物理地址的? 虚拟地址是指在采用虚拟存储管理的操作系统中,进程访问内存时所使用的地址。每个进程都有独立的虚拟地址空间,虚拟地址通过操作系统和硬件(如MMU,内存…

去噪扩散概率模型(DDPM)全解:从数学基础到实现细节

一、 概述 在这篇博客文章中,我们将深入探讨去噪扩散概率模型(也被称为 DDPMs,扩散模型,基于得分的生成模型,或简称为自动编码器),这可以说是AIGC最近几年飞速发展的基石,如果你想做…

【系统分析师】2011年真题:案例分析-答案及详解

文章目录 试题1【问题 1】【问题 2】【问题 3】 试题2【问题 1】【问题 2】【问题 3】 试题3【问题 1】【问题 2】【问题 3】 试题4【问题 1】【问题 2】【问题 3】 试题5【问题 1】【问题 2】【问题 3】 试题1 随着宽带应用快速发展,用户要求系统服务提供商提供基…

【unitrix】 1.7 规范化常量类型结构(standardization.rs)

一、源码 这段代码实现了一个二进制数字标准化系统,用于将二进制数字类型(B0/B1)转换为更简洁的表示形式。 //! 二进制数字标准化模块 / Binary Number Normalization Module //! //! 提供将二进制数字(B0/B1)标准化为更简洁表示形式的功能…

NJet Portal 应用门户管理介绍

nginx向云原生演进,All in OpenNJet! 1. 应用门户简介 NJet 应用引擎是基于 Nginx 的面向互联网和云原生应用提供的运行时组态服务程序,作为底层引擎,NJet 实现了NGINX 云原生功能增强、安全加固和代码重构,利用动态加…

uni-app学习笔记三十六--分段式选项卡组件的使用

先来看效果: 上图有3个选项卡(PS:uniapp官方称之为分段器,我还是习惯叫选项卡),需要实现点击不同的选项卡时下方切换显示对应的数据。 下面介绍下实现的过程。 1.在uniapp官方文档下载并安装该扩展组件:u…

Qt:Qt桌面程序正常退出注意事项

一般情况下,Qt窗体的创建和显示命令如下: Main_window main_window; main_window.show(); 主窗体中设置属性Qt::WA_DeleteOnClose setAttribute(Qt::WA_DeleteOnClose); 则在main.cpp中可以将窗体创建为指针,这样在退出时可以正确释放指针…

【arXiv2024】时间序列|TimesFM-ICF:即插即用!时间序列预测新王者!吊打微调!

论文地址:https://arxiv.org/pdf/2410.24087 代码地址:https://github.com/uctb/TSFM 为了更好地理解时间序列模型的理论与实现,推荐参考UP “ThePPP时间序列” 的教学视频。该系列内容系统介绍了时间序列相关知识,并提供配套的论…

从0开始学习语言模型--Day02-如何最大化利用硬件

如何利用硬件 这个单元分为内核、并行处理和推理。 内核(Kernels) 我们说的内核一般指的就是GPU,这是我们用于计算的地方,一般说的计算资源就指的是GPU的大小。我们模型所用的数据和参数一般存储在内存里,假设把内存…

ElasticSearch配置详解:设置内存锁定的好处

什么是内存锁定 "bootstrap": {"memory_lock": "true" }内存锁定是指将Elasticsearch的JVM堆内存锁定在物理内存中,防止操作系统将其交换(swap)到磁盘。 内存交换是操作系统的虚拟内存管理机制,当…

成功解决 ValueError: Unable to find resource t64.exe in package pip._vendor.distlib

解决问题 我们在本地的命令行中运行指令"python -m pip install --upgrade pip"的时候,报了如下的错误: 解决思路 我们需要重新安装一下pip。 解决方法 步骤1: 通过执行下面的指令删除本地的pip: python -m pip uninstall pip…

仓库物资出入库管理系统源码+uniapp小程序

一款基于ThinkPHPuniapp开发的仓库物资出入库管理系统,适用于单位内部物资采购、发放管理的库存管理系统。提供全部无加密源码,支持私有化部署。 更新日志: 新增 基于UNIAPP开发的手机端,适配微信小程序 新增 字典管理 新增页面…

基于机器学习的逐巷充填开采岩层运动地表沉降预测

基于机器学习的逐巷充填开采岩层运动地表沉降预测 1. 项目概述 本报告详细介绍了使用Python和机器学习技术预测逐巷充填开采过程中地表沉降的方法。通过分析地质参数、开采参数和充填参数,构建预测模型评估地表沉降风险。 # 导入必要的库 import numpy as np import pandas…

MotleyCrew ——抛弃dify、coze,手动搭建多agent工作流

1. MotleyCrew 核心组件 - 协调器: Crew MotleyCrew 的核心是一个 “Crew” 对象,即多代理系统的指挥者。Crew 持有一个全局的知识图谱(使用 Kuzu 图数据库),用于记录所有任务、任务单元和其执行状态。 Cr…

掌握这些 Python 函数,让你的代码更简洁优雅

在 Python 编程世界里,代码的简洁性与可读性至关重要。简洁优雅的代码不仅便于自己后期维护,也能让其他开发者快速理解逻辑。而 Python 丰富的内置函数和一些实用的第三方库函数,就是实现这一目标的有力武器。接下来,就为大家介绍…

简说ping、telnet、netcat

简说 ping 和 telnet 命令的作用、用法和区别,方便理解它们在网络诊断中的用途。 🌐 ping 命令 ✅ 作用: ping 用于检测网络连通性。它通过向目标主机发送 ICMP Echo 请求 并等待回应,从而判断目标主机是否可达,并测…

基于STM32的超声波模拟雷达设计

一、雷达概述 雷达(Radio Detection and Ranging,无线电探测与测距)是一种利用电磁波探测目标位置、速度等信息的主动式传感器系统。其基本原理是发射电磁波并接收目标反射的回波,通过分析回波的时间差、频率变化等参数&#xff0…

飞书多维表格利用 Amazon Bedrock AI 能力赋能业务

背景 飞书多维表格是一款功能强大的在线数据管理与协作工具。它打破传统表格局限,将电子表格与数据库特性融合,支持看板、甘特图、表单等多种视图自由切换,可根据项目进度、任务管理等不同场景灵活展示数据。其丰富的字段类型能精准适配各类…

表格对比工具推荐,快速比对Excel文件

软件介绍 今天为大家推荐一款专为Excel用户设计的表格比较工具,简单易用,零基础也能快速掌握。 轻量高效的办公助手 Excel比较工具体积仅为11MB,占用空间小,运行流畅,适合各类电脑配置使用。 简洁明了的操作界面 软…

深入探究其内存开销与JVM布局——Java Record

Java 14引入的Record类型如同一股清流,旨在简化不可变数据载体的定义。它的核心承诺是:​​透明的数据建模​​和​​简洁的语法​​。自动生成的equals(), hashCode(), toString()以及构造器极大地提升了开发效率。 当我们看到这样的代码: …