VAE(变分自编码器) 详解

  • Post author:
  • Post category:其他


近期看论文要用到VAE,看了很多资料,有这样一种感觉,要么过度过于偏向数学原理,要么只是讲了讲网络结构。本文将两者结合,以简洁易懂的语言结合代码实现来介绍VAE。



1 解决问题

VAE是变分推断(variational inference )以及自编码器(Auto-encoder)的组合,是一种非监督的生成模型,其概率图模型和深度学习有机结合,近年来比较火热。VAE可以用于但不局限于降维信息检索等任务,我看文献遇到的是一篇做配准的论文,也用到了VAE。



2 从神经网络理解

先从神经网络的角度去看VAE,VAE实际上是在Atuo-encoder(AE)的变种,其基本架构也如Atuo-encoder,包含两部分encoder(编码器)和decoder(解码器)。大概如下:

在这里插入图片描述

图片摘自:

github


以最开始的自编码器为例,其loss函数一般是输入和输出的MSE,通过调整Encoder输出层的节点数(需要低于输入的维度),我们可以从低维度的数据(code)通过Decoder重建出输入。

自编码器存在这样的问题,倘若模型过完备(中间层维度大于输入),模型会直接复制模型的输入作为输出。

在一般实际使用中,我们往往会添加正则项。

除此之外,还有VAE,可以学习出高容量且过完备(中间层维度大于输入)模型。VAE的网络结构如下:

在这里插入图片描述

本图摘自:李宏毅2020深度学习课程

可以看到VAE和AE的区别在于两方面:

1.中间层引入了一个noise;

2.loss函数的改变多了:



i

=

1

3

(

x

e

p

(

σ

i

(

1

+

σ

i

)

+

(

m

i

2

)

)

)

\sum_{i=1}^{3}(xep(\sigma_{i}-(1+\sigma_{i})+(m_{i}^{2})))



















i


=


1










3



















(


x


e


p


(



σ











i






























(


1




+









σ











i



















)




+








(



m











i










2



















)))





这一项。根据上面的结构,我们基本可以很容易地代码实现。


那么如何直观地理解上面地改变呢?


1.为什么要引入noise?首先直观理解,就算有一个noise也要尽量输入和输出相似,这样的decoder更加鲁棒。另一个直观理解就是,在引入noise之前,我们的decoder的输入和输出地映射可以看作是离散的,但是在加入noise之后,可以看作把不连续地变成连续的了。

2.为什么更改loss?首先假设没有更改,其实最好的情况肯定是方差为0,即



σ

=

0

\sigma=0






σ




=








0





。这就回到了AE的形式。直观地说,多加地这一项避免了这一点。那么怎么要这样更改loss呢?首先公式地前两项的值域大于等于0,最后一项可以看作一个L2正则项。

上面的理解都是直观的,感性的认识。下面一部分将从变分推断的角度推导出所谓的引入noise其实是重参数化技巧的结果,而loss的改变也是推导所得,本质是一个KL散度。如果读者到目前为止还有求知欲,就可以继续往下看。



3 从变分推断理解VAE

设:x:为观测数据,可以看作是样本

y:为隐变量,包含但不限于模型的参数

首先变分推断的核心思想是:因为一般情况下后验概率



p

(

z

x

)

p(z|x)






p


(


z





x


)





是不可求解的,所以变分推断采用了一种迂回的策略,即使用



q

(

z

)

q(z)






q


(


z


)





去近似



p

(

z

x

)

p(z|x)






p


(


z





x


)





。如果读者对生成模型不是很理解,可以把



p

(

z

x

)

p(z|x)






p


(


z





x


)





看作是AE中的编码器,



p

(

x

z

)

p(x|z)






p


(


x





z


)





看作是AE中的解码器。

VAE是典型的生成模型,那就从下面公式开始:





log

p

(

x

)

=

log

p

(

x

,

z

)

p

(

z

x

)

=

log

p

(

x

,

z

)

q

(

z

)

log

p

(

z

x

)

q

(

z

)

\begin{aligned}\log p(x)&=\log \frac{p(x,z)}{p(z|x)} \\ &= \log \frac{p(x,z)}{q(z)}-\log \frac{p(z|x)}{q(z)} \end{aligned}
















lo

g





p


(


x


)



































=




lo

g
















p


(


z





x


)














p


(


x


,




z


)






























=




lo

g
















q


(


z


)














p


(


x


,




z


)



























lo

g
















q


(


z


)














p


(


z





x


)










































两边关于



q

(

z

)

q(z)






q


(


z


)





同时求期望,则:




