应用场景
随着深度学习领域的发展,研究人员发现模型越大训练出来的效果越好,因此模型越来越大成为深度学习领域的一个显著特征。但是越大的模型对设备的要求越高,即需要单卡的算力更强,内存空间更大。当单卡运行不能满足模型的要求时,往往需要多卡甚至多台机器协调工作,共同完成训练工作。但如何协调多卡/多机来完成大模型的训练,是大规模分布式训练所需要解决的问题。
模型并行策略是大规模分布式训练很常见的策略之一。它通过将模型中特定子图中的权值均匀的分配到多张卡上,从而降低了模型对单卡的内存要求。帮助模型顺利运行起来。
什么是模型并行
模型并行是算子层面的并行,它利用某些算子的特性将算子拆分到好多个设备上进行计算。因此并不是网络中所有的算子都可以拆分计算,可以拆分的算子需满足如下特性:
- 可以并行计算的算子
- 算子其中一个输入来自于权值
综上, 模型并行中最主要应用的算子就是matmul算子。(conv算子也满足上述要求,不知道为啥没用,个人猜测可能是CNN网络规模不是很大,单卡可以cover, 没必要进行模型并行)。
理论上只要模型中有一个matmu算子就可以进行模型并行策略,算法示例图如下,左图为一个matmul算子的网络结构,右图为通过模型并行策略将一个matmul算子拆分到两张卡上进行计算,每张卡只需要保存原算子1/2的权值。最后通过通信算子allgather获取其他卡的计算结果,从而使每张卡都可以获取完整的计算结果。
模型并行在网络中的应用
从上图可以看出,只要模型中存在matmul算子就可以使用模型并行策略,但实际的网络中并不推荐这么使用,因为这样的话,一个matmul算子就需要配套一个allgather通信算子,这样极大的扩大了通信开销,从而大大拖慢网络训练的速度。真实的网络中往往并不这么用,而是通过多个matmul算子组合,尽可能少的使用通信算子也能达到一样的目标。常见的组合有两种:
case1:MLP子图
MLP的子图介绍
MLP是transformer网络中的一段子图,其简化后的结构如下:
MLP子图的代码如下
# x是输入, w1和w2是训练的权值
out1 = torch.matmul(x, w1)
out2 = torch.matmul(out1, w2)
模型并行策略拆分MLP子图
子图中有两个matmul,拆分策略如下:
- 第一个matmul的权值按列切分
- 第二个matmul的权值按行切分
-
通过all_reduce通信算子获取完整的输出结果
MLP的拆分策略如下图:
模型并行策略实现的MLP代码示例:
模型并行策略实现的MLP
case2:Attention子图
注意力机制是NLP网络很重要的一个特性。以下是自注意力机制中涉及matmul代码的简易实现:
# bs: batch_size, s: seq_len, h: hidden_size
query = F.linear(x, w_q) # query.shape: [bs,s,h]
key = F.linear(x, w_k)
value = F.linear(x, w_v)
attention_scores = torch.matmul(query, key.permute(0,2,1)) # attention_scores.shape: [bs,s,s]
context_layer = torch.matmul(attention_scores, value) # context_layer.shape: [bs,s,h]
output = F.linear(context_layer, w_o) # output.shape: [bs,s,h]
使用模型并行策略拆分Attention子图:
仔细分析代码,可以发现以上代码是两组MLP子图。因此需要额外添加两组all_reduce通信。具体实现代码如下:
使用模型并行策略拆分Attention 网络子图
模型并行的优化点
优点
-
在模型中实现,不依赖第三方框架平台。大规模分布式训练的另外两种策略:ZeRO数据并行和流水并行。由于其实现的复杂性。它们分别需要借助DeeSpeed和Megatron来实现。而模型并行策略只需要原先的模型中修改即可。减少了学习第三方平台的成本。
缺点
缺点
- 网络中并不是所有节点都可以进行模型并行拆分,因此在内存优化的效果上1+1<< 2。
- 针对特定的子图结构才生效, 不通用。ZeRO数据并行和流水并行是两种通用的大规模分布式训练的策略,适用于任何模型。而模型并行是一种广泛应用于Transformer类网络中策略,且仅适用于其中的Self-Attention和MLP结构,其他网络暂时无法使用。
- 需要保存多份checkpoint。由于模型并行中每个device只保存部分的权值,因此每张卡的权值都需要保存下来。即需要保存mp_size份checkpoint。`