pytorch中的squeeze()与unsqueeze() 包懂

  • Post author:
  • Post category:其他




squeeze



下文是根据官方文档的定义给出的代码解释
torch.squeeze(input, dim=None) → Tensor



不传入dim

不传尺寸时候,返回一个张量,其中删除了大小为1的输入的所有维度。

例如,如果输入的形状为:(A×1×B×C×1×D)则输出张量的形状为(A×B×C×D)



代码演示
import torch
a = torch.tensor([[[1],[2]],[[3],[4]]])
b = torch.tensor([[[[1],[2]]],[[[3],[4]]]])
c = torch.tensor([[[1,2],[2,2]],[[3,4],[4,5]]])
print(a.size(),a.squeeze().size())
print(a.squeeze())
print(b.size(),b.squeeze().size())
print(b.squeeze())
print(c.size(),c.squeeze().size())

在这里插入图片描述



结果分析

a的形状是(2 * 2 * 1)

b的形状是(2 * 2 * 2 * 1)

c的形状是(2 * 2 * 2)

根据官方给的公式: (

A×1×B×C×1×D) -> (A×B×C×D)


删除大小为一的维度


故 a =(2 * 2 * 1) – > (2 * 2)

b =(2 * 2 * 2 * 1)- > (2 * 2 * 2)

c = (2 * 2 * 2) – > (2 * 2 * 2)



传入dim

当给定尺寸时,仅在给定尺寸中进行挤压操作。如果输入的形状为:(A×1×B),则挤压(input,0)使张量保持不变,但挤压(input,1)将张量挤压为形状(A×B)



注意!!! 就是这里很多博客都没说清楚!!

这里注意A * 1 * B是指维度第一个数字为A,如果第二个数字是1则会消掉,否则不变,如果是1,则后面的数字都是B,不如第二数字不是1,则第一个数字是A,后面都是B



举例子
x = torch.zeros(2, 1, 2, 1, 2)
y = torch.zeros(2, 1, 1, 1, 1,2)
b = torch.tensor([[[[1],[2]]],[[[3],[4]]]])
print(x.size(),x.squeeze(0).size())
print(x.size(),x.squeeze(1).size(),x.squeeze(1).size())
print(b.size(),b.squeeze(1).size(),b.squeeze(1).squeeze(1).size())
print(y.size(),y.squeeze(1).size(),y.squeeze(1).squeeze(1).size(),y.squeeze(1).squeeze(1).squeeze(1).size())

在这里插入图片描述

比如x的维度是2,1,2,1,2 则A=2,B=[2,1,2] 则构成了A * 1 * B 故输出是2,2,1,2,如是dim=0则不变这很简单。

再比如y的维度是2,1,1,1,1,2则A=2,B=[1,1,1,2],由构成了A * 1 * B 故输出是2,1,1,1,2,有的朋友可能已经想到了如果在加一层squeeze(1)呢,没错和你想得一样会变成2,1,1,2,在继续加会变成2,1,2一直变成2,2

是不是很简单呢,unsqueeze()更简单



注意

这是官方给的提醒

1、返回的张量与输入张量共享存储,因此更改一个张量的内容将更改另一个张量的内容。

2、如果张量具有大小为1的批次维度,那么挤压(输入)也将移除批次维度,这可能导致意外错误。



unsqueeze()

torch.unsqueeze(input, dim) → Tensor

返回一个新的张量,在指定位置插入一个尺寸为1的张量。返回的张量与该张量共享相同的底层数据。其实与squeeze()相似只不过是反着来。

unsqueeze(0)则表示在第一个位置加入一个尺寸为1 的张量

unsqueeze(1)则还是安装A * 1 * B的公式在A的后面插入一个尺寸为1 的张量



代码演示
x = torch.zeros(2, 1, 2, 1, 2)
y = torch.zeros(2, 1, 1,2)
b = torch.tensor([[[[1],[2]]],[[[3],[4]]]])
print(b.size(),b.squeeze(1).size(),b.squeeze(1).unsqueeze(1).size())
print(x.size(),x.unsqueeze(0).size())
print(x.size(),x.unsqueeze(1).size())

在这里插入图片描述

squeeze了解后unsqueeze比较简单了,代码一目了然

如果还有不熟悉的地方拿代码自己再试一下就更清楚了



结尾

个人觉得自己写一遍代码就是最好的理解的方式



版权声明:本文为qq_45893652原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。