transformer demo

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import pytestclass PositionalEncoding(nn.Module):def __init__(self, d_model, max_seq_length=5000):super(PositionalEncoding, self).__init__()# 创建位置编码矩阵pe = torch.zeros(max_seq_length, d_model)position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))# 计算正弦和余弦位置编码pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)# 注册为非训练参数self.register_buffer('pe', pe)def forward(self, x):# 添加位置编码到输入张量return x + self.pe[:, :x.size(1)]class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()assert d_model % num_heads == 0self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_heads# 定义线性变换层self.q_linear = nn.Linear(d_model, d_model)self.k_linear = nn.Linear(d_model, d_model)self.v_linear = nn.Linear(d_model, d_model)self.out_linear = nn.Linear(d_model, d_model)def forward(self, q, k, v, mask=None):batch_size = q.size(0)# 线性变换和重塑q = self.q_linear(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)k = self.k_linear(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)v = self.v_linear(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)# 计算注意力分数scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)# 应用掩码(如果提供)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)# 应用softmax获取注意力权重attn_weights = F.softmax(scores, dim=-1)# 应用注意力权重到值向量attn_output = torch.matmul(attn_weights, v)# 重塑并应用最终线性变换attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)output = self.out_linear(attn_output)return outputclass FeedForward(nn.Module):def __init__(self, d_model, d_ff):super(FeedForward, self).__init__()self.linear1 = nn.Linear(d_model, d_ff)self.linear2 = nn.Linear(d_ff, d_model)def forward(self, x):return self.linear2(F.relu(self.linear1(x)))class EncoderLayer(nn.Module):def __init__(self, d_model, num_heads, d_ff, dropout=0.1):super(EncoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, num_heads)self.feed_forward = FeedForward(d_model, d_ff)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, mask=None):# 自注意力层和残差连接attn_output = self.self_attn(x, x, x, mask)x = self.norm1(x + self.dropout(attn_output))# 前馈网络和残差连接ff_output = self.feed_forward(x)x = self.norm2(x + self.dropout(ff_output))return xclass DecoderLayer(nn.Module):def __init__(self, d_model, num_heads, d_ff, dropout=0.1):super(DecoderLayer, self).__init__()self.self_attn = MultiHeadAttention(d_model, num_heads)self.cross_attn = MultiHeadAttention(d_model, num_heads)self.feed_forward = FeedForward(d_model, d_ff)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.norm3 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_output, src_mask=None, tgt_mask=None):# 自注意力层和残差连接attn_output = self.self_attn(x, x, x, tgt_mask)x = self.norm1(x + self.dropout(attn_output))# 编码器-解码器注意力层和残差连接cross_attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)x = self.norm2(x + self.dropout(cross_attn_output))# 前馈网络和残差连接ff_output = self.feed_forward(x)x = self.norm3(x + self.dropout(ff_output))return xclass Transformer(nn.Module):def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers,num_decoder_layers, d_ff, max_seq_length, dropout=0.1):super(Transformer, self).__init__()# 词嵌入层self.src_embedding = nn.Embedding(src_vocab_size, d_model)self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)# 位置编码self.positional_encoding = PositionalEncoding(d_model, max_seq_length)# 编码器和解码器层self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout)for _ in range(num_encoder_layers)])self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout)for _ in range(num_decoder_layers)])# 输出层self.output_layer = nn.Linear(d_model, tgt_vocab_size)self.dropout = nn.Dropout(dropout)self.d_model = d_model# 初始化参数self._init_parameters()def _init_parameters(self):for p in self.parameters():if p.dim() > 1:nn.init.xavier_uniform_(p)def forward(self, src, tgt, src_mask=None, tgt_mask=None):# 源序列和目标序列的嵌入和位置编码src = self.src_embedding(src) * math.sqrt(self.d_model)src = self.positional_encoding(src)src = self.dropout(src)tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)tgt = self.positional_encoding(tgt)tgt = self.dropout(tgt)# 编码器前向传播enc_output = srcfor enc_layer in self.encoder_layers:enc_output = enc_layer(enc_output, src_mask)# 解码器前向传播dec_output = tgtfor dec_layer in self.decoder_layers:dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)# 输出层output = self.output_layer(dec_output)return output# 创建掩码函数
def create_masks(src, tgt):# 源序列掩码(用于屏蔽填充标记)src_mask = (src != 0).unsqueeze(1).unsqueeze(2)# 目标序列掩码(用于屏蔽填充标记和未来标记)tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)# 创建后续标记掩码(用于自回归解码)seq_length = tgt.size(1)nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()# 合并掩码tgt_mask = tgt_mask & nopeak_maskreturn src_mask, tgt_mask# 简单的训练函数
def train_transformer(model, optimizer, criterion, train_loader, epochs):model.train()for epoch in range(epochs):total_loss = 0for src, tgt in train_loader:# 创建掩码src_mask, tgt_mask = create_masks(src, tgt[:, :-1])# 前向传播output = model(src, tgt[:, :-1], src_mask, tgt_mask)# 计算损失loss = criterion(output.contiguous().view(-1, output.size(-1)),tgt[:, 1:].contiguous().view(-1))# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()print(f'Epoch {epoch + 1}, Loss: {total_loss / len(train_loader):.4f}')# 添加model fixture
@pytest.fixture
def model():# 定义超参数d_model = 512num_heads = 8num_encoder_layers = 6num_decoder_layers = 6d_ff = 2048max_seq_length = 100dropout = 0.1# 假设的词汇表大小src_vocab_size = 10000tgt_vocab_size = 10000# 创建模型model = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads,num_encoder_layers, num_decoder_layers, d_ff, max_seq_length, dropout)return model# 添加test_loader fixture
@pytest.fixture
def test_loader():# 创建一个简单的测试数据集batch_size = 2seq_length = 10# 随机生成一些测试数据src_data = torch.randint(1, 10000, (batch_size, seq_length))tgt_data = torch.randint(1, 10000, (batch_size, seq_length))# 创建DataLoaderfrom torch.utils.data import TensorDataset, DataLoaderdataset = TensorDataset(src_data, tgt_data)test_loader = DataLoader(dataset, batch_size=batch_size)return test_loader# 简单的测试函数
def test_transformer(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for src, tgt in test_loader:# 创建掩码src_mask, _ = create_masks(src, tgt)# 预测output = model(src, tgt, src_mask, None)pred = output.argmax(dim=-1)# 计算准确率total += tgt.size(0) * tgt.size(1)correct += (pred == tgt).sum().item()accuracy = correct / totalprint(f'Test Accuracy: {accuracy:.4f}')# 简单的序列到序列翻译示例
def translate(model, src_sequence, src_vocab, tgt_vocab, max_length=50):model.eval()# 将源序列转换为索引src_indices = [src_vocab.get(token, src_vocab['<unk>']) for token in src_sequence]src_tensor = torch.LongTensor(src_indices).unsqueeze(0)# 创建源序列掩码src_mask = (src_tensor != 0).unsqueeze(1).unsqueeze(2)# 初始目标序列为开始标记tgt_indices = [tgt_vocab['<sos>']]with torch.no_grad():for i in range(max_length):tgt_tensor = torch.LongTensor(tgt_indices).unsqueeze(0)# 创建目标序列掩码_, tgt_mask = create_masks(src_tensor, tgt_tensor)# 预测下一个标记output = model(src_tensor, tgt_tensor, src_mask, tgt_mask)next_token_logits = output[:, -1, :]next_token = next_token_logits.argmax(dim=-1).item()# 添加预测的标记到目标序列tgt_indices.append(next_token)# 如果预测到结束标记,则停止if next_token == tgt_vocab['<eos>']:break# 将目标序列索引转换回标记tgt_sequence = [tgt_vocab.get(index, '<unk>') for index in tgt_indices]return tgt_sequence# 示例使用
if __name__ == "__main__":# 定义超参数d_model = 512num_heads = 8num_encoder_layers = 6num_decoder_layers = 6d_ff = 2048max_seq_length = 100dropout = 0.1# 假设的词汇表大小src_vocab_size = 10000tgt_vocab_size = 10000# 创建模型model = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads,num_encoder_layers, num_decoder_layers, d_ff, max_seq_length, dropout)# 定义优化器和损失函数optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)criterion = nn.CrossEntropyLoss(ignore_index=0)  # 忽略填充标记# 这里应该有实际的数据加载代码# train_loader = ...# test_loader = ...# 训练模型# train_transformer(model, optimizer, criterion, train_loader, epochs=10)# 测试模型# test_transformer(model, test_loader)# 翻译示例# src_vocab = ...# tgt_vocab = ...# src_sequence = ["hello", "world", "!"]# translation = translate(model, src_sequence, src_vocab, tgt_vocab)# print(f"Source: {' '.join(src_sequence)}")# print(f"Translation: {' '.join(translation)}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/diannao/87081.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

centos 8.3(阿里云服务器)mariadb由系统自带版本(10.3)升级到10.6

1. 备份数据库 在进行任何升级操作前&#xff0c;务必备份所有数据库&#xff1a; mysqldump -u root -p --all-databases > all_databases_backup.sql # 或者为每个重要数据库单独备份 mysqldump -u root -p db_name1 > db_name1_backup.sql mysqldump -u root -p db…

如何稳定地更新你的大模型知识(算法篇)

目录 在线强化学习的稳定知识获取机制:算法优化与数据策略一、算法层面的稳定性控制机制二、数据处理策略的稳定性保障三、训练过程中的渐进式优化策略四、环境设计与反馈机制的稳定性影响五、稳定性保障的综合应用策略六、总结与展望通过强化学习来让大模型学习高层语义知识,…

图的遍历模板

图的遍历 BFS 求距离 #include<bits/stdc.h>using namespace std;int n, m, k,q[20001],dist[20001]; vector<int> edge[20001];int main(){scanf("%d%d%d",&n,&m,&k);for (int i 1;i<m;i){int x,y;scanf("%d%d",&x,&am…

Java集合 - LinkedList底层源码解析

以下是基于 JDK 8 的 LinkedList 深度源码解析&#xff0c;涵盖其数据结构、核心方法实现、性能特点及使用场景。我们从 类结构、Node节点、插入/删除/访问操作、线程安全、性能对比 等角度进行详细分析 一、类结构与继承关系 1. 类定义 public class LinkedList<E> e…

Pytorch 卷积神经网络参数说明一

系列文章目录 文章目录 系列文章目录前言一、卷积层的定义1.常见的卷积操作2. 感受野3. 如何理解参数量和计算量4.如何减少计算量和参数量 二、神经网络结构&#xff1a;有些层前面文章说过&#xff0c;不全讲1. 池化层&#xff08;下采样&#xff09;2. 上采样3. 激活层、BN层…

C++ 中的 iostream 库:cin/cout 基本用法

iostream 是 C 标准库中用于输入输出操作的核心库&#xff0c;它基于面向对象的设计&#xff0c;提供了比 C 语言的 stdio.h 更强大、更安全的 I/O 功能。下面详细介绍 iostream 库中最常用的输入输出工具&#xff1a;cin 和 cout。 一、 基本概念 iostream 库&#xff1a;包…

SAP复制一个自定义移动类型

SAP复制移动类型 在SAP系统中&#xff0c;复制移动类型201可以通过事务码OMJJ或SPRO路径完成&#xff0c;用于创建自定义的移动类型以满足特定业务需求。 示例操作步骤 进入OMJJ事务码&#xff1a; 打开事务码OMJJ&#xff0c;选择“移动类型”选项。 复制移动类型&#xff…

Bambu Studio 中的“回抽“与“装填回抽“的区别

回抽 装填回抽: Bambu Studio 中的“回抽” (Retraction) 和“装填回抽”(Prime/Retract) 是两个不同的概念&#xff0c;它们都与材料挤出机的操作过程相关&#xff0c;但作用和触发条件有所不同。 回抽(Retraction): 回抽的作用, 在打印机移动到另一个位置之前&#xff0c;将…

危化品安全监测数据分析挖掘范式:从被动响应到战略引擎的升维之路

在危化品生产的复杂生态系统中,安全不仅仅是合规性要求,更是企业生存和发展的生命线。传统危化品安全生产风险监测预警系统虽然提供了基础保障,但其“事后响应”和“单点预警”的局限性日益凸显。我们正处在一个由大数据、人工智能、数字孪生和物联网技术驱动的范式变革前沿…

C++ RPC 远程过程调用详细解析

一、RPC 基本原理 RPC (Remote Procedure Call) 是一种允许程序调用另一台计算机上子程序的协议,而不需要程序员显式编码这个远程交互细节。其核心思想是使远程调用看起来像本地调用一样。 RPC 工作流程 客户端调用:客户端调用本地存根(stub)方法参数序列化:客户端存根将参…

Python:操作 Excel 预设色

💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 Python 操作 Excel 系列 读取单元格数据按行写入设置行高和列宽自动调整行高和列宽水平…

中科院1区|IF10+:加大医学系团队利用GPT-4+电子病历分析,革新肝硬化并发症队列识别

中科院1区|IF10&#xff1a;加大医学系团队利用GPT-4电子病历分析&#xff0c;革新肝硬化并发症队列识别 在当下的科研领域&#xff0c;人工智能尤其是大语言模型的迅猛发展&#xff0c;正为各个学科带来前所未有的机遇与变革。在医学范畴&#xff0c;从疾病的早期精准筛查&am…

Python学习小结

bg&#xff1a;记录一下&#xff0c;怕忘了&#xff1b;先写一点&#xff0c;后面再补充。 1、没有方法重载 2、字段都是公共字段 3、都是类似C#中顶级语句的写法 4、对类的定义直接&#xff1a; class Student: 创建对象不需要new关键字&#xff0c;直接stu Student() 5、方…

QCustomPlot 中实现拖动区域放大‌与恢复

1、拖动区域放大‌ 在 QCustomPlot 中实现 ‌拖动区域放大‌&#xff08;即通过鼠标左键拖动绘制矩形框选区域进行放大&#xff09;的核心方法是设置 SelectionRectMode。具体操作步骤&#xff1a; 1‌&#xff09;禁用拖动模式‌ 确保先关闭默认的图表拖动功能&#xff08;否…

如何将文件从 iPhone 传输到闪存驱动器

您想将文件从 iPhone 或 iPad 传输到闪存盘进行备份吗&#xff1f;这是一个很好的决定&#xff0c;但您需要先了解一些实用的方法。虽然 Apple 生态系统在很大程度上是封闭的&#xff0c;但您可以使用一些实用工具将文件从 iPhone 或 iPad 传输到闪存盘。下文提供了这些行之有效…

互联网大厂Java求职面试:云原生架构与微服务设计中的复杂挑战

互联网大厂Java求职面试&#xff1a;云原生架构与微服务设计中的复杂挑战 面试官开场白 面试官&#xff08;严肃模式开启&#xff09;&#xff1a;郑薪苦&#xff0c;欢迎来到我们的技术面试环节。我是本次面试的技术总监&#xff0c;接下来我们将围绕云原生架构、微服务设计、…

leetcode-hot-100 (链表)

1. 相交链表 题目链接&#xff1a;相交链表 题目描述&#xff1a;给你两个单链表的头节点 headA 和 headB &#xff0c;请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点&#xff0c;返回 null 。 解答&#xff1a; 其实这道题目我一开始没太看懂题目给…

Web前端基础之HTML

一、浏览器 火狐浏览器、谷歌浏览器(推荐)、IE浏览器 推荐谷歌浏览器原因&#xff1a; 1、简洁大方,打开速度快 2、开发者调试工具&#xff08;右键空白处->检查&#xff0c;打开调试模式&#xff09; 二、开发工具 核心IDE工具 Visual Studio Code (VS Code)‌ 微软开发…

11.TCP三次握手

TCP连接建立与传输 1&#xff0e;主机 A 与主机 B 使用 TCP 传输数据&#xff0c;A 是 TCP 客户&#xff0c;B 是 TCP 服务器。假设有512B 的数据要传输给 B&#xff0c;B 仅给 A 发送确认&#xff1b;A 的发送窗口 swnd 的尺寸为 100B&#xff0c;而 TCP 数据报文段每次也携带…

Python 爬虫入门 Day 3 - 实现爬虫多页抓取与翻页逻辑

Python 第二阶段 - 爬虫入门 &#x1f3af; 今日目标 掌握网页分页的原理和定位“下一页”的链接能编写循环逻辑自动翻页抓取内容将多页抓取整合到爬虫系统中 &#x1f4d8; 学习内容详解 &#x1f501; 网页分页逻辑介绍 以 quotes.toscrape.com 为例&#xff1a; 首页链…