左边

=

z

q

(

z

)

log

p

(

x

)

d

z

=

log

p

(

x

)

左边=\int_{z}q(z)\log p(x) dz=\log p(x)






左边




=





















z





















q


(


z


)




lo

g





p


(


x


)


d


z




=








lo

g





p


(


x


)










右边

=

z

q

(

z

)

log

p

(

x

,

z

)

q

(

z

)

d

z

z

q

(

z

x

)

log

p

(

z

x

)

q

(

z

)

d

z

=

L

(

q

)

+

K

L

(

q

(

z

)

p

(

z

x

)

)

\begin{aligned} 右边 &= \int_{z} q(z)\log \frac{p(x,z)}{q(z)} dz-\int_{z} q(z|x)\log \frac{p(z|x)}{q(z)} dz \\ &=\mathcal{L}(q)+KL(q(z)||p(z|x)) \end{aligned}
















右边



































=

















z





















q


(


z


)




lo

g
















q


(


z


)














p


(


x


,




z


)




















d


z






















z





















q


(


z





x


)




lo

g
















q


(


z


)














p


(


z





x


)




















d


z












=




L


(


q


)




+




K


L


(


q


(


z


)


∣∣


p


(


z





x


))























其中右边化简后的



L

(

q

)

\mathcal{L}(q)






L


(


q


)





通常被称为变分下界,



K

L

(

q

(

z

)

p

(

z

x

)

)

KL(q(z)||p(z|x))






K


L


(


q


(


z


)


∣∣


p


(


z





x


))





是KL散度,用来衡量两个分布之间相似性。





L

(

q

)

\mathcal{L}(q)






L


(


q


)





做进一步化简,




L

(

q

)

=

z

q

(

z

)

log

p

(

x

,

z

)

q

(

z

)

d

z

=

z

q

(

z

)

log

p

(

x

,

z

)

q

(

z

)

q

(

z

)

p

(

z

)

z

q

(

z

x

)

log

q

(

z

)

p

(

z

)

=

E

q

(

z

)

log

p

(

x

z

)

K

L

(

q

(

z

)

p

(

z

)

)

\begin{aligned} \mathcal{L}(q)&=\int_{z} q(z)\log \frac{p(x,z)}{q(z)} dz \\ &=\int_{z}q(z)\log \frac{p(x,z)q(z)}{q(z)p(z)}-\int_{z}q(z|x)\log{q(z)}{p(z)} \\ &=\mathbb{E_{q(z)}}\log p(x|z) – KL(q(z)||p(z)) \end{aligned}
















L


(


q


)









































=

















z





















q


(


z


)




lo

g
















q


(


z


)














p


(


x


,




z


)




















d


z












=

















z





















q


(


z


)




lo

g
















q


(


z


)


p


(


z


)














p


(


x


,




z


)


q


(


z


)








































z





















q


(


z





x


)




lo

g






q


(


z


)




p


(


z


)













=





E











q


(


z


)





















lo

g





p


(


x





z


)









K


L


(


q


(


z


)


∣∣


p


(


z


))























回到最初的目标,我们要用



q

(

z

)

q(z)






q


(


z


)





去近似



p

(

z

x

)

p(z|x)






p


(


z





x


)





,那么我们优化的目标就是最大化



K

L

(

q

(

z

)

p

(

z

x

)

)

KL(q(z)||p(z|x))






K


L


(


q


(


z


)


∣∣


p


(


z





x


))





。即:




q

(

z

x

)

=

a

r

g

max

q

(

z

)

K

L

(

q

(

z

)

p

(

z

x

)

)

=

a

r

g

max

q

(

z

)

log

p

(

x

)

L

(

q

)

=

a

r

g

max

q

(

z

)

log

p

(

x

)

E

q

(

z

)

log

p

(

x

z

)

+

K

L

(

q

(

z

)

p

(

z

)

)

\begin{aligned} q(z|x)&=arg \max_{q(z)} KL(q(z)||p(z|x)) \\ & = arg\max_{q(z)} \log p(x)-\mathcal{L}(q) \\ & = arg\max_{q(z)} \log p(x)-\mathbb{E_{q(z)}}\log p(x|z) + KL(q(z)||p(z)) \end{aligned}
















q


(


z





x


)









































=




a


r


g













q


(


z


)









max



















K


L


(


q


(


z


)


∣∣


p


(


z





x


))












=




a


r


g













q


(


z


)









max



















lo

g





p


(


x


)









L


(


q


)












=




a


r


g













q


(


z


)









max



















lo

g





p


(


x


)










E











q


