pytorch搭建卷积神经网络(alexnet、vgg16、resnet50)以及训练

  • Post author:
  • Post category:其他


文末有代码和数据集链接!!!!

(注:文章中所有path指文件的路径)

因毕业设计需要,接触卷积神经网络。由于pytorch方便使用,所以最后使用pytorch来完成卷积神经网络训练。

接触到的网络有Alexnet、vgg16、resnet50,毕业答辩完后,一直在训练Alexnet。

1.卷积神经网络搭建

pytorch中有torchvision.models,里面有许多已搭建好的模型。如果采用预训练模型,只需要修改最后分类的类别。

虽然这样但是我还是inception v3模型修改上失败。

alexnet和vgg16修改的是全连接层的最后一层。

model.classifier = nn.Sequential(nn.Linear(25088, 4096),      #vgg16
                                 nn.ReLU(),
                                 nn.Dropout(p=0.5),
                                 nn.Linear(4096, 4096),
                                 nn.ReLU(),
                                 nn.Dropout(p=0.5),
                                 nn.Linear(4096, 2))
alexnet_model.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 2),
        )

resnet50只需要修改最后的fc层。

model.fc = nn.Linear(2048, 2)

简单的修改,就可以完成。

如果采用要采用预训练模型的话,还需要对修改处参数的进行修改。(vgg16和alexnet需要,resnet50不需要,原因我认为是修改的地方不同)

 for index, parma in enumerate(model.classifier.parameters()):
     if index == 6:
        parma.requires_grad = True

2.训练

这张图是我所认为的神经网络训练的七步吧。

(1) 模型的创建上文已介绍。

(2) 数据集的建立:在PyTorch中对于数据集的文件格式有一定的要求。如图4-10所示,在目录下分别建cat和dog文件夹,这就相当于做标签

(3)对数据集进行预处理:这里采用的是数据增强变化的方法,包括对图片大小进行压缩和输入像素统一,都为224224,还有图像翻转以及归一化。

data_transform = transforms.Compose([
    transforms.Scale((224,224), 2),                           #对图像大小统一
    transforms.RandomHorizontalFlip(),                        #图像翻转
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[    #图像归一化
                             0.229, 0.224, 0.225])
         ])

(4)数据集的加载,加载方式有三种:1.如果采用pytorch模块自带的数据集就可以使用torchvision.datasets.       来添加数据集。2.和我下面代码一样,使用torchvision.datasets.ImageFolder,不过文件夹要按照(2)中固定格式来创建数据集。3.参照pytorch中的源码自己写一个相对应的函数。

train_dataset = torchvision.datasets.ImageFolder(root='/path/data/train/',transform=data_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle=True, num_workers=0)

val_dataset = torchvision.datasets.ImageFolder(root='/path/data/val/', transform=data_transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = batch_size, shuffle=True, num_workers=0)

(5)  模型的训练

    for epoch in range(num_epochs):
        batch_size_start = time.time()
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            if epoch >= 5:
                optimizer = torch.optim.SGD(model.classifier.parameters(), lr=lr2)
                print("lr", lr2)
            else:
                optimizer = torch.optim.SGD(model.classifier.parameters(), lr=lr1)
                print("lr", lr1)
            inputs = Variable(inputs)
            labels = Variable(labels)
            optimizer.zero_grad()
            outputs = model(inputs)
            criterion = nn.CrossEntropyLoss()
            loss = criterion(outputs, labels)        #交叉熵
            loss.backward()
            optimizer.step()                          #更新权重
            running_loss += loss.data[0]

        print('Epoch [%d/%d], Loss: %.4f,need time %.4f'
                  % (epoch + 1, num_epochs, running_loss / (4000 / batch_size), time.time() - batch_size_start))

