一文读懂循环神经网络—从零实现长短期记忆网络(LSTM)

目录

一、遗忘门(Forget Gate):决定 “该忘记什么”

二、输入门(Input Gate):决定 “该记住什么新信息”

三、输出门(Output Gate):决定 “该输出什么”

四、候选记忆元(Candidate Cell State):“待存入的新信息”

五、记忆元(Cell State):长期记忆的 “仓库”

六、 隐状态(Hidden State):短期输出与信息传递

七、 门控记忆元(Gated Cell):整体协同机制

八、各组件协同流程

九、为什么能解决长期依赖?

十、LSTM的结构图

 十一、完整代码

十二、实验结果


一、遗忘门(Forget Gate):决定 “该忘记什么”

遗忘门的作用是筛选上一时刻记忆元中需要保留的信息。它根据 “上一时刻的隐状态” 和 “当前输入”,判断哪些历史信息可以被丢弃,哪些需要继续保留。

  • 输入:上一时刻的隐状态 h_{t-1} + 当前时间步的输入 x_t
  • 计算过程: 先将两者拼接,通过一个全连接层(权重为W_f,偏置为b_f),再经过 sigmoid 激活函数(输出范围 0~1):f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)其中,\sigma是 sigmoid 函数(\sigma(z) = 1/(1+e^{-z}),输出 f_t是一个与记忆元同维度的向量(每个元素对应记忆元中的一个 “信息位”)。
  • 含义
    • f_t中元素越接近 1:表示上一时刻记忆元中对应位置的信息 “完全保留”;
    • 越接近 0:表示对应位置的信息 “完全遗忘”。

:在句子 “我喜欢吃苹果,不喜欢吃香蕉,……” 中,当读到 “香蕉” 时,遗忘门会让 “苹果” 的相关信息适当保留(但权重可能降低),以便后续对比。

二、输入门(Input Gate):决定 “该记住什么新信息”

输入门的作用是筛选当前输入中需要存入记忆元的新信息。它和遗忘门配合,共同完成记忆元的更新(先忘旧的,再记新的)。

  • 输入:同样是上一时刻的隐状态h_{t-1} + 当前输入 x_t
  • 计算过程: 拼接后通过另一个全连接层(权重 W_i,偏置 b_i),再经 sigmoid 激活: i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)其中,i_t也是 0~1 的向量,每个元素对应 “当前输入中该信息位是否允许进入记忆元”。
  • 含义
    • i_t中元素越接近 1:表示当前输入中对应位置的新信息 “允许存入记忆元”;
    • 越接近 0:表示该新信息 “不存入记忆元”。

三、输出门(Output Gate):决定 “该输出什么”

输出门的作用是当前记忆元中筛选信息,生成当前时间步的隐状态(即模型的 “当前输出”)。隐状态会传递到下一时间步,同时作为当前步的输出(比如预测下一个词)。

  • 输入:依然是 h_{t-1}\) + \(x_t
  • 计算过程: 拼接后通过第四个全连接层(权重 W_o,偏置 b_o),经 sigmoid 激活: o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) 然后,用输出门 o_t筛选当前记忆元c_t中的信息,再经 tanh 处理(确保输出范围 - 1~1): h_t = o_t \odot \tanh(c_t)
  • 含义
    • o_t 中元素越接近 1:表示记忆元中对应位置的信息 “允许输出到隐状态”;
    • 越接近 0:表示该信息 “仅保存在记忆元中,不输出”。

:在 “我喜欢吃苹果,它很甜” 中,当处理 “它” 时,输出门会从记忆元中筛选 “苹果” 的信息,让隐状态包含 “苹果”,从而正确预测 “它” 指代 “苹果”。

四、候选记忆元(Candidate Cell State):“待存入的新信息”

候选记忆元是当前输入中可能被存入记忆元的 “原始新信息”(未经筛选),相当于 “草稿”,最终是否存入由输入门决定。

  • 输入:还是 h_{t-1}+ x_t
  • 计算过程: 拼接后通过第三个全连接层(权重 W_{\tilde{c}},偏置 b_{\tilde{c}}),经 tanh 激活(输出范围 - 1~1): \tilde{c}_t = \tanh(W_{\tilde{c}} \cdot [h_{t-1}, x_t] + b_{\tilde{c}})
  • 为什么用 tanh:tanh 将值限制在 - 1~1 之间,避免新信息数值过大导致记忆元 “溢出”,同时保留正负信息(比如 “喜欢” vs “不喜欢”)。

五、记忆元(Cell State):长期记忆的 “仓库”

记忆元是 LSTM 的 “核心仓库”,负责存储长期信息,其状态会在时间步之间传递并被不断更新

  • 更新规则:结合遗忘门(保留旧信息)和输入门 + 候选记忆元(添加新信息):c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t其中,\odot 是元素级乘法(Hadamard 积)。
  • 含义
    • f_t \odot c_{t-1}:对上一时刻的记忆元 c_{t-1} 按遗忘门的筛选保留部分信息;
    • i_t \odot \tilde{c}_t:对候选记忆元 \tilde{c}_t 按输入门的筛选保留部分新信息;
    • 两者相加:得到当前时刻的记忆元c_t(旧信息 + 新信息的融合)。

:在 “我出生在巴黎,……,现在住在伦敦” 中,记忆元会先保留 “巴黎”(遗忘门让其保留),当读到 “伦敦” 时,输入门允许 “伦敦” 进入,记忆元更新为 “巴黎 + 伦敦”(或根据重要性调整权重)。

六、 隐状态(Hidden State):短期输出与信息传递

