Pytorch Dataloader 模块源码分析(二):Sampler / Fetcher 组件及 Dataloader 核心代码

  • Post author:
  • Post category:其他




Dataloader 组件



Sampler 类

在看 Sampler 的具体实现之前,我们先看看 Dataloader 在什么时候产生 Sampler 对象:

class DataLoader(object):
    def __init__(self, ...):
        ...
        if sampler is None:  
            ...
             # 如果指定shuffle就使用随机采样,否则使用顺序采样
                if shuffle: 
                    sampler = RandomSampler(dataset, generator=generator)
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # 如果指定了batch_size又没有指定自定义的batch_sampler,就开启自动批采样
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
        ...

我们可以看到 Sampler 对象的主要职责就是生成用于访问 Dataset 的 index。其中 Sampler 的子类如下:

  • SequentialSampler 顺序采样
  • RandomSampler 随机采样
  • BatchSampler 批采样

实际上还有其他的采样方法,但是因为使用的不多,本文主要讲解上述的三种 Sampler。上述提到的几种采样类都是 Sampler 的子类,Sampler 中的__iter__方法定义为 raise NotImplementedError:

class Sampler(Generic[T_co]):
    def __init__(self, data_source: Optional[Sized]) -> None:
        pass
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError



SequentialSampler

SequentialSampler 实现:

class SequentialSampler(Sampler[int]):
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
    	# 创建一个迭代器
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

这里主要关注__Iter__方法,实际上返回的 index 就是 range(len(self.data_source)) 顺序递增的结果:len(data_source) 实际上就是 Dataset 返回的 samples 的长度。创建迭代器之后,当对这个迭代器调用__next__方法,就会返回 0, 1, 2, 3, 4, … 顺序递增的 index。



RandomSampler

RandomSampler 实现:

class RandomSampler(Sampler[int]):
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator    
		...
    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator
		# replacement 表示是否可以生成重复 index
        if self.replacement:
        	# num_samples 表示一次性采样的数据量
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist(



版权声明:本文为weixin_41670608原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。