构建网络：
import torch
import torchvision.models as models
class VGG(torch.nn.Module):
    """VGG convolutional backbone from torchvision with a small MLP head.

    The convolutional ``features`` stack is taken from the torchvision model
    named by ``vgg``. The torchvision classifier is discarded and replaced by
    ``Linear(512, 512) -> BatchNorm1d(512) -> ReLU -> Linear(512, num_class)``.

    Args:
        vgg: Name of a callable in ``torchvision.models`` whose result has a
            ``.features`` attribute, e.g. ``'vgg16_bn'``.
        data_set: Dataset flag that must contain ``'CIFAR'``. The text after
            ``'CIFAR'`` is parsed as the number of classes
            (``'CIFAR10'`` -> 10, ``'CIFAR100'`` -> 100).
        pretrained: Forwarded unchanged to the torchvision constructor.

    Raises:
        RuntimeError: If ``data_set`` does not contain ``'CIFAR'``.
        ValueError: If the text after ``'CIFAR'`` is not a positive integer.
    """

    def __init__(self, vgg='vgg16_bn', data_set='CIFAR10', pretrained=False):
        super().__init__()
        # Validate the dataset flag before building the backbone, so a bad
        # flag fails fast instead of after constructing (and possibly
        # downloading weights for) the torchvision model.
        if 'CIFAR' not in data_set:
            raise RuntimeError("Not expected data flag !!!")
        suffix = data_set.split("CIFAR")[1]
        try:
            num_class = int(suffix)
        except ValueError as err:
            raise ValueError(
                f"Cannot parse the number of classes from data_set={data_set!r}"
            ) from err
        if num_class <= 0:
            raise ValueError(
                f"Number of classes must be positive, got {num_class} "
                f"from data_set={data_set!r}"
            )

        self.features = models.__dict__[vgg](pretrained=pretrained).features
        # The ReLU is what gives the 512-unit hidden layer a purpose. Without a
        # nonlinearity, Linear -> BatchNorm1d -> Linear composes to a single
        # affine map at inference time.
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(512, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(512, num_class),
        )

    def forward(self, x):
        """Return class logits for the image batch ``x``.

        The feature map is flattened from dim 1 onward and must contain
        exactly 512 values per sample to match ``Linear(512, ...)``.
        NOTE(review): with torchvision VGG feature stacks this presumably
        means 32x32 inputs (512 x 1 x 1 feature maps) — confirm for other
        input sizes.
        """
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)
生成网络：
# Build the model from run-time settings. `args` is defined elsewhere
# (presumably an argparse.Namespace with `vgg` and `data_set` attributes).
network = VGG(vgg=args.vgg, data_set=args.data_set)
版权声明：本文为GeneralJing原创文章，遵循 CC 4.0 BY-SA 版权协议，转载请附上原文出处链接和本声明。