隐状态 h_t是 LSTM 在当前时间步的 “对外输出”,有两个作用:

  • 作为当前时间步的模型输出(比如用于序列预测、分类等);
  • 传递到下一时间步,作为计算下一个时间步各大门和候选记忆元的输入。

与记忆元的区别:

  • 记忆元 c_t:长期存储,更新频率低(主要保留关键信息);
  • 隐状态 h_t:短期输出,随时间步快速变化(反映当前时刻的重点信息)。

七、 门控记忆元(Gated Cell):整体协同机制

“门控记忆元” 不是一个独立组件,而是对 LSTM 中 “记忆元 + 三大门控” 整体机制的统称。它强调记忆元的更新和输出是被输入门、遗忘门、输出门 “控制” 的,而非像传统 RNN 那样无差别传递。这种 “门控” 机制正是 LSTM 能处理长期依赖的核心。

八、各组件协同流程

为了更清晰理解,用一个时间步的流程总结:

  1. 遗忘旧信息:遗忘门f_t 决定从 c_{t-1} 中保留哪些信息;
  2. 筛选新信息:输入门 i_t决定从候选记忆元 \tilde{c}_t 中保留哪些新信息;
  3. 更新记忆元c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t(旧信息 + 新信息);
  4. 生成输出:输出门 o_t从 c_t中筛选信息,生成隐状态h_t = o_t \odot \tanh(c_t)

这个流程在每个时间步重复,使得记忆元能长期保留关键信息,隐状态能灵活输出当前重点,从而解决长期依赖问题。

  • 遗忘门、输入门、输出门是 “控制器”,决定信息的删、存、取;
  • 候选记忆元是 “新信息草稿”;
  • 记忆元是 “长期仓库”;
  • 隐状态是 “当前输出”。

九、为什么能解决长期依赖?

  • 遗忘门的灵活性:可以让不重要的信息快速遗忘(f_t \approx 0),而关键信息长期保留(f_t \approx 1);
  • 记忆元的稳定性:记忆元的更新是 “加性” 的(c_t = ... + ...),而非 RNN 中隐状态的 “替换性” 更新(h_t = \tanh(W \cdot [h_{t-1}, x_t])),梯度在反向传播时更稳定,不易消失;
  • 门控的选择性:输入门和输出门可以 “按需” 添加和提取信息,避免无关信息干扰。

十、LSTM的结构图

 十一、完整代码

