grid_sample函数
这篇博客只对bilinear mode进行解释说明,并且会对align_corners为True或False两种情况进行分情况讨论。
torch.nn.functional.grid_sample(input, grid, mode=‘bilinear’, padding_mode=‘zero’, align_corners=None)
nn.functional下的grid_sample函数会根据提供的坐标(grid)对input pixels进行采样(sampling),这篇文章只以bilinear interpolation sampling为例。 根据官方文档介绍,input shape必须是4D或5D的,分别用于二维和三维图像的采样(前两个维度为batch size和channel)。
input的shape(4D case)是
(
N
,
C
,
H
i
n
,
W
i
n
)
(N, C, H_{in}, W_{in})
(
N
,
C
,
H
in
,
W
in
)
, 这个很好理解。
gird的shape(4D case)是
(
N
,
H
o
u
t
,
W
o
u
t
,
2
)
(N, H_{out}, W_{out}, 2)
(
N
,
H
o
u
t
,
W
o
u
t
,
2
)
, 这里的H和W是output的长和宽,有一点需要注意的是,grid_sample的output shape是
(
N
,
C
,
H
o
u
t
,
W
o
u
t
)
(N, C, H_{out}, W_{out})
(
N
,
C
,
H
o
u
t
,
W
o
u
t
)
, 所以output的shape和grid的shape是一样的, 而不是和input的shape一样。grid的最后一个维度2表示的是x,y坐标, 如果是5D的情况,也就是处理三维图像的时候,gird的最后一个维度就是3,因为需要引入z坐标。
grid表示的是的sampling pixel的坐标,这个坐标是被normalized过的,grid坐标取值范围为[-1, 1]。 点(-1,-1)为左上角的pixel,(1,1)为右下的pixel。中间的坐标值为某个浮点数。
grid_sample函数做的就是根据grid坐标,从input的pixels里采样。 如果此坐标下没有对应的input pixel,就要用bilinear interpolation从周围的pixels采样。
下面是Piotr给出的一个例子
https://discuss.pytorch.org/t/solved-torch-grid-sample/51662/2
inp = torch.arange(4*4).view(1, 1, 4, 4).float()
d = torch.linspace(-1, 1, 8)
meshx, meshy = torch.meshgrid((d, d))
grid = torch.stack((meshy, meshx), 2).unsqueeze(0)
output = torch.nn.functional.grid_sample(inp, grid, align_corners=False)
meshy是x坐标
meshx是y坐标
align_corners=True
当align_corners=True时,以坐标(-0.7143, -0.7143)为例,请看下图。
因为align_corners=True,所以(-1, -1)点的值为0, (1, 1)点的值为15,可以认为grid的-1和1在是在corner pixel的中心位置。由此可以推出值为1和2的坐标为(-0.3333, -1)和(0.3333, -1)。我们要采样的点(-0.7143, -0.7143)在0, 1, 4, 5中间,所以要从这四点进行采样。根据坐标算出长度比例,然后用bilinear interpolation算出坐标(-0.7143, -0.7143)的值就okay了。
下图是align_corners=True的output
align_corners=False
当align_corners=False时,以坐标(0.7143, -0.7143)为例,请看下图。
注意:这个例子的坐标和上个例子的坐标不一样。
因为align_corners=False, 所以(-1, -1)点的值不为0,(1, 1)点的值也不是15,grid的-1和1不在corner pixel的中心位置,而是在正方形像素的角。所以(-0.75, -0.75)的值才是0, (0.75, 0.75)的值才是15。由此可以推出值为1和2的坐标分别为(-0.25, -0.75)和(0.25, -0.75)。我们要采样的点(0.7143, -0.7143)在2, 3, 6, 7中间,所以要从这四点进行采样。根据坐标算出长度比例,然后用bilinear interpolation算出坐标(0.7143, -0.7143)的值就okay了。
下图是align_corners=False的output: