When I first got into deep learning I used Caffe, then for a while I switched to the mainstream Keras (with TensorFlow as the backend). But Keras is wrapped up so neatly that its high-level API actually took some getting used to for a mediocre programmer like me. After writing code with it for a while I moved on to PyTorch (the boss wants both frameworks anyway, so I might as well learn both). At the source-code level PyTorch really is much friendlier; you do have to define the forward pass yourself, but that part is easy.
First, two very beginner-friendly tutorials on GitHub (in case the official docs feel too slow to get started with):
https://github.com/SherlockLiao/pytorch-beginner
https://github.com/hunkim/PyTorchZeroToAll
Official references:
https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/#torchvisiondatasets
http://pytorch.org/tutorials/beginner/data_loading_tutorial.html
**Part 1: Training the datasets that ship with torchvision**
Building a network and training the datasets that ship with torchvision are easy stuff; this tutorial mainly records the things in my actual code that took a bit more effort.
First, how to train on a dataset. Training on one of the datasets built into torchvision is very simple: read it directly with torchvision.datasets, then instantiate a torch.utils.data.DataLoader (specifying the batch_size and whether to shuffle), and feed the batches in during training with enumerate. You can also use the following code to check that the data loaded correctly by displaying an image:
for i, data in enumerate(dataLoader, 0):
    img_batch, label_batch = data
    print(img_batch[0])
    # convert the first image of the batch back to a PIL image and display it
    img = transforms.ToPILImage()(img_batch[0])
    img.show()
    break
The complete code is as follows:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from logger import Logger
# hyperparameters
batch_size = 128
learning_rate = 1e-2
num_epoches = 20

def to_np(x):
    # move a tensor back to the CPU and convert it to a numpy array (for the logger)
    return x.cpu().data.numpy()
# download datasets
train_dataset = datasets.CIFAR10(
    root='./cifar_data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.CIFAR10(
    root='./cifar_data', train=False, transform=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# define the model
class slice_ssc(nn.Module):
    def __init__(self, in_channel, n_class):
        super(slice_ssc, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, 32, 3, 1, 1),
            nn.ReLU(True),
            nn.MaxPool2d(2))
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(True),
            nn.MaxPool2d(2))
        # after two 2x2 poolings a 32x32 input becomes 64 channels of 8x8
        self.fc = nn.Sequential(
            nn.Linear(64 * 8 * 8, 128),
            nn.Linear(128, 64),
            nn.Linear(64, n_class))

    def forward(self, x):
        conv1_out = self.conv1(x)
        conv2_out = self.conv2(conv1_out)
        conv2_out = conv2_out.view(conv2_out.size(0), -1)
        out = self.fc(conv2_out)
        return out

model = slice_ssc(1, 10)
print(model)
use_gpu = torch.cuda.is_available()  # check whether GPU acceleration is available
if use_gpu:
    model = model.cuda()
# define the loss and the optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
logger = Logger('./logs')
# training
for epoch in range(num_epoches):
    print('epoch {}'.format(epoch + 1))
    train_loss = 0.0
    train_acc = 0.0
    # ========== training ==========
    for i, data in enumerate(train_loader, 1):
        img, label = data
        # split every RGB image into three single-channel "slices" and triple the labels to match
        img = img.view(img.size(0) * 3, 1, 32, 32)
        label = torch.cat((label, label, label), 0)
        # print(img.size())
        # print(label.size())
        if use_gpu:
            img = img.cuda()
            label = label.cuda()
        img = Variable(img)
        label = Variable(label)
        # forward
        out = model(img)
        loss = criterion(out, label)
        train_loss += loss.data[0]  # *label.size(0)
        _, pred = torch.max(out, 1)
        train_correct = (pred == label).sum()
        accuracy = (pred == label).float().mean()
        train_acc += train_correct.data[0]
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # ============ log ===============
        step = epoch * len(train_loader) + i
        info = {'loss': loss.data[0], 'accuracy': accuracy.data[0]}
        for tag, value in info.items():
            logger.scalar_summary(tag, value, step)
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            logger.histo_summary(tag, to_np(value), step)
            logger.histo_summary(tag + '/grad', to_np(value.grad), step)
        info = {'images': to_np(img.view(-1, 32, 32)[:10])}
        for tag, images in info.items():
            logger.image_summary(tag, images, step)
        if i % 300 == 0:
            print('[{}/{}] Loss: {:.6f}, Acc: {:.6f}'.format(
                epoch + 1, num_epoches, train_loss / (batch_size * i),
                train_acc / (batch_size * i)))
    print('Finish {} epoch, Loss: {:.6f}, Acc: {:.6f}'.format(
        epoch + 1, train_loss / len(train_dataset), train_acc / len(train_dataset)))
# ============ testing ============
model.eval()
eval_loss = 0.0
eval_acc = 0.0
for data in test_loader:
    img, label = data
    img = img.view(img.size(0) * 3, 1, 32, 32)
    label = torch.cat((label, label, label), 0)
    if use_gpu:
        img = Variable(img, volatile=True).cuda()
        label = Variable(label, volatile=True).cuda()
    else:
        img = Variable(img, volatile=True)
        label = Variable(label, volatile=True)
    out = model(img)
    loss = criterion(out, label)
    eval_loss += loss.data[0] * label.size(0)
    _, pred = torch.max(out, 1)
    num_correct = (pred == label).sum()
    eval_acc += num_correct.data[0]
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(
    eval_loss / len(test_dataset), eval_acc / len(test_dataset)))
# save the trained weights
torch.save(model.state_dict(), './cnn.pth')
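To load the saved weights back later, the usual pattern is to rebuild the network and restore the state dict (a short sketch, assuming the slice_ssc class above is in scope):
# rebuild the network and load the weights saved with torch.save above
model = slice_ssc(1, 10)
model.load_state_dict(torch.load('./cnn.pth'))
model.eval()  # switch to evaluation mode before running inference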
The logger.py that writes the log files is:
import tensorflow as tf
import numpy as np
import scipy.misc
try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO  # Python 3.x

class Logger(object):
    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag,
                                                     simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images."""
        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to a string buffer
            try:
                s = StringIO()
            except:
                s = BytesIO()
            scipy.misc.toimage(img).save(s, format="png")
            # Create an Image object
            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            # Create a Summary value
            img_summaries.append(
                tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
        # Create and write Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""
        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)
        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values**2))
        # Drop the start of the first bin
        bin_edges = bin_edges[1:]
        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)
        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()
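Since this Logger writes standard TensorFlow event files, the loss/accuracy curves, histograms, and images can be viewed with TensorBoard (assuming TensorFlow/TensorBoard is installed) by running `tensorboard --logdir ./logs` in a terminal and opening the printed URL in a browser.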
**Part 2: Training your own dataset**
**1. The Dataset class**
torch.utils.data.Dataset is an abstract class representing a dataset. To use your own data you subclass Dataset and override its methods; the most important ones to override are:
(1) __init__: reads the dataset in whatever format, the paths, etc., and handles the constructor arguments
(2) __getitem__: lets dataset[i] return the i-th sample, i.e. where the actual data gets loaded
(3) __len__: len(dataset) returns the size of the dataset
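Before the full example, a minimal skeleton of those three methods looks like this (just a sketch; MyDataset and its fields are illustrative names, not part of the code below):
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, samples, transform=None):
        # receive file lists / paths / transforms from the caller
        self.samples = samples
        self.transform = transform
    def __getitem__(self, index):
        # load and return the index-th sample and its label
        img, label = self.samples[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, label
    def __len__(self):
        # len(dataset) == number of samples
        return len(self.samples)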
The full code example is as follows:
import os
import torch
import torch.utils.data as data
from PIL import Image

def default_loader(path):
    return Image.open(path).convert('RGB')

############# Dataset ############
class myImageFloder(data.Dataset):
    def __init__(self, root, image_path, label_path, transform=None, target_transform=None, loader=default_loader):
        f_img = open(image_path)
        f_label = open(label_path)
        imgs = []
        img_names = []
        label_names = []
        for line in f_img.readlines():
            cls = line.split()
            # the second whitespace-separated field is the image name;
            # the remaining fields are kept as float labels
            img_name = cls.pop(1)
            img_names.append(img_name)
            # keep the sample only if the image file actually exists
            if os.path.isfile(os.path.join(root, img_name)):
                imgs.append((img_name, tuple([float(v) for v in cls])))
        for line in f_label.readlines():
            cls = line.split()
            label_name = cls.pop(1)
            label_names.append(label_name)
        f_img.close()
        f_label.close()
        self.root = root
        self.imgs = imgs
        self.img_names = img_names
        self.label_names = label_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        img_name, label = self.imgs[index]
        img = self.loader(os.path.join(self.root, img_name))
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.Tensor(label)

    def __len__(self):
        return len(self.imgs)
**2. Transform**
You usually need some transform functions to preprocess the input images; commonly used ones are:
Rescale: scales the image
RandomCrop: crops the image at a random location, for data augmentation
ToTensor: converts the numpy image to a torch image
For example, a complete definition looks like this:
########### Transform ############
mytransform = transforms.Compose([
    transforms.ToTensor()
])
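If you also want the rescaling / random-cropping augmentation mentioned above, the same Compose pattern extends naturally (a sketch; the 256/224 sizes are only illustrative, and Resize, RandomCrop, and RandomHorizontalFlip are standard torchvision transforms):
augment_transform = transforms.Compose([
    transforms.Resize(256),             # rescale the shorter side to 256 pixels
    transforms.RandomCrop(224),         # random 224x224 crop for data augmentation
    transforms.RandomHorizontalFlip(),  # flip left-right with probability 0.5
    transforms.ToTensor()               # convert the PIL image to a torch tensor
])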
**3. Instantiating the DataLoader**
This step takes the dataset obtained above and handles the rest: batch the data, shuffle the data, and load the data in parallel using multiprocessing workers. The trainloader and testloader are instantiated separately.
The full code example is as follows:
########## Dataloader ############
trainloader = torch.utils.data.DataLoader(
    myFloder.myImageFloder(root='/home/zzq/Distillation/Datasets/bird_classification-master/data/images',
                           image_path='/home/zzq/Distillation/Datasets/bird_classification-master/data/images_train.txt',
                           label_path='/home/zzq/Distillation/Datasets/bird_classification-master/data/image_class_labels_train.txt',
                           transform=mytransform),
    batch_size=24, shuffle=True, num_workers=2)
print("TrainLoader success...")
testloader = torch.utils.data.DataLoader(
    myFloder.myImageFloder(root='/home/zzq/Distillation/Datasets/bird_classification-master/data/images',
                           image_path='/home/zzq/Distillation/Datasets/bird_classification-master/data/images_test.txt',
                           label_path='/home/zzq/Distillation/Datasets/bird_classification-master/data/image_class_labels_test.txt',
                           transform=mytransform),
    batch_size=24, shuffle=False, num_workers=2)
print("TestLoader success...")