张量的拼接操作在神经网络搭建过程中是非常常用的方法,例如: 在后面将要学习的注意力机制中都使用到了张量拼接。
torch.cat 函数可以将两个张量根据指定的维度拼接起来,不改变数据维度。
前提:除了拼接的维度,其他维度一定要相同。
def cat_tensor():data1=torch.randint(0,10,(1,2,3))data2=torch.randint(0,10,(1,2,3))# 0轴拼接 dim=0# data3=torch.cat([data1,data2]) # 默认dim=0,torch.Size([2, 2, 3])# data3=torch.cat([data1,data2],dim=0)# 1轴拼接 dim=1# data3=torch.cat([data1,data2],dim=1) # torch.Size([1, 4, 3])# 2轴拼接 dim=2data3=torch.cat([data1,data2],dim=2) # torch.Size([1, 2, 6])# data3=torch.concat([data1,data2],dim=2) # torch.Size([1, 2, 6]) # 第二种写法 同种结果# data3=torch.concatenate([data1,data2],dim=2) # torch.Size([1, 2, 6]) # 第三种写法 同种结果print(data3.shape)if __name__ == '__main__':cat_tensor()