一文搞懂F.cross_entropy的具体实现

  • Post author:
  • Post category:其他


交叉熵是在分类任务中常用的一种损失函数,本文详细介绍了pytorch中是如何实现交叉熵的细节!!!

pytorch中的交叉熵函数为F.cross_entropy(input, target),本文以变化检测或语义分割中用到的数据模型为例:input的维度为[batchsize,classes,width,height],target的维度为[batchsize,width,height]。


  1. 随机生成模型数据
input = torch.rand([1, 2, 3, 3])
import numpy as np
target = np.random.randint(2, size=(1, 3, 3))
target = torch.from_numpy(target)
target = target.long()
print(target)
print(input)
# target
tensor([[[1, 1, 1],
         [1, 1, 1],
         [0, 0, 1]]])
# input
tensor([[[[0.5546, 0.1304, 0.9288],
          [0.6879, 0.3553, 0.9984],
          [0.1474, 0.6745, 0.8948]],

		 [[0.8524, 0.2278, 0.6476],
          [0.6203, 0.6977, 0.3352],
          [0.4946, 0.4613, 0.6882]]]])

  1. pytorch中交叉熵的计算结果
loss = F.cross_entropy(input, target)
print(loss)
# loss
tensor(0.7403)

  1. 分析pytorch中交叉熵函数的具体实现

交叉熵在数学上的计算公式:

在这里插入图片描述

其中,p表示真实值,q表示预测值,H(p,q)表示交叉熵损失。

首先,先了解一下pytorch中是如何实现的,按住ctrl查看F.cross_entropy()函数的具体实现,发现该函数返回的是

return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)

发现首先是对input求log_softmax,也就是先对input求softmax,然后在对结果求log,后面就是根据每个像素的预测值以及真实值,按照上面的交叉熵公式计算损失。


  1. 自己实现pytorch中的交叉熵函数


    大概了解了交叉熵的运算过程后,下面动手实现验证一下。
input = F.softmax(input, dim=1)
loss = 0.0
for b in range(target.shape[0]):
    for i in range(target.shape[1]):
        for j in range(target.shape[2]):
            loss -= torch.log(input[b][target[b][i][j]][i][j])
# 求均值
print(loss/9)  
# 在对pytorch中交叉熵运算过程理解后,自己实现后的结果
tensor(0.7403)

首先,明确一下,这个结果和利用pytorch中的交叉熵函数计算的损失是一样的。再解释一下上面的代码,input.shape=[1,2,3,3],target.shape=[1,3,3],input经过softmax运算后的结果如下:

tensor([[[[0.4261, 0.4757, 0.5698],
          [0.5169, 0.4152, 0.6600],
          [0.4141, 0.5531, 0.5515]],

         [[0.5739, 0.5243, 0.4302],
          [0.4831, 0.5848, 0.3400],
          [0.5859, 0.4469, 0.4485]]]])

input的第二个维度是2,表示变化检测中的两个类别,未变化的像素和变化的像素,也就是上面tensor的上下两块,分别表示每个像素未变化以及变化的概率,tensor中共计包含width

height=3

3=9个像素,这些像素对应的真实标签为:

# target
tensor([[[1, 1, 1],
         [1, 1, 1],
         [0, 0, 1]]])

每个像素的交叉熵损失的数学计算过程如下:以第一个像素为例,该像素的真实标签是1,即target[0][0][0];该像素预测标签为0的概率为0.4261,即input[0][0][0][0],预测标签为1的概率为0.5739,即input[0][1][0][0];故该像素对应的交叉熵损失为-log(0.5739)。


补充:


这里可能就会有人有疑问了,明明交叉熵的计算过程是一个求和的过程,你上面的交叉熵计算怎么只有一项,这里我再解释一下,分类过程中真实标签采用one-hot形式,可以和预测值一一对应上,看下面的表格就很容易理解计算过程了:

类别1 类别2
真实值 0 1
预测值 0.4261 0.5739

该像素的真实标签是1,转换成one-hot形式就是[0,1],预测值是[0.4261,0,5739],对应的交叉熵损失为loss=-0*log(0.4261)-1*log(0.5739)=-log(0.5739),发现和上面的计算是一样的。ok!


因此,9个像素的交叉熵损失和为:


loss = -log(0.5739)-log(0.5243)-log(0.4302)-log(0.4831)-log(0.5848)-log(0.3400)-log(0.4141)-log(0.5531)-log(0.4485)


由于pytorch交叉熵函数的参数size_average默认值是true,所以对loss求均值可以得到最终的结果。


注:本人njust cv新生,第一次写博客,希望能帮助到大家,如有错误还请指出!



版权声明:本文为code_plus原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。