在 PyTorch 中,tensor
是一种强大且灵活的数据结构,可以与多种 Python 常用数据结构(如 int
, list
, numpy array
等)互相转换。下面是详细解释和代码示例:
1. Tensor ↔ int / float
转为 int / float(前提是 tensor 中只有一个元素)
import torcht = torch.tensor(3.14)
i = t.item() # 转为 float
j = int(t.item()) # 强制转为 intprint(i) # 3.14
print(j) # 3
.item()
只能用于单元素张量:tensor.numel() == 1
,否则会报错。
2. Tensor ↔ list
Tensor 转 list(Python 原生嵌套 list)
t = torch.tensor([[1, 2], [3, 4]])
lst = t.tolist()
print(lst) # [[1, 2], [3, 4]]
list 转 Tensor
lst = [[1, 2], [3, 4]]
t = torch.tensor(lst)
print(t) # tensor([[1, 2], [3, 4]])
支持嵌套 list(矩阵)、一维 list(向量)。
3. Tensor ↔ numpy.ndarray
PyTorch Tensor 和 NumPy array 之间可以无缝转换,共享内存(改变其中一个会影响另一个)。
Tensor → numpy array
import numpy as np
t = torch.tensor([[1, 2], [3, 4]])
a = t.numpy()
print(type(a)) # <class 'numpy.ndarray'>
numpy array → Tensor
a = np.array([[1, 2], [3, 4]])
t = torch.from_numpy(a)
print(type(t)) # <class 'torch.Tensor'>
numpy 数组必须是数值型(不能是对象数组等),否则会报错。
4. Tensor ↔ Python scalar 类型(int, float)
如果你从计算结果中获取单个数值,比如:
t = torch.tensor([5.5])
val = float(t) # 也可以使用 float(t.item())
print(val) # 5.5# 对于整型:
t2 = torch.tensor([3])
val2 = int(t2) # 等效于 int(t2.item())
print(val2) # 3
5. Tensor ↔ bytes(用于序列化,如保存到文件)
Tensor → bytes
t = torch.tensor([1, 2, 3])
b = t.numpy().tobytes()
bytes → Tensor
import numpy as np
b = b'\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00'
a = np.frombuffer(b, dtype=np.int32)
t = torch.from_numpy(a)
print(t) # tensor([1, 2, 3], dtype=torch.int32)
6.实战示例
下面我们从三个实际应用场景来讲解 PyTorch 中 tensor 与其他类型(如 list
、numpy
、int
等)相互转换的用途和技巧:
场景一:数据加载与预处理
读取图像数据(使用 PIL) → 转为 tensor
from PIL import Image
from torchvision import transformsimg = Image.open('cat.jpg') # 打开图片为 PIL.Image
to_tensor = transforms.ToTensor()
t = to_tensor(img) # 转为 [C, H, W] 的 float32 Tensor
此时你获得了一个 Tensor
,可以送入模型。但如果你想可视化或分析:
Tensor → numpy → 可视化或保存
import matplotlib.pyplot as pltimg_np = t.permute(1, 2, 0).numpy() # [H, W, C]
plt.imshow(img_np)
plt.show()
permute
是因为ToTensor
会变成[C,H,W]
,而 matplotlib 需要[H,W,C]
。
场景二:模型推理后的结果处理(转为 Python 值)
假设你有一个分类网络,输出如下:
output = torch.tensor([[0.1, 0.7, 0.2]]) # 假设输出为 batch_size=1 的 logits
pred_idx = output.argmax(dim=1) # tensor([1])
你要拿到预测类别的整数值:
pred_class = pred_idx.item() # 1
print(type(pred_class)) # <class 'int'>
.item()
在推理阶段非常常用!
场景三:保存 Tensor 到磁盘 / 网络传输
Tensor 保存和加载时经常需要转为 numpy 或 byte 流:
保存为 bytes 再写入文件
t = torch.tensor([1, 2, 3, 4], dtype=torch.int32)
with open("tensor.bin", "wb") as f:f.write(t.numpy().tobytes())
从文件读回 tensor
with open("tensor.bin", "rb") as f:byte_data = f.read()import numpy as np
arr = np.frombuffer(byte_data, dtype=np.int32)
t2 = torch.from_numpy(arr)
print(t2) # tensor([1, 2, 3, 4], dtype=torch.int32)
你必须记住原始
dtype
和shape
才能正确还原!
场景四:构造 batch 时将 list 转为 Tensor
在训练时经常从数据集中拿到多个样本组成 batch(Python list):
samples = [[1.0, 2.0], [3.0, 4.0]]
batch_tensor = torch.tensor(samples, dtype=torch.float32)
print(batch_tensor.shape) # torch.Size([2, 2])
或者更通用的方式(可以处理动态 shape):
batch_tensor = torch.stack([torch.tensor(s) for s in samples])
补充:在 with torch.no_grad()
中常用转换
推理阶段经常用 Tensor → numpy → list
:
with torch.no_grad():output = model(input_tensor)pred = output.softmax(dim=1)top1_class = pred.argmax(dim=1).item()
小结对照表
转换类型 | 方法 | 注意事项 |
---|---|---|
Tensor → int/float | .item() | 只能单元素 |
Tensor → list | .tolist() | 支持嵌套 |
list → Tensor | torch.tensor(list) | 自动推断类型 |
Tensor → ndarray | .numpy() | 共享内存 |
ndarray → Tensor | torch.from_numpy(ndarray) | 共享内存 |
Tensor → bytes | tensor.numpy().tobytes() | 用于存储 |
bytes → Tensor | np.frombuffer + from_numpy | 需知道 dtype |