A Simple CNN Implementation in PyTorch
Loading the Dataset
PyTorch ships with torch.utils.data.Dataset for data handling. It is an abstract class and must be used through subclassing. Since this article works with the MNIST dataset, we can load the data directly with the built-in torchvision.datasets.MNIST and then hand it to Data.DataLoader for batching. The advantage of DataLoader shows when training on very large datasets: the whole dataset does not have to be loaded into memory at once. A minimal sketch of such a subclass is shown below.
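For illustration, here is a minimal sketch of subclassing Dataset. The class name MyDataset and the in-memory toy tensors are hypothetical, not part of the original code; a real dataset would typically read each sample lazily from disk inside __getitem__, which is exactly what keeps memory use low.

import torch
import torch.utils.data as Data

class MyDataset(Data.Dataset):
    # A toy Dataset backed by in-memory tensors.
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        # DataLoader uses this to know how many samples exist
        return len(self.labels)

    def __getitem__(self, index):
        # called once per sample; only this sample needs to be materialized
        return self.features[index], self.labels[index]

# usage: wrap it in a DataLoader just like the built-in MNIST dataset
dataset = MyDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))
loader = Data.DataLoader(dataset, batch_size=10, shuffle=True)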
The dataset is loaded and processed as follows:
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision  # datasets and transforms
import matplotlib.pyplot as plt
# note: torch.autograd.Variable is deprecated; plain tensors are used below

EPOCH = 1              # number of passes over the training set
BATCH_SIZE = 10
LR = 0.001             # learning rate
DOWNLOAD_MNIST = True  # set to False once the data has been downloaded

train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,  # this is training data
    transform=torchvision.transforms.ToTensor(),  # converts to FloatTensor in [0, 1]
    download=DOWNLOAD_MNIST,
)
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# take the first 2000 test images; shape goes from (2000, 28, 28) to
# (2000, 1, 28, 28) and values are scaled to [0, 1], since the network
# expects inputs of shape (1, 28, 28)
test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:2000] / 255.
test_y = test_data.targets[:2000]
train_loader = Data.DataLoader(dataset=train_data, shuffle=True, batch_size=BATCH_SIZE)
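As a quick sanity check of the pipeline (not part of the original post), one batch can be pulled from the loader and its shapes inspected:

x, y = next(iter(train_loader))
print(x.shape)  # torch.Size([10, 1, 28, 28]) -- (batch, channel, height, width)
print(y.shape)  # torch.Size([10]) -- one integer label per image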
Implementing the CNN with nn.Module
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,    # grayscale input
                out_channels=16,  # number of filters
                kernel_size=5,    # 5x5 convolution kernel
                stride=1,         # step size of the kernel
                padding=2,        # padding so the output keeps the same 28x28 size
            ),
            nn.ReLU(),                    # activation
            nn.MaxPool2d(kernel_size=2),  # downsample with a 2x2 window
        )  # output shape: (16, 14, 14)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )  # output shape: (32, 7, 7)
        self.out = nn.Linear(32 * 7 * 7, 10)  # 10 classes, one per digit

    # forward pass
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten the feature maps to (batch, 32*7*7)
        y = self.out(x)
        return y, x  # also return the flattened features for inspection
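A quick way to confirm the shape arithmetic in the comments (padding=2 keeps 28x28 because (28 + 2*2 - 5)/1 + 1 = 28, and each MaxPool2d halves the spatial size) is to push a dummy batch through the model. This check is added here for illustration, not part of the original post:

logits, features = CNN()(torch.randn(1, 1, 28, 28))
print(logits.shape)    # torch.Size([1, 10])
print(features.shape)  # torch.Size([1, 1568]) -- 32 * 7 * 7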
Training and Results
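The loop below uses cnn, optimizer, and loss_fun, which the original snippet never defines. A minimal setup consistent with the hyperparameters above would be the following; Adam and cross-entropy are assumptions on my part, though they are the usual choices for this example:

cnn = CNN()
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # assumed optimizer
loss_fun = nn.CrossEntropyLoss()  # assumed loss for 10-class classification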
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):  # Variable wrappers are no longer needed
        out = cnn(b_x)[0]          # keep only the logits
        loss = loss_fun(out, b_y)
        optimizer.zero_grad()      # clear gradients from the previous step
        loss.backward()            # backpropagation
        optimizer.step()           # update the weights
        if step % 50 == 0:
            with torch.no_grad():  # no gradients needed for evaluation
                test_output, last_layer = cnn(test_x)
            pred_y = torch.max(test_output, 1)[1].numpy()
            accuracy = float((pred_y == test_y.numpy()).astype(int).sum()) / float(test_y.size(0))
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| test accuracy: %.2f' % accuracy)
Epoch: 0 | train loss: 0.0134 | test accuracy: 0.98
Epoch: 0 | train loss: 0.0129 | test accuracy: 0.98
Epoch: 0 | train loss: 0.0046 | test accuracy: 0.98
Epoch: 0 | train loss: 0.0572 | test accuracy: 0.98
Epoch: 0 | train loss: 0.0599 | test accuracy: 0.98
Epoch: 0 | train loss: 0.0076 | test accuracy: 0.98
Epoch: 0 | train loss: 0.8489 | test accuracy: 0.98
Epoch: 0 | train loss: 0.0195 | test accuracy: 0.98
Epoch: 0 | train loss: 0.0288 | test accuracy: 0.98
Epoch: 0 | train loss: 0.0013 | test accuracy: 0.98
Epoch: 0 | train loss: 0.0078 | test accuracy: 0.98
Epoch: 0 | train loss: 0.5733 | test accuracy: 0.98
Epoch: 0 | train loss: 0.4503 | test accuracy: 0.98
Epoch: 0 | train loss: 0.2883 | test accuracy: 0.98
Epoch: 0 | train loss: 0.0023 | test accuracy: 0.98
Epoch: 0 | train loss: 0.6105 | test accuracy: 0.98
Partial results are shown above; the test accuracy reaches 0.98 within the first epoch.
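After training, predictions can be compared against the ground-truth labels on a few test images. This is an illustrative check, not shown in the original post:

with torch.no_grad():
    test_output, _ = cnn(test_x[:10])
pred_y = torch.max(test_output, 1)[1].numpy()
print(pred_y, 'prediction numbers')
print(test_y[:10].numpy(), 'real numbers')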