Implementing an Autoencoder in torch: PyTorch Convolutional Autoencoder



How does one construct the decoder part of a convolutional autoencoder? Suppose I have this

(input -> conv2d -> maxpool2d -> maxunpool2d -> convTranspose2d -> output):

import torch.nn as nn

# CIFAR images shape = 3 x 32 x 32

class ConvDAE(nn.Module):

    def __init__(self):
        super().__init__()

        # input: batch x 3 x 32 x 32 -> output: batch x 16 x 16 x 16
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=1, padding=1),  # batch x 16 x 32 x 32
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2, stride=2)                  # batch x 16 x 16 x 16
        )

        # input: batch x 16 x 16 x 16 -> output: batch x 3 x 32 x 32
        self.decoder = nn.Sequential(
            # this line does not work
            # nn.MaxUnpool2d(2, stride=2, padding=0),  # batch x 16 x 32 x 32
            nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, output_padding=1),  # batch x 16 x 32 x 32
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 3, 3, stride=1, padding=1, output_padding=0),   # batch x 3 x 32 x 32
            nn.ReLU()
        )

    def forward(self, x):
        print(x.size())
        out = self.encoder(x)
        print(out.size())
        out = self.decoder(out)
        print(out.size())
        return out

PyTorch-specific question: why can't I use MaxUnpool2d in the decoder part? It gives me the following error:

TypeError: forward() missing 1 required positional argument: 'indices'
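
For illustration, here is a minimal sketch that reproduces the error (the tensor shape is just an example): nn.MaxUnpool2d.forward requires the pooling indices in addition to the input, but nn.Sequential only passes a single tensor from module to module, so the unpooling layer is called without them.

import torch
import torch.nn as nn

unpool = nn.MaxUnpool2d(2, stride=2)
feat = torch.randn(1, 16, 16, 16)  # a pooled feature map, batch x 16 x 16 x 16

try:
    unpool(feat)  # no indices supplied, as happens inside nn.Sequential
except TypeError as e:
    print(e)  # "forward() missing 1 required positional argument: 'indices'" (wording may vary by version)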

And the conceptual question: shouldn't the decoder invert whatever we did in the encoder? I saw some implementations, and it seems they only care about the dimensions of the decoder's input and output. Here and here are some examples.

Solution

For the torch part of the question: unpool modules take, as a required positional argument, the indices returned by the corresponding pooling modules, which are only produced when the pooling layer is constructed with return_indices=True. So you could do:

class ConvDAE(nn.Module):

    def __init__(self):
        super().__init__()

        # input: batch x 3 x 32 x 32 -> output: batch x 16 x 16 x 16 (plus pooling indices)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=1, padding=1),       # batch x 16 x 32 x 32
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2, stride=2, return_indices=True)  # batch x 16 x 16 x 16
        )

        self.unpool = nn.MaxUnpool2d(2, stride=2, padding=0)

        # input: batch x 16 x 32 x 32 (after unpooling) -> output: batch x 3 x 32 x 32
        self.decoder = nn.Sequential(
            # unpooling already restores the 32 x 32 resolution, so use stride=1 here
            nn.ConvTranspose2d(16, 16, 3, stride=1, padding=1),                   # batch x 16 x 32 x 32
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 3, 3, stride=1, padding=1, output_padding=0),  # batch x 3 x 32 x 32
            nn.ReLU()
        )

    def forward(self, x):
        print(x.size())
        # MaxPool2d is the last module in the Sequential, so it returns (output, indices)
        out, indices = self.encoder(x)
        out = self.unpool(out, indices)
        out = self.decoder(out)
        print(out.size())
        return out
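
As a quick sanity check, here is a hypothetical usage snippet (not part of the original answer): push a random CIFAR-sized batch through the model and confirm that the reconstruction has the same shape as the input.

import torch

model = ConvDAE()
x = torch.randn(4, 3, 32, 32)  # a fake batch of CIFAR-sized images
out = model(x)
print(out.shape)               # expected: torch.Size([4, 3, 32, 32])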

As for the general part of the question, I don't think the state of the art is to use a symmetric decoder, as it has been shown that deconvolution/transposed convolution produces checkerboard artifacts, and many approaches use upsampling modules instead. You will find more info faster through the PyTorch channels.
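
As an illustration, here is a minimal sketch of such an upsampling-based decoder for the batch x 16 x 16 x 16 code used above; the specific layer choices are assumptions for the example, not taken from the original post.

import torch.nn as nn

# Nearest-neighbor upsampling followed by ordinary convolutions avoids the
# checkerboard artifacts associated with strided transposed convolutions.
upsample_decoder = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='nearest'),  # batch x 16 x 32 x 32
    nn.Conv2d(16, 16, 3, stride=1, padding=1),    # batch x 16 x 32 x 32
    nn.ReLU(),
    nn.BatchNorm2d(16),
    nn.Conv2d(16, 3, 3, stride=1, padding=1),     # batch x 3 x 32 x 32
    nn.ReLU()
)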



Copyright notice: this article is an original post by weixin_39999781 and is licensed under CC 4.0 BY-SA. Please include a link to the original source and this notice when reposting.