(6)验证集的验证  ,代码中有模型的保存

        correct = 0
        total = 0
        model.eval()
        for (images, labels) in val_loader:
            batch_size_start = time.time()
            images = Variable(images)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()

        # print("正确的数量:", correct)
        print(" Val BatchSize cost time :%.4f s" % (time.time() - batch_size_start))
        print('Test Accuracy of the model on the %d Val images: %.4f' % (total, float(correct) / total))
        if (float(correct) / total) >= 0.99:
            print('the Accuracy>=0.98 the num_epochs:%d'% epoch)
            break
        x_epoch.append(epoch)
        Acc = round((float(correct) / total), 3)
        y_acc.append(Acc)

        picName = os.path.join(codeDirRoot, "log", "pic",
                               "alexnet%s.png" % experimentSuffix)
        line_chart(x_epoch, y_acc, picName)

        # if (epoch + 1) % adjustLREpoch == 0:
        #     adjust_learning_rate(optimizer, LRModulus)

        if (epoch+1) % saveModelEpoch != 0:
            continue
        saveModelName = os.path.join(codeDirRoot, "model", "alexnet%s_model.pkl"%experimentSuffix + "_" + str(epoch))
        torch.save(model.state_dict(), saveModelName)

(7) 测试集的测试,代码中包含模型的加载。

model.load_state_dict(torch.load(
    "/path/cnn/model/vgg16/39_vgg16_model.pkl",map_location=lambda storage, loc: storage))
model.eval()
correct = 0
total = 0
for images, labels in test_loader:
    images = Variable(images)
    outputs = model(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
print("正确的数量%d,所有图片数量%d:" % (correct, total))
print('val accuracy of the %d val images:%.4f' % (total, float(correct) / total))

这是完整的过程。在这个过程中加入了,警告忽略、日志保存、图形化数据。代码如下。

import warnings
warnings.filterwarnings("ignore")
class Logger(object):

    def __init__(self, filename="Default.log"):
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        pass
sys.stdout = Logger("/path/cnn/log/resnet50_image_show.txt")
# 画折线图形并保存
def line_chart(x_epoch, y_acc, picName):
    plt.figure()#创建绘图对象
    plt.plot(x_epoch, y_acc, "b--", linewidth=1)   #在当前绘图对象绘图(X轴,Y轴,蓝色虚线,线宽度)
    plt.ylim(0.00, 1.00)
    plt.xlabel("epoch")            #X轴标签
    plt.ylabel("accuracy")               #Y轴标签
    plt.title("alexnet-Line _chart")          #图标题
    # plt.savefig(os.path.join(codeDirRoot, "log", "pic", "resnet50%s.png"%experimentSuffix))  # 保存图
    plt.savefig(picName)  # 保存图

我在老师要求下,做了最后的识别结果输出。下面是完整的代码。

warnings.filterwarnings("ignore")     #忽略警告
class Logger(object):                                #保存日志函数
    def __init__(self, filename="Default.log"):
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        pass
sys.stdout = Logger("path/cnn/log/alexnet_image_show.txt")

#显示图片函数
def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)
# 模型搭建
model = models.alexnet(pretrained=False)
model.classifier = nn.Sequential(nn.Linear(9216, 4096),
                                 nn.ReLU(),
                                 nn.Dropout(p=0.5),
                                 nn.Linear(4096, 4096),
                                 nn.ReLU(),
                                 nn.Dropout(p=0.5),
                                 nn.Linear(4096, 2))
print("model", model)
#加载预训练模型
model.load_state_dict(torch.load("/path/cnn/model/alexnet_model.pkl", map_location=lambda storage, loc: storage))
#数据预处理
data_transform = transforms.Compose([
    transforms.Scale((224, 224), 2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
#创建数据集
test_dataset = torchvision.datasets.ImageFolder("/path/data/show", data_transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)
#分类的类别
class_names = test_dataset.classes
# 显示一些图片预测函数
def visualize_model(model, num_images):
    model.eval()
    images_so_far = 0

    for i, data in enumerate(test_loader):
        inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)

        for j in range(inputs.size()[0]):
            images_so_far += 1
            ax = plt.subplot(num_images//2, 2, images_so_far)
            ax.axis('off')
            ax.set_title('predicted: {}'.format(class_names[predicted[j]]))
            imshow(inputs.cpu().data[j])
            if images_so_far == num_images:
                return
visualize_model(model, 10)       显示十张图片

# plt.ioff()     #“关闭交互模式”。
plt.savefig("/path/cnn/log/pic/alexnet.png")  # 保存图
plt.show()

这就是整个过程。

百度网盘链接:链接: https://pan.baidu.com/s/1pVb_JFp-WJjSUxMTruqxeg 提取码: m9ip



版权声明:本文为PC1022原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。