(


z


)





















lo

g





p


(


x





z


)




+




K


L


(


q


(


z


)


∣∣


p


(


z


))























看到这里,我们可能会想到。



log

p

(

x

)

E

q

(

z

x

)

log

p

(

x

z

)

\log p(x)-\mathbb{E_{q(z|x)}}\log p(x|z)






lo

g





p


(


x


)














E











q


(


z





x


)





















lo

g





p


(


x





z


)





不就是真实的X预测的X之间的差异吗?我们用MSE或者MAE代替,后面的KL散度根据预测和假设的结果不是可以直接算出来,就已经解释了上一节的网络结构。但是细想一下感觉又不对,因为这里虽然表示为q(z),但是q(z)要近似的是p(z|x),所以这个z必定和x有关,这样的话上面的想法就不对了。

在此之前我们把q(z)表示为



q

(

z

x

)

q(z|x)






q


(


z





x


)





,这时通常要使用

重参数化技巧

对上式进一步变形,现在想要把



q

(

z

x

)

q(z|x)






q


(


z





x


)





中的x的成分消去。那么上面是重参数化技巧呢?先举个例子吧,一个随机变量a服从概率分布N(0,1),那么对于随机变量b=a+m,服从高斯分布N(m,1)。现在我们采样这个b的时候采用这样的策略:

1.从高斯分布N(0,1)中采样得a。

2.取b = a+m。

其实这就是重采样技巧。对于我们的



q

(

z

x

)

q(z|x)






q


(


z





x


)





,我们假设



z

=

g

Φ

(

x

,

ϵ

)

z=g_{\Phi}(x,\epsilon)






z




=









g











Φ



















(


x


,




ϵ


)





,然后



ϵ

\epsilon






ϵ





服从某个分布,记为



p

(

ϵ

)

p(\epsilon)






p


(


ϵ


)





,一般我们假设其服从标准正态分布。那么我们采样



q

(

z

x

)

q(z|x)






q


(


z





x


)





就变成了,先根据



p

(

ϵ

)

p(\epsilon)






p


(


ϵ


)





采样一个



ϵ

i

\epsilon^{i}







ϵ











i













,再根据



z

=

g

Φ

(

x

,

ϵ

)

z=g_{\Phi}(x,\epsilon)






z




=









g











Φ



















(


x


,




ϵ


)





计算出z。根据重采样技巧,我们忘掉之前的结果,重新推导KL(q||p):




K

L

(

q

(

z

x

)

p

(

z

x

)

)

=

log

p

(

x

)

z

q

(

z

)

log

p

(

x

,

z

)

q

(

z

x

)

q

(

z

x

)

p

(

z

)

+

z

q

(

z

x

)

log

q

(

z

x

)

p

(

z

)

=

log

p

(

x

)

z

q

(

z

)

log

p

(

x

z

)

d

z

+

K

L

(

q

(

z

x

)

p

(

z

)

)

=

log

p

(

x

)

ϵ

p

(

ϵ

)

log

p

(

x

g

Φ

(

x

,

ϵ

)

)

d

ϵ

+

K

L

(

q

(

z

x

)

p

(

z

)

)

=

log

p

(

x

)

E

p

(

ϵ

)

log

p

(

x

g

Φ

(

x

,

ϵ

)

)

+

K

L

(

q

(

z

x

)

p

(

z

)

)

\begin{aligned} KL(q(z|x)||p(z|x))&= \log p(x)-\int_{z}q(z)\log \frac{p(x,z)q(z|x)}{q(z|x)p(z)}+\int_{z}q(z|x)\log{q(z|x)}{p(z)} \\ &=\log p(x)-\int_{z}q(z)\log p(x|z)dz+KL(q(z|x)||p(z)) \\ &=\log p(x)-\int_{\epsilon}p(\epsilon)\log p(x|g_{\Phi}(x,\epsilon))d\epsilon+KL(q(z|x)||p(z)) \\ &=\log p(x)-\mathbb{E_{p(\epsilon)}}\log p(x|g_{\Phi}(x,\epsilon))+KL(q(z|x)||p(z)) \end{aligned}
















K


L


(


q


(


z





x


)


∣∣


p


(


z





x


))















































=




lo

g





p


(


x


)






















z





















q


(


z


)




lo

g
















q


(


z





x


)


p


(


z


)














p


(


x


,




z


)


q


(


z





x


)






















+

















z





















q


(


z





x


)




lo

g






q


(


z





x


)




p


(


z


)













=




lo

g





p


(


x


)






















z





















q


(


z


)




lo

g





p


(


x





z


)


