一文捋清【reshape、view、rearrange、contiguous、transpose、squeeze、unsqueeze】——python & torch

  • Post author:
  • Post category:python


一文捋清【reshape、view、rearrange、contiguous、transpose、squeeze、unsqueeze】



1. reshape


reshape() 函数:

用于在不更改数据的情况下为数组赋予新形状。


注意:

用于低维度转高维度

c = np.arange(6)
print("** ", c)
c1 = c.reshape(3, -1)
print("** ", c1)
c2 = c.reshape(-1, 6)
print("** ", c2)
**  [0 1 2 3 4 5]
**  [[0 1]
 [2 3]
 [4 5]]
**  [[0 1 2 3 4 5]]



2. view

torch中,view() 的作用相当于numpy中的reshape,重新定义矩阵的形状。

v1 = torch.range(1, 16) 
v2 = v1.view(-1, 4)  
print("** ", v1)
print("** ", v2)
**  tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
        15., 16.])
**  tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])



3. rearrange

rearrange是einops中的一个函数调用方法。

from einops import rearrange

image1 = torch.zeros(2, 224, 224, 3)
image2 = rearrange(image1, 'b w h c -> b (w h) c')
image3 = rearrange(image1, 'b w h c -> (b c) (w h)')

print("** ", image1.shape)
print("** ", image2.shape)
print("** ", image3.shape)
**  torch.Size([2, 224, 224, 3])
**  torch.Size([2, 50176, 3])
**  torch.Size([6, 50176])



4. transpose

torch.transpose(Tensor,dim0,dim1)是pytorch中的ndarray矩阵进行转置的操作。


注意:

transpose()一次只能在两个维度间进行转置(也可以理解为维度转换)

x = torch.Tensor(2, 3, 4, 5)  # 这是一个4维的矩阵(只用空间位置,没有数据)
print(x.shape)
# 先转置0维和1维,之后在第2,3维间转置,之后在第1,3间转置
y = x.transpose(0, 1).transpose(3, 2).transpose(1, 3)
print(y.shape)
torch.Size([2, 3, 4, 5])
torch.Size([3, 4, 5, 2])



5. permute


注意:

permute相当于可以同时操作于tensor的若干维度,transpose只能同时作用于tensor的两个维度,

permute是transpose的进阶版。

print(torch.Tensor(2,3,4,5).permute(3,2,0,1).shape)
torch.Size([5, 4, 2, 3])



6. contiguous


x.is_contiguous()

——判断tensor是否连续


x.contiguous()

——把tensor变成在内存中连续分布的形式


需要变成连续分布的情况:

view只能用在contiguous的variable上。如果在view之前用了transpose, permute等,需要用contiguous()来返回一个contiguous copy。

x = torch.Tensor(5, 10)
print(x.is_contiguous())
print(x.transpose(0, 1).is_contiguous())
print(x.transpose(0, 1).contiguous().is_contiguous())
True
False
True

写代码时,一般没写contiguous()会报错提示,所以不用担心…



7. squeeze

squeeze()函数的功能是

维度压缩

。返回一个tensor(张量),其中 input 中大小为1的所有维都已删除。

x = torch.Tensor(2, 1, 2, 1, 2)
print(x.shape)
y = torch.squeeze(x) # 默认是把所有是1的维度都删掉
print(y.shape)
y = torch.squeeze(x, 1)
print(y.shape)
torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 2, 2])
torch.Size([2, 2, 1, 2])



8. unsqueeze

squeeze()函数的功能是

增加维度

x = torch.arange(0,6)
print(x.shape)
y = x.unsqueeze(0)
print(y.shape)
z = x.unsqueeze(1)
print(z.shape)
w = x.unsqueeze(2)
print(w.shape)
torch.Size([6])
torch.Size([1, 6])
torch.Size([6, 1])
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)