Using PyTorch, build a network model from VGG blocks and a residual structure (model definition only)

Sure. Below is a network model built with PyTorch that contains both VGG blocks and a residual structure:

```python
import torch
import torch.nn as nn


class VGGBlock(nn.Module):
    """VGG-style block: two 3x3 convolutions, each followed by batch norm."""

    def __init__(self, in_channels, middle_channels, out_channels, activation=nn.ReLU):
        super(VGGBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.activation = activation()
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.activation(out)
        out = self.conv2(out)
        out = self.bn2(out)
        return out


class ResidualBlock(nn.Module):
    """Residual block: two 3x3 convolutions plus an identity skip connection."""

    def __init__(self, channels, activation=nn.ReLU):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.activation = activation()
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.activation(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual  # skip connection: add the block input back
        return out


class Network(nn.Module):
    def __init__(self, num_classes=10):
        super(Network, self).__init__()
        # Stem: 3-channel image -> 64 feature maps
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.activation = nn.ReLU()
        self.vgg_block1 = VGGBlock(64, 128, 128)
        self.vgg_block2 = VGGBlock(128, 256, 256)
        self.residual_block = ResidualBlock(256)
        # The classifier expects 256 channels of 8x8 feature maps
        self.fc = nn.Linear(256 * 8 * 8, num_classes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.activation(out)
        out = self.
```
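
The listing above breaks off partway through `Network.forward`. The following is a minimal sketch of one way it might continue, not the author's original code. It assumes the remaining layers are applied in the order they are declared. Since the model declares no pooling layer but `fc` expects `256 * 8 * 8` inputs, it also assumes an adaptive average pool to 8x8 before flattening. The 32x32 input used for the shape check is an assumption as well.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VGGBlock(nn.Module):
    """Same layers as the VGGBlock in the listing above."""

    def __init__(self, in_channels, middle_channels, out_channels, activation=nn.ReLU):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.activation = activation()
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(out))


class ResidualBlock(nn.Module):
    """Same layers as the ResidualBlock in the listing above."""

    def __init__(self, channels, activation=nn.ReLU):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.activation = activation()
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return out + x  # identity skip connection


class Network(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.activation = nn.ReLU()
        self.vgg_block1 = VGGBlock(64, 128, 128)
        self.vgg_block2 = VGGBlock(128, 256, 256)
        self.residual_block = ResidualBlock(256)
        self.fc = nn.Linear(256 * 8 * 8, num_classes)

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        # Everything below is an assumed completion of the truncated listing.
        out = self.vgg_block1(out)
        out = self.vgg_block2(out)
        out = self.residual_block(out)
        # Assumption: pool to 8x8 so the flattened size matches self.fc.
        out = F.adaptive_avg_pool2d(out, (8, 8))
        out = torch.flatten(out, 1)
        return self.fc(out)


# Shape check on a hypothetical batch of two 32x32 RGB images.
model = Network(num_classes=10).eval()
with torch.no_grad():
    logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 10])
```

With the adaptive pool, the classifier input size does not depend on the image resolution. If fixed `nn.MaxPool2d` layers were intended instead, the input size would have to be chosen so that the feature maps end up 8x8.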



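Copyright notice: This is an original article by weixin_42584586, licensed under the CC 4.0 BY-SA license. When reposting, please include a link to the original source and this notice.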