Meta-Learning: Code Walkthrough of "Matching Networks for One Shot Learning"




Meta-Learning Series

  1. Optimization-based meta-learning


    1. "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks": translated notes on the paper

    2. Optimization-based meta-learning: a detailed walkthrough of the MAML paper

    3. MAML source code explained (Part 1)

    4. MAML source code explained (Part 2)

    5. "On First-Order Meta-Learning Algorithms": a detailed walkthrough

    6. "OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING": a detailed walkthrough
  2. Metric-based meta-learning

    1. "Matching Networks for One Shot Learning": code walkthrough (this post)
  3. Model-based meta-learning: to be updated…



Preface

This is the first post on metric-based meta-learning. "Metric-based" means that some similarity measure is used to decide which training sample a test sample most resembles, and the label of that most similar sample is then assigned to the test sample. The overall idea is reminiscent of KNN.



Matching Network

The core idea of the paper is an end-to-end nearest-neighbour classifier which, trained in a meta-learning fashion, can quickly adapt to a new few-shot task and predict labels for that task's test samples. The figure below shows the Matching Network architecture:

[Figure: Matching Network architecture]

The figure can be confusing on a first read, and the formulas in the paper are not much clearer, but the author's code makes the structure easy to follow. Without further ado, here is the code.

    def build(self, support_set_image, support_set_label, image):
        """the main graph of matching networks"""
        # image [None, 28, 28, 1] -> [None, 1*1*64]
        # 1. raw-image feature extraction module
        image_encoded = self.image_encoder(image)   # (batch_size, 64)
        # list of (batch_size, 64) tensors, length n*k
        support_set_image_encoded = [self.image_encoder(i) for i in tf.unstack(support_set_image, axis=1)]

        # 2. Full Context Embeddings module
        if self.use_fce:
            g_embedding = self.fce_g(support_set_image_encoded)     # (n * k, batch_size, 64)
            f_embedding = self.fce_f(image_encoded, g_embedding)    # (batch_size, 64)
        else:
            g_embedding = tf.stack(support_set_image_encoded)       # (n * k, batch_size, 64)
            f_embedding = image_encoded                             # (batch_size, 64)

        # c(f(x_hat), g(x_i))
        # 3. distance metric module
        # g: support samples with known labels; f: the query (test) sample whose label is unknown
        embeddings_similarity = self.cosine_similarity(f_embedding, g_embedding) # (batch_size, n * k)

        # compute softmax on similarity to get a(x_hat, x_i)
        # 4. attention module
        attention = tf.nn.softmax(embeddings_similarity)

        # \hat{y} = \sum_{i=1}^{k} a(\hat{x}, x_i)y_i
        # [batch_size, 1, n*k] * [batch_size, n*k, n] = [batch_size, 1, n]
        y_hat = tf.matmul(tf.expand_dims(attention, 1), tf.one_hot(support_set_label, self.n_way))
        # [batch_size, 1, n] -> [batch_size, n]
        self.logits = tf.squeeze(y_hat)   # (batch_size, n)

        self.pred = tf.argmax(self.logits, 1)

The whole network can be broken into four modules:

  1. raw-image feature extraction module
  2. Full Context Embeddings module
  3. distance metric module
  4. attention module
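
All snippets in this post are TensorFlow 1.x code and assume import tensorflow as tf plus from tensorflow.contrib import rnn, slim at the top of the file. As a rough sketch of what feeds build(), the episode tensors for a 5-way 1-shot Omniglot setup might look like the following; the placeholder names and the model object are illustrative, not taken verbatim from the repo:

    # Hypothetical 5-way 1-shot episode, batched over tasks.
    # Shapes follow the comments inside build(); variable names are illustrative only.
    import tensorflow as tf

    batch_size, n_way, k_shot = 32, 5, 1

    support_set_image = tf.placeholder(tf.float32, [batch_size, n_way * k_shot, 28, 28, 1])
    support_set_label = tf.placeholder(tf.int32,   [batch_size, n_way * k_shot])
    image             = tf.placeholder(tf.float32, [batch_size, 28, 28, 1])   # query image to classify

    # model.build(support_set_image, support_set_label, image)   # assuming `model` is the network object
    # model.logits -> (batch_size, n_way), model.pred -> (batch_size,)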



Feature Extraction Module

The feature extraction module is straightforward: a 4-layer convolutional network encodes each raw image into a 64-dimensional feature vector, so the encoder's output shape is [batch_size, 64]. Its code is as follows:

    def image_encoder(self, image):
        """the embedding function for image (potentially f = g)
        For omniglot it's a simple 4 layer ConvNet, for mini-imagenet it's VGG or Inception
        """
        with slim.arg_scope([slim.conv2d], num_outputs=64, kernel_size=3, normalizer_fn=slim.batch_norm):
            net = slim.conv2d(image)
            net = slim.max_pool2d(net, [2, 2])
            net = slim.conv2d(net)
            net = slim.max_pool2d(net, [2, 2])
            net = slim.conv2d(net)
            net = slim.max_pool2d(net, [2, 2])
            net = slim.conv2d(net)
            net = slim.max_pool2d(net, [2, 2])
        return tf.reshape(net, [-1, 1 * 1 * 64])
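
A quick sanity check of the spatial sizes (plain Python; assumes slim.conv2d's default 'SAME' padding for the 3x3 convolutions and slim.max_pool2d's default 'VALID' padding): a 28x28 Omniglot image is halved four times down to 1x1, so flattening gives 1 * 1 * 64 = 64 features.

    # 'SAME' 3x3 convs keep the spatial size; each 2x2 stride-2 'VALID' pooling floors the halving.
    size = 28
    for _ in range(4):
        size = size // 2       # 28 -> 14 -> 7 -> 3 -> 1
    print(size * size * 64)    # 64, matching the final tf.reshape(net, [-1, 1 * 1 * 64])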



Full Context Embeddings Module

This module is the key contribution of the paper: how to embed the support (training) samples of a sampled task to obtain $g_\theta$, and how to embed the task's query (test) sample to obtain $f_\theta$.

  1. $g_\theta$

    The support samples are fed into a bidirectional LSTM whose forward and backward hidden sizes are both 32. The LSTM output is a list of length n*k, where each element has shape (batch_size, 64). Finally, the input embedding and the LSTM output are added together, and the sum is $g_\theta$, which amounts to a skip connection. The implementation of $g_\theta$ is as follows:

    def fce_g(self, encoded_x_i):
        """the fully conditional embedding function g
        This is a bi-directional LSTM, g(x_i, S) = h_i(->) + h_i(<-) + g'(x_i) where g' is the image encoder
        For omniglot, this is not used.

        encoded_x_i: g'(x_i) in the equation.   length n * k list of (batch_size ,64)
        """
        fw_cell = rnn.BasicLSTMCell(32) # 32 is half of 64 (output from cnn)
        bw_cell = rnn.BasicLSTMCell(32)
        # outputs: [(batch_size, 64), (batch_size, 64), ...], list of length n*k
        outputs, state_fw, state_bw = rnn.static_bidirectional_rnn(fw_cell, bw_cell, encoded_x_i, dtype=tf.float32)

        # [n*k, batch_size, 64] + [n*k, batch_size, 64]
        return tf.add(tf.stack(encoded_x_i), tf.stack(outputs))

Note that batch_size here is the number of randomly sampled tasks; each task has n*k support samples, where n means the task is an n-way classification problem and k is the number of samples per class. In effect the LSTM is unrolled for n*k time steps: the input at each step has shape (batch_size, 64), and the forward and backward outputs at each step each have shape (batch_size, 32). The unrolled LSTM is illustrated below:

[Figure: the bidirectional LSTM unrolled over the n*k support embeddings]
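
As a quick shape check of the unrolling above, the following sketch runs the same ops as fce_g on dummy tensors (a minimal sketch assuming TensorFlow 1.x with tf.contrib available; a 5-way 1-shot task and a batch of 2 tasks):

    import tensorflow as tf
    from tensorflow.contrib import rnn

    n_k, batch_size, dim = 5, 2, 64                                   # n*k support samples, 2 tasks, 64-d features
    encoded_x_i = [tf.zeros([batch_size, dim]) for _ in range(n_k)]   # stands in for g'(x_i)

    fw_cell = rnn.BasicLSTMCell(32)
    bw_cell = rnn.BasicLSTMCell(32)
    outputs, _, _ = rnn.static_bidirectional_rnn(fw_cell, bw_cell, encoded_x_i, dtype=tf.float32)

    # each output concatenates the 32-d forward and backward states -> 64-d,
    # and the skip connection adds back the original embedding
    g = tf.add(tf.stack(encoded_x_i), tf.stack(outputs))
    print(g.get_shape().as_list())   # [5, 2, 64], i.e. (n*k, batch_size, 64)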




  2. $f_\theta$

    To embed the task's query (test) sample, the feature is again fed into an LSTM, but this time a unidirectional LSTM unrolled for a fixed number of steps, processing_steps, which can be set in advance. The special part is the attention added at every step: the previous step's hidden state h is multiplied with $g_\theta$ and softmaxed to form a read-out over the support embeddings, which is folded back into the state. The LSTM's output at the last time step is taken as $f_\theta$. The code for this part is as follows:

    def fce_f(self, encoded_x, g_embedding):
        """the fully conditional embedding function f
        This is just a vanilla LSTM with attention where the input at each time step is constant and the hidden state
        is a function of previous hidden state but also a concatenated readout vector.
        For omniglot, this is not used.

        encoded_x: f'(x_hat) in equation (3) in paper appendix A.1.     (batch_size, 64)
        g_embedding: g(x_i) in equation (5), (6) in paper appendix A.1. (n * k, batch_size, 64)
        """
        cell = rnn.BasicLSTMCell(64)
        prev_state = cell.zero_state(self.batch_size, tf.float32) # state[0] is c, state[1] is h

        for step in xrange(self.processing_steps):
            output, state = cell(encoded_x, prev_state) # output: (batch_size, 64)

            h_k = tf.add(output, encoded_x) # (batch_size, 64)

            content_based_attention = tf.nn.softmax(tf.multiply(prev_state[1], g_embedding))    # (n * k, batch_size, 64)
            r_k = tf.reduce_sum(tf.multiply(content_based_attention, g_embedding), axis=0)      # (batch_size, 64)

            prev_state = rnn.LSTMStateTuple(state[0], tf.add(h_k, r_k))

        # the output at the final processing step is used as f(x_hat) and returned to build()
        return output   # (batch_size, 64)
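
For reference, the loop corresponds to the attLSTM read-out equations in appendix A.1 of the paper (K is processing_steps, S is the support set):

    \hat{h}_k, c_k = \mathrm{LSTM}\left(f'(\hat{x}),\ [h_{k-1}, r_{k-1}],\ c_{k-1}\right)
    h_k = \hat{h}_k + f'(\hat{x})
    r_{k-1} = \sum_{i=1}^{|S|} a(h_{k-1}, g(x_i))\, g(x_i)
    a(h_{k-1}, g(x_i)) = \mathrm{softmax}\left(h_{k-1}^{\top} g(x_i)\right)

Note that the snippet above deviates slightly from these equations: instead of a dot product softmaxed over the support samples, it uses an element-wise product (tf.multiply) and applies the softmax over the 64 feature dimensions.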



Distance Metric Module

We now have $g_\theta$ and $f_\theta$, where $g_\theta$ has shape (n*k, batch_size, 64) and $f_\theta$ has shape (batch_size, 64). For each task, the distance metric module computes the cosine similarity between the query embedding and every support embedding, producing an output of shape (batch_size, n*k). The cosine similarity is implemented as follows:

    def cosine_similarity(self, target, support_set):
        """the c() function that calculate the cosine similarity between (embedded) support set and (embedded) target
        
        note: the author uses one-sided cosine similarity as zergylord said in his repo (zergylord/oneshot)
        """
        #target_normed = tf.nn.l2_normalize(target, 1) # (batch_size, 64)
        target_normed = target
        sup_similarity = []
        for i in tf.unstack(support_set):
            i_normed = tf.nn.l2_normalize(i, 1) # (batch_size, 64)
            similarity = tf.matmul(tf.expand_dims(target_normed, 1), tf.expand_dims(i_normed, 2)) # (batch_size, 1, 1)
            sup_similarity.append(similarity)

        return tf.squeeze(tf.stack(sup_similarity, axis=1)) # (batch_size, n * k)
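
As the docstring notes, only the support embeddings are L2-normalized ("one-sided" cosine similarity), so the score is really a dot product between the raw query embedding and unit-length support embeddings. A small numpy sketch of the same computation, for illustration only:

    import numpy as np

    def one_sided_cosine(target, support_set):
        """target: (batch_size, 64); support_set: (n*k, batch_size, 64) -> (batch_size, n*k)"""
        sims = []
        for s in support_set:                                     # s: (batch_size, 64)
            s_normed = s / np.linalg.norm(s, axis=1, keepdims=True)
            sims.append(np.sum(target * s_normed, axis=1))        # dot product per task in the batch
        return np.stack(sims, axis=1)

    target = np.random.randn(2, 64)
    support = np.random.randn(5, 2, 64)
    print(one_sided_cosine(target, support).shape)   # (2, 5)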



Attention Module

This module produces the label for each query sample. The attention here is simple: take a softmax over the similarities from the previous step, use the resulting weights to form a weighted sum of the one-hot support labels, and take the class with the largest total weight as the predicted test label. The code for this module is as follows:

        embeddings_similarity = self.cosine_similarity(f_embedding, g_embedding) # (batch_size, n * k)

        # compute softmax on similarity to get a(x_hat, x_i)
        # 4. attention module
        attention = tf.nn.softmax(embeddings_similarity)

        # \hat{y} = \sum_{i=1}^{k} a(\hat{x}, x_i)y_i
        # [batch_size, 1, n*k] * [batch_size,n*k, n] = [batch_size, 1, n]
        y_hat = tf.matmul(tf.expand_dims(attention, 1), tf.one_hot(support_set_label, self.n_way))
        # [batch_size,1,n] -> [batch_size, n]
        self.logits = tf.squeeze(y_hat)   # (batch_size, n)

        self.pred = tf.argmax(self.logits, 1)
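
A tiny numpy example of the final label computation (illustrative shapes: one task, 3-way 1-shot): the softmaxed similarities weight the one-hot support labels, and the argmax of the summed weights gives the prediction.

    import numpy as np

    similarities = np.array([2.0, 0.5, 1.0])            # cosine scores against the 3 support samples
    attention = np.exp(similarities) / np.exp(similarities).sum()

    one_hot_labels = np.eye(3)                          # support labels 0, 1, 2 as one-hot rows
    class_probs = attention @ one_hot_labels            # \hat{y} = sum_i a(x_hat, x_i) y_i
    print(class_probs, class_probs.argmax())            # highest weight on class 0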



Experimental Results

[Figures: experimental result tables from the paper]



References

  • https://github.com/markdtw/matching-networks
  • https://github.com/karpathy/paper-notes/blob/master/matching_networks.md



Copyright notice: this is an original article by tiangcs, released under the CC 4.0 BY-SA license. Please include a link to the original post and this notice when reposting.