交叉熵是在分类任务中常用的一种损失函数,本文详细介绍了pytorch中是如何实现交叉熵的细节!!!
pytorch中的交叉熵函数为F.cross_entropy(input, target),本文以变化检测或语义分割中用到的数据模型为例:input的维度为[batchsize,classes,width,height],target的维度为[batchsize,width,height]。
-
随机生成模型数据
input = torch.rand([1, 2, 3, 3])
import numpy as np
target = np.random.randint(2, size=(1, 3, 3))
target = torch.from_numpy(target)
target = target.long()
print(target)
print(input)
# target
tensor([[[1, 1, 1],
[1, 1, 1],
[0, 0, 1]]])
# input
tensor([[[[0.5546, 0.1304, 0.9288],
[0.6879, 0.3553, 0.9984],
[0.1474, 0.6745, 0.8948]],
[[0.8524, 0.2278, 0.6476],
[0.6203, 0.6977, 0.3352],
[0.4946, 0.4613, 0.6882]]]])
-
pytorch中交叉熵的计算结果
loss = F.cross_entropy(input, target)
print(loss)
# loss
tensor(0.7403)
-
分析pytorch中交叉熵函数的具体实现
交叉熵在数学上的计算公式:
其中,p表示真实值,q表示预测值,H(p,q)表示交叉熵损失。
首先,先了解一下pytorch中是如何实现的,按住ctrl查看F.cross_entropy()函数的具体实现,发现该函数返回的是
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
发现首先是对input求log_softmax,也就是先对input求softmax,然后在对结果求log,后面就是根据每个像素的预测值以及真实值,按照上面的交叉熵公式计算损失。
-
自己实现pytorch中的交叉熵函数
大概了解了交叉熵的运算过程后,下面动手实现验证一下。
input = F.softmax(input, dim=1)
loss = 0.0
for b in range(target.shape[0]):
for i in range(target.shape[1]):
for j in range(target.shape[2]):
loss -= torch.log(input[b][target[b][i][j]][i][j])
# 求均值
print(loss/9)
# 在对pytorch中交叉熵运算过程理解后,自己实现后的结果
tensor(0.7403)
首先,明确一下,这个结果和利用pytorch中的交叉熵函数计算的损失是一样的。再解释一下上面的代码,input.shape=[1,2,3,3],target.shape=[1,3,3],input经过softmax运算后的结果如下:
tensor([[[[0.4261, 0.4757, 0.5698],
[0.5169, 0.4152, 0.6600],
[0.4141, 0.5531, 0.5515]],
[[0.5739, 0.5243, 0.4302],
[0.4831, 0.5848, 0.3400],
[0.5859, 0.4469, 0.4485]]]])
input的第二个维度是2,表示变化检测中的两个类别,未变化的像素和变化的像素,也就是上面tensor的上下两块,分别表示每个像素未变化以及变化的概率,tensor中共计包含width
height=3
3=9个像素,这些像素对应的真实标签为:
# target
tensor([[[1, 1, 1],
[1, 1, 1],
[0, 0, 1]]])
每个像素的交叉熵损失的数学计算过程如下:以第一个像素为例,该像素的真实标签是1,即target[0][0][0];该像素预测标签为0的概率为0.4261,即input[0][0][0][0],预测标签为1的概率为0.5739,即input[0][1][0][0];故该像素对应的交叉熵损失为-log(0.5739)。
补充:
这里可能就会有人有疑问了,明明交叉熵的计算过程是一个求和的过程,你上面的交叉熵计算怎么只有一项,这里我再解释一下,分类过程中真实标签采用one-hot形式,可以和预测值一一对应上,看下面的表格就很容易理解计算过程了:
类别1 | 类别2 | |
---|---|---|
真实值 | 0 | 1 |
预测值 | 0.4261 | 0.5739 |
该像素的真实标签是1,转换成one-hot形式就是[0,1],预测值是[0.4261,0,5739],对应的交叉熵损失为loss=-0*log(0.4261)-1*log(0.5739)=-log(0.5739),发现和上面的计算是一样的。ok!
因此,9个像素的交叉熵损失和为:
loss = -log(0.5739)-log(0.5243)-log(0.4302)-log(0.4831)-log(0.5848)-log(0.3400)-log(0.4141)-log(0.5531)-log(0.4485)
由于pytorch交叉熵函数的参数size_average默认值是true,所以对loss求均值可以得到最终的结果。
注:本人njust cv新生,第一次写博客,希望能帮助到大家,如有错误还请指出!