RNN小练习
要求:
假设有 4 个字 吃 了 没 ?
,请使用 torch.nn.RNN
完成以下任务
- 将每个进行 one-hot 编码
- 请使用
吃 了 没
作为输入序列,了 没 ?
作为输出序列 - RNN 的
hidden_size = 64
- 请将 RNN 的输出使用全连接转换成 4 个特征,并使用 CrossEntropyLoss 训练模型
- 训练模型并验证
1、准备数据集
import torch.nn.functional
from torch.utils.data import Datasetclass mydataset(Dataset):def __init__(self):super().__init__()texts = '吃 了 没 ?'self.words = texts.split()self.input = self.words[:3]self.label = self.words[1:]def __len__(self):return 1def __getitem__(self, idx):# 对输入进行 one_hot 编码inp = torch.nn.functional.one_hot(torch.tensor([self.words.index(word) for word in self.input]),len(self.words)).float()# 对标签进行编码,返回文字的索引label = torch.tensor([self.words.index(word) for word in self.label])return inp, label
2、创建模型
import torch.nn as nnclass mymodel(nn.Module):def __init__(self):super().__init__()self.rnn = nn.RNN(4,64,nonlinearity='relu')self.fc1 = nn.Linear(64,4)def forward(self, x,h=None):x,h = self.rnn(x,h)x = self.fc1(x)return x,h
3、训练模型以及预测
import torch.nn as nn
from torch import optimfrom myset import mydataset
from mymodel import mymodelEPOCH = 1000
LR = 1e-2ds = mydataset()
inputs,lables = ds[0]model = mymodel()loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr=LR)for epoch in range(EPOCH):optimizer.zero_grad()y,h = model(inputs)loss = loss_fn(y,lables)print(loss)loss.backward()optimizer.step()model.eval()y,h = model(inputs)y = y.softmax(-1)
maxarg = y.argmax(-1)print([ds.words[indx] for indx in maxarg.tolist()])