元学习系列文章
- optimization based meta-learning
-
metric based meta-learning
- 元学习之《Matching Networks for One Shot Learning》代码解读
- model based meta-learning: 待更新…
前言
此篇是 metric-based metalearning 的第一篇,所谓 metric-based 即通过某种度量方式来判断测试样本和训练集中的哪个样本最相似,进而把最相似样本的 label 作为测试样本的 label,总体思想有点类似于 KNN。
Matching Network
此篇论文的核心思想就是构造了一个端到端的最近邻分类器,并通过 meta-learning 的训练,可以使得该分类器在新的少样本任务上快速适应,并对该任务的测试样本进行预测。下图是 Matching Network 的网络结构:
初看论文时看到这个图时会比较懵,以及论文里的各种公式也让人摸不着头脑,但是看作者的代码就能理清楚这里面的结构了,话不多上代码。
def build(self, support_set_image, support_set_label, image):
"""the main graph of matching networks"""
# image [None, 28, 28, 1] -> [None, 1*1*64]
# 1. 原始图片特征提取模块
image_encoded = self.image_encoder(image) # (batch_size, 64)
#[(batch_size, 64), ] list 长度是 n*k
support_set_image_encoded = [self.image_encoder(i) for i in tf.unstack(support_set_image, axis=1)]
# 2. Full Context Embeddings 模块
if self.use_fce:
g_embedding = self.fce_g(support_set_image_encoded) # (n * k, batch_size, 64)
f_embedding = self.fce_f(image_encoded, g_embedding) # (batch_size, 64)
else:
g_embedding = tf.stack(support_set_image_encoded) # (n * k, batch_size, 64)
f_embedding = image_encoded # (batch_size, 64)
# c(f(x_hat), g(x_i))
# 3. 距离度量模块
# g 已知 label,f 是 test,未知 label
embeddings_similarity = self.cosine_similarity(f_embedding, g_embedding) # (batch_size, n * k)
# compute softmax on similarity to get a(x_hat, x_i)
# 4. attention 模块
attention = tf.nn.softmax(embeddings_similarity)
# \hat{y} = \sum_{i=1}^{k} a(\hat{x}, x_i)y_i
# [batch_size, 1, n*k] * [batch_size,n*k, n] = [batch_size, 1, n]
y_hat = tf.matmul(tf.expand_dims(attention, 1), tf.one_hot(support_set_label, self.n_way))
# [batch_size,1,n] -> [batch_size, n]
self.logits = tf.squeeze(y_hat) # (batch_size, n)
self.pred = tf.argmax(self.logits, 1)
整个网络结构可以分为四个模块:
- 原始图片特征提取模块
- Full Context Embeddings 模块
- 距离度量模块
- attention 模块
特征提取模块
特征提取模块比较简单,就是用一个4层的卷积网络,提取原始图片的全连接层特征,全连接层维度是64,即卷积网络后的输出shape是 [batch_size, 64]。该卷积网络的代码如下:
def image_encoder(self, image):
"""the embedding function for image (potentially f = g)
For omniglot it's a simple 4 layer ConvNet, for mini-imagenet it's VGG or Inception
"""
with slim.arg_scope([slim.conv2d], num_outputs=64, kernel_size=3, normalizer_fn=slim.batch_norm):
net = slim.conv2d(image)
net = slim.max_pool2d(net, [2, 2])
net = slim.conv2d(net)
net = slim.max_pool2d(net, [2, 2])
net = slim.conv2d(net)
net = slim.max_pool2d(net, [2, 2])
net = slim.conv2d(net)
net = slim.max_pool2d(net, [2, 2])
return tf.reshape(net, [-1, 1 * 1 * 64])
Full Context Embeddings 模块
此模块是论文的重点也是创新的地方,即如何对某个抽样任务的训练集样本进行 embedding 得到
g
θ
g_\theta
g
θ
,如何对该任务的测试样本进行 embedding 得到
f
θ
f_\theta
f
θ
。
-
gθ
g_\theta
g
θ
其中对训练样本,是输入到一个双向LSTM网络中,LSTM的前向和后向隐藏层单元数都是 32,LSTM 网络的输出是一个长为
n*k
的list,list中每个元素的shape是 (batch_size,64)。最后将输入embedding和 LSTM output 相加,相加后的结果即是
gθ
g_\theta
g
θ
,相当于做了一个 skip connection的操作。
gθ
g_\theta
g
θ
的实现过程如下:
def fce_g(self, encoded_x_i):
"""the fully conditional embedding function g
This is a bi-directional LSTM, g(x_i, S) = h_i(->) + h_i(<-) + g'(x_i) where g' is the image encoder
For omniglot, this is not used.
encoded_x_i: g'(x_i) in the equation. length n * k list of (batch_size ,64)
"""
fw_cell = rnn.BasicLSTMCell(32) # 32 is half of 64 (output from cnn)
bw_cell = rnn.BasicLSTMCell(32)
# outputs: [(batch_size, 64), (batch_size, 64), ...], list 长度是 n*k
outputs, state_fw, state_bw = rnn.static_bidirectional_rnn(fw_cell, bw_cell, encoded_x_i, dtype=tf.float32)
# [n*k, batch_size, 64] + [n*k, batch_size, 64]
return tf.add(tf.stack(encoded_x_i), tf.stack(outputs))
其中需要注意的是 batch_size 是随机抽样的 batch_size 个task,每个 task 共有
n*k
个训练样本,n值该task是n分类任务,k指每个类别共有k个样本。实际训练时,相当于LSTM网络共有
n*k
个时刻,每个时刻的输入shape都是(batch_size,64),每个时刻的前向输出shape是(batch_size,32),后向输出shape是(batch_size,32)。LSTM的训练过程示意图如下:
-
fθ
f_\theta
f
θ
对测试任务的样本求 embedding 时,同样也是输入到一个LSTM网络中,只不过这个LSTM是有固定步数的单向lstm,共有
processing_steps
步,
processing_steps
可以提取设定。特殊的地方是,在每步的计算中加了 attention 部分,即让上一步的输出状态 h 乘以
g
θ
g_\theta
g
θ
。最后将最后一个时刻 lstm 网络的 softmax 输出作为
f
θ
f_\theta
f
θ
。此部分的代码实现如下:
def fce_f(self, encoded_x, g_embedding):
"""the fully conditional embedding function f
This is just a vanilla LSTM with attention where the input at each time step is constant and the hidden state
is a function of previous hidden state but also a concatenated readout vector.
For omniglot, this is not used.
encoded_x: f'(x_hat) in equation (3) in paper appendix A.1. (batch_size, 64)
g_embedding: g(x_i) in equation (5), (6) in paper appendix A.1. (n * k, batch_size, 64)
"""
cell = rnn.BasicLSTMCell(64)
prev_state = cell.zero_state(self.batch_size, tf.float32) # state[0] is c, state[1] is h
for step in xrange(self.processing_steps):
output, state = cell(encoded_x, prev_state) # output: (batch_size, 64)
h_k = tf.add(output, encoded_x) # (batch_size, 64)
content_based_attention = tf.nn.softmax(tf.multiply(prev_state[1], g_embedding)) # (n * k, batch_size, 64)
r_k = tf.reduce_sum(tf.multiply(content_based_attention, g_embedding), axis=0) # (batch_size, 64)
prev_state = rnn.LSTMStateTuple(state[0], tf.add(h_k, r_k))
距离度量模块
现在有了
g
θ
g_\theta
g
θ
和
f
θ
f_\theta
f
θ
,其中
g
θ
g_\theta
g
θ
d shape是
(n*k, batch_size, 64)
,
f
θ
f_\theta
f
θ
的shape是
(batch_size,64)
。距离度量模块就是针对每个 task,求出test和train中每个样本的余弦距离,最后输出shape为
(batch_size,n*k)
。余弦相似性的代码实现如下:
def cosine_similarity(self, target, support_set):
"""the c() function that calculate the cosine similarity between (embedded) support set and (embedded) target
note: the author uses one-sided cosine similarity as zergylord said in his repo (zergylord/oneshot)
"""
#target_normed = tf.nn.l2_normalize(target, 1) # (batch_size, 64)
target_normed = target
sup_similarity = []
for i in tf.unstack(support_set):
i_normed = tf.nn.l2_normalize(i, 1) # (batch_size, 64)
similarity = tf.matmul(tf.expand_dims(target_normed, 1), tf.expand_dims(i_normed, 2)) # (batch_size, )
sup_similarity.append(similarity)
return tf.squeeze(tf.stack(sup_similarity, axis=1)) # (batch_size, n * k)
attention模块
此模块将求出每个测试样本的label。所谓 attention,其实很简单,就是将上一步求出的相似度结果做了 softmax 激活操作,然后将最大值处的train label作为 test label。此模块的代码实现如下:
embeddings_similarity = self.cosine_similarity(f_embedding, g_embedding) # (batch_size, n * k)
# compute softmax on similarity to get a(x_hat, x_i)
# 4. attention 模块
attention = tf.nn.softmax(embeddings_similarity)
# \hat{y} = \sum_{i=1}^{k} a(\hat{x}, x_i)y_i
# [batch_size, 1, n*k] * [batch_size,n*k, n] = [batch_size, 1, n]
y_hat = tf.matmul(tf.expand_dims(attention, 1), tf.one_hot(support_set_label, self.n_way))
# [batch_size,1,n] -> [batch_size, n]
self.logits = tf.squeeze(y_hat) # (batch_size, n)
self.pred = tf.argmax(self.logits, 1)
实验结果
参考资料
- https://github.com/markdtw/matching-networks
- https://github.com/karpathy/paper-notes/blob/master/matching_networks.md