"""
文件名: 9.2 从零实现长短期记忆网络(LSTM)
作者: 墨尘
日期: 2025/7/16
项目名: dl_env
备注: 
"""
# -------------------------- 基础工具库导入 --------------------------
import collections  # 用于统计词频(构建词表时需统计每个词元出现的次数)
import random  # 随机抽样生成训练数据(增加数据随机性,提升模型泛化能力)
import re  # 文本清洗(通过正则表达式过滤非目标字符)
import requests  # 下载数据集(从网络获取《时间机器》文本数据)
from pathlib import Path  # 文件路径处理(创建目录、检查文件是否存在等)
from d2l import torch as d2l  # 深度学习工具库(提供训练辅助、可视化等功能)
import math  # 数学运算(计算困惑度等指标)
import torch  # PyTorch框架(核心深度学习库,提供张量运算、自动求导等)
from torch import nn  # 神经网络模块(提供损失函数、层定义等)
from torch.nn import functional as F  # 函数式API(提供激活函数、one-hot编码等工具)# 图像显示相关库(解决中文和符号显示问题)
import matplotlib.pyplot as plt
import matplotlib.text as text# -------------------------- 核心解决方案:解决文本显示问题 --------------------------
def replace_minus(s):"""解决Matplotlib中Unicode减号(U+2212)显示为方块的问题原理:将特殊减号替换为普通ASCII减号('-'),确保所有环境都能正常显示"""if isinstance(s, str):  # 仅处理字符串类型return s.replace('\u2212', '-')  # 替换Unicode减号为ASCII减号return s  # 非字符串直接返回# 重写matplotlib的Text类的set_text方法,实现全局生效
original_set_text = text.Text.set_text  # 保存原始方法(避免覆盖后无法恢复)def new_set_text(self, s):s = replace_minus(s)  # 先处理减号return original_set_text(self, s)  # 调用原始方法设置文本text.Text.set_text = new_set_text  # 应用重写后的方法(所有文本显示都会经过此处理)# -------------------------- 字体配置(确保中文和数学符号正常显示)--------------------------
plt.rcParams["font.family"] = ["SimHei"]  # 设置中文字体(SimHei支持中文显示,避免中文乱码)
plt.rcParams["text.usetex"] = True  # 使用LaTeX渲染文本(提升数学符号显示美观度)
plt.rcParams["axes.unicode_minus"] = True  # 确保负号正确显示(避免负号显示为方块)
plt.rcParams["mathtext.fontset"] = "cm"  # 数学符号使用Computer Modern字体(LaTeX标准字体,更专业)
d2l.plt.rcParams.update(plt.rcParams)  # 让d2l库的绘图工具继承上述配置(保持显示一致性)# -------------------------- 1. 读取数据 --------------------------
def read_time_machine():"""下载并读取《时间机器》数据集,返回清洗后的文本行列表作用:获取原始文本数据并预处理,为后续词元化做准备"""data_dir = Path('./data')  # 数据存储目录(当前目录下的data文件夹)data_dir.mkdir(exist_ok=True)  # 目录不存在则创建(exist_ok=True避免重复创建报错)file_path = data_dir / 'timemachine.txt'  # 数据集文件路径# 检查文件是否存在,不存在则下载if not file_path.exists():print("开始下载时间机器数据集...")# 从d2l官方地址下载文本(《时间机器》是经典数据集,适合语言模型训练)response = requests.get('http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt')# 写入文件(utf-8编码确保兼容多种字符)with open(file_path, 'w', encoding='utf-8') as f:f.write(response.text)print(f"数据集下载完成,保存至: {file_path}")# 读取文件并清洗文本with open(file_path, 'r', encoding='utf-8') as f:lines = f.readlines()  # 按行读取(每行作为列表元素)print(f"文件读取成功,总行数: {len(lines)}")if len(lines) > 0:print(f"第一行内容: {lines[0].strip()}")  # 打印首行验证是否正确读取# 清洗规则:# 1. re.sub('[^A-Za-z]+', ' ', line):保留字母,其他字符(如数字、符号)替换为空格# 2. strip():去除首尾空格# 3. lower():转小写(统一大小写,减少词元数量)# 4. 过滤空行(if line.strip()确保仅保留非空行)cleaned_lines = [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines if line.strip()]print(f"清洗后有效行数: {len(cleaned_lines)}")  # 清洗后非空行数量(去除纯空格行)return cleaned_lines# -------------------------- 2. 词元化与词表构建 --------------------------
def tokenize(lines, token='char'):"""将文本行转换为词元列表(词元是文本的最小处理单位)参数:lines: 清洗后的文本行列表(如["abc def", "ghi jkl"])token: 词元类型('char'字符级/'word'单词级)返回:词元列表(如字符级:[['a','b','c',' ','d','e','f'], ...])作用:将文本拆分为模型可处理的最小单元(词元),字符级适合简单语言模型"""if token == 'char':# 字符级词元化:将每行拆分为单个字符列表(包括空格,如"abc"→['a','b','c'])return [list(line) for line in lines]elif token == 'word':# 单词级词元化:按空格拆分每行(需确保文本已用空格分隔单词,如"abc def"→['abc','def'])return [line.split() for line in lines]else:raise ValueError('未知词元类型:' + token)class Vocab:"""词表类:实现词元与索引的双向映射,用于将文本转换为模型可处理的数字序列核心功能:将字符串形式的词元转换为整数索引(模型只能处理数字),同时支持索引转词元(用于生成文本)"""def __init__(self, tokens, min_freq=0, reserved_tokens=None):"""构建词表参数:tokens: 词元列表(可嵌套,如[[token1, token2], [token3]])min_freq: 最低词频阈值(低于此值的词元不加入词表,减少词汇量)reserved_tokens: 预留特殊词元(如分隔符、填充符等,模型可能需要的特殊标记)"""if reserved_tokens is None:reserved_tokens = []  # 默认为空(无预留词元)# 统计词频:# 1. 展平嵌套列表([token for line in tokens for token in line])# 2. 用Counter计数(得到{词元: 出现次数}字典)counter = collections.Counter([token for line in tokens for token in line])# 按词频降序排序(便于后续按频率筛选,高频词优先保留)self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)# 初始化词表:# <unk>(未知词元)固定在索引0(所有未见过的词元都映射到<unk>)#  followed by预留词元(如用户指定的特殊标记)self.idx_to_token = ['<unk>'] + reserved_tokens# 构建词元到索引的映射(字典,便于快速查询)self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}# 按词频添加词元(过滤低频词)for token, freq in self.token_freqs:if freq < min_freq:break  # 低频词不加入词表(提前终止,提升效率)if token not in self.token_to_idx:  # 避免重复添加预留词元(如预留词元已在列表中)self.idx_to_token.append(token)self.token_to_idx[token] = len(self.idx_to_token) - 1  # 索引为当前长度-1(保持连续)def __len__(self):"""返回词表大小(词元总数,用于模型输入/输出维度设置)"""return len(self.idx_to_token)def __getitem__(self, tokens):"""词元→索引(支持单个词元或词元列表)未知词元返回<unk>的索引(0),确保模型输入始终有效"""if not isinstance(tokens, (list, tuple)):# 单个词元:查字典,默认返回<unk>的索引(0)return self.token_to_idx.get(tokens, self.unk)# 词元列表:递归转换每个词元(如['a','b']→[2,3])return [self.__getitem__(token) for token in tokens]def to_tokens(self, indices):"""索引→词元(支持单个索引或索引列表,用于将模型输出转换为文本)"""if not isinstance(indices, (list, tuple)):# 单个索引:直接查列表(如2→'a')return self.idx_to_token[indices]# 索引列表:递归转换每个索引(如[2,3]→['a','b'])return [self.idx_to_token[index] for index in indices]@propertydef unk(self):"""返回<unk>的索引(固定为0,便于统一处理未知词元)"""return 0# -------------------------- 3. 数据迭代器(随机抽样) --------------------------
def seq_data_iter_random(corpus, batch_size, num_steps):"""随机抽样生成批量子序列(生成器),用于模型训练的批量输入原理:从语料中随机截取多个长度为num_steps的子序列,组成批次(避免模型学习到固定的句子顺序)参数:corpus: 词元索引序列(1D列表,如[1,3,5,2,...],所有文本的词元索引拼接而成)batch_size: 批量大小(每个批次包含的子序列数,影响训练效率和内存占用)num_steps: 子序列长度(时间步,即模型一次处理的序列长度,如35表示一次输入35个词元)返回:生成器,每次返回(X, Y):X: 输入序列(batch_size, num_steps),模型的输入Y: 标签序列(batch_size, num_steps),是X右移一位的结果(模型需要预测的下一个词元)"""# 检查数据是否足够生成至少一个子序列(子序列长度+1,因Y是X右移1位,需多1个元素)if len(corpus) < num_steps + 1:raise ValueError(f"语料库长度({len(corpus)})不足,需至少{num_steps + 1}")# 随机偏移起始位置(0到num_steps-1),增加数据随机性(避免每次从固定位置开始)corpus = corpus[random.randint(0, num_steps - 1):]# 计算可生成的子序列总数:# (语料长度-1) // num_steps(-1是因Y需多1个元素,每个子序列需num_steps+1个元素)num_subseqs = (len(corpus) - 1) // num_stepsif num_subseqs < 1:raise ValueError(f"无法生成子序列(语料库长度不足)")# 生成所有子序列的起始索引(间隔为num_steps,如0, num_steps, 2*num_steps...)initial_indices = list(range(0, num_subseqs * num_steps, num_steps))random.shuffle(initial_indices)  # 打乱起始索引,实现随机抽样(核心:避免子序列顺序固定)# 计算可生成的批次数:子序列总数 // 批量大小(确保每个批次有batch_size个子序列)num_batches = num_subseqs // batch_sizeif num_batches < 1:raise ValueError(f"子序列数量({num_subseqs})不足,需至少{batch_size}个")# 生成批量数据for i in range(0, batch_size * num_batches, batch_size):# 当前批次的起始索引(从打乱的索引中取batch_size个,如i=0时取前batch_size个)indices = initial_indices[i: i + batch_size]# 输入序列X:每个子序列从indices[j]开始,取num_steps个元素(如indices[j]=0→[0:35])X = [corpus[j: j + num_steps] for j in indices]# 标签序列Y:每个子序列从indices[j]+1开始,取num_steps个元素(X右移1位,如[1:36])Y = [corpus[j + 1: j + num_steps + 1] for j in indices]# 转换为张量返回(便于模型处理,PyTorch模型输入需为张量)yield torch.tensor(X), torch.tensor(Y)# -------------------------- 4. 数据加载函数(关键修复:返回可重置的迭代器) --------------------------
def load_data_time_machine(batch_size, num_steps):"""加载《时间机器》数据,返回数据迭代器生成函数和词表修复点:返回迭代器生成函数(而非一次性迭代器),确保训练时可重复生成数据(每个epoch重新抽样)参数:batch_size: 批量大小num_steps: 子序列长度(时间步)返回:data_iter: 迭代器生成函数(调用时返回新的迭代器,每次调用重新抽样)vocab: 词表对象(用于词元与索引的转换)"""lines = read_time_machine()  # 读取清洗后的文本行tokens = tokenize(lines, token='char')  # 字符级词元化(每个字符为词元,适合简单语言模型)vocab = Vocab(tokens)  # 构建词表(根据词元生成索引映射)# 将所有词元转换为索引(展平为1D序列,如[[ 'a', 'b' ], [ 'c' ]]→[2,3,4])corpus = [vocab[token] for line in tokens for token in line]print(f"语料库长度: {len(corpus)}(词元索引总数)")# 定义迭代器生成函数:每次调用生成新的随机抽样迭代器(确保每个epoch数据不同)def data_iter():return seq_data_iter_random(corpus, batch_size, num_steps)return data_iter, vocab  # 返回生成函数和词表# -------------------------- 5. LSTM模型核心实现 --------------------------def get_lstm_params(vocab_size, num_hiddens, device):"""初始化LSTM的所有参数(权重和偏置)参数:vocab_size: 词表大小(输入/输出维度,因是语言模型,输入输出均为词表词元)num_hiddens: 隐藏层维度(记忆元/隐状态的维度,控制模型容量)device: 计算设备(CPU/GPU,参数需存储在对应设备上)返回:参数列表:包含所有门控、候选记忆元、输出层的权重和偏置"""num_inputs = num_outputs = vocab_size  # 输入维度=输出维度=词表大小def normal(shape):"""生成正态分布的随机参数(均值0,标准差0.01,避免初始值过大)"""return torch.randn(size=shape, device=device) * 0.01def three():"""生成一组参数(输入权重、隐藏层权重、偏置),用于门控或候选记忆元"""return (normal((num_inputs, num_hiddens)),  # 输入X的权重(vocab_size × num_hiddens)normal((num_hiddens, num_hiddens)),  # 上一时刻隐状态H的权重(num_hiddens × num_hiddens)torch.zeros(num_hiddens, device=device))  # 偏置(初始为0,num_hiddens维度)# 输入门参数(W_xi:输入X到输入门的权重;W_hi:上一H到输入门的权重;b_i:偏置)W_xi, W_hi, b_i = three()# 遗忘门参数(W_xf:输入X到遗忘门的权重;W_hf:上一H到遗忘门的权重;b_f:偏置)W_xf, W_hf, b_f = three()# 输出门参数(W_xo:输入X到输出门的权重;W_ho:上一H到输出门的权重;b_o:偏置)W_xo, W_ho, b_o = three()# 候选记忆元参数(W_xc:输入X到候选记忆元的权重;W_hc:上一H到候选记忆元的权重;b_c:偏置)W_xc, W_hc, b_c = three()# 输出层参数(将隐状态H映射到输出词表维度)W_hq = normal((num_hiddens, num_outputs))  # H到输出的权重(num_hiddens × vocab_size)b_q = torch.zeros(num_outputs, device=device)  # 输出层偏置# 附加梯度(所有参数需要计算梯度,后续训练时更新)params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)  # 启用梯度计算return paramsdef init_lstm_state(batch_size, num_hiddens, device):"""初始化LSTM的初始状态(记忆元和隐状态)LSTM有两个状态:记忆元(Cell State,长期记忆)和隐状态(Hidden State,短期输出)初始状态均为全0张量"""return (torch.zeros((batch_size, num_hiddens), device=device),  # 记忆元c的初始状态(batch_size × num_hiddens)torch.zeros((batch_size, num_hiddens), device=device))  # 隐状态h的初始状态(batch_size × num_hiddens)def lstm(inputs, state, params):"""LSTM前向传播(核心计算逻辑)参数:inputs: 输入序列(num_steps, batch_size, vocab_size),每个时间步的输入(one-hot编码)state: 初始状态((H_0, C_0),H_0是初始隐状态,C_0是初始记忆元)params: LSTM的所有参数(门控、候选记忆元、输出层的权重和偏置)返回:outputs: 所有时间步的输出拼接(num_steps*batch_size, vocab_size)(H, C): 最终的隐状态和记忆元(用于传递到下一批次或预测)"""# 解析参数(从params列表中提取各部分参数)[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = state  # 初始状态:H是上一时刻隐状态,C是上一时刻记忆元outputs = []  # 存储每个时间步的输出# 逐时间步计算(inputs的第0维是时间步)for X in inputs:# 1. 计算输入门(I_t):控制新信息进入记忆元的比例(0~1)# 输入门由当前输入X和上一隐状态H共同决定,sigmoid激活(输出0~1)I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)  # X×W_xi:输入X的贡献;H×W_hi:上一H的贡献;加偏置后激活# 2. 计算遗忘门(F_t):控制记忆元中旧信息保留的比例(0~1)# 同样由X和H决定,sigmoid激活F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)# 3. 计算输出门(O_t):控制记忆元中信息输出到隐状态的比例(0~1)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)# 4. 计算候选记忆元(C_tilda):当前时间步的新信息(-1~1)# tanh激活确保值在-1~1之间,避免数值过大C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)# 5. 更新记忆元(C_t):旧信息保留 + 新信息加入# F×C:遗忘门控制旧记忆元C保留的部分;I×C_tilda:输入门控制新信息加入的部分C = F * C + I * C_tilda# 6. 更新隐状态(H_t):输出门控制记忆元中信息的输出# tanh(C)将记忆元值缩放到-1~1,再由输出门O筛选H = O * torch.tanh(C)# 7. 计算当前时间步的输出(Y_t):隐状态H映射到词表维度Y = (H @ W_hq) + b_q  # H×W_hq:隐状态到输出的映射;加偏置outputs.append(Y)  # 保存当前时间步的输出# 拼接所有时间步的输出(按时间步维度拼接),返回输出和最终状态return torch.cat(outputs, dim=0), (H, C)# -------------------------- 6. RNN模型包装类 --------------------------
class RNNModelScratch:  # @save"""从零实现的RNN模型包装类,统一模型调用接口(适配训练和预测流程)"""def __init__(self, vocab_size, num_hiddens, device,get_params, init_state, forward_fn):"""参数:vocab_size: 词表大小(输入/输出维度)num_hiddens: 隐藏层维度(记忆元/隐状态的维度)device: 计算设备get_params: 参数初始化函数(如get_lstm_params)init_state: 状态初始化函数(如init_lstm_state)forward_fn: 前向传播函数(如lstm)"""self.vocab_size, self.num_hiddens = vocab_size, num_hiddensself.params = get_params(vocab_size, num_hiddens, device)  # 模型参数(通过get_params获取)self.init_state, self.forward_fn = init_state, forward_fn  # 状态初始化和前向传播函数def __call__(self, X, state):"""模型调用接口(前向传播入口,兼容PyTorch的调用方式)参数:X: 输入序列(batch_size, num_steps),元素为词元索引(未编码的原始输入)state: 初始隐藏状态((H_0, C_0))返回:y_hat: 输出(num_steps*batch_size, vocab_size),所有时间步的输出拼接state: 最终隐藏状态((H_t, C_t))"""# 处理输入:# 1. X.T:转置为(num_steps, batch_size)(便于逐时间步处理,时间步在前)# 2. F.one_hot:转换为one-hot编码(num_steps, batch_size, vocab_size),将索引转为向量# 3. type(torch.float32):转换为浮点型(适配后续矩阵运算,权重为浮点型)X = F.one_hot(X.T, self.vocab_size).type(torch.float32)# 调用前向传播函数(如lstm)计算输出和新状态return self.forward_fn(X, state, self.params)def begin_state(self, batch_size, device):"""获取初始隐藏状态(调用初始化函数,封装状态初始化逻辑)"""return self.init_state(batch_size, self.num_hiddens, device)# -------------------------- 7. 预测函数(文本生成) --------------------------
def predict_ch8(prefix, num_preds, net, vocab, device):  # @save"""根据前缀生成后续字符(文本生成,验证模型学习效果)参数:prefix: 前缀字符串(如"time traveller",模型基于此生成后续内容)num_preds: 要生成的字符数net: 训练好的LSTM模型vocab: 词表(用于词元与索引的转换)device: 计算设备返回:生成的字符串(前缀+预测字符,如前缀"ti"生成"time...")"""# 初始化状态(批量大小为1,因仅生成一条序列,无需并行)state = net.begin_state(batch_size=1, device=device)# 记录输出索引:初始为前缀首字符的索引(将前缀转换为索引序列)outputs = [vocab[prefix[0]]]# 辅助函数:获取当前输入(最后一个输出的索引,形状(1,1),符合模型输入格式)def get_input():return torch.tensor([outputs[-1]], device=device).reshape((1, 1))# 预热期:用前缀更新模型状态(不生成新字符,仅让模型"记住"前缀的信息)for y in prefix[1:]:_, state = net(get_input(), state)  # 前向传播,更新状态(忽略输出,因只需状态)outputs.append(vocab[y])  # 记录前缀字符的索引(确保outputs包含完整前缀)# 预测期:生成num_preds个字符for _ in range(num_preds):y, state = net(get_input(), state)  # 前向传播,获取输出和新状态(y是当前时间步的输出)# 取概率最大的字符索引(贪婪采样:简单策略,选择模型认为最可能的下一个字符)outputs.append(int(y.argmax(dim=1).reshape(1)))# 将索引转换为字符,拼接成字符串返回(完成从索引到文本的转换)return ''.join([vocab.idx_to_token[i] for i in outputs])# -------------------------- 8. 梯度裁剪(防止梯度爆炸) --------------------------
def grad_clipping(net, theta):  # @save"""裁剪梯度(将梯度L2范数限制在theta内),防止梯度爆炸(RNN训练中常见问题)原理:若梯度范数超过阈值theta,则按比例缩小所有梯度,确保训练稳定参数:net: 模型(自定义模型或nn.Module)theta: 梯度阈值(如1.0,根据经验设置)"""# 获取需要梯度更新的参数if isinstance(net, nn.Module):# 若为PyTorch官方Module,直接取parameters(包含所有需要梯度的参数)params = [p for p in net.parameters() if p.requires_grad]else:# 若为自定义模型(如RNNModelScratch),取params属性(存储模型参数)params = net.params# 计算所有参数梯度的L2范数(平方和开根号)norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))if norm > theta:  # 若范数超过阈值,按比例裁剪(保持梯度方向不变,缩小幅度)for param in params:param.grad[:] *= theta / norm# -------------------------- 9. 训练函数 --------------------------
def train_epoch_ch8(net, train_iter_fn, loss, updater, device, use_random_iter):"""训练一个周期(单轮遍历数据集)参数:net: LSTM模型train_iter_fn: 迭代器生成函数(调用后返回新迭代器,每个epoch重新生成数据)loss: 损失函数(如CrossEntropyLoss,计算预测与标签的差距)updater: 优化器(如SGD,用于更新模型参数)device: 计算设备use_random_iter: 是否使用随机抽样(影响状态处理方式:随机抽样时状态独立,无需传递)返回:ppl: 困惑度(perplexity,语言模型性能指标,越低表示模型越好)speed: 训练速度(词元/秒,衡量训练效率)"""state, timer = None, d2l.Timer()  # 初始化状态和计时器(timer用于计算训练速度)metric = d2l.Accumulator(2)  # 累加器:(总损失, 总词元数),用于计算平均损失batches_processed = 0  # 记录处理的批次数量(验证是否有数据被处理)# 关键修复:每次训练都通过函数生成新的迭代器(避免迭代器被提前消费,确保每个epoch数据不同)train_iter = train_iter_fn()# 遍历批量数据(每个X, Y是一个批次)for X, Y in train_iter:batches_processed += 1# 初始化状态:# - 首次迭代时需初始化(state为None)# - 随机抽样时,每个批次的序列独立(无上下文关联),需重新初始化if state is None or use_random_iter:state = net.begin_state(batch_size=X.shape[0], device=device)else:# 非随机抽样时,分离状态(切断梯度回流到之前的批次,避免梯度计算依赖过长导致爆炸)if isinstance(net, nn.Module) and not isinstance(state, tuple):state.detach_()  # 单个状态直接detach(如GRU只有隐状态)else:for s in state:  # 多个状态(如LSTM有隐状态和记忆元)逐个detachs.detach_()# 处理标签:# Y.T.reshape(-1):转置后展平为(num_steps*batch_size,)(与输出y_hat的形状匹配)# 输出y_hat的形状是(num_steps*batch_size, vocab_size),标签需为1D张量y = Y.T.reshape(-1)# 将输入和标签移到目标设备(GPU/CPU,确保与模型参数在同一设备)X, y = X.to(device), y.to(device)# 前向传播:获取输出和新状态y_hat, state = net(X, state)# 计算损失(mean()是因损失函数可能返回每个样本的损失,取平均得到批次损失)l = loss(y_hat, y.long()).mean()# 反向传播与参数更新:if isinstance(updater, torch.optim.Optimizer):# 若为PyTorch优化器(如SGD)updater.zero_grad()  # 清零梯度(避免梯度累积)l.backward()  # 反向传播计算梯度grad_clipping(net, 1)  # 裁剪梯度(阈值1,防止梯度爆炸)updater.step()  # 更新参数else:# 若为自定义优化器(如d2l的sgd函数)l.backward()grad_clipping(net, 1)updater(batch_size=1)  # 假设批量大小为1的更新(简化实现)# 累加总损失和总词元数(用于计算平均损失)# metric[0] += l * y.numel():总损失=批次损失×词元数(因l是平均损失)# metric[1] += y.numel():总词元数=累加每个批次的词元数量metric.add(l * y.numel(), y.numel())# 检查是否有批次被处理(避免空迭代导致的错误)if batches_processed == 0:print("警告:没有处理任何训练批次!")return float('inf'), 0# 计算困惑度(perplexity = exp(平均损失),语言模型专用指标,与交叉熵损失正相关)# 平均损失 = 总损失 / 总词元数,exp后得到困惑度(完美模型困惑度=1)# 速度 = 总词元数 / 训练时间(词元/秒,衡量训练效率)return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()def train_ch8(net, train_iter_fn, vocab, lr, num_epochs, device, use_random_iter=False):"""训练模型(多周期,整合单周期训练逻辑,输出训练过程和结果)参数:net: LSTM模型train_iter_fn: 迭代器生成函数vocab: 词表lr: 学习率(控制参数更新幅度)num_epochs: 训练周期数(遍历数据集的次数,影响模型收敛程度)device: 计算设备use_random_iter: 是否使用随机抽样(默认False,即顺序抽样)"""loss = nn.CrossEntropyLoss()  # 交叉熵损失(适用于分类任务,此处为词元预测,多分类问题)# 动画器:可视化训练过程(实时绘制困惑度随周期变化的曲线,直观观察模型收敛情况)animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',legend=['train'], xlim=[10, num_epochs])# 初始化优化器:if isinstance(net, nn.Module):# 若为PyTorch Module,使用SGD优化器(随机梯度下降,适合简单模型)updater = torch.optim.SGD(net.parameters(), lr)else:# 若为自定义模型,使用d2l的sgd函数(简化的随机梯度下降实现)updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)# 定义预测函数:根据前缀"time traveller"生成50个字符(验证模型学习效果)predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)# 多周期训练for epoch in range(num_epochs):# 训练一个周期,返回困惑度和速度ppl, speed = train_epoch_ch8(net, train_iter_fn, loss, updater, device, use_random_iter)# 每10个周期打印一次预测结果(观察生成文本质量变化,判断模型是否学到有意义的模式)if (epoch + 1) % 10 == 0:print(f"epoch {epoch + 1} 预测: {predict('time traveller')}")animator.add(epoch + 1, [ppl])  # 记录困惑度,更新动画# 训练结束后输出最终结果(总结模型性能)print(f'最终困惑度 {ppl:.1f}, 速度 {speed:.1f} 词元/秒 {device}')print(f"time traveller 预测: {predict('time traveller')}")  # 用"time traveller"前缀生成文本print(f"traveller 预测: {predict('traveller')}")  # 用"traveller"前缀生成文本# -------------------------- 主程序 --------------------------
if __name__ == '__main__':# 超参数设置(根据经验和任务调整)batch_size, num_steps = 32, 35  # 批量大小=32(每次处理32个序列),时间步=35(每个序列35个词元)# 加载数据:获取迭代器生成函数和词表train_iter, vocab = load_data_time_machine(batch_size, num_steps)# 模型参数vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()  # 词表大小、隐藏层维度=256,自动选择GPU/CPUnum_epochs, lr = 500, 0.12  # 训练周期=500(充分训练),学习率=0.12(控制更新幅度)# 初始化LSTM模型(使用自定义的从零实现的模型)model = RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)# 开始训练(调用训练函数,启动多周期训练)train_ch8(model, train_iter, vocab, lr, num_epochs, device)plt.show(block=True)  # 显示训练过程的动画图(阻塞模式,确保图不闪退,便于观察)

