Notes on TripletLoss and HardTripletLoss

  • Post category: Other



1. TripletLoss


[Figure: an (Anchor, Positive, Negative) triplet]

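As shown in the figure above, a triplet is a three-tuple built as follows: randomly pick a sample from the training set and call it the Anchor (denoted x_a); then randomly pick one sample from the same class as the Anchor and one sample from a different class, called the Positive (denoted x_p) and the Negative (denoted x_n) respectively. Together they form an (Anchor, Positive, Negative) triplet.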
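The sampling procedure above can be sketched in a few lines. This is a minimal illustration, assuming the training set is represented only by an integer label array; `sample_triplet` is a hypothetical helper, not part of the original post, and it assumes at least one class has two or more samples and at least two classes exist.

```python
import numpy as np


def sample_triplet(labels, rng):
    """Return indices (a, p, n) of a random (Anchor, Positive, Negative) triplet.

    Hypothetical helper: assumes at least one class has >= 2 samples
    and that the labels contain at least two classes.
    """
    labels = np.asarray(labels)
    idx = np.arange(len(labels))
    # Only samples whose class has at least one other member can be anchors
    candidates = [i for i in idx if np.sum(labels == labels[i]) >= 2]
    a = rng.choice(candidates)
    # Positive: same class as the anchor, but a different sample
    positives = idx[(labels == labels[a]) & (idx != a)]
    # Negative: any sample from a different class
    negatives = idx[labels != labels[a]]
    p = rng.choice(positives)
    n = rng.choice(negatives)
    return int(a), int(p), int(n)


rng = np.random.default_rng(0)
labels = np.array([0, 0, 1, 1, 2])
print(sample_triplet(labels, rng))
```

Class 2 has a single sample here, so it can only ever appear as a Negative, never as an Anchor.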

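With the triplet defined, triplet loss is easy to understand. Each element (sample) of the triplet is fed through a network, whose parameters may or may not be shared across the three branches, to obtain a feature representation for each element, denoted $f(x_i^a)$, $f(x_i^p)$ and $f(x_i^n)$. The goal of triplet loss is to learn representations in which the distance between x_a and x_p is as small as possible and the distance between x_a and x_n is as large as possible, such that: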




$$\|f(x_{i}^{a})-f(x_{i}^{p})\|_{2}^{2}+\alpha<\|f(x_{i}^{a})-f(x_{i}^{n})\|_{2}^{2}$$

where $\alpha$ is the margin: the Anchor-Negative distance must exceed the Anchor-Positive distance by more than $\alpha$.
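To make the shared-parameter case and the inequality concrete, here is a minimal sketch in which the "network" is just a hypothetical linear map `W` applied identically to all three inputs. The names `embed` and `satisfies_margin` and all numbers are illustrative assumptions, not from the original post.

```python
import numpy as np


def embed(x, W):
    # Shared-parameter "network": the same matrix W maps anchor, positive and negative
    return x @ W


def satisfies_margin(f_a, f_p, f_n, alpha):
    # Squared Euclidean distances, as in the inequality above
    d_ap = np.sum((f_a - f_p) ** 2)
    d_an = np.sum((f_a - f_n) ** 2)
    return d_ap + alpha < d_an


W = np.array([[1.0, 0.0], [0.0, 2.0]])  # hypothetical weights
x_a = np.array([0.0, 0.0])
x_p = np.array([1.0, 0.0])  # f(x_p) = [1, 0], squared distance 1 from f(x_a)
x_n = np.array([0.0, 1.0])  # f(x_n) = [0, 2], squared distance 4 from f(x_a)
f_a, f_p, f_n = (embed(x, W) for x in (x_a, x_p, x_n))
print(satisfies_margin(f_a, f_p, f_n, alpha=1.0))  # 1 + 1 < 4
print(satisfies_margin(f_a, f_p, f_n, alpha=3.0))  # 1 + 3 < 4 fails
```

The same triplet satisfies the constraint for a small margin but violates it once $\alpha$ is large enough, which is exactly when the loss below becomes non-zero.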






















The corresponding objective function is then:

$$\sum_{i}^{N}\left[\|f(x_{i}^{a})-f(x_{i}^{p})\|_{2}^{2}-\|f(x_{i}^{a})-f(x_{i}^{n})\|_{2}^{2}+\alpha\right]_{+}$$

where the sum runs over the N triplets.






















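Here distances are Euclidean (the formula uses the squared L2 norm), and the subscript + means: when the value inside [] is greater than zero, that value is taken as the loss; otherwise the loss is zero.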
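The objective can be written directly in NumPy for a batch of N precomputed triplets. This is a minimal sketch assuming the embeddings are already stacked row-wise into arrays `f_a`, `f_p`, `f_n` of shape (N, d); the function name `triplet_objective` and the toy numbers are hypothetical.

```python
import numpy as np


def triplet_objective(f_a, f_p, f_n, alpha):
    # Squared L2 distance for each of the N triplets, shape (N,)
    d_ap = np.sum((f_a - f_p) ** 2, axis=1)
    d_an = np.sum((f_a - f_n) ** 2, axis=1)
    # [.]_+ hinge per triplet, then sum over the N triplets
    return np.sum(np.maximum(d_ap - d_an + alpha, 0.0))


f_a = np.array([[0.0, 0.0], [0.0, 0.0]])
f_p = np.array([[1.0, 0.0], [1.0, 0.0]])  # d_ap = 1 for both triplets
f_n = np.array([[0.0, 1.0], [3.0, 0.0]])  # d_an = 1 and 9
print(triplet_objective(f_a, f_p, f_n, alpha=0.5))  # [1 - 1 + 0.5]_+ + [1 - 9 + 0.5]_+ = 0.5
```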

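From the objective we can see that:

  • When the distance between x_a and x_n is less than the distance between x_a and x_p plus α, the value inside [] is positive and a loss is incurred.

  • When the distance between x_a and x_n is greater than or equal to the distance between x_a and x_p plus α, the loss is zero.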
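The two cases can be checked with a scalar version of the hinge. This sketch works on already-computed distances (`d_ap` for Anchor-Positive, `d_an` for Anchor-Negative); the helper name `hinge_loss` and the numbers are illustrative only.

```python
def hinge_loss(d_ap, d_an, alpha):
    # [d_ap - d_an + alpha]_+ : positive only while d_an < d_ap + alpha
    return max(d_ap - d_an + alpha, 0.0)


alpha = 1.0
d_ap = 2.0
print(hinge_loss(d_ap, d_an=2.5, alpha=alpha))  # 2.5 < 2 + 1, so loss = 0.5
print(hinge_loss(d_ap, d_an=3.0, alpha=alpha))  # 3.0 >= 2 + 1, so loss = 0.0
print(hinge_loss(d_ap, d_an=5.0, alpha=alpha))  # 5.0 >= 2 + 1, so loss = 0.0
```

The boundary case d_an = d_ap + α already gives zero loss, matching the "greater than or equal" condition in the second bullet.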

Code implementation:

import numpy as np

def test_pairwise_distances(squared=False):
    '''Pairwise distances between embeddings. In the first row, for example, the distance from 0 to 0 is 0,
    from 0 to 1 is 8 and from 0 to 2 is 16 (note: square root already taken)
    [[ 0.  8. 16.]
     [ 8.  0.  8.]
     [16.  8.  0.]]
    squared=True returns squared distances (as in the formula); the default returns Euclidean distances.
    '''
    embeddings = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.float32)
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, computed for all pairs at once
    dot_product = np.dot(embeddings, np.transpose(embeddings))
    square_norm = np.diag(dot_product)
    distances = np.expand_dims(square_norm, axis=1) - 2.0 * dot_product + np.expand_dims(square_norm, 0)
    # Floating-point error can make some entries slightly negative; clip them to 0
    distances = np.maximum(distances, 0.0)
    mask = np.equal(distances, 0.0).astype(np.float32)
    if not squared:
        # Add a tiny epsilon where the distance is 0 before sqrt (in autodiff frameworks this
        # avoids an infinite gradient at 0), then zero those entries again
        distances = distances + mask * 1e-16
        distances = np.sqrt(distances)
        distances = distances * (1.0 - mask)
    print(distances)
    return distances

def test_get_triplet_mask(labels):
    '''
    A triplet (i, j, k) is valid when
         - i, j, k are pairwise distinct
         - labels[i] == labels[j] and labels[i] != labels[k]
    '''
    # 2D matrix whose entry (i, j) is True when i != j: indices_not_equal
    indices_equal = np.eye(np.shape(labels)[0], dtype=bool)
    indices_not_equal = np.logical_not(indices_equal)
    # The final mask is 3D, indexed by (i, j, k), so add one axis to each 2D matrix for broadcasting:
    # i_not_equal_j gets shape (batch_size, batch_size, 1), and the others are analogous
    i_not_equal_j = np.expand_dims(indices_not_equal, 2)
    i_not_equal_k = np.expand_dims(indices_not_equal, 1)
    j_not_equal_k = np.expand_dims(indices_not_equal, 0)
    # i != j != k: combine the three conditions with AND.
    # For a batch of 3 this gives (True only where the indices (i, j, k) are all distinct):
    # array([[[False, False, False],
    #         [False, False,  True],
    #         [False,  True, False]],
    #        [[False, False,  True],
    #         [False, False, False],
    #         [ True, False, False]],
    #        [[False,  True, False],
    #         [ True, False, False],
    #         [False, False, False]]])
    distinct_indices = np.logical_and(np.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k)

    # Likewise, from the labels: labels[i] == labels[j] and labels[i] != labels[k]
    label_equal = np.equal(np.expand_dims(labels, 0), np.expand_dims(labels, 1))
    i_equal_j = np.expand_dims(label_equal, 2)
    i_equal_k = np.expand_dims(label_equal, 1)
    valid_labels = np.logical_and(i_equal_j, np.logical_not(i_equal_k))

    # The mask must satisfy both constraints, so AND the two 3D arrays
    mask = np.logical_and(valid_labels, distinct_indices)
    return mask

def test_batch_all_triplet_loss(margin):
    # Compute all pairwise embedding distances, then add an axis to each copy so they broadcast
    # into a (batch_size, batch_size, batch_size) 3D array; finally multiply by the valid-triplet mask
    # The 1st and 3rd samples share a label (positives) and the 2nd is the negative, so the loss should be 16 - 8 = 8
    labels = np.array([1, 0, 1])
    pairwise_distances = test_pairwise_distances()
    anchor_positive = np.expand_dims(pairwise_distances, axis=2)  # d(a, p), shape (B, B, 1)
    anchor_negative = np.expand_dims(pairwise_distances, axis=1)  # d(a, n), shape (B, 1, B)
    triplet_loss = anchor_positive - anchor_negative + margin

    mask = test_get_triplet_mask(labels).astype(np.float32)
    triplet_loss = np.multiply(mask, triplet_loss)  # zero out invalid triplets
    triplet_loss = np.maximum(triplet_loss, 0.0)    # hinge [.]_+

    # Number of triplets with a positive loss, and their fraction among all valid triplets
    valid_triplet_loss = np.greater(triplet_loss, 1e-16).astype(np.float32)
    num_positive_triplet = np.sum(valid_triplet_loss)
    num_valid_triplet_loss = np.sum(mask)
    fraction_positive_triplet = num_positive_triplet / (num_valid_triplet_loss + 1e-16)

    # Average only over the triplets with a positive loss
    triplet_loss = np.sum(triplet_loss) / (num_positive_triplet + 1e-16)
    return triplet_loss, fraction_positive_triplet

if __name__ == '__main__':
    loss, fraction = test_batch_all_triplet_loss(margin=0.0)
    print(loss, fraction)
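As a sanity check on the broadcasting in the code above, the same "batch all" loss can be computed with explicit loops over every (anchor, positive, negative) index triple. This is a minimal sketch, not part of the original; `batch_all_loop` is a hypothetical name, and it reuses the same three embeddings and labels, so with margin 0 it should reproduce the expected loss of 8.

```python
import itertools

import numpy as np


def batch_all_loop(embeddings, labels, margin):
    # Euclidean (non-squared) distances, as in test_pairwise_distances(squared=False)
    dist = np.linalg.norm(embeddings[:, None, :] - embeddings[None, :, :], axis=-1)
    losses = []
    # permutations(..., 3) yields only index triples with a, p, n pairwise distinct
    for a, p, n in itertools.permutations(range(len(labels)), 3):
        # Valid triplet: a and p share a label, n has a different one
        if labels[a] == labels[p] and labels[a] != labels[n]:
            losses.append(float(max(dist[a, p] - dist[a, n] + margin, 0.0)))
    positive = [v for v in losses if v > 1e-16]
    # Average only over triplets with a positive loss, like the vectorized version
    mean_loss = sum(positive) / len(positive) if positive else 0.0
    fraction = len(positive) / len(losses) if losses else 0.0
    return mean_loss, fraction


embeddings = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.float32)
labels = np.array([1, 0, 1])
print(batch_all_loop(embeddings, labels, margin=0.0))  # → (8.0, 1.0)
```

The loop is O(B³) in Python and only meant for checking small batches; the broadcast version is what you would use in practice.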


2. HardTripletLoss

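The idea is to find, for each Anchor in the batch, the largest distance between the Anchor and its Positives and the smallest distance between the Anchor and its Negatives (i.e. the hardest cases to train on).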

The corresponding objective function is then:

$$\sum_{i}^{N}\left[\max\left(\|f(x_{i}^{a})-f(x_{i}^{p})\|_{2}^{2}\right)-\min\left(\|f(x_{i}^{a})-f(x_{i}^{n})\|_{2}^{2}\right)+\alpha\right]_{+}$$

where, for anchor $x_i^a$, the max is taken over all of its positives and the min over all of its negatives.
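For a single anchor, the max and min in this objective amount to looking up the anchor's row of the distance matrix and picking the farthest same-label sample and the closest different-label sample. The sketch below uses a hypothetical helper `hardest_indices` and made-up 1D embeddings; it assumes the anchor has at least one positive and one negative.

```python
import numpy as np


def hardest_indices(dist, labels, anchor):
    """Return (hardest positive, hardest negative) sample indices for one anchor.

    Hypothetical helper; assumes the anchor has at least one positive and one negative.
    """
    idx = np.arange(len(labels))
    pos = idx[(labels == labels[anchor]) & (idx != anchor)]  # same label, excluding the anchor
    neg = idx[labels != labels[anchor]]                      # different label
    hardest_pos = pos[np.argmax(dist[anchor, pos])]          # farthest positive
    hardest_neg = neg[np.argmin(dist[anchor, neg])]          # closest negative
    return int(hardest_pos), int(hardest_neg)


# Toy 1D embeddings (made up for illustration) and their pairwise distances
x = np.array([0.0, 1.0, 3.0, 2.0, 5.0])
labels = np.array([0, 0, 0, 1, 1])
dist = np.abs(x[:, None] - x[None, :])
print(hardest_indices(dist, labels, anchor=0))  # → (2, 3)
```

For anchor 0 the hardest negative (distance 2) is closer than the hardest positive (distance 3), so this anchor would incur a loss even with α = 0.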






















Code implementation:

import numpy as np

def test_pairwise_distances(squared=False):
    '''Pairwise distances between embeddings. In the first row, for example, the distance from 0 to 0 is 0,
    from 0 to 1 is 8 and from 0 to 2 is 16 (note: square root already taken)
    [[ 0.  8. 16.]
     [ 8.  0.  8.]
     [16.  8.  0.]]
    '''
    embeddings = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.float32)
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, computed for all pairs at once
    dot_product = np.dot(embeddings, np.transpose(embeddings))
    square_norm = np.diag(dot_product)
    distances = np.expand_dims(square_norm, axis=1) - 2.0 * dot_product + np.expand_dims(square_norm, 0)
    # Floating-point error can make some entries slightly negative; clip them to 0
    distances = np.maximum(distances, 0.0)
    mask = np.equal(distances, 0.0).astype(np.float32)
    if not squared:
        # Epsilon before sqrt where the distance is 0, then zero those entries again
        distances = distances + mask * 1e-16
        distances = np.sqrt(distances)
        distances = distances * (1.0 - mask)
    print(distances)
    return distances

def test_anchor_positive_triplet_mask(labels):
    # 2D mask of valid anchor-positive pairs: i != j and labels[i] == labels[j]
    indices_equal = np.eye(np.shape(labels)[0], dtype=bool)
    indices_not_equal = np.logical_not(indices_equal)
    labels_equal = np.equal(np.expand_dims(labels, 0), np.expand_dims(labels, 1))
    mask = np.logical_and(indices_not_equal, labels_equal)
    return mask

def test_get_anchor_negative_triplet_mask(labels):
    # 2D mask of valid anchor-negative pairs: labels[i] != labels[k]
    labels_equal = np.equal(np.expand_dims(labels, 0), np.expand_dims(labels, 1))
    mask = np.logical_not(labels_equal)
    return mask


def test_batch_hard_triplet_loss(margin):
    # 1. Compute all pairwise distances (pairwise_distances)
    # 2. Hardest positive distance: after masking, simply take the maximum of each row
    # 3. Hardest negative distance: each row's minimum cannot be taken directly, because invalid [a, n]
    #    entries would be 0; instead add the row maximum to invalid positions, then take each row's minimum
    labels = np.array([1, 0, 1])
    pairwise_distances = test_pairwise_distances()  # all pairwise distances
    mask_anchor_positive = test_anchor_positive_triplet_mask(labels).astype(np.float32)  # anchor-positive mask as 0/1
    anchor_positive_dist = np.multiply(mask_anchor_positive, pairwise_distances)  # keep only anchor-positive distances
    hardest_positive_dist = np.max(anchor_positive_dist, axis=1, keepdims=True)  # largest anchor-positive distance per anchor
    mask_anchor_negative = test_get_anchor_negative_triplet_mask(labels).astype(np.float32)  # anchor-negative mask as 0/1
    max_anchor_negative_dist = np.max(pairwise_distances, axis=1, keepdims=True)  # maximum of each row of pairwise_distances
    anchor_negative_dist = pairwise_distances + max_anchor_negative_dist * (1.0 - mask_anchor_negative)  # invalid entries can no longer be the row minimum
    hardest_negative_dist = np.min(anchor_negative_dist, axis=1, keepdims=True)  # smallest anchor-negative distance per anchor
    triplet_loss = np.maximum(hardest_positive_dist - hardest_negative_dist + margin, 0.0)  # objective (hinge)
    triplet_loss = np.mean(triplet_loss)  # average over anchors
    return triplet_loss


if __name__ == '__main__':
    print(test_batch_hard_triplet_loss(margin=0.0))
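A loop-based version of the same batch-hard computation can be used to cross-check the masking trick above. This is a minimal sketch with the hypothetical name `batch_hard_loop`; like the vectorized code, it uses 0 as the hardest-positive distance for an anchor that has no positive, and it assumes every anchor has at least one negative. With the example embeddings and labels, anchors 0 and 2 each contribute 16 - 8 = 8 and anchor 1 (no positive) contributes 0, so the result should be (8 + 0 + 8) / 3 ≈ 5.33 for margin 0.

```python
import numpy as np


def batch_hard_loop(embeddings, labels, margin):
    # Euclidean (non-squared) distances, as in test_pairwise_distances(squared=False)
    dist = np.linalg.norm(embeddings[:, None, :] - embeddings[None, :, :], axis=-1)
    losses = []
    for a in range(len(labels)):
        pos = [dist[a, p] for p in range(len(labels)) if p != a and labels[p] == labels[a]]
        neg = [dist[a, n] for n in range(len(labels)) if labels[n] != labels[a]]
        hardest_pos = max(pos) if pos else 0.0  # matches the masked max in the vectorized code
        hardest_neg = min(neg)                  # assumes every anchor has at least one negative
        losses.append(max(hardest_pos - hardest_neg + margin, 0.0))
    # Mean over anchors, like np.mean in test_batch_hard_triplet_loss
    return float(np.mean(losses))


embeddings = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.float32)
labels = np.array([1, 0, 1])
print(batch_hard_loop(embeddings, labels, margin=0.0))  # (8 + 0 + 8) / 3
```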

These losses can be used to train models for text similarity matching and text classification.



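Copyright notice: This article is an original work by weixin_40746796, published under the CC 4.0 BY-SA license. When reposting, please include a link to the original source and this notice.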