Dataloader 源码分析(二)
Dataloader 组件
Sampler 类
在看 Sampler 的具体实现之前,我们先看看 Dataloader 在什么时候产生 Sampler 对象:
class DataLoader(object):
def __init__(self, ...):
...
if sampler is None:
...
# 如果指定shuffle就使用随机采样,否则使用顺序采样
if shuffle:
sampler = RandomSampler(dataset, generator=generator)
else:
sampler = SequentialSampler(dataset)
if batch_size is not None and batch_sampler is None:
# 如果指定了batch_size又没有指定自定义的batch_sampler,就开启自动批采样
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
...
我们可以看到 Sampler 对象的主要职责就是生成用于访问 Dataset 的 index。其中 Sampler 的子类如下:
- SequentialSampler 顺序采样
- RandomSampler 随机采样
- BatchSampler 批采样
实际上还有其他的采样方法,但是因为使用的不多,本文主要讲解上述的三种 Sampler。上述提到的几种采样类都是 Sampler 的子类,Sampler 中的__iter__方法定义为 raise NotImplementedError:
class Sampler(Generic[T_co]):
def __init__(self, data_source: Optional[Sized]) -> None:
pass
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
SequentialSampler
SequentialSampler 实现:
class SequentialSampler(Sampler[int]):
data_source: Sized
def __init__(self, data_source: Sized) -> None:
self.data_source = data_source
def __iter__(self) -> Iterator[int]:
# 创建一个迭代器
return iter(range(len(self.data_source)))
def __len__(self) -> int:
return len(self.data_source)
这里主要关注__Iter__方法,实际上返回的 index 就是 range(len(self.data_source)) 顺序递增的结果:len(data_source) 实际上就是 Dataset 返回的 samples 的长度。创建迭代器之后,当对这个迭代器调用__next__方法,就会返回 0, 1, 2, 3, 4, … 顺序递增的 index。
RandomSampler
RandomSampler 实现:
class RandomSampler(Sampler[int]):
data_source: Sized
replacement: bool
def __init__(self, data_source: Sized, replacement: bool = False,
num_samples: Optional[int] = None, generator=None) -> None:
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
self.generator = generator
...
@property
def num_samples(self) -> int:
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
if self.generator is None:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)
else:
generator = self.generator
# replacement 表示是否可以生成重复 index
if self.replacement:
# num_samples 表示一次性采样的数据量
for _ in range(self.num_samples // 32):
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist(
版权声明:本文为weixin_41670608原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。