十二、实验结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/89273.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/89273.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FreeRTOS之链表关键数据结构和函数操作接口-1

FreeRTOS之链表操作相关接口1 FreeRTOS源码下载地址2 任务控制块TCB2.1 任务控制块TCB2.1.1 任务控制块的关键成员2.1.2 TCB 的核心作用2.2 ListItem_t2.3 List_t3 函数接口3.1 vListInitialise3.2 vListInitialiseItem1 FreeRTOS源码下载地址 https://www.freertos.org/ 2 …

OpenVela之 Arch Timer 驱动框架使用指南

一、概述 在嵌入式系统开发中&#xff0c;定时器是实现任务调度、精确延时等功能的核心组件。Arch Timer 作为基于 Timer Driver 实现的间隔定时器&#xff0c;在系统调度中扮演着重要角色。本文将全面介绍 Arch Timer 驱动框架&#xff0c;从基本概念到实际应用&#xff0c;帮…

AAC编解码

AAC&#xff08;Advanced Audio Coding&#xff0c;高级音频编码&#xff09;是一种基于心理声学原理的有损音频编解码技术&#xff0c;广泛应用于流媒体、数字广播、移动音频等场景。其编解码流程围绕 “保留人耳可感知信息、去除冗余” 设计&#xff0c;分为编码&#xff08;…

