循环神经网络(RNN):从理论到翻译

循环神经网络(RNN)是一种专为处理序列数据设计的神经网络,如时间序列、自然语言或语音。与传统的全连接神经网络不同,RNN具有"记忆"功能,通过循环传递信息,使其特别适合需要考虑上下文或顺序的任务。它出现在Transformer之前,广泛应用于文本生成、语音识别和时间序列预测(如股价预测)等领域。

RNN的数学基础

rnn-https://zlu.me

核心方程

在每个时间步 t t t,RNN执行以下操作:

  1. 隐藏状态更新
    h t = tanh ( W h h h t − 1 + W x h x t + b h ) h_t = \text{tanh}(W_{hh}h_{t-1} + W_{xh}x_t + b_h) ht=tanh(Whhht1+Wxhxt+bh)

    • h t h_t ht: 时间 t t t的新隐藏状态(形状:[hidden_size]
    • h t − 1 h_{t-1} ht1: 前一个隐藏状态(形状:[hidden_size]
    • x t x_t xt: 时间 t t t的输入(形状:[input_size]
    • W h h W_{hh} Whh: 隐藏到隐藏的权重矩阵(形状:[hidden_size, hidden_size]
    • W x h W_{xh} Wxh: 输入到隐藏的权重矩阵(形状:[hidden_size, input_size]
    • b h b_h bh: 隐藏层偏置项(形状:[hidden_size]
    • tanh \text{tanh} tanh: 双曲正切激活函数
  2. 输出计算
    o t = W h y h t + b y o_t = W_{hy}h_t + b_y ot=Whyht+by

    • o t o_t ot: 时间 t t t的输出(形状:[output_size]
    • W h y W_{hy} Why: 隐藏到输出的权重矩阵(形状:[output_size, hidden_size]
    • b y b_y by: 输出偏置项(形状:[output_size]

随时间反向传播(BPTT)

RNN使用BPTT进行训练,它通过时间展开网络并应用链式法则:

∂ L ∂ W = ∑ t = 1 T ∂ L t ∂ o t ∂ o t ∂ h t ∑ k = 1 t ( ∏ i = k + 1 t ∂ h i ∂ h i − 1 ) ∂ h k ∂ W \frac{\partial L}{\partial W} = \sum_{t=1}^T \frac{\partial L_t}{\partial o_t} \frac{\partial o_t}{\partial h_t} \sum_{k=1}^t \left( \prod_{i=k+1}^t \frac{\partial h_i}{\partial h_{i-1}} \right) \frac{\partial h_k}{\partial W} WL=t=1TotLthtotk=1t(i=k+1thi1hi)Whk

这可能导致梯度消失/爆炸问题,LSTM和GRU架构可以解决这个问题。

GRU:门控循环单元

在深入翻译示例之前,让我们先了解GRU的数学基础。GRU通过门控机制解决了标准RNN中的梯度消失问题。

GRU方程

在每个时间步 t t t,GRU计算以下内容:

  1. 更新门 ( z t z_t zt):
    z t = σ ( W z ⋅ [ h t − 1 , x t ] + b z ) z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) zt=σ(Wz[ht1,xt]+bz)

    • z t z_t zt: 更新门(形状:[hidden_size]
    • W z W_z Wz: 更新门的权重矩阵(形状:[hidden_size, hidden_size + input_size]
    • b z b_z bz: 更新门的偏置项(形状:[hidden_size]
    • h t − 1 h_{t-1} ht1: 前一个隐藏状态
    • x t x_t xt: 当前输入
    • σ \sigma σ: Sigmoid激活函数(将值压缩到0和1之间)

    更新门决定保留多少之前的隐藏状态。

  2. 重置门 ( r t r_t rt):
    r t = σ ( W r ⋅ [ h t − 1 , x t ] + b r ) r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) rt=σ(Wr[ht1,xt]+br)

    • r t r_t rt: 重置门(形状:[hidden_size]
    • W r W_r Wr: 重置门的权重矩阵(形状:[hidden_size, hidden_size + input_size]
    • b r b_r br: 重置门的偏置项(形状:[hidden_size]

    重置门决定忘记多少之前的隐藏状态。

  3. 候选隐藏状态 ( h ~ t \tilde{h}_t h~t):
    h ~ t = tanh ( W ⋅ [ r t ⊙ h t − 1 , x t ] + b ) \tilde{h}_t = \text{tanh}(W \cdot [r_t \odot h_{t-1}, x_t] + b) h~t=tanh(W[rtht1,xt]+b)

    • h ~ t \tilde{h}_t h~t: 候选隐藏状态(形状:[hidden_size]
    • W W W: 候选状态的权重矩阵(形状:[hidden_size, hidden_size + input_size]
    • b b b: 偏置项(形状:[hidden_size]
    • ⊙ \odot : 逐元素乘法(哈达玛积)

    这表示可能使用的新隐藏状态内容。

  4. 最终隐藏状态 ( h t h_t ht):
    h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ~ t h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t ht=(1zt)ht1+zth~t

    • 最终隐藏状态是前一个隐藏状态和候选状态的组合
    • z t z_t zt作为新旧信息之间的插值因子

GRU在翻译中的优势

  1. 更新门

    • 在英中翻译中,这有助于决定:
      • 保留多少上下文(例如,保持句子的主语)
      • 更新多少新信息(例如,遇到新词时)
  2. 重置门

    • 帮助忘记不相关的信息
    • 例如,在翻译新句子时,可以重置前一个句子的上下文
  3. 梯度流动

    • 最终隐藏状态计算中的加法更新( + + +)有助于保持梯度流动
    • 这对于学习翻译任务中的长程依赖关系至关重要

简单的RNN示例

这个简化示例训练一个RNN来预测单词"hello"中的下一个字符。

  1. 模型定义

    • nn.RNN处理循环计算
    • 全连接层(fc)将隐藏状态映射到输出(字符预测)
  2. 数据

    • 使用"hell"作为输入,期望输出为"ello"(序列移位)
    • 字符转换为one-hot向量(例如,‘h’ → [1, 0, 0, 0])
  3. 训练

    • 通过最小化预测字符和目标字符之间的交叉熵损失来学习
  4. 预测

    • 训练后,模型可以预测下一个字符
import torch
import torch.nn as nnclass SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.rnn = nn.RNN(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x, hidden):out, hidden = self.rnn(x, hidden)out = self.fc(out)return out, hiddendef init_hidden(self, batch_size):return torch.zeros(1, batch_size, self.hidden_size)# 超参数
input_size = 4   # 唯一字符数 (h, e, l, o)
hidden_size = 8  # 隐藏状态大小
output_size = 4  # 与input_size相同
learning_rate = 0.01# 字符词汇表
chars = ['h', 'e', 'l', 'o']
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}# 输入数据:"hell" 预测 "ello"
input_seq = "hell"
target_seq = "ello"# 转换为one-hot编码
def to_one_hot(seq):tensor = torch.zeros(1, len(seq), input_size)  # [batch_size, seq_len, input_size]for t, char in enumerate(seq):tensor[0][t][char_to_idx[char]] = 1  # 批大小为1return tensor# 准备输入和目标张量
input_tensor = to_one_hot(input_seq)  # 形状: [1, 4, 4]
print("输入张量形状:", input_tensor.shape)
target_tensor = torch.tensor([char_to_idx[ch] for ch in target_seq], dtype=torch.long)  # 形状: [4]# 初始化模型、损失函数和优化器
model = SimpleRNN(input_size, hidden_size, output_size)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 训练循环
for epoch in range(100):hidden = model.init_hidden(1)  # 批大小为1print("隐藏状态形状:", hidden.shape)  # 应该是 [1, 1, 8]optimizer.zero_grad()output, hidden = model(input_tensor, hidden)  # 输出: [1, 4, 4], 隐藏: [1, 1, 8]loss = criterion(output.squeeze(0), target_tensor)  # output.squeeze(0): [4, 4], target: [4]loss.backward()optimizer.step()if epoch % 20 == 0:print(f'轮次 {epoch}, 损失: {loss.item():.4f}')# 测试模型
with torch.no_grad():hidden = model.init_hidden(1)

英中翻译示例

我们将使用PyTorch的GRU(门控循环单元)构建一个简单的英中翻译模型,GRU是RNN的一种变体,能更好地处理长程依赖关系。

1. 数据准备

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np# 样本平行语料(英文 -> 中文)
english_sentences = ["hello", "how are you", "i love machine learning","good morning", "artificial intelligence"
]chinese_sentences = ["你好", "你好吗", "我爱机器学习","早上好", "人工智能"
]# 创建词汇表
eng_chars = sorted(list(set(' '.join(english_sentences))))
zh_chars = sorted(list(set(''.join(chinese_sentences))))# 添加特殊标记
SOS_token = 0  # 句子开始
EOS_token = 1  # 句子结束
eng_chars = ['<SOS>', '<EOS>', '<PAD>'] + eng_chars
zh_chars = ['<SOS>', '<EOS>', '<PAD>'] + zh_chars# 创建词到索引的映射
eng_to_idx = {ch: i for i, ch in enumerate(eng_chars)}
zh_to_idx = {ch: i for i, ch in enumerate(zh_chars)}# 将句子转换为张量
def sentence_to_tensor(sentence, vocab, is_target=False):indices = [vocab[ch] for ch in (sentence if not is_target else sentence)]if is_target:indices.append(EOS_token)  # 为目标添加EOS标记return torch.tensor(indices, dtype=torch.long).view(-1, 1)

2. 模型架构

class Seq2Seq(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Seq2Seq, self).__init__()self.hidden_size = hidden_size# 编码器(英文到隐藏状态)self.embedding = nn.Embedding(input_size, hidden_size)self.gru = nn.GRU(hidden_size, hidden_size)# 解码器(隐藏状态到中文)self.out = nn.Linear(hidden_size, output_size)self.softmax = nn.LogSoftmax(dim=1)def forward(self, input_seq, hidden=None, max_length=10):# 编码器embedded = self.embedding(input_seq).view(1, 1, -1)output, hidden = self.gru(embedded, hidden)# 解码器decoder_input = torch.tensor([[SOS_token]], device=input_seq.device)decoder_hidden = hiddendecoded_words = []for _ in range(max_length):output, decoder_hidden = self.gru(self.embedding(decoder_input).view(1, 1, -1),decoder_hidden)output = self.softmax(self.out(output[0]))topv, topi = output.topk(1)if topi.item() == EOS_token:breakdecoded_words.append(zh_chars[topi.item()])decoder_input = topi.detach()return ''.join(decoded_words), decoder_hiddendef init_hidden(self):return torch.zeros(1, 1, self.hidden_size)

3. 训练模型

# 超参数
hidden_size = 256
learning_rate = 0.01
n_epochs = 1000# 初始化模型
model = Seq2Seq(len(eng_chars), hidden_size, len(zh_chars))
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 训练循环
for epoch in range(n_epochs):total_loss = 0for eng_sent, zh_sent in zip(english_sentences, chinese_sentences):# 准备数据input_tensor = sentence_to_tensor(eng_sent, eng_to_idx)target_tensor = sentence_to_tensor(zh_sent, zh_to_idx, is_target=True)# 前向传播model.zero_grad()hidden = model.init_hidden()# 编码器前向传播embedded = model.embedding(input_tensor).view(len(input_tensor), 1, -1)_, hidden = model.gru(embedded, hidden)# 准备解码器decoder_input = torch.tensor([[SOS_token]])decoder_hidden = hiddenloss = 0# 教师强制:使用目标作为下一个输入for di in range(len(target_tensor)):output, decoder_hidden = model.gru(model.embedding(decoder_input).view(1, 1, -1),decoder_hidden)output = model.out(output[0])loss += criterion(output, target_tensor[di])decoder_input = target_tensor[di]# 反向传播和优化loss.backward()optimizer.step()total_loss += loss.item() / len(target_tensor)# 打印进度if (epoch + 1) % 100 == 0:print(f'轮次 {epoch + 1}, 平均损失: {total_loss / len(english_sentences):.4f}')# 测试翻译
def translate(sentence):with torch.no_grad():input_tensor = sentence_to_tensor(sentence.lower(), eng_to_idx)output_words, _ = model(input_tensor)return output_words# 示例翻译
print("\n翻译结果:")
print(f"'hello' -> '{translate('hello')}'")
print(f"'how are you' -> '{translate('how are you')}'")
print(f"'i love machine learning' -> '{translate('i love machine learning')}'")

4. 理解输出

训练后,模型应该能够将简单的英文短语翻译成中文。例如:

  • 输入: “hello”

    • 输出: “你好”
  • 输入: “how are you”

    • 输出: “你好吗”
  • 输入: “i love machine learning”

    • 输出: “我爱机器学习”

5. 关键组件解释

  1. 嵌入层

    • 将离散的词索引转换为连续向量
    • 捕捉词与词之间的语义关系
  2. GRU(门控循环单元)

    • 使用更新门和重置门控制信息流
    • 解决标准RNN中的梯度消失问题
  3. 教师强制

    • 在训练过程中使用目标输出作为下一个输入
    • 帮助模型更快地学习正确的翻译
  4. 束搜索

    • 可以用于提高翻译质量
    • 在解码过程中跟踪多个可能的翻译

6. 挑战与改进

  1. 处理变长序列

    • 使用填充和掩码
    • 实现注意力机制以获得更好的对齐
  2. 词汇表大小

    • 使用子词单元(如Byte Pair Encoding, WordPiece)
    • 实现指针生成网络处理稀有词
  3. 性能

    • 使用双向RNN增强上下文理解
    • 实现Transformer架构以实现并行处理

这个示例为使用RNN进行序列到序列学习提供了基础。对于生产系统,建议使用基于Transformer的模型(如BART或T5),这些模型在机器翻译任务中表现出色。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/bicheng/84591.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

window批处理文件(.bat),用来清理git的master分支

echo off chcp 65001 > nul setlocal enabledelayedexpansionecho 正在检查Git仓库... git rev-parse --is-inside-work-tree >nul 2>&1 if %errorlevel% neq 0 (echo 错误&#xff1a;当前目录不是Git仓库&#xff01;pauseexit /b 1 )echo 警告&#xff1a;这将…

C#中的CLR属性、依赖属性与附加属性

CLR属性的主要特征 封装性&#xff1a; 隐藏字段的实现细节 提供对字段的受控访问 访问控制&#xff1a; 可单独设置get/set访问器的可见性 可创建只读或只写属性 计算属性&#xff1a; 可以在getter中执行计算逻辑 不需要直接对应一个字段 验证逻辑&#xff1a; 可以…

【mysql】联合索引和单列索引的区别

区别核心&#xff1a;联合索引可加速多个字段组合查询&#xff0c;单列索引只能加速一个字段。 &#x1f539;联合索引&#xff08;复合索引&#xff09; INDEX(col1, col2, col3)适用范围&#xff1a; WHERE col1 ... ✅ WHERE col1 ... AND col2 ... ✅ WHERE col1 ..…

如何用 HTML 展示计算机代码

原文&#xff1a;如何用 HTML 展示计算机代码 | w3cschool笔记 &#xff08;请勿将文章标记为付费&#xff01;&#xff01;&#xff01;&#xff01;&#xff09; 在编程学习和文档编写过程中&#xff0c;清晰地展示代码是一项关键技能。HTML 作为网页开发的基础语言&#x…

大模型笔记_模型微调

1. 大模型微调的概念 大模型微调&#xff08;Fine-tuning&#xff09;是指在预训练大语言模型&#xff08;如GPT、BERT、LLaMA等&#xff09;的基础上&#xff0c;针对特定任务或领域&#xff0c;使用小量的目标领域数据对模型进行进一步训练&#xff0c;使其更好地适配具体应…

React Native UI 框架与动画系统:打造专业移动应用界面

React Native UI 框架与动画系统&#xff1a;打造专业移动应用界面 关键要点 UI 框架加速开发&#xff1a;NativeBase、React Native Paper、UI Kitten 和 Tailwind-RN 提供预构建组件&#xff0c;帮助开发者快速创建美观、一致的界面。动画提升体验&#xff1a;React Native…

在QT中使用OpenGL

参考资料&#xff1a; 主页 - LearnOpenGL CN https://blog.csdn.net/qq_40120946/category_12566573.html 由于OpenGL的大多数实现都是由显卡厂商编写的&#xff0c;当产生一个bug时通常可以通过升级显卡驱动来解决。 OpenGL中的名词解释 OpenGL 上下文&#xff08;Conte…

Qt::QueuedConnection详解

在多线程编程中&#xff0c;线程间的通信是一个关键问题。Qt框架提供了强大的信号和槽机制来处理线程通信&#xff0c;其中Qt::QueuedConnection是一种非常有用的连接类型。本文将深入探讨Qt::QueuedConnection的原理、使用场景及注意事项。 一、基本概念 Qt::QueuedConnecti…

X86 OpenHarmony5.1.0系统移植与安装

近期在研究X86鸿蒙,通过一段时间的研究终于成功了,在X86机器上成功启动了openharmony系统了.下面做个总结和分享 1. 下载源码 获取OpenHarmony标准系统源码 repo init -u https://gitee.com/openharmony/manifest.git -b refs/tags/OpenHarmony-v5.1.0-Release --no-repo-ve…

如何诊断服务器硬盘故障?出现硬盘故障如何处理比较好?

当服务器硬盘出现故障时&#xff0c;及时诊断问题并采取正确的处理方法至关重要。硬盘故障可能导致数据丢失和系统不稳定&#xff0c;影响服务器的正常运行。以下是诊断服务器硬盘故障并处理的最佳实践&#xff1a; 诊断服务器硬盘故障的步骤 1. 监控警报 硬盘监控工具&#…

vue3提供的hook和通常的函数有什么区别

Vue 3 提供的 hook&#xff08;组合式函数&#xff09; 和普通函数在使用场景、功能和设计目的上有明显区别&#xff0c;它们是 Vue 3 组合式 API 的核心概念。下面从几个关键维度分析它们的差异&#xff1a; 1. 设计目的不同 Hook&#xff08;组合式函数&#xff09; 专为 Vu…

Spark提交流程

bin/spark-submit --class org.apache.spark.examples.SparkPi --master yarn ./examples/jars/spark-examples_2.12-3.3.1.jar 10 这一句命令实际上是 启动一个Java程序 java org.apache.spark.deploy.SparkSubmit 并将命令行参数解析到这个类的对应属性上 因为master给…

Microsoft Copilot Studio - 尝试一下Agent

1.简单介绍 Microsoft Copilot Studio以前的名字是Power Virtual Agent(简称PVA)。Power Virutal Agent是2019年出现的&#xff0c;是低代码平台Power Platform的一部分。当时Generative AI还没有出现&#xff0c;但是基于已有的Conversation AI技术&#xff0c;即Microsoft L…

【源码剖析】2-搭建kafka源码环境

在上篇文章kafka核心概念中&#xff0c;解释了kafka的核心概念&#xff0c;下面开始进行kafka源码编译。为什么学习源码需要进行源码编译呢&#xff0c;我认为主要有两点&#xff1a; 可以进行debug&#xff0c;跟踪代码执行逻辑可以对源码改动&#xff0c;强化学习学习效果 …

小红书视频图文提取:采集+CV的实战手记

项目说明&#xff1a;这波视频&#xff0c;值不值得采&#xff1f; 你有没有遇到过这样的场景&#xff1f;老板说&#xff1a;“我们得看看最近小红书上关于‘旅行’的视频都说了些什么。”团队做数据分析的&#xff0c;立马傻眼&#xff1a;官网打不开、接口抓不着、视频不能…

Cloudflare 从 Nginx 到 Pingora:性能、效率与安全的全面升级

在互联网的快速发展中&#xff0c;高性能、高效率和高安全性的网络服务成为了各大互联网基础设施提供商的核心追求。Cloudflare 作为全球领先的互联网安全和基础设施公司&#xff0c;近期做出了一个重大技术决策&#xff1a;弃用长期使用的 Nginx&#xff0c;转而采用其内部开发…

从编辑到安全设置: 如何满足专业文档PDF处理需求

随着数字化办公的发展&#xff0c;PDF 已成为跨平台文档交互的标准格式。无论是在日常办公、学术研究&#xff0c;还是项目协作中&#xff0c;对 PDF 文件进行高效编辑与管理的需求日益增长。功能全面、操作流畅且无额外负担的 PDF 编辑工具&#xff0c;它是一款在功能上可与 A…

Kafka消费者组位移重设指南

#作者&#xff1a;张桐瑞 文章目录 一、Kafka 与传统消息引擎的核心差异二、重设消费者组位移的核心原因三、重设位移的两大维度与七种策略四、重设位移的实现方式&#xff08;一&#xff09;Java API 方式&#xff08;二&#xff09;命令行脚本方式&#xff08;Kafka 0.11&am…

分类模型:逻辑回归

1、针对设计&#xff1a;二分类 Logistic 回归最初是为二分类问题设计的&#xff0c; Logistic 回归基于概率&#xff0c;通过 Sigmoid 函数转换输入特征的线性组合&#xff0c;将任意实数映射到 [0, 1] 区间内。 通过引入一个决策规则&#xff08;通常是概率的阈值&#xff…

CppCon 2015 学习:C++ WAT

这段代码展示了 C 中的一些有趣和令人困惑的特性&#xff0c;尤其是涉及数组访问和某些语法的巧妙之处。让我们逐个分析&#xff1a; 1. assert(map[“Hello world!”] e;) 这一行看起来很不寻常&#xff0c;因为 map 在这里被用作数组下标访问器&#xff0c;但是在前面没有…