PyTorch tutorials, hands-on tutorial (1): training on your own dataset (code walkthrough)

  • Post category: Other


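When I first got into deep learning I used Caffe. A while ago I switched to Keras, a mainstream framework (with TensorFlow as the backend). Keras is very heavily encapsulated, though, and for a weak programmer like me such a high-level API was actually awkward to get used to. After writing code with it for a while, I started using PyTorch (my supervisor wants me to be fluent in both frameworks anyway, so I'm learning both). PyTorch is much friendlier when it comes to reading the source code. You do have to define the forward pass yourself, but that is simple too.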

First, here are two very beginner-friendly tutorials on GitHub, in case you find the official documentation slow to get started with:


https://github.com/SherlockLiao/pytorch-beginner



https://github.com/hunkim/PyTorchZeroToAll


Official references:


https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/#torchvisiondatasets



http://pytorch.org/tutorials/beginner/data_loading_tutorial.html


**I. Training on a dataset bundled with torchvision:**

Building a network and training on torchvision's built-in datasets are easy. This tutorial focuses on the things that took a bit more effort when I wrote real code.

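First, how to train on a dataset. Training on a dataset bundled with torchvision is very simple:

1. Read the dataset directly with torchvision.datasets.
2. Instantiate torch.utils.data.DataLoader, specifying batch_size and whether to shuffle.
3. Iterate over the batches with enumerate during training.

You can use the following code to check that the data loaded correctly by displaying one image: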

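for i, (images, labels) in enumerate(train_loader, 0):
    print(images[0])  # first image tensor of the batch, shape (C, H, W)
    # convert the tensor back to a PIL image and display it
    img = transforms.ToPILImage()(images[0])
    img.show()
    break  # only inspect the first batch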
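You can also check batch shapes without downloading CIFAR-10. The sketch below builds a DataLoader over random tensors, a hypothetical stand-in for the real dataset. It unpacks one batch into the same (images, labels) pair that the loop above iterates over.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for datasets.CIFAR10: 100 random 3x32x32 "images" with labels in 0..9
fake_images = torch.rand(100, 3, 32, 32)
fake_labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(fake_images, fake_labels)

loader = DataLoader(dataset, batch_size=16, shuffle=True)

# Each iteration yields an (images, labels) pair for one batch
for i, (images, labels) in enumerate(loader):
    print(i, images.shape, labels.shape)  # 0 torch.Size([16, 3, 32, 32]) torch.Size([16])
    break
```

With 100 samples and batch_size=16 the loader yields 7 batches per epoch, and the last batch holds only 4 samples.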

The complete code is as follows:

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from logger import Logger

# hyperparameters
batch_size = 128
learning_rate = 1e-2
num_epoches = 20


def to_np(x):
    return x.cpu().data.numpy()


# download datasets
train_dataset = datasets.CIFAR10(
    root='./cifar_data', train=True, transform=transforms.ToTensor(), download=True)

test_dataset = datasets.CIFAR10(
    root='./cifar_data', train=False, transform=transforms.ToTensor())

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


#define model
class slice_ssc(nn.Module):
    def __init__(self,in_channel,n_class):
        super(slice_ssc,self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel,32,3,1,1),
            nn.ReLU(True),
            nn.MaxPool2d(2))
        self.conv2 = nn.Sequential(
            nn.Conv2d(32,64,3,1,1),
            nn.ReLU(True),
            nn.MaxPool2d(2))
        self.fc = nn.Sequential(
            nn.Linear(64*8*8,128),
            nn.Linear(128,64),
            nn.Linear(64,n_class))

    def forward(self,x):
        conv1_out = self.conv1(x)
        conv2_out = self.conv2(conv1_out)
        conv2_out = conv2_out.view(conv2_out.size(0),-1)
        out = self.fc(conv2_out)
        return out

model = slice_ssc(1, 10)  # 1 input channel: each RGB image is fed in as 3 single-channel images
print(model)

use_gpu = torch.cuda.is_available()  # check whether a GPU is available
if use_gpu:
    model = model.cuda()
# define the loss and the optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
logger = Logger('./logs')
#training
for epoch in range(num_epoches):
    print('epoch {}'.format(epoch + 1))
    model.train()  # back to training mode (the testing phase below calls model.eval())
    train_loss = 0.0
    train_acc = 0.0
    num_seen = 0  # single-channel samples processed so far in this epoch

    #==========training============
    for i, data in enumerate(train_loader, 1):
        img, label = data
        # split each 3-channel image into 3 single-channel images:
        # (B, 3, 32, 32) -> (3B, 1, 32, 32), ordered img0_c0, img0_c1, img0_c2, img1_c0, ...
        img = img.view(img.size(0) * 3, 1, 32, 32)
        # repeat each label 3 times in place (l0, l0, l0, l1, ...) to match that order;
        # torch.cat((label, label, label), 0) would give l0, l1, ..., l0, l1, ... and misalign
        label = label.repeat_interleave(3)
        if use_gpu:
            img = img.cuda()
            label = label.cuda()

        #forward
        out = model(img)
        loss = criterion(out, label)  # mean loss over the batch
        _, pred = torch.max(out, 1)
        train_correct = (pred == label).sum().item()
        accuracy = train_correct / label.size(0)
        train_loss += loss.item() * label.size(0)  # accumulate the summed loss
        train_acc += train_correct
        num_seen += label.size(0)
        #backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        #=============log===============
        step = epoch * len(train_loader) + i
        info = {'loss': loss.item(), 'accuracy': accuracy}
        for tag, value in info.items():
            logger.scalar_summary(tag, value, step)

        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            logger.histo_summary(tag, to_np(value), step)
            logger.histo_summary(tag + '/grad', to_np(value.grad), step)

        info = {'images': to_np(img.view(-1, 32, 32)[:10])}
        for tag, images in info.items():
            logger.image_summary(tag, images, step)
        if i % 300 == 0:
            print('[{}/{}] Loss: {:.6f}, Acc: {:.6f}'.format(
                epoch + 1, num_epoches, train_loss / num_seen,
                train_acc / num_seen))

    print('Finish {} epoch, Loss: {:.6f}, Acc: {:.6f}'.format(
        epoch + 1, train_loss / num_seen, train_acc / num_seen))

    #============testing=============
    model.eval()
    eval_loss = 0.0
    eval_acc = 0.0
    num_eval = 0
    with torch.no_grad():  # replaces the removed Variable(..., volatile=True)
        for data in test_loader:
            img, label = data
            img = img.view(img.size(0) * 3, 1, 32, 32)
            label = label.repeat_interleave(3)  # same label expansion as in training
            if use_gpu:
                img = img.cuda()
                label = label.cuda()
            out = model(img)
            loss = criterion(out, label)
            eval_loss += loss.item() * label.size(0)
            _, pred = torch.max(out, 1)
            eval_acc += (pred == label).sum().item()
            num_eval += label.size(0)
    print('Test Loss: {:.6f}, Acc: {:.6f}'.format(
        eval_loss / num_eval, eval_acc / num_eval))

# save the model weights
torch.save(model.state_dict(), './cnn.pth')
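The script feeds each CIFAR-10 image to a one-channel network by reshaping a batch from (B, 3, 32, 32) to (3B, 1, 32, 32). The check below shows which order that `view` produces and how the labels must be expanded to stay aligned with it. It uses a hypothetical random 2-image batch with made-up labels 7 and 3.

```python
import torch

B = 2
img = torch.rand(B, 3, 32, 32)
label = torch.tensor([7, 3])

# (B, 3, 32, 32) -> (3B, 1, 32, 32): rows are img0_c0, img0_c1, img0_c2, img1_c0, ...
split = img.view(B * 3, 1, 32, 32)
assert torch.equal(split[1, 0], img[0, 1])  # row 1 is channel 1 of image 0

# Aligned: each label repeated 3 times in place
good = label.repeat_interleave(3)
# Misaligned: the whole label vector tiled 3 times
bad = torch.cat((label, label, label), 0)

print(good.tolist())  # [7, 7, 7, 3, 3, 3]
print(bad.tolist())   # [7, 3, 7, 3, 7, 3]
```

With the tiled version, rows 1 and 3 of `split` (channels of image 0) would be trained against label 3 and vice versa, so about half of the targets are wrong.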

The logger.py that saves the logs is:

import tensorflow as tf
import numpy as np
import scipy.misc
try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO         # Python 3.x


class Logger(object):

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag,
                                                     simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images."""

        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to a string
            try:
                s = StringIO()
            except:
                s = BytesIO()
            scipy.misc.toimage(img).save(s, format="png")

            # Create an Image object
            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            # Create a Summary value
            img_summaries.append(
                tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        # Create and write Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""

        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values**2))

        # Drop the start of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()
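logger.py has two version dependencies:

- It relies on the TensorFlow 1.x summary API (`tf.summary.FileWriter`, `tf.Summary`), which TensorFlow 2 only keeps under `tf.compat.v1`.
- It calls `scipy.misc.toimage`, which newer SciPy releases no longer ship.

If either breaks for you, one option is to keep the same three method names but back them with PyTorch's own `torch.utils.tensorboard.SummaryWriter`. This is a swapped-in approach, not the original setup, and it requires the `tensorboard` package. The class below is a hypothetical drop-in sketch.

```python
import os
import tempfile

import numpy as np
from torch.utils.tensorboard import SummaryWriter


class Logger(object):
    """Hypothetical replacement for logger.py, backed by torch.utils.tensorboard."""

    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        self.writer.add_scalar(tag, value, step)

    def image_summary(self, tag, images, step):
        """Log single-channel images shaped (N, H, W) with values in [0, 1]."""
        for i, img in enumerate(images):
            self.writer.add_image('%s/%d' % (tag, i), img, step, dataformats='HW')

    def histo_summary(self, tag, values, step):
        """Log a histogram of a numpy array or tensor."""
        self.writer.add_histogram(tag, values, step)
        self.writer.flush()

    def close(self):
        self.writer.close()


# Usage sketch: log one step of fake data into a temporary directory
log_dir = tempfile.mkdtemp()
logger = Logger(log_dir)
logger.scalar_summary('loss', 0.5, 1)
logger.histo_summary('conv1/weight', np.random.randn(1000), 1)
logger.image_summary('images', np.random.rand(2, 32, 32), 1)
logger.close()
print(os.listdir(log_dir))  # should list an events.out.tfevents.* file
```

The method signatures match the calls in the training script, so `from logger import Logger` can stay unchanged. The resulting logs are viewed with `tensorboard --logdir ./logs` as before.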

**II. Training on your own dataset:**

**1. Dataset class:**


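torch.utils.data.Dataset:

An abstract class that represents a dataset. To use your own data, subclass Dataset and override its methods. The most important ones to override are:

(1) __init__: reads the dataset files (in whatever format), paths, etc., and handles the arguments passed in
(2) __getitem__: makes dataset[i] return the i-th sample, i.e. this is where the actual data is loaded
(3) __len__: len(dataset) returns the size of the dataset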
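The sketch below makes the three methods concrete with a toy `Dataset` subclass. The class name `SquaresDataset` and its data are purely hypothetical:

- `__init__` stores what is needed to locate the samples.
- `__getitem__` returns one (input, target) pair.
- `__len__` reports how many samples there are.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the tensor [i], its label is i * i."""

    def __init__(self, n):
        # Prepare whatever is needed to locate samples (here just a count)
        self.n = n

    def __getitem__(self, index):
        # Return the index-th sample as (input, target)
        x = torch.tensor([float(index)])
        y = index * index
        return x, y

    def __len__(self):
        # Number of samples; used by len(dataset) and by DataLoader's sampler
        return self.n


ds = SquaresDataset(5)
print(len(ds))  # 5
print(ds[3])    # (tensor([3.]), 9)
```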

A complete example:

import os

import torch.utils.data as data
from PIL import Image


def default_loader(path):
    return Image.open(path).convert('RGB')

############# Dataset ############
class myImageFloder(data.Dataset):
    def __init__(self, root, image_path, label_path, transform=None, target_transform=None, loader=default_loader):
        # label file: one "<image_id> <class_id>" line per image
        id_to_label = {}
        with open(label_path) as f_label:
            for line in f_label:
                cls = line.split()
                if cls:
                    id_to_label[cls[0]] = int(cls[1])

        # image list file: one "<image_id> <image_name>" line per image
        imgs = []
        with open(image_path) as f_img:
            for line in f_img:
                cls = line.split()
                if not cls:
                    continue
                img_id, img_name = cls[0], cls[1]
                # keep only images that exist on disk and have a label, paired by image id
                # so that a missing file cannot shift later labels
                if os.path.isfile(os.path.join(root, img_name)) and img_id in id_to_label:
                    imgs.append((img_name, id_to_label[img_id]))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        img_name, label = self.imgs[index]
        img = self.loader(os.path.join(self.root, img_name))
        if self.transform is not None:
            img = self.transform(img)
        # label is the class id from the file; shift it to start at 0 (e.g. via
        # target_transform) if the ids start at 1 and you train with nn.CrossEntropyLoss
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.imgs)
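myImageFloder reads two whitespace-separated text files:

- an image list with one `<image_id> <image_name>` line per image;
- a label list with one `<image_id> <class_id>` line per image.

This is the layout the parsing code assumes; CUB-200-2011's images.txt and image_class_labels.txt use it. Pairing the two files by image id, rather than by line position, keeps labels correct when lines are reordered or an image is missing. The helper below is a hypothetical, framework-free sketch of that join. The file names in the example are made up, and `zero_based=True` assumes class ids start at 1.

```python
def pair_images_with_labels(image_lines, label_lines, zero_based=True):
    """Join '<image_id> <image_name>' lines with '<image_id> <class_id>' lines on image_id.

    Returns (image_name, class_index) tuples in image-list order; images without a
    label are skipped. zero_based=True assumes class ids start at 1 and shifts them
    to start at 0, as nn.CrossEntropyLoss expects.
    """
    id_to_class = {}
    for line in label_lines:
        parts = line.split()
        if parts:
            id_to_class[parts[0]] = int(parts[1])

    pairs = []
    for line in image_lines:
        parts = line.split()
        if parts and parts[0] in id_to_class:
            cls = id_to_class[parts[0]]
            pairs.append((parts[1], cls - 1 if zero_based else cls))
    return pairs


images = ["1 class_a/img1.jpg", "2 class_b/img2.jpg", "3 class_c/img3.jpg"]
labels = ["2 2", "1 1"]  # different order, and image 3 has no label
print(pair_images_with_labels(images, labels))
# [('class_a/img1.jpg', 0), ('class_b/img2.jpg', 1)]
```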

**2. Transform:**

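The input images need to be converted with a few transform functions. Commonly used ones are:

Rescale: rescale the image to a given size
RandomCrop: crop the image at a random position, for data augmentation
ToTensor: convert a numpy image (H x W x C) to a torch image (C x H x W)

For example, a complete definition: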

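from torchvision import transforms

########### Transform ############
mytransform = transforms.Compose([
    transforms.ToTensor()  # PIL image -> float tensor (C, H, W) in [0, 1]; image size is unchanged
    ]
)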

**3. Instantiating the DataLoader:**

This step processes the data obtained above. It batches the data, shuffles it, and loads it in parallel using multiprocessing workers. The trainloader and testloader are instantiated separately.

A complete example:

import torch
import myFloder  # assumption: the myImageFloder class above is saved as myFloder.py

########## Dataloader ############
trainloader = torch.utils.data.DataLoader(
    myFloder.myImageFloder(root = '/home/zzq/Distillation/Datasets/bird_classification-master/data/images',
                           image_path = '/home/zzq/Distillation/Datasets/bird_classification-master/data/images_train.txt',
                           label_path = '/home/zzq/Distillation/Datasets/bird_classification-master/data/image_class_labels_train.txt',
                           transform = mytransform),
    batch_size = 24,shuffle = True,num_workers = 2)
print("TrainLoader success...")

testloader = torch.utils.data.DataLoader(
    myFloder.myImageFloder(root = '/home/zzq/Distillation/Datasets/bird_classification-master/data/images',
                          image_path = '/home/zzq/Distillation/Datasets/bird_classification-master/data/images_test.txt',
                          label_path = '/home/zzq/Distillation/Datasets/bird_classification-master/data/image_class_labels_test.txt',
                          transform = mytransform),
    batch_size = 24,shuffle = False,num_workers = 2)

print("TestLoader success...")
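The sketch below shows what batch_size, shuffle and drop_last do without needing the bird images. It runs on a hypothetical 50-sample in-memory dataset:

- There are ceil(50 / 24) = 3 batches per epoch.
- The last batch is smaller unless drop_last=True.
- shuffle=True only changes the order in which samples are visited.

num_workers=0 keeps loading in the main process here. With num_workers > 0 on Windows and macOS, where worker processes are spawned, the loader code should sit under an `if __name__ == '__main__':` guard.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for myImageFloder: sample i has value i and label i % 2
values = torch.arange(50).float().unsqueeze(1)
labels = torch.arange(50) % 2
dataset = TensorDataset(values, labels)

trainloader = DataLoader(dataset, batch_size=24, shuffle=True, num_workers=0)
testloader = DataLoader(dataset, batch_size=24, shuffle=False, num_workers=0)

print(len(trainloader))                    # 3 batches: 24 + 24 + 2
print([x.size(0) for x, _ in testloader])  # [24, 24, 2]

# With drop_last=True the incomplete final batch is discarded
print(len(DataLoader(dataset, batch_size=24, drop_last=True)))  # 2

# Shuffling changes the order, not the set of samples seen in one epoch
seen = torch.cat([x for x, _ in trainloader]).squeeze(1)
print(sorted(seen.tolist()) == values.squeeze(1).tolist())  # True
```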



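Copyright notice: this is an original article by Lucifer_zzq, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.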