STM32 | HC-SR04 超声波传感器测距

模块&#xff1a;HC-SR04感应角度&#xff1a;不大于15度 探测距离&#xff1a;2cm-450cm 高精度&#xff1a;可达0.3cmTrig&#xff1a;触发信号&#xff0c;接收MCU发送的控制脉冲&#xff0c;MCU对应GPIO 设置为输出Echo&#xff1a;反馈信号&#xff0c;向MCU发送数据…

【RTSP从零实践】12、TCP传输H264格式RTP包(RTP_over_TCP)的RTSP服务器(附带源码)

&#x1f601;博客主页&#x1f601;&#xff1a;&#x1f680;https://blog.csdn.net/wkd_007&#x1f680; &#x1f911;博客内容&#x1f911;&#xff1a;&#x1f36d;嵌入式开发、Linux、C语言、C、数据结构、音视频&#x1f36d; &#x1f923;本文内容&#x1f923;&a…

【unitrix】 6.1 类型化整数特征(t_int.rs)

一、源码 这段代码定义了一个 Rust 特征&#xff08;trait&#xff09;TInt 和一些实现&#xff0c;用于表示类型化的整数。 use crate::number::{Null, B, Bit, TNumber};/// 类型化整数标记特征 /// /// 要求&#xff1a; /// - 实现 TNumber /// - 可复制 (Copy) /// - 默认…

