torch.nn 究竟是什么?
PyTorch 提供了设计精良的模块和类,如 torch.nn、torch.optim、Dataset 和 DataLoader,帮助你创建和训练神经网络。为了充分利用它们的能力并根据你的问题进行定制,你需要真正理解它们到底在做什么。为了帮助你理解这一点,我们将首先在不使用这些模块的任何功能的情况下,在 MNIST 数据集上训练一个基本的神经网络;我们最初只使用最基本的 PyTorch 张量功能。然后,我们将每次增量添加 torch.nn
、torch.optim
、Dataset
或 DataLoader
中的一个功能,准确展示每个部分的作用,以及它是如何使代码更简洁或更灵活的。
本教程假设你已安装 PyTorch,并熟悉张量操作的基础知识。(如果你熟悉 Numpy 数组操作,你会发现此处使用的 PyTorch 张量操作几乎相同)。
MNIST 数据设置
我们将使用经典的 MNIST 数据集,该数据集包含手写数字(0 到 9)的黑白图像。
我们将使用 pathlib 来处理路径(它是 Python 3 标准库的一部分),并将使用 requests 下载数据集。我们只在使用时导入模块,这样你就可以清楚地看到每一步使用了什么。
from pathlib import Path
import requestsDATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)URL = "https://github.com/pytorch/tutorials/raw/main/_static/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)
该数据集采用 numpy 数组格式,并使用 pickle(一种 Python 特有的数据序列化格式)存储。
import pickle
import gzipwith gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
每张图像是 28 x 28 像素,并以长度为 784 (=28x28) 的扁平行存储。让我们看一张;我们需要先将其重塑为 2D 形式。
from matplotlib import pyplot
import numpy as nppyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
# ``pyplot.show()`` only if not on Colab
try:import google.colab
except ImportError:pyplot.show()
print(x_train.shape)
# plt.show()
输出为:
(50000, 784)
得到的图像:
PyTorch 使用 torch.tensor
而非 numpy 数组,因此我们需要转换数据。
import torchx_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())
输出为:
tensor([[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],...,[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.]]) tensor([5, 0, 4, ..., 8, 4, 8])
torch.Size([50000, 784])
tensor(0) tensor(9)
从头开始构建神经网络(不使用 torch.nn
)
我们首先只使用 PyTorch 张量操作创建一个模型。我们假设你已经熟悉神经网络的基础知识。(如果你不熟悉,可以在