一文捋清【reshape、view、rearrange、contiguous、transpose、squeeze、unsqueeze】
1. reshape
reshape() 函数:
用于在不更改数据的情况下为数组赋予新形状。
注意:
用于低维度转高维度
c = np.arange(6)
print("** ", c)
c1 = c.reshape(3, -1)
print("** ", c1)
c2 = c.reshape(-1, 6)
print("** ", c2)
** [0 1 2 3 4 5]
** [[0 1]
[2 3]
[4 5]]
** [[0 1 2 3 4 5]]
2. view
torch中,view() 的作用相当于numpy中的reshape,重新定义矩阵的形状。
v1 = torch.range(1, 16)
v2 = v1.view(-1, 4)
print("** ", v1)
print("** ", v2)
** tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
15., 16.])
** tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]])
3. rearrange
rearrange是einops中的一个函数调用方法。
from einops import rearrange
image1 = torch.zeros(2, 224, 224, 3)
image2 = rearrange(image1, 'b w h c -> b (w h) c')
image3 = rearrange(image1, 'b w h c -> (b c) (w h)')
print("** ", image1.shape)
print("** ", image2.shape)
print("** ", image3.shape)
** torch.Size([2, 224, 224, 3])
** torch.Size([2, 50176, 3])
** torch.Size([6, 50176])
4. transpose
torch.transpose(Tensor,dim0,dim1)是pytorch中的ndarray矩阵进行转置的操作。
注意:
transpose()一次只能在两个维度间进行转置(也可以理解为维度转换)
x = torch.Tensor(2, 3, 4, 5) # 这是一个4维的矩阵(只用空间位置,没有数据)
print(x.shape)
# 先转置0维和1维,之后在第2,3维间转置,之后在第1,3间转置
y = x.transpose(0, 1).transpose(3, 2).transpose(1, 3)
print(y.shape)
torch.Size([2, 3, 4, 5])
torch.Size([3, 4, 5, 2])
5. permute
注意:
permute相当于可以同时操作于tensor的若干维度,transpose只能同时作用于tensor的两个维度,
permute是transpose的进阶版。
print(torch.Tensor(2,3,4,5).permute(3,2,0,1).shape)
torch.Size([5, 4, 2, 3])
6. contiguous
x.is_contiguous()
——判断tensor是否连续
x.contiguous()
——把tensor变成在内存中连续分布的形式
需要变成连续分布的情况:
view只能用在contiguous的variable上。如果在view之前用了transpose, permute等,需要用contiguous()来返回一个contiguous copy。
x = torch.Tensor(5, 10)
print(x.is_contiguous())
print(x.transpose(0, 1).is_contiguous())
print(x.transpose(0, 1).contiguous().is_contiguous())
True
False
True
写代码时,一般没写contiguous()会报错提示,所以不用担心…
7. squeeze
squeeze()函数的功能是
维度压缩
。返回一个tensor(张量),其中 input 中大小为1的所有维都已删除。
x = torch.Tensor(2, 1, 2, 1, 2)
print(x.shape)
y = torch.squeeze(x) # 默认是把所有是1的维度都删掉
print(y.shape)
y = torch.squeeze(x, 1)
print(y.shape)
torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 2, 2])
torch.Size([2, 2, 1, 2])
8. unsqueeze
squeeze()函数的功能是
增加维度
。
x = torch.arange(0,6)
print(x.shape)
y = x.unsqueeze(0)
print(y.shape)
z = x.unsqueeze(1)
print(z.shape)
w = x.unsqueeze(2)
print(w.shape)
torch.Size([6])
torch.Size([1, 6])
torch.Size([6, 1])
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)