速通LVS

一、LVS的使用lvs部署命令介绍lvs软件相关信息&#xff1a;程序包&#xff1a;ipvsadm Unit File: ipvsadm.service 主程序&#xff1a;/usr/sbin/ipvsadm 规则保存工具&#xff1a;/usr/sbin/ipvsadm-save 规则重载工具&#xff1a;/usr/sbin/ipvsadm-restore 配置文件&#x…

Nginx,MD5和Knife4j

一、 Nginx: 项目网关与流量调度核心原理反向代理 (Reverse Proxy):在Web架构中&#xff0c;Nginx作为系统的统一入口&#xff08;API网关&#xff09;&#xff0c;接收所有外部客户端请求。它通过解析请求的URL路径&#xff08;location指令&#xff09;&#xff0c;判断请求的…

多态,内部类(匿名内部类),常用API(1)

多态 什么是多态&#xff1f; 同一个对象在不同时刻表现出来的不同形态&#xff08;多种形态&#xff09; 例&#xff1a;Cat extends Animal 第一种形态&#xff1a;Cat c1 new Cat(); //c1是只猫 第二种形态&#xff1a;Animal c2 new Cat(); //c2是个动物 &#xff08…

Qt小组件 - 7 SQL Thread Qt访问数据库ORM

简介网上关于Qt访问数据库的资料大多使用QSqlDatabase模块。虽然这在C中尚可接受&#xff0c;但在Python中使用就显得过于繁琐了——不仅要手动编写SQL语句&#xff0c;还与Python追求简洁的理念背道而驰。在这里写一个基于sqlalchemy的示例&#xff0c;也可以使用其他的ORM库 …

