基于pytorch的猫狗预测-CNN

  • Post author:
  • Post category:其他


基于 PyTorch 的猫狗预测可以使用卷积神经网络(CNN)实现。下面是一个简要的介绍:

1. 数据预处理

首先,需要对数据进行预处理,将其转换为模型所需要的格式。可以使用 PyTorch 中的 `transforms` 模块对数据进行处理,例如将图片缩放到指定大小,进行水平翻转等数据增强操作。

2. 加载数据集

将预处理后的数据集加载到内存中,可以使用 PyTorch 中的 `datasets` 模块加载图片数据集,并使用 `dataloaders` 模块对数据进行分批处理和加载。

3. 定义模型

使用 PyTorch 中的 `nn` 模块定义卷积神经网络模型。可以定义多个卷积层、池化层和全连接层,以及添加激活函数和批归一化等技巧来提高模型精度。

4. 训练模型

使用 `optim` 模块定义优化器,如随机梯度下降(SGD)或 Adam 等。然后使用 `nn` 模块中的损失函数,如交叉熵损失函数等,对模型进行训练。在训练期间,可以使用 `scheduler` 模块对学习率进行调整。

5. 评估模型

训练完成后,使用测试集对模型进行评估,计算模型的准确率和损失值等指标。

至此,基于 PyTorch 的猫狗预测的 CNN 模型训练和评估过程就完成了。当然,为了获得更好的预测结果,还可以对数据集进行更多的预处理和增强操作,以及调整模型的超参数等。

这里也附上代码:导入需要的包

%matplotlib inline
import torch
import torchvision
import torch.nn as nn
from torchvision import datasets,transforms
from torch import nn
from d2l import torch as d2l
from torch.utils.data import random_split
from torch.utils import data
transforms = transforms.Compose(
[
transforms.RandomResizedCrop(150),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
]
)

train_data = torchvision.datasets.ImageFolder('D:\\Jupyter Notebook\\Pytorch入门\\catsdogs\\train',transform=transforms)
valid_data =  torchvision.datasets.ImageFolder('D:\\Jupyter Notebook\\Pytorch入门\\catsdogs\\test',transform=transforms)
#设置迭代器
batch_size = 32
train_iter = data.DataLoader(train_data,batch_size,shuffle = True,num_workers = 0)
valid_iter = data.DataLoader(valid_data,batch_size,shuffle = False,num_workers = 0)
class CNN_net(nn.Module):
    def __init__(self):
        super(CNN_net,self).__init__()
        self.seq = nn.Sequential(
            #x = (32,3,150,150)
            nn.Conv2d(3,20,5,5),
            #x = (20,30,30)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,stride=2),
            #x = (20,15,15)
            nn.Conv2d(20,50,4,1),
            #x = (50,12,12)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,stride=2),
            nn.Flatten(),
            nn.Linear(50*6*6,200),
            nn.ReLU(),
            nn.Linear(200,2),
            nn.Sigmoid()
        )

    def forward(self,x):
        x = self.seq(x)
        return x
lr=1e-4
device=torch.device("cuda" if torch.cuda.is_available() else "cpu" )

model=CNN_net().to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=lr)
loss_fn = nn.CrossEntropyLoss().to(device)

print(device)
def train(model,device,train_iter,optimizer,loss,epochs):
    total_train_step = 0
    for epoch in range(epochs):
        print("第{}轮训练开始".format(epoch+1))
        model.train()
        for idx,(data,target) in enumerate(train_iter):
            data,target = data.to(device),target.to(device)
            pred = model(data)
            loss = loss_fn(pred,target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_train_step = total_train_step+1
            
            if total_train_step %10 == 0:
                print("训练次数:{},Loss:{}".format(total_train_step,loss.item()))

def test(model,device,test_iter,loss_fn):
    total_test_step = 0
    total_test_loss = 0
    total_accuracy = 0
    model.eval()
    correct = 0
    with torch.no_grad():
        for idx,(data,target) in enumerate(test_iter):
            data,target = data.to(device),target.to(device)
            pred = model(data)
            loss = loss_fn(pred,target)
            total_test_loss = total_test_loss + loss.item()#计算测试Loss
            #计算精确度
            accuracy = (pred.argmax(1) == target).sum()
            total_accuracy = total_accuracy + accuracy
    print("整体测试集上的Loss:{}".format(total_test_loss))
    print("整体测试集上的accuracy:{}".format(total_accuracy/len(valid_data)))
#     acc=correct/len(valid_data)
#     print("accuracy:{},average_loss:{}".format(acc,sum(losses)/len(valid_data)))
num_epochs=30
import time
begin_time=time.time()
print(time.ctime(begin_time))
train(model,device,train_iter,optimizer,loss_fn,num_epochs)
# test(model,device,test_loader)
end_time=time.time() 
print(time.ctime(end_time))
test(model,device,valid_iter,loss_fn) 



版权声明:本文为renzhenxuanxuan原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。