函数描述:
unsqueeze(input, dim) → Tensor
作用:在指定位置插入一个维度,对数据维度进行扩充
input:输入的Tensor
dim:要插入的维度
a = torch.arange(6).reshape(2, 3)
print(a)
b = a.unsqueeze(1)#在第2维度加一维度
print(b)
print(b.shape)
>>>
tensor([[0, 1, 2],[3, 4, 5]])
tensor([[[0, 1, 2]],[[3, 4, 5]]])
torch.Size([2, 1, 3])
函数描述:
squeeze(input, dim) → Tensor
作用:对数据维度进行压缩
a = torch.arange(12).reshape(1, 2, 6)
print(a)
a1 = a.squeeze(0)#将第一个维度去掉
print(a1)
print(a1.shape)
>>>
tensor([[[ 0, 1, 2, 3, 4, 5],[ 6, 7, 8, 9, 10, 11]]])
tensor([[ 0, 1, 2, 3, 4, 5],[ 6, 7, 8, 9, 10, 11]])
torch.Size([2, 6])
-----------------------
a2 = a.squeeze(-1)#最后一个维度并没有被去掉,因为不为1
print(a2)
print(a2.shape)
>>>
tensor([[[ 0, 1, 2, 3, 4, 5],[ 6, 7, 8, 9, 10, 11]]])
torch.Size([1, 2, 6])