使用Gin框架构建高并发教练预约微服务:架构设计与实战解析

项目概述 技术栈 Web框架&#xff1a;Gin&#xff08;高性能HTTP框架&#xff09;数据存储&#xff1a;Redis&#xff08;内存数据库&#xff0c;用于高并发读写&#xff09; 项目结构 coach-booking-service ├── main.go # 程序入口&#xff0c;路由初始化&am…

深入拆解Spring第二大核心思想:AOP

什么是AOP Aspect Oriented Programming&#xff08;面向切面编程&#xff09; 什么是面向切面编程呢? 切⾯就是指某⼀类特定问题, 所以AOP也可以理解为面向特定方法编程. 什么是面向特定方法编程呢? 比如对于"登录校验", 就是⼀类特定问题. 登录校验拦截器, 就是…

linux服务器stress-ng的使用

安装方法 • Ubuntu/Debian&#xff1a;sudo apt update && sudo apt install stress-ng -y• CentOS/RHEL&#xff08;需EPEL源&#xff09;&#xff1a;sudo yum install epel-release -ysudo yum install stress-ng -y• 源码编译&#xff08;适合定制化需求&#x…

探索阿里云DMS:解锁高效数据管理新姿势

一、阿里云 DMS 是什么 阿里云 DMS&#xff0c;全称为 Data Management Service&#xff0c;即数据管理服务 &#xff0c;是一种集数据管理、结构管理、安全管理于一体的全面数据库服务平台。它能够有效地支持各类数据库产品&#xff0c;包括但不限于 MySQL、SQL Server、Post…

