背景
深度学习的训练数据往往很多,如果一次性训练所有的数据,不但会导致时间过长,而且训练次数不够,参数也不能得到很好的迭代。为此,将训练数据分成小的batch,一次batch迭代就可以完成一次参数更新,大大提高了训练速度。
Pytorch中有现成的batch生成器,但是为了底层原理的理解,最好自己能够写出这样的代码,就先从能看懂现成代码开始吧。
batch生成器函数
def data_iter(batch_size,features,labels):
num_examples = len(features)
indices = list(range(num_examples))
random.shuffle(indices)
for i in range(0,num_examples,batch_size):
j = torch.LongTensor(indices[i:min(i+batch_size,num_examples)])
yield features.index_select(0,j),labels.index_select(0,j)
这个函数名为batch_iter,参数有三:batch_size(batch的大小)、features(训练数据的特征,可以视为自变量)和labels(训练数据的标签,可以视为因变量)。
首先,num_examples获得features变量的长度,这个值就是训练数据的个数;
接着,利用range函数生成从0到training number-1(num_examples-1)的range,利用list函数将其转为列表,这样,indices就是一个包含从0到num_examples-1的list了;
然后利用random包的shuffle函数对indices的数进行洗牌(实际上就是打乱,然后随机排列);
接下来,range(0, num_examples, batch_size)是从0到num_examples-1,每隔batch_size步长产生一个数,即这里的i分别为0,10,20,…,990;
接着,indices[i:min(i+batch_size,num_examples)]是一个索引操作,表示取出indices这个list里从i到下一个变量值-1的子list,其中min(i+batch_size,num_examples)的作用是防止索引超出最大范围;这样,j就得到了一个值类型为long的tensor,值是indices里索引出来的子list;
最后,yield函数是一个生成器,在这里可以简单的看作是return;features.index_select(0,j)中,0表示的是dim,这个操作从features中索引出了所有行数为j的features,组成一个tensor;labels同理。
这样,调用这个函数之后就可以获得一个训练数据的batch了。
实例演示
import torch
import numpy as np
import random
num_inputs = 2
num_examples = 1000
true_w = [2,-3.4]
true_b = 4.2
features = torch.from_numpy(np.random.normal(0,1,(num_examples,num_inputs)))
labels = true_w[0]*features[:,0]+true_w[1]*features[:,1]+true_b
labels += torch.from_numpy(np.random.normal(0,0.01,size = labels.size()))
def data_iter(batch_size,features,labels):
num_examples = len(features)
indices = list(range(num_examples))
random.shuffle(indices)
for i in range(0,num_examples,batch_size):
j = torch.LongTensor(indices[i:min(i+batch_size,num_examples)])
yield features.index_select(0,j),labels.index_select(0,j)
batch_size = 10
for x,y in data_iter(batch_size,features,labels):
print(x,y) #这里只演示出第一个batch
break
输出结果:
tensor([[ 1.0289, -0.5676],
[ 0.4811, 0.0651],
[-0.7113, -0.7735],
[ 0.5077, 1.5935],
[ 0.5343, 0.8802],
[-1.1659, -1.0234],
[ 0.1249, -0.2690],
[-1.9804, 0.9771],
[-0.5953, -0.0802],
[ 0.2558, -1.0796]], dtype=torch.float64) tensor([ 8.2047, 4.9386, 5.4247, -0.2031, 2.2704, 5.3475, 5.3715, -3.0973,
3.2829, 8.3943], dtype=torch.float64)