STANet网络复现–可运行/持续更新
前言
STANet作为入门的第一篇研究文章,github上已有开源代码,也有优秀的博主对文章内容做了讲解,但是,新手实在是太菜了。下面对遇到的几个问题做总结,并给出成功运行代码的方法,特别是BAM/PAM内存溢出问题的解决。
一、out of memory?
RuntimeError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 6.00 GiB total capacity; 2.98 GiB already allocated; 1.78 GiB free; 2.89 MiB cached)
如果你感到上述报错非常熟悉,那你就看对文章了。
分析可知主要是attention矩阵维度过大,导致了内存溢出,目前找到的解决办法有两种:
1.ds参数设置
不是batch size,这里的是代码注意力机制默认的参数ds,代码默认为1,想要跑通需要设置为2/4/8。
具体的bam运行代码为:
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot ./LEVIR-CD/train --val_dataroot ./LEVIR-CD/val --name LEVIR-CDFA0 --lr 0.001 --model CDFA --SA_mode BAM --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop --ds 4
具体的pam运行代码为:
python ./train.py --save_epoch_freq 1 --angle 15 --dataroot ./LEVIR-CD/train --val_dataroot ./LEVIR-CD/val --name LEVIR-CDFAp0 --lr 0.001 --model CDFA --SA_mode PAM --batch_size 8 --load_size 256 --crop_size 256 --preprocess rotate_and_crop --ds 4
若显存仍旧不足,请同时降低bz和ds(笔者同设置为2则可以高精度跑通)
2.数据集分割
这里你肯定会疑惑?代码运行的时候不是做了数据集分割?
在验证过程笔者发现有时候数据集可能没有被分割,所以保险起见进行了分割以后直接参与运行
分割数据集的代码(1024×1024分割为16张64*64的图片):
from PIL import Image
import os
def image_crop(data_dir,save_dir):
if not os.path.exists(save_dir):
os.mkdir(save_dir)
path=os.path.join(data_dir)
img_list=os.listdir(path)
for img in img_list:
a=0
# if img.endswith('.png') or img.endswith('.jpg' ):
if img.endswith('.png'):
img_name = path+'/' +img
im=Image.open(img_name)
#这里是将[1024, 1024]裁剪为[64,64]
for i in range(4):
for j in range(4):
x=i*256
y=j*256
region=im.crop((x,y,x+256,y+256))
region.save(save_dir+'/'+img.split('-')[0]+'_'+str(a)+'.png')
a+=1
data_dir=r"E:/this desk/STANet-master/LEVIR-CD/val/label"#原数据集位置
save_dir=r"E:/this desk/STANet-master/LEVIR-CD1/val/label"#新数据集位置
image_crop(data_dir,save_dir)
二、原因分析
1.导致oom的原因
我们来看网络的主要结构,如下:
在softmax之前和之后,分别两次
O
(
N
2
)
{\rm{O}}\left( {
{N^2}} \right)
O
(
N
2
)
级别的矩阵乘法,自然吃掉了极大部分的显存。
2.ds参数是什么
代码中:ds表示控制输入下采样(从灰色到蓝橙绿)
代码如下(示例):
class _PAMBlock(nn.Module):
'''
The basic implementation for self-attention block/non-local block
Input/Output:
N * C * H * (2*W)
Parameters:
in_channels : the dimension of the input feature map
key_channels : the dimension after the key/query transform
value_channels : the dimension after the value transform
scale : choose the scale to partition the input feature maps
ds : downsampling scale
自我注意块/非局部块的基本实现
输入/输出:
N * C * H * (2*W)
参数:
in_channels:输入特征映射的维度
key_channels:键/查询转换之后的维度
value_c、channels:值转换之后的维度
比例尺:选择比例尺对输入的特征图进行划分
ds:下采样尺度
'''
# 进行定义
def __init__(self, in_channels, key_channels, value_channels, scale=1, ds=1):
super(_PAMBlock, self).__init__()
self.scale = scale
self.ds = ds
self.pool = nn.AvgPool2d(self.ds)#
总结
雪儿妹妹的求学路第一篇到此结束,谢谢大家支持,互相学习~~