模型训练时gpu内存不足的解决办法

  • Post author:
  • Post category:其他

最近在训练微调bert预训练模型的时候,gpu内存老是不足,跑不了一个epoch就爆掉了,在网上来来回回找了很多资料,这里把一些方法总结一下:

半精度训练

半精度float16比单精度float32占用内存小,计算更快,但是半精度也有不好的地方,它的舍入误差更大,而且在训练的时候有时候会出现nan的情况(我自己训练的时候也遇到过,解决方法可以参考我的另一篇博客)。
模型在gpu上训练,模型和输入数据都要.cuda()一下,转成半精度直接input.half()和model.half() 就行了。
另外,还有混合精度训练,可以参考:https://zhuanlan.zhihu.com/p/103685761

累积梯度

一般我们在训练模型的时候都是一个batch更新一次模型参数,但是在gpu内存不够的时候batchsize就不能设的比较大,但是batchsize比较小又影响模型的性能和训练速度。
这个时候累积梯度的作用就出来了,累积梯度就是让模型累积几个batch的梯度之后再更新参数,相当于变相增大batchsize。具体的实现代码如下:

# 梯度累积,相当于增大batch_size
loss.backward()  # 计算梯度
accumulation_steps =if ((i + 1) % accumulation_steps) == 0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)  # 梯度截断
    optimizer.step()  # 反向传播,更新网络参数
    optimizer.zero_grad()  # 清空梯度

with torch.no_grad()

在每个epoch训练完之后进行验证集测试的时候,将测试部分的代码用with torch.no_grad()包装一下,这样就不会计算梯度占用内存了。
另外,一般验证的时候都会加上model.eval(),这个的作用是对normalization和dropout层起作用的,因为这两个层在训练和测试的时候是不一样的,但是model.eval()好像不会影响梯度计算,因此在加上model.eval()之后还是要再加上with torch.no_grad()

loss.item()

一般我们在训练的时候都会监督loss的变化,如果直接像下面这样写:

epoch_loss += loss

显存占用会逐渐增大,因为loss是requires_grad=True的tensor,是计算图的一部分,是需要计算梯度的。如果在累加损失的时候直接加上loss会让系统认为epoch_loss也是计算图的一部分来计算梯度,造成大量的内存占用
正确的写法应该是这样,用.item()直接提取元素值:

epoch_loss += loss.item()

累加accuracy也是同理:

epoch_acc += acc.item()

del和torch.cuda.empty_cache()

在每个batch计算完成之后,输入、输出、损失、准确率啥的其实都可以删除了,删除之后,使用torch.cuda.empty_cache()释放掉内存,这样可以节省很多内存,毕竟那么多batch累积起来要占用很多内存的。

input,label=batch
output=model(input) 
loss=criteria(output,label)
acc=torch.sum((output>0.5).int()==label)/output.shape[0]

# 梯度累积,相当于增大batch_size
loss.backward()  # 计算梯度
accumulation_steps = 2
if ((i + 1) % accumulation_steps) == 0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)  # 梯度截断
    optimizer.step()  # 反向传播,更新网络参数
    optimizer.zero_grad()  # 清空梯度

epoch_loss += loss.item()
epoch_acc += acc.item()

del input,output,loss,acc    #内存释放
torch.cuda.empty_cache()

参考:https://blog.csdn.net/fish_like_apple/article/details/101448551

原创不易,转载请征得本人同意并注明出处!


版权声明:本文为lppfwl原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。