转载地址:
https://bbs.huaweicloud.com/forum/thread-137101-1-1.html
作者: 雨丝儿
最近在参加华为与高校合做开发mindspore模型的活动,使用mindspore开发了SRGAN模型,下面几篇帖子想针对SRGAN做一些自己的经验分享。这篇帖子分享SRAGAN loss pytorch的实现。
Pytorch版本参考:https://github.com/dongheehand/SRGAN-PyTorch
Paper中SRGAN的loss:
对于Discriminator:
就是基础GAN中Discriminator的loss
代码实现:
其中gt为原始高分辨率图像,lr为gt经过双三次插值缩小四倍的低分辨率图像,cross_ent为BCELoss()
对与Generator:
Generator的loss包含三部分,一是基础的MSELoss,二是adversarial loss,三是将生成的HR图像与原始高清分辨率图像分别经过预训练的vgg19提取特征后,计算MSELoss.
代码部分:
VGG_loss = perceptual_loss(vgg_net)
cross_ent = nn.BCELoss()
tv_loss = TVLoss()
real_label = torch.ones((args.batch_size, 1)).to(device)
fake_label = torch.zeros((args.batch_size, 1)).to(device)
for i, tr_data in enumerate(loader):
gt = tr_data[‘GT’].to(device)
lr = tr_data[‘LR’].to(device)
output, _ = generator(lr)
fake_prob = discriminator(output)
# 第一部分
L2_loss = l2_loss(output, gt)
# 第二部分
adversarial_loss = args.adv_coeff * cross_ent(fake_prob, real_label)
# 第三部分
_percep_loss, hr_feat, sr_feat = VGG_loss((gt + 1.0) / 2.0, (output + 1.0) / 2.0, layer = args.feat_layer)
percep_loss = args.vgg_rescale_coeff * _percep_loss
g_loss = L2_loss + adversarial_loss + percep_loss
g_optim.zero_grad()
d_optim.zero_grad()
g_loss.backward()
g_optim.step()
其中vgg19是在imagenet上训练好的vgg19,选取其前37层,args.adv_coeff,args.vgg_rescale_coeff 为loss的系数,分别取0.003和0.006。
以上就是srgan loss部分的pytorch代码实现,下篇帖子将分享srgan loss部分minspore代码的实现。