一.前言
本期我们来说一下张量的索引操作,需要掌握张量不同索引操作,我们在操作张量时,经常需要去进⾏获取或者修改操作,掌握张量的花式索引操作是必须的⼀项能⼒。
二.简单行、列索引
import torchdata = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 50)def test01():print(data[0]) #取行print(data[:, 0]) #取列print('-' * 50)if __name__ == '__main__':test01()
结果展示:
tensor([[5, 2, 1, 3, 0],
[1, 7, 3, 9, 1],
[9, 3, 4, 8, 4],
[0, 3, 3, 1, 1]])
--------------------------------------------------
tensor([5, 2, 1, 3, 0])
tensor([5, 1, 9, 0])
--------------------------------------------------
三.列表索引
import torchdata = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 50)def test02():# 返回 (0, 1)、(1, 2) 两个位置的元素print(data[[0, 1], [1, 2]])print('-' * 50)# 返回 0、1 ⾏的 1、2 列共4个元素print(data[[[0], [1]], [1, 2]])if __name__ == '__main__':test02()
结果展示:
tensor([[3, 8, 7, 5, 6],
[3, 4, 4, 2, 4],
[3, 7, 0, 1, 5],
[5, 6, 8, 5, 8]])
--------------------------------------------------
tensor([8, 4])
--------------------------------------------------
tensor([[8, 7],
[4, 4]])
四.索引范围
import torchdata = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 50)# 3. 范围索引
def test03():# 前3⾏的前2列数据print(data[:3, :2])# 第2⾏到最后的前2列数据print(data[2:, :2])if __name__ == '__main__':test03()
结果展示:
tensor([[4, 7, 1, 8, 5],
[4, 4, 8, 7, 0],
[1, 4, 3, 8, 8],
[7, 0, 1, 7, 1]])
--------------------------------------------------
tensor([[4, 7],
[4, 4],
[1, 4]])
tensor([[1, 4],
[7, 0]])
五.布尔索引
import torchdata = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 50)# 布尔索引
def test():# 第2列⼤于5的⾏数据print(data[data[:, 2] > 5])# 第1⾏⼤于5的列数据print(data[:, data[1] > 5])if __name__ == '__main__':test()
结果展示:
tensor([[2, 4, 2, 2, 3],
[7, 8, 3, 8, 6],
[8, 4, 5, 4, 1],
[7, 5, 8, 7, 0]])
--------------------------------------------------
tensor([[7, 5, 8, 7, 0]])
tensor([[2, 4, 2, 3],
[7, 8, 8, 6],
[8, 4, 4, 1],
[7, 5, 7, 0]])
六.多维索引
import torch# 布尔索引
def test05():data = torch.randint(0, 10, [3, 4, 5])print(data)print('-' * 50)print(data[0, :, :])print(data[:, 0, :])print(data[:, :, 0])if __name__ == '__main__':test05()
结果展示:
tensor([[[7, 4, 2, 3, 5],
[6, 1, 2, 1, 0],
[8, 2, 2, 4, 3],
[3, 5, 4, 5, 8]],[[8, 4, 9, 8, 2],
[1, 3, 5, 8, 1],
[2, 3, 0, 6, 6],
[7, 2, 4, 7, 9]],[[8, 8, 1, 7, 2],
[8, 9, 9, 2, 6],
[6, 5, 3, 0, 4],
[9, 8, 1, 2, 1]]])
--------------------------------------------------
tensor([[7, 4, 2, 3, 5],
[6, 1, 2, 1, 0],
[8, 2, 2, 4, 3],
[3, 5, 4, 5, 8]])
tensor([[7, 4, 2, 3, 5],
[8, 4, 9, 8, 2],
[8, 8, 1, 7, 2]])
tensor([[7, 6, 8, 3],
[8, 1, 2, 7],
[8, 8, 6, 9]])
七.总结
列表的索引和numpy取值是基本一样的,大家把代码自己运行一遍就差不多就知道了。