1. TripletLoss
A triplet is a three-element tuple built as follows: randomly pick a sample from the training set, called the Anchor (denoted x_a); then randomly pick one sample of the same class as the Anchor and one sample of a different class, called the Positive (denoted x_p) and the Negative (denoted x_n) respectively. Together they form an (Anchor, Positive, Negative) triplet.
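As a minimal sketch of how such a triplet could be sampled from a labeled batch (the function name sample_triplet, the toy data, and the sampling strategy are illustrative assumptions):

import numpy as np

def sample_triplet(features, labels):
    rng = np.random.default_rng()
    anchor_idx = int(rng.integers(len(labels)))
    same_class = (labels == labels[anchor_idx])
    # Candidate positives: same class as the anchor, excluding the anchor itself
    # (assumes every class has at least two samples).
    positive_candidates = np.flatnonzero(same_class & (np.arange(len(labels)) != anchor_idx))
    negative_candidates = np.flatnonzero(~same_class)
    positive_idx = int(rng.choice(positive_candidates))
    negative_idx = int(rng.choice(negative_candidates))
    return features[anchor_idx], features[positive_idx], features[negative_idx]

features = np.arange(16, dtype=np.float32).reshape(4, 4)
labels = np.array([0, 0, 1, 1])
anchor, positive, negative = sample_triplet(features, labels)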
With the notion of a triplet in place, triplet loss is easy to understand. For each element (sample) of the triplet, a network with shared or separate parameters is trained to produce a feature representation, denoted:
f(x_{i}^{a}), f(x_{i}^{p}), f(x_{i}^{n})
The goal of triplet loss is to learn representations such that the distance between the features of x_a and x_p is as small as possible, while the distance between the features of x_a and x_n is as large as possible, so that the following constraint holds:
||f(x_{i}^{a})-f(x_{i}^{p})||_{2}^{2}+\alpha<||f(x_{i}^{a})-f(x_{i}^{n})||_{2}^{2}
The corresponding objective function then follows directly:
\sum_{i}^{N}[||f(x_{i}^{a})-f(x_{i}^{p})||_{2}^{2}-||f(x_{i}^{a})-f(x_{i}^{n})||_{2}^{2}+\alpha]_{+}
Here the distance is the Euclidean distance, and the subscript + means that when the value inside the brackets is greater than zero, that value is taken as the loss; when it is less than zero, the loss is zero.
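A minimal NumPy transcription of this objective (the function name and the assumption that the anchor, positive and negative embeddings come pre-stacked as (N, d) arrays are illustrative):

import numpy as np

def triplet_objective(f_a, f_p, f_n, alpha):
    # Squared Euclidean distances for each triplet in the batch.
    d_ap = np.sum((f_a - f_p) ** 2, axis=1)
    d_an = np.sum((f_a - f_n) ** 2, axis=1)
    # [.]_+ hinge: only triplets that violate the margin contribute to the loss.
    return np.sum(np.maximum(d_ap - d_an + alpha, 0.0))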
From the objective function we can see:
- When the distance between x_a and x_n is less than the distance between x_a and x_p plus α, the value inside the brackets is greater than zero and a loss is incurred.
- When the distance between x_a and x_n is greater than or equal to the distance between x_a and x_p plus α, the loss is zero.
A quick numeric check of both cases follows.
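For example, with α = 1 and squared distances of 16 for the anchor-positive pair and 8 for the anchor-negative pair, the bracket evaluates to [16 - 8 + 1]_+ = 9, so a loss of 9 is incurred; swapping the two distances gives [8 - 16 + 1]_+ = 0, i.e. no loss (the distance values and margin here are illustrative).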
Code implementation:
import numpy as np


def test_pairwise_distances(squared=False):
    '''Pairwise distances between embeddings. For the toy embeddings below, row 0 gives:
    distance(0, 0) = 0, distance(0, 1) = 8, distance(0, 2) = 16 (square root already taken):
    [[ 0.  8. 16.]
     [ 8.  0.  8.]
     [16.  8.  0.]]
    '''
    embeddings = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.float32)
    dot_product = np.dot(embeddings, np.transpose(embeddings))
    square_norm = np.diag(dot_product)
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2
    distances = np.expand_dims(square_norm, axis=1) - 2.0 * dot_product + np.expand_dims(square_norm, 0)
    mask = np.float32(np.equal(distances, 0.0))
    if not squared:
        # Add a small epsilon where the distance is 0 so the sqrt is stable, then zero it back out.
        distances = distances + mask * 1e-16
        distances = np.sqrt(distances)
        distances = distances * (1.0 - mask)
    print(distances)
    return distances


def test_get_triplet_mask(labels):
    '''A triplet (i, j, k) is valid if
    - i, j, k are pairwise distinct
    - labels[i] == labels[j] and labels[i] != labels[k]
    '''
    # Build a 2D boolean matrix where (i, j) with i != j is True, giving indices_not_equal.
    indices_equal = np.eye(np.shape(labels)[0], dtype=bool)
    indices_not_equal = np.logical_not(indices_equal)
    # The final mask is a 3D tensor over (i, j, k), so broadcast each pairwise condition along the
    # missing axis, e.g. i_not_equal_j gets shape (batch_size, batch_size, 1), and so on.
    i_not_equal_j = np.expand_dims(indices_not_equal, 2)
    i_not_equal_k = np.expand_dims(indices_not_equal, 1)
    j_not_equal_k = np.expand_dims(indices_not_equal, 0)
    # i != j != k is the AND of the three pairwise conditions.
    # For a batch of 3 this gives:
    '''array([[[False, False, False],
               [False, False,  True],
               [False,  True, False]],
              [[False, False,  True],
               [False, False, False],
               [ True, False, False]],
              [[False,  True, False],
               [ True, False, False],
               [False, False, False]]])'''
    # True only when the indices (i, j, k) are pairwise distinct.
    distinct_indices = np.logical_and(np.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k)
    # Similarly, use the labels to require labels[i] == labels[j] and labels[i] != labels[k].
    label_equal = np.equal(np.expand_dims(labels, 0), np.expand_dims(labels, 1))
    i_equal_j = np.expand_dims(label_equal, 2)
    i_equal_k = np.expand_dims(label_equal, 1)
    valid_labels = np.logical_and(i_equal_j, np.logical_not(i_equal_k))
    # The mask must satisfy both constraints, so AND the two 3D tensors.
    mask = np.logical_and(valid_labels, distinct_indices)
    return mask


def test_batch_all_triplet_loss(margin):
    # Compute the pairwise distances between all embeddings, expand them into a
    # (batch_size, batch_size, batch_size) 3D tensor of triplet losses,
    # then multiply by the valid-triplet mask.
    labels = np.array([1, 0, 1])  # e.g. the 1st and 3rd samples are positives of each other and the
                                  # 2nd is a negative, so the computed loss should be 16 - 8 = 8
    pairwise_distances = test_pairwise_distances()
    anchor_positive = np.expand_dims(pairwise_distances, axis=2)
    anchor_negative = np.expand_dims(pairwise_distances, axis=1)
    triplet_loss = anchor_positive - anchor_negative + margin
    mask = test_get_triplet_mask(labels)
    mask = mask.astype(np.float32)
    triplet_loss = np.multiply(mask, triplet_loss)
    triplet_loss = np.maximum(triplet_loss, 0.0)
    valid_triplet_loss = np.greater(triplet_loss, 1e-16).astype(np.float32)
    num_positive_triplet = np.sum(valid_triplet_loss)
    num_valid_triplet_loss = np.sum(mask)
    fraction_positive_triplet = num_positive_triplet / (num_valid_triplet_loss + 1e-16)
    triplet_loss = np.sum(triplet_loss) / (num_positive_triplet + 1e-16)
    return triplet_loss, fraction_positive_triplet


if __name__ == '__main__':
    test_batch_all_triplet_loss(margin=0.0)
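Appended to the script above, a quick usage check; the expected values follow from the toy batch (two valid triplets, each contributing 16 - 8 = 8), so treat them as estimates:

loss, fraction = test_batch_all_triplet_loss(margin=0.0)
print(loss, fraction)  # expected: loss of about 8.0, positive-triplet fraction of about 1.0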
2. HardTripletLoss
The idea is to find, for each Anchor, the largest distance to a Positive and the smallest distance to a Negative (i.e. the hardest cases to train on). The corresponding objective function then follows directly:
\sum_{i}^{N}[\max(||f(x_{i}^{a})-f(x_{i}^{p})||_{2}^{2})-\min(||f(x_{i}^{a})-f(x_{i}^{n})||_{2}^{2})+\alpha]_{+}
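To make the mining concrete, consider the toy batch used in the code below (distance matrix rows [0, 8, 16], [8, 0, 8], [16, 8, 0] and labels [1, 0, 1]): anchor 0 has hardest positive distance 16 (sample 2) and hardest negative distance 8 (sample 1), so it contributes [16 - 8 + α]_+; anchor 1 has no positive of its own class, so its masked positive distance is 0 and, with hardest negative distance 8, it contributes [0 - 8 + α]_+ = 0 for small α. This is a worked illustration derived from those toy values.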
Code implementation:
import numpy as np


def test_pairwise_distances(squared=False):
    '''Pairwise distances between embeddings. For the toy embeddings below, row 0 gives:
    distance(0, 0) = 0, distance(0, 1) = 8, distance(0, 2) = 16 (square root already taken):
    [[ 0.  8. 16.]
     [ 8.  0.  8.]
     [16.  8.  0.]]
    '''
    embeddings = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.float32)
    dot_product = np.dot(embeddings, np.transpose(embeddings))
    square_norm = np.diag(dot_product)
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2
    distances = np.expand_dims(square_norm, axis=1) - 2.0 * dot_product + np.expand_dims(square_norm, 0)
    mask = np.float32(np.equal(distances, 0.0))
    if not squared:
        # Add a small epsilon where the distance is 0 so the sqrt is stable, then zero it back out.
        distances = distances + mask * 1e-16
        distances = np.sqrt(distances)
        distances = distances * (1.0 - mask)
    print(distances)
    return distances


def test_anchor_positive_triplet_mask(labels):
    # 2D mask of valid (anchor, positive) pairs: i != j and labels[i] == labels[j].
    indices_equal = np.eye(np.shape(labels)[0], dtype=bool)
    indices_not_equal = np.logical_not(indices_equal)
    labels_equal = np.equal(np.expand_dims(labels, 0), np.expand_dims(labels, 1))
    mask = np.logical_and(indices_not_equal, labels_equal)
    return mask


def test_get_anchor_negative_triplet_mask(labels):
    # 2D mask of valid (anchor, negative) pairs: labels[i] != labels[j].
    labels_equal = np.equal(np.expand_dims(labels, 0), np.expand_dims(labels, 1))
    mask = np.logical_not(labels_equal)
    return mask


def test_batch_hard_triplet_loss(margin):
    # Compute the pairwise distance matrix.
    # The hardest positive distance is just the row-wise maximum over valid (a, p) pairs.
    # The hardest negative distance cannot be taken as the row-wise minimum directly, because
    # invalid (a, n) entries are 0; instead, set the invalid positions to the row maximum first,
    # then take the row-wise minimum.
    labels = np.array([1, 0, 1])
    pairwise_distances = test_pairwise_distances()  # pairwise distance matrix
    mask_anchor_positive = test_anchor_positive_triplet_mask(labels)  # boolean mask of (anchor, positive) pairs
    mask_anchor_positive = mask_anchor_positive.astype(np.float32)  # convert to 0/1
    anchor_positive_dist = np.multiply(mask_anchor_positive, pairwise_distances)  # distances of valid (a, p) pairs
    hardest_positive_dist = np.max(anchor_positive_dist, axis=1, keepdims=True)  # largest anchor-positive distance
    mask_anchor_negative = test_get_anchor_negative_triplet_mask(labels)  # boolean mask of (anchor, negative) pairs
    mask_anchor_negative = mask_anchor_negative.astype(np.float32)  # convert to 0/1
    max_anchor_negative_dist = np.max(pairwise_distances, axis=1, keepdims=True)  # row-wise maximum of pairwise_distances
    # Taking the minimum directly would pick up the invalid 0 entries, so push invalid positions up to the row maximum.
    anchor_negative_dist = pairwise_distances + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
    hardest_negative_dist = np.min(anchor_negative_dist, axis=1, keepdims=True)  # smallest anchor-negative distance
    triplet_loss = np.maximum(hardest_positive_dist - hardest_negative_dist + margin, 0.0)  # objective
    triplet_loss = np.mean(triplet_loss)
    return triplet_loss


if __name__ == '__main__':
    test_batch_hard_triplet_loss(margin=0.0)
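Run from the same script, this gives a quick sanity check; the expected value is reasoned from the toy batch (per-anchor hard losses of 8, 0 and 8 averaging to 16/3), so treat it as an estimate:

loss = test_batch_hard_triplet_loss(margin=0.0)
print(loss)  # expected to be roughly 5.33 for the toy batch above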
Triplet loss can be used for training text similarity matching and text classification models.