SRGAN模型——PyTorch实现

  • Post author:
  • Post category:其他


论文传送门:

https://arxiv.org/pdf/1609.04802.pdf


SRGAN模型目的:输入低分辨率图像,生成高分辨率图像。


生成网络

由三部分构成:

①卷积+PReLU激活函数;

②(卷积+BN+PReLU+卷积+BN,连接残差边)x16+卷积+BN,连接残差边;

③(卷积+像素重组+PReLU)x2+卷积;

①②用于提取图像特征,③用于图像上采样(两次像素重组,共放大4倍),实现超分辨率重建。


生成网络的目的:输入低分辨率图像,输出高分辨率图像。


鉴别网络

类似VGG结构:首层为卷积+LeakyReLU,随后由多个(卷积+BN+LeakyReLU)块交替进行通道扩张与下采样,最后经两层全连接(代码中以1x1卷积实现)和Sigmoid输出真假概率。


鉴别网络目的:输入高分辨率图像,判断输入图像是真实图像还是生成图像。

class D_Block(nn.Module):
    """Discriminator building block: 3x3 conv -> BatchNorm -> LeakyReLU(0.2).

    Args:
        in_channel: number of input channels.
        out_channle: number of output channels (spelling kept so existing
            keyword callers keep working).
        strid: stride of the 3x3 convolution. Padding is 1, so stride 1
            keeps H/W and stride 2 maps H to ceil(H / 2).
    """

    def __init__(self, in_channel, out_channle, strid):
        super().__init__()
        layers = [
            nn.Conv2d(in_channel, out_channle, 3, strid, 1),
            nn.BatchNorm2d(out_channle),
            nn.LeakyReLU(0.2),
        ]
        # A single Sequential named `block` keeps state_dict keys as block.0.*, block.1.*.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through conv, batch norm and LeakyReLU."""
        out = self.block(x)
        return out


class Discriminator(nn.Module):
    """SRGAN discriminator: scores each input image with a probability of being real.

    Layout:
      * 3x3 conv (3 -> 64) + LeakyReLU(0.2), without batch norm;
      * seven D_Block stages (VGG-like) that alternate a stride-2 downsample
        with a stride-1 channel doubling, ending at 512 channels after four
        stride-2 reductions;
      * global average pooling, then two 1x1 convolutions acting as the dense
        layers (512 -> 1024 -> 1) with LeakyReLU in between, and a sigmoid.

    The adaptive pooling makes the head independent of the input's H/W.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)  # first conv, no BN
        self.leakyrelu = nn.LeakyReLU(0.2)
        self.downsample = nn.Sequential(  # VGG-style feature extractor
            D_Block(64, 64, 2),
            D_Block(64, 128, 1),
            D_Block(128, 128, 2),
            D_Block(128, 256, 1),
            D_Block(256, 256, 2),
            D_Block(256, 512, 1),
            D_Block(512, 512, 2),
        )
        self.linear = nn.Sequential(  # classifier head
            nn.AdaptiveAvgPool2d(1),  # (n,512,h,w) -> (n,512,1,1)
            nn.Conv2d(512, 1024, 1, 1, 0),  # 1x1 conv used in place of a dense layer
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, 1, 1, 0),  # 1x1 conv used in place of a dense layer
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: image batch, expected shape (n, 3, H, W) (high-resolution
                real or generated images).

        Returns:
            Tensor of shape (n,) with sigmoid scores in (0, 1). The batch
            dimension is preserved even when n == 1.
        """
        x = self.leakyrelu(self.conv1(x))  # (n,3,H,W) -> (n,64,H,W)
        # e.g. (n,64,256,256) -> ... -> (n,512,16,16)
        x = self.downsample(x)
        x = self.linear(x)  # -> (n,1,1,1)
        # squeeze() would drop *every* size-1 dim, turning a batch of one into a
        # 0-d tensor; reshape to (n,) explicitly so the batch dim always survives.
        return x.reshape(x.size(0))


class G_Block(nn.Module):
    """Generator residual block computing ``x + BN(conv(PReLU(BN(conv(x)))))``.

    Both convolutions are 3x3 with stride 1 and padding 1, so the branch
    output matches the input shape and the identity shortcut adds directly.

    Args:
        channel: number of channels, identical at input and output.
    """

    def __init__(self, channel):
        super().__init__()
        # Layers are created in the same order as the Sequential below.
        conv_in = nn.Conv2d(channel, channel, 3, 1, 1)
        norm_in = nn.BatchNorm2d(channel)
        act = nn.PReLU(channel)  # one learnable slope per channel
        conv_out = nn.Conv2d(channel, channel, 3, 1, 1)
        norm_out = nn.BatchNorm2d(channel)
        self.block = nn.Sequential(conv_in, norm_in, act, conv_out, norm_out)

    def forward(self, x):
        """Return the input plus the residual branch F(x)."""
        residual = self.block(x)
        return x + residual


class Generator(nn.Module):
    """SRGAN generator: maps a low-resolution RGB image to one 4x larger in H and W.

    Structure:
      1. 9x9 conv (3 -> 64) + PReLU for shallow feature extraction.
      2. 16 G_Block residual blocks, then 3x3 conv + BN, whose output is added
         back onto the step-1 features (long skip connection).
      3. Two stages of 3x3 conv (64 -> 256) + PixelShuffle(2) + PReLU, each
         doubling H and W, then a 9x9 conv (64 -> 3) producing the image.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 9, 1, 4)
        self.prelu1 = nn.PReLU(64)
        self.blocks = nn.Sequential(*[G_Block(64) for _ in range(16)])  # 16 residual blocks
        self.conv2 = nn.Conv2d(64, 64, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(64)
        self.upsample = nn.Sequential(
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),  # (n,256,h,w) -> (n,64,2h,2w): channels rearranged into space
            nn.PReLU(64),
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.PReLU(64),
            nn.Conv2d(64, 3, 9, 1, 4),
        )

    def forward(self, x):
        """Upscale a batch of low-resolution images.

        Args:
            x: image batch, expected shape (n, 3, h, w).

        Returns:
            Tensor of shape (n, 3, 4h, 4w).
        """
        x = self.prelu1(self.conv1(x))  # (n,3,h,w) -> (n,64,h,w)
        # Out-of-place add: `x` was just fed to the first conv inside self.blocks,
        # which keeps it for its backward pass; an in-place `x += ...` would make
        # loss.backward() fail with an autograd in-place modification error.
        x = x + self.bn2(self.conv2(self.blocks(x)))  # (n,64,h,w) -> (n,64,h,w)
        # (n,64,h,w) -> (n,64,2h,2w) -> (n,64,4h,4w) -> (n,3,4h,4w)
        x = self.upsample(x)
        return x



版权声明:本文为Peach_____原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。