d


z




+




K


L


(


q


(


z





x


)


∣∣


p


(


z


))












=




lo

g





p


(


x


)






















ϵ





















p


(


ϵ


)




lo

g





p


(


x






g











Φ



















(


x


,




ϵ


))


d


ϵ




+




K


L


(


q


(


z





x


)


∣∣


p


(


z


))












=




lo

g





p


(


x


)










E











p


(


ϵ


)





















lo

g





p


(


x






g











Φ



















(


x


,




ϵ


))




+




K


L


(


q


(


z





x


)


∣∣


p


(


z


))

























其实整个VAE的构建就是根据上面的等式





g

Φ

(

x

,

ϵ

)

g_{\Phi}(x,\epsilon)







g











Φ



















(


x


,




ϵ


)





不知道是什么,那就用一个神经网络代替。



p

(

x

z

)

p(x|z)






p


(


x





z


)





不知道是什么,也用一个神经网络代替。下面文字叙述一下VAE的前向传播。

1.先从假设的



p

(

ϵ

)

p(\epsilon)






p


(


ϵ


)





中采样一个



ϵ

\epsilon






ϵ





,即上一节网络图中的



e

e






e







2.从假设的encoder中输入x以及



ϵ

\epsilon






ϵ





,输出隐变量



z

z






z





,即上一节网络图中的



c

c






c







3.将隐变量z输入decoder,输出



x

^

\hat{x}













x







^













而这个前向的过程表示在上面公式里,就是



E

p

(

ϵ

)

log

p

(

x

g

Φ

(

x

,

ϵ

)

)

d

ϵ

\mathbb{E_{p(\epsilon)}}\log p(x|g_{\Phi}(x,\epsilon))d\epsilon







E











p


(


ϵ


)





















lo

g





p


(


x






g











Φ



















(


x


,




ϵ


))


d


ϵ





,显然优化这个网络,我们要让



K

L

(

q

(

z

x

)

p

(

z

x

)

)

KL(q(z|x)||p(z|x))






K


L


(


q


(


z





x


)


∣∣


p


(


z





x


))





最小,



log

p

(

x

)

E

p

(

ϵ

)

log

p

(

x

g

Φ

(

x

,

ϵ

)

)

\log p(x)-\mathbb{E_{p(\epsilon)}}\log p(x|g_{\Phi}(x,\epsilon))






lo

g





p


(


x


)














E











p


(


ϵ


)





















lo

g





p


(


x






g











Φ



















(


x


,




ϵ


))





就用MSE表示,



K

L

(

q

(

z

x

)

p

(

z

)

)

KL(q(z|x)||p(z))






K


L


(


q


(


z





x


)


∣∣


p


(


z


))





是可以求出来的。具体的推导也不难,如果假设是高斯分布,即根据多维高斯分布的KL散度,结合我们重采样的q(z|x),推导出来最后的结果就是第一节中图中的公式,这里就省略推导了。

能看到这里的宝贝都很厉害,毕竟我感觉自己也写的不是很清楚,才疏学浅了。不过最困难的部分也过去了,我们不妨看看VAE的pytorch代码实现,看看自己理解的是不是对的。



4 pytorch代码

本文的代码来自

GITHUB

__author__ = 'SherlockLiao'

import torch
import torchvision
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import os

if not os.path.exists('./vae_img'):
    os.mkdir('./vae_img')


def to_img(x):
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x


num_epochs = 100
batch_size = 128
learning_rate = 1e-3

img_transform = transforms.Compose([
    transforms.ToTensor()
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = MNIST('./data', transform=img_transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        if torch.cuda.is_available():
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar


model = VAE()
if torch.cuda.is_available():
    model.cuda()

reconstruction_function = nn.MSELoss(size_average=False)


def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    BCE = reconstruction_function(recon_x, x)  # mse loss
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return BCE + KLD


optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(dataloader):
        img, _ = data
        img = img.view(img.size(0), -1)
        img = Variable(img)
        if torch.cuda.is_available():
            img = img.cuda()
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(img)
        loss = loss_function(recon_batch, img, mu, logvar)
        loss.backward()
        train_loss += loss.data[0]
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * len(img),
                len(dataloader.dataset), 100. * batch_idx / len(dataloader),
                loss.data[0] / len(img)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(dataloader.dataset)))
    if epoch % 10 == 0:
        save = to_img(recon_batch.cpu().data)
        save_image(save, './vae_img/image_{}.png'.format(epoch))

torch.save(model.state_dict(), './vae.pth')



版权声明:本文为weixin_42217488原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。