好的，下面是一个使用 PyTorch 搭建的网络模型，同时包含 VGG 块和残差结构：
```python
import torch
import torch.nn as nn
class VGGBlock(nn.Module):
    """Two VGG-style stages of 3x3 conv -> BatchNorm -> activation.

    Spatial size is preserved (3x3 kernels with ``padding=1``); the channel
    count goes ``in_channels -> middle_channels -> out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        middle_channels: Number of channels produced by the first conv.
        out_channels: Number of channels produced by the second conv.
        activation: Activation *class* (not an instance). It is
            instantiated once and applied after each BatchNorm layer.
    """

    def __init__(self, in_channels, middle_channels, out_channels, activation=nn.ReLU):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.activation = activation()
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply conv-BN-activation twice.

        Args:
            x: Input batch; presumably (N, in_channels, H, W) -- TODO confirm
                against callers.

        Returns:
            Tensor with ``out_channels`` channels and the same H, W as ``x``.
        """
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Final non-linearity. Without it, this block's conv2 and the next
        # block's conv1 would be separated only by an affine BatchNorm, so the
        # two convolutions would compose into a single linear map.
        return self.activation(out)
class ResidualBlock(nn.Module):
    """Basic residual block: two 3x3 conv-BN stages plus an identity shortcut.

    Both convolutions keep ``channels`` channels and (with ``padding=1``) the
    spatial size, so the shortcut is a plain identity addition.

    Args:
        channels: Number of input and output channels.
        activation: Activation *class* (not an instance). It is
            instantiated once and applied after ``bn1`` and after the
            residual addition.
    """

    def __init__(self, channels, activation=nn.ReLU):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.activation = activation()
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Compute ``activation(bn2(conv2(act(bn1(conv1(x))))) + x)``.

        Args:
            x: Input batch; presumably (N, channels, H, W) -- TODO confirm
                against callers.

        Returns:
            Tensor with the same shape as ``x``.
        """
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Out-of-place add keeps ``x`` and the BN output untouched for autograd.
        out = out + x
        # Post-addition non-linearity, as in the standard ResNet basic block.
        return self.activation(out)
class Network(nn.Module): def
init
(self, num_classes=10): super(Network, self).
init
() self.conv1 = nn.Conv2d(3, 64, 3, padding=1) self.bn1 = nn.BatchNorm2d(64) self.activation = nn.ReLU() self.vgg_block1 = VGGBlock(64, 128, 128) self.vgg_block2 = VGGBlock(128, 256, 256) self.residual_block = ResidualBlock(256) self.fc = nn.Linear(256 * 8 * 8, num_classes)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.activation(out)
out = self.