python爬取新浪财经网站上行业板块股票信息的代码

在这个多行业持续高速发展的时代&#xff0c;科技正在改变着我们的生活。 在世界科技领域中&#xff0c;中国正占据越来越重要的位置。当下&#xff0c;每个行业都提到了区块链、人工智能、大数据、5G等科技力量&#xff0c;强调了科技在行业咨询与数据分析领域的重要意义。 随…

【JAVA】监听windows中鼠标侧面键的按钮按下事件

监听windows中鼠标侧面键的按钮按下事件用到的包核心类使用这个类用到的包 jna-5.11.0.jar jna-platform-5.11.0.jar核心类 package sample.tt.mouse;import com.sun.jna.Pointer; import com.sun.jna.platform.win32.*; import com.sun.jna.platform.win32.WinDef.HMODULE; …

Redis突发写入阻断?解析“MISCONF Redis is configured to save RDB…“故障处理

当你的Redis服务器突然拒绝写入并抛出 MISCONF Redis is configured to save RDB snapshots... 错误时&#xff0c;别慌&#xff01;这是Redis的数据安全保护机制在发挥作用。本文带你深度解析故障根因&#xff0c;并提供完整的解决方案。&#x1f525; 故障现象还原 客户端&am…

产品更新丨谷云科技 iPaaS 集成平台 V7.6 版本发布

六月&#xff0c;谷云科技iPaaS集成平台更新了V7.6版本。这次更新中我们着重对API网关、API编排、组织管理权限、API监控等功能进行了增强以及优化&#xff0c;一起来看看有什么新变化吧&#xff01; 网关、监控、编排、组织权限全方位升级 1.API网关 错误码预警&#xff0c;可…

图像处理中的模板匹配:原理与实现

目录 一、什么是模板匹配&#xff1f; 二、模板匹配的匹配方法 1. 平方差匹配&#xff08;cv2.TM_SQDIFF&#xff09; 2. 归一化平方差匹配&#xff08;cv2.TM_SQDIFF_NORMED&#xff09; 3. 相关匹配&#xff08;cv2.TM_CCORR&#xff09; 4. 归一化相关匹配&#xff08…

高性能架构模式——高性能NoSQL

目录 一、关系数据库的缺点二、常见的 NoSQL 方案分 类2.1、K-V 存储2.2、文档数据库2.3、列式数据库2.4、全文搜索引擎三、高性能 NoSQL 方案的典型特征和应用场景3.1、K-V 存储典型特征和应用场景3.2、文档数据库典型特征和应用场景3.1.1、文档数据库的 no-schema 特性的优势…