近期看论文要用到VAE,看了很多资料,有这样一种感觉,要么过度过于偏向数学原理,要么只是讲了讲网络结构。本文将两者结合,以简洁易懂的语言结合代码实现来介绍VAE。
1 解决问题
VAE是变分推断(variational inference )以及自编码器(Auto-encoder)的组合,是一种非监督的生成模型,其概率图模型和深度学习有机结合,近年来比较火热。VAE可以用于但不局限于降维信息检索等任务,我看文献遇到的是一篇做配准的论文,也用到了VAE。
2 从神经网络理解
先从神经网络的角度去看VAE,VAE实际上是在Atuo-encoder(AE)的变种,其基本架构也如Atuo-encoder,包含两部分encoder(编码器)和decoder(解码器)。大概如下:
图片摘自:
github
以最开始的自编码器为例,其loss函数一般是输入和输出的MSE,通过调整Encoder输出层的节点数(需要低于输入的维度),我们可以从低维度的数据(code)通过Decoder重建出输入。
自编码器存在这样的问题,倘若模型过完备(中间层维度大于输入),模型会直接复制模型的输入作为输出。
在一般实际使用中,我们往往会添加正则项。
除此之外,还有VAE,可以学习出高容量且过完备(中间层维度大于输入)模型。VAE的网络结构如下:
本图摘自:李宏毅2020深度学习课程
可以看到VAE和AE的区别在于两方面:
1.中间层引入了一个noise;
2.loss函数的改变多了:
∑
i
=
1
3
(
x
e
p
(
σ
i
−
(
1
+
σ
i
)
+
(
m
i
2
)
)
)
\sum_{i=1}^{3}(xep(\sigma_{i}-(1+\sigma_{i})+(m_{i}^{2})))
∑
i
=
1
3
(
x
e
p
(
σ
i
−
(
1
+
σ
i
)
+
(
m
i
2
)))
这一项。根据上面的结构,我们基本可以很容易地代码实现。
那么如何直观地理解上面地改变呢?
1.为什么要引入noise?首先直观理解,就算有一个noise也要尽量输入和输出相似,这样的decoder更加鲁棒。另一个直观理解就是,在引入noise之前,我们的decoder的输入和输出地映射可以看作是离散的,但是在加入noise之后,可以看作把不连续地变成连续的了。
2.为什么更改loss?首先假设没有更改,其实最好的情况肯定是方差为0,即
σ
=
0
\sigma=0
σ
=
0
。这就回到了AE的形式。直观地说,多加地这一项避免了这一点。那么怎么要这样更改loss呢?首先公式地前两项的值域大于等于0,最后一项可以看作一个L2正则项。
上面的理解都是直观的,感性的认识。下面一部分将从变分推断的角度推导出所谓的引入noise其实是重参数化技巧的结果,而loss的改变也是推导所得,本质是一个KL散度。如果读者到目前为止还有求知欲,就可以继续往下看。
3 从变分推断理解VAE
设:x:为观测数据,可以看作是样本
y:为隐变量,包含但不限于模型的参数
首先变分推断的核心思想是:因为一般情况下后验概率
p
(
z
∣
x
)
p(z|x)
p
(
z
∣
x
)
是不可求解的,所以变分推断采用了一种迂回的策略,即使用
q
(
z
)
q(z)
q
(
z
)
去近似
p
(
z
∣
x
)
p(z|x)
p
(
z
∣
x
)
。如果读者对生成模型不是很理解,可以把
p
(
z
∣
x
)
p(z|x)
p
(
z
∣
x
)
看作是AE中的编码器,
p
(
x
∣
z
)
p(x|z)
p
(
x
∣
z
)
看作是AE中的解码器。
VAE是典型的生成模型,那就从下面公式开始:
log
p
(
x
)
=
log
p
(
x
,
z
)
p
(
z
∣
x
)
=
log
p
(
x
,
z
)
q
(
z
)
−
log
p
(
z
∣
x
)
q
(
z
)
\begin{aligned}\log p(x)&=\log \frac{p(x,z)}{p(z|x)} \\ &= \log \frac{p(x,z)}{q(z)}-\log \frac{p(z|x)}{q(z)} \end{aligned}
lo
g
p
(
x
)
=
lo
g
p
(
z
∣
x
)
p
(
x
,
z
)
=
lo
g
q
(
z
)
p
(
x
,
z
)
−
lo
g
q
(
z
)
p
(
z
∣
x
)
两边关于
q
(
z
)
q(z)
q
(
z
)
同时求期望,则:
左边
=
∫
z
q
(
z
)
log
p
(
x
)
d
z
=
log
p
(
x
)
左边=\int_{z}q(z)\log p(x) dz=\log p(x)
左边
=
∫
z
q
(
z
)
lo
g
p
(
x
)
d
z
=
lo
g
p
(
x
)
右边
=
∫
z
q
(
z
)
log
p
(
x
,
z
)
q
(
z
)
d
z
−
∫
z
q
(
z
∣
x
)
log
p
(
z
∣
x
)
q
(
z
)
d
z
=
L
(
q
)
+
K
L
(
q
(
z
)
∣
∣
p
(
z
∣
x
)
)
\begin{aligned} 右边 &= \int_{z} q(z)\log \frac{p(x,z)}{q(z)} dz-\int_{z} q(z|x)\log \frac{p(z|x)}{q(z)} dz \\ &=\mathcal{L}(q)+KL(q(z)||p(z|x)) \end{aligned}
右边
=
∫
z
q
(
z
)
lo
g
q
(
z
)
p
(
x
,
z
)
d
z
−
∫
z
q
(
z
∣
x
)
lo
g
q
(
z
)
p
(
z
∣
x
)
d
z
=
L
(
q
)
+
K
L
(
q
(
z
)
∣∣
p
(
z
∣
x
))
其中右边化简后的
L
(
q
)
\mathcal{L}(q)
L
(
q
)
通常被称为变分下界,
K
L
(
q
(
z
)
∣
∣
p
(
z
∣
x
)
)
KL(q(z)||p(z|x))
K
L
(
q
(
z
)
∣∣
p
(
z
∣
x
))
是KL散度,用来衡量两个分布之间相似性。
对
L
(
q
)
\mathcal{L}(q)
L
(
q
)
做进一步化简,
L
(
q
)
=
∫
z
q
(
z
)
log
p
(
x
,
z
)
q
(
z
)
d
z
=
∫
z
q
(
z
)
log
p
(
x
,
z
)
q
(
z
)
q
(
z
)
p
(
z
)
−
∫
z
q
(
z
∣
x
)
log
q
(
z
)
p
(
z
)
=
E
q
(
z
)
log
p
(
x
∣
z
)
−
K
L
(
q
(
z
)
∣
∣
p
(
z
)
)
\begin{aligned} \mathcal{L}(q)&=\int_{z} q(z)\log \frac{p(x,z)}{q(z)} dz \\ &=\int_{z}q(z)\log \frac{p(x,z)q(z)}{q(z)p(z)}-\int_{z}q(z|x)\log{q(z)}{p(z)} \\ &=\mathbb{E_{q(z)}}\log p(x|z) – KL(q(z)||p(z)) \end{aligned}
L
(
q
)
=
∫
z
q
(
z
)
lo
g
q
(
z
)
p
(
x
,
z
)
d
z
=
∫
z
q
(
z
)
lo
g
q
(
z
)
p
(
z
)
p
(
x
,
z
)
q
(
z
)
−
∫
z
q
(
z
∣
x
)
lo
g
q
(
z
)
p
(
z
)
=
E
q
(
z
)
lo
g
p
(
x
∣
z
)
−
K
L
(
q
(
z
)
∣∣
p
(
z
))
回到最初的目标,我们要用
q
(
z
)
q(z)
q
(
z
)
去近似
p
(
z
∣
x
)
p(z|x)
p
(
z
∣
x
)
,那么我们优化的目标就是最大化
K
L
(
q
(
z
)
∣
∣
p
(
z
∣
x
)
)
KL(q(z)||p(z|x))
K
L
(
q
(
z
)
∣∣
p
(
z
∣
x
))
。即:
q
(
z
∣
x
)
=
a
r
g
max
q
(
z
)
K
L
(
q
(
z
)
∣
∣
p
(
z
∣
x
)
)
=
a
r
g
max
q
(
z
)
log
p
(
x
)
−
L
(
q
)
=
a
r
g
max
q
(
z
)
log
p
(
x
)
−
E
q
(
z
)
log
p
(
x
∣
z
)
+
K
L
(
q
(
z
)
∣
∣
p
(
z
)
)
\begin{aligned} q(z|x)&=arg \max_{q(z)} KL(q(z)||p(z|x)) \\ & = arg\max_{q(z)} \log p(x)-\mathcal{L}(q) \\ & = arg\max_{q(z)} \log p(x)-\mathbb{E_{q(z)}}\log p(x|z) + KL(q(z)||p(z)) \end{aligned}
q
(
z
∣
x
)
=
a
r
g
q
(
z
)
max
K
L
(
q
(
z
)
∣∣
p
(
z
∣
x
))
=
a
r
g
q
(
z
)
max
lo
g
p
(
x
)
−
L
(
q
)
=
a
r
g
q
(
z
)
max
lo
g
p
(
x
)
−
E
q
(
z
)
lo
g
p
(
x
∣
z
)
+
K
L
(
q
(
z
)
∣∣
p
(
z
))
看到这里,我们可能会想到。
log
p
(
x
)
−
E
q
(
z
∣
x
)
log
p
(
x
∣
z
)
\log p(x)-\mathbb{E_{q(z|x)}}\log p(x|z)
lo
g
p
(
x
)
−
E
q
(
z
∣
x
)
lo
g
p
(
x
∣
z
)
不就是真实的X预测的X之间的差异吗?我们用MSE或者MAE代替,后面的KL散度根据预测和假设的结果不是可以直接算出来,就已经解释了上一节的网络结构。但是细想一下感觉又不对,因为这里虽然表示为q(z),但是q(z)要近似的是p(z|x),所以这个z必定和x有关,这样的话上面的想法就不对了。
在此之前我们把q(z)表示为
q
(
z
∣
x
)
q(z|x)
q
(
z
∣
x
)
,这时通常要使用
重参数化技巧
对上式进一步变形,现在想要把
q
(
z
∣
x
)
q(z|x)
q
(
z
∣
x
)
中的x的成分消去。那么上面是重参数化技巧呢?先举个例子吧,一个随机变量a服从概率分布N(0,1),那么对于随机变量b=a+m,服从高斯分布N(m,1)。现在我们采样这个b的时候采用这样的策略:
1.从高斯分布N(0,1)中采样得a。
2.取b = a+m。
其实这就是重采样技巧。对于我们的
q
(
z
∣
x
)
q(z|x)
q
(
z
∣
x
)
,我们假设
z
=
g
Φ
(
x
,
ϵ
)
z=g_{\Phi}(x,\epsilon)
z
=
g
Φ
(
x
,
ϵ
)
,然后
ϵ
\epsilon
ϵ
服从某个分布,记为
p
(
ϵ
)
p(\epsilon)
p
(
ϵ
)
,一般我们假设其服从标准正态分布。那么我们采样
q
(
z
∣
x
)
q(z|x)
q
(
z
∣
x
)
就变成了,先根据
p
(
ϵ
)
p(\epsilon)
p
(
ϵ
)
采样一个
ϵ
i
\epsilon^{i}
ϵ
i
,再根据
z
=
g
Φ
(
x
,
ϵ
)
z=g_{\Phi}(x,\epsilon)
z
=
g
Φ
(
x
,
ϵ
)
计算出z。根据重采样技巧,我们忘掉之前的结果,重新推导KL(q||p):
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
∣
x
)
)
=
log
p
(
x
)
−
∫
z
q
(
z
)
log
p
(
x
,
z
)
q
(
z
∣
x
)
q
(
z
∣
x
)
p
(
z
)
+
∫
z
q
(
z
∣
x
)
log
q
(
z
∣
x
)
p
(
z
)
=
log
p
(
x
)
−
∫
z
q
(
z
)
log
p
(
x
∣
z
)
d
z
+
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
)
)
=
log
p
(
x
)
−
∫
ϵ
p
(
ϵ
)
log
p
(
x
∣
g
Φ
(
x
,
ϵ
)
)
d
ϵ
+
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
)
)
=
log
p
(
x
)
−
E
p
(
ϵ
)
log
p
(
x
∣
g
Φ
(
x
,
ϵ
)
)
+
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
)
)
\begin{aligned} KL(q(z|x)||p(z|x))&= \log p(x)-\int_{z}q(z)\log \frac{p(x,z)q(z|x)}{q(z|x)p(z)}+\int_{z}q(z|x)\log{q(z|x)}{p(z)} \\ &=\log p(x)-\int_{z}q(z)\log p(x|z)dz+KL(q(z|x)||p(z)) \\ &=\log p(x)-\int_{\epsilon}p(\epsilon)\log p(x|g_{\Phi}(x,\epsilon))d\epsilon+KL(q(z|x)||p(z)) \\ &=\log p(x)-\mathbb{E_{p(\epsilon)}}\log p(x|g_{\Phi}(x,\epsilon))+KL(q(z|x)||p(z)) \end{aligned}
K
L
(
q
(
z
∣
x
)
∣∣
p
(
z
∣
x
))
=
lo
g
p
(
x
)
−
∫
z
q
(
z
)
lo
g
q
(
z
∣
x
)
p
(
z
)
p
(
x
,
z
)
q
(
z
∣
x
)
+
∫
z
q
(
z
∣
x
)
lo
g
q
(
z
∣
x
)
p
(
z
)
=
lo
g
p
(
x
)
−
∫
z
q
(
z
)
lo
g
p
(
x
∣
z
)
d
z
+
K
L
(
q
(
z
∣
x
)
∣∣
p
(
z
))
=
lo
g
p
(
x
)
−
∫
ϵ
p
(
ϵ
)
lo
g
p
(
x
∣
g
Φ
(
x
,
ϵ
))
d
ϵ
+
K
L
(
q
(
z
∣
x
)
∣∣
p
(
z
))
=
lo
g
p
(
x
)
−
E
p
(
ϵ
)
lo
g
p
(
x
∣
g
Φ
(
x
,
ϵ
))
+
K
L
(
q
(
z
∣
x
)
∣∣
p
(
z
))
其实整个VAE的构建就是根据上面的等式
g
Φ
(
x
,
ϵ
)
g_{\Phi}(x,\epsilon)
g
Φ
(
x
,
ϵ
)
不知道是什么,那就用一个神经网络代替。
p
(
x
∣
z
)
p(x|z)
p
(
x
∣
z
)
不知道是什么,也用一个神经网络代替。下面文字叙述一下VAE的前向传播。
1.先从假设的
p
(
ϵ
)
p(\epsilon)
p
(
ϵ
)
中采样一个
ϵ
\epsilon
ϵ
,即上一节网络图中的
e
e
e
。
2.从假设的encoder中输入x以及
ϵ
\epsilon
ϵ
,输出隐变量
z
z
z
,即上一节网络图中的
c
c
c
。
3.将隐变量z输入decoder,输出
x
^
\hat{x}
x
^
。
而这个前向的过程表示在上面公式里,就是
E
p
(
ϵ
)
log
p
(
x
∣
g
Φ
(
x
,
ϵ
)
)
d
ϵ
\mathbb{E_{p(\epsilon)}}\log p(x|g_{\Phi}(x,\epsilon))d\epsilon
E
p
(
ϵ
)
lo
g
p
(
x
∣
g
Φ
(
x
,
ϵ
))
d
ϵ
,显然优化这个网络,我们要让
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
∣
x
)
)
KL(q(z|x)||p(z|x))
K
L
(
q
(
z
∣
x
)
∣∣
p
(
z
∣
x
))
最小,
log
p
(
x
)
−
E
p
(
ϵ
)
log
p
(
x
∣
g
Φ
(
x
,
ϵ
)
)
\log p(x)-\mathbb{E_{p(\epsilon)}}\log p(x|g_{\Phi}(x,\epsilon))
lo
g
p
(
x
)
−
E
p
(
ϵ
)
lo
g
p
(
x
∣
g
Φ
(
x
,
ϵ
))
就用MSE表示,
K
L
(
q
(
z
∣
x
)
∣
∣
p
(
z
)
)
KL(q(z|x)||p(z))
K
L
(
q
(
z
∣
x
)
∣∣
p
(
z
))
是可以求出来的。具体的推导也不难,如果假设是高斯分布,即根据多维高斯分布的KL散度,结合我们重采样的q(z|x),推导出来最后的结果就是第一节中图中的公式,这里就省略推导了。
能看到这里的宝贝都很厉害,毕竟我感觉自己也写的不是很清楚,才疏学浅了。不过最困难的部分也过去了,我们不妨看看VAE的pytorch代码实现,看看自己理解的是不是对的。
4 pytorch代码
本文的代码来自
GITHUB
__author__ = 'SherlockLiao'
import torch
import torchvision
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import os
if not os.path.exists('./vae_img'):
os.mkdir('./vae_img')
def to_img(x):
x = x.clamp(0, 1)
x = x.view(x.size(0), 1, 28, 28)
return x
num_epochs = 100
batch_size = 128
learning_rate = 1e-3
img_transform = transforms.Compose([
transforms.ToTensor()
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = MNIST('./data', transform=img_transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
self.fc1 = nn.Linear(784, 400)
self.fc21 = nn.Linear(400, 20)
self.fc22 = nn.Linear(400, 20)
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784)
def encode(self, x):
h1 = F.relu(self.fc1(x))
return self.fc21(h1), self.fc22(h1)
def reparametrize(self, mu, logvar):
std = logvar.mul(0.5).exp_()
if torch.cuda.is_available():
eps = torch.cuda.FloatTensor(std.size()).normal_()
else:
eps = torch.FloatTensor(std.size()).normal_()
eps = Variable(eps)
return eps.mul(std).add_(mu)
def decode(self, z):
h3 = F.relu(self.fc3(z))
return F.sigmoid(self.fc4(h3))
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparametrize(mu, logvar)
return self.decode(z), mu, logvar
model = VAE()
if torch.cuda.is_available():
model.cuda()
reconstruction_function = nn.MSELoss(size_average=False)
def loss_function(recon_x, x, mu, logvar):
"""
recon_x: generating images
x: origin images
mu: latent mean
logvar: latent log variance
"""
BCE = reconstruction_function(recon_x, x) # mse loss
# loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
KLD = torch.sum(KLD_element).mul_(-0.5)
# KL divergence
return BCE + KLD
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(num_epochs):
model.train()
train_loss = 0
for batch_idx, data in enumerate(dataloader):
img, _ = data
img = img.view(img.size(0), -1)
img = Variable(img)
if torch.cuda.is_available():
img = img.cuda()
optimizer.zero_grad()
recon_batch, mu, logvar = model(img)
loss = loss_function(recon_batch, img, mu, logvar)
loss.backward()
train_loss += loss.data[0]
optimizer.step()
if batch_idx % 100 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch,
batch_idx * len(img),
len(dataloader.dataset), 100. * batch_idx / len(dataloader),
loss.data[0] / len(img)))
print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(dataloader.dataset)))
if epoch % 10 == 0:
save = to_img(recon_batch.cpu().data)
save_image(save, './vae_img/image_{}.png'.format(epoch))
torch.save(model.state_dict(), './vae.pth')