Pytorch grid_sample解析

  • Post author:
  • Post category:其他




grid_sample函数

这篇博客只对bilinear mode进行解释说明,并且会对align_corners为True或False两种情况进行分情况讨论。

torch.nn.functional.grid_sample(input, grid, mode=‘bilinear’, padding_mode=‘zero’, align_corners=None)

nn.functional下的grid_sample函数会根据提供的坐标(grid)对input pixels进行采样(sampling),这篇文章只以bilinear interpolation sampling为例。 根据官方文档介绍,input shape必须是4D或5D的,分别用于二维和三维图像的采样(前两个维度为batch size和channel)。

input的shape(4D case)是



(

N

,

C

,

H

i

n

,

W

i

n

)

(N, C, H_{in}, W_{in})






(


N


,




C


,





H











in



















,





W











in



















)





, 这个很好理解。

gird的shape(4D case)是



(

N

,

H

o

u

t

,

W

o

u

t

,

2

)

(N, H_{out}, W_{out}, 2)






(


N


,





H











o


u


t



















,





W











o


u


t



















,




2


)





, 这里的H和W是output的长和宽,有一点需要注意的是,grid_sample的output shape是



(

N

,

C

,

H

o

u

t

,

W

o

u

t

)

(N, C, H_{out}, W_{out})






(


N


,




C


,





H











o


u


t



















,





W











o


u


t



















)





, 所以output的shape和grid的shape是一样的, 而不是和input的shape一样。grid的最后一个维度2表示的是x,y坐标, 如果是5D的情况,也就是处理三维图像的时候,gird的最后一个维度就是3,因为需要引入z坐标。

grid表示的是的sampling pixel的坐标,这个坐标是被normalized过的,grid坐标取值范围为[-1, 1]。 点(-1,-1)为左上角的pixel,(1,1)为右下的pixel。中间的坐标值为某个浮点数。

grid_sample函数做的就是根据grid坐标,从input的pixels里采样。 如果此坐标下没有对应的input pixel,就要用bilinear interpolation从周围的pixels采样。

下面是Piotr给出的一个例子


https://discuss.pytorch.org/t/solved-torch-grid-sample/51662/2

inp = torch.arange(4*4).view(1, 1, 4, 4).float()
d = torch.linspace(-1, 1, 8)
meshx, meshy = torch.meshgrid((d, d))
grid = torch.stack((meshy, meshx), 2).unsqueeze(0)
output = torch.nn.functional.grid_sample(inp, grid, align_corners=False)

meshy是x坐标

在这里插入图片描述

meshx是y坐标

在这里插入图片描述



align_corners=True

当align_corners=True时,以坐标(-0.7143, -0.7143)为例,请看下图。

因为align_corners=True,所以(-1, -1)点的值为0, (1, 1)点的值为15,可以认为grid的-1和1在是在corner pixel的中心位置。由此可以推出值为1和2的坐标为(-0.3333, -1)和(0.3333, -1)。我们要采样的点(-0.7143, -0.7143)在0, 1, 4, 5中间,所以要从这四点进行采样。根据坐标算出长度比例,然后用bilinear interpolation算出坐标(-0.7143, -0.7143)的值就okay了。

在这里插入图片描述

下图是align_corners=True的output

在这里插入图片描述



align_corners=False

当align_corners=False时,以坐标(0.7143, -0.7143)为例,请看下图。

注意:这个例子的坐标和上个例子的坐标不一样。

因为align_corners=False, 所以(-1, -1)点的值不为0,(1, 1)点的值也不是15,grid的-1和1不在corner pixel的中心位置,而是在正方形像素的角。所以(-0.75, -0.75)的值才是0, (0.75, 0.75)的值才是15。由此可以推出值为1和2的坐标分别为(-0.25, -0.75)和(0.25, -0.75)。我们要采样的点(0.7143, -0.7143)在2, 3, 6, 7中间,所以要从这四点进行采样。根据坐标算出长度比例,然后用bilinear interpolation算出坐标(0.7143, -0.7143)的值就okay了。

在这里插入图片描述

下图是align_corners=False的output:

在这里插入图片描述



版权声明:本文为xingye_fan原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。