多分类器softmax——绝对简单易懂的梯度推导

  • Post author:
  • Post category:其他


首先说明,求导不只是链式法则这么简单。我们常常不知道需要对谁求导,如何从最后的损失函数一步一步的计算到每一个参数上。此外,我们也有可能遇到不知道根据公式来进行编程,根本原因在于公式和编程并不是同样的语言,这是有差别的,我们如何跨越这个差别呢?

如果你有以上两个困惑,希望本文和下一篇博客能助你一臂之力。

本文主要针对第一个问题。第二个问题将会在下篇博客详细说明。



损失函数的计算

首先说明本文解决的是softmax的多分类器的梯度求导,以下先给出损失函数的计算方式:

这里将最终的loss分为4步进行计算,如下所示,当然,这里不解释为什么是这样的计算方式。

注意到,本文并不限制训练样本的数量,训练样本的特征数,以及最后分为几类。

公式(1)

这里x表示输入,w表示权重参数。

说明:这里的x和w的下标表示x的某一行和w的某一列相乘在逐项相加得到s。

然后再根据s计算每一个类的概率,如下公式(2)

公式(2)

这里采用的下标和公式(1)不相同,其中,n表示样本的个数,y表示样本为n时的正确分类标号。k表示有多少分类。这个公式就是先将s进行e次方计算,然后归一化,求得该样本正确分类下的概率p.

根据p可以计算出每一个样本的损失,如公式(3):

公式(3)

这个公式说明,每一个样本的损失仅仅是正确分类对应的概率值的log函数,这里准确说应该是ln函数,也就是以自然对数为底的,这样计算导数更方便,后面会以ln为版本进行计算。

最后,根据公式(4)计算所有样本的损失:

公式(4)

也就是将所有样本的损失求平均数。

注意:以上下标是独立系统,与下面的推导过程没有必然关系,

这里特别指ij

,其他字母的含义基本相同。



基本求导法则

所谓梯度,就是求损失函数对参数w的导数,将其用在更新参数w上,达到优化的目的。

我们知道,梯度计算遵循着链式法则,而基本求导公式也是需要的,防止有人忘记,我先给出这里将会用到的基本求导公式。知道的请跳过这一节,直接看下一节。

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述



以下开始正式求梯度

计算整个损失函数对w(下标为ij)的导数。

根据链式法则,

考虑到总损失为每个样本损失的平均数,且每个样本的损失都与wij相关

,这个说明很有必要,假如某个损失与wij无关,我们就不用对它进行求导了。

有公式(5)

公式(5)

这里Ln表示样本为n时的损失函数。

不失一般性,这里对最后一项进行继续推导,然后将其相加。

同样的,由于pny是和wij的函数,有公式(6):

公式(6)

结合公式(2),前一部分有有公式(7):

在这里插入图片描述

后一个部分我们继续来考虑,pny的上下两项是否都是wij的函数?肯定的回答是,这不一定,由公式(2)和(1)可知,如果公式2中分子的下标y不是j,那么实际上这里公式2的分子就不是wij的函数。


我们细说一下,由公式1,ij是公式1中的下标,当sij与wij有关系建立在这个j相等的情况,但是公式2的分子并不一定就满足这个关系的,什么情况满足呢?那就是样本n的正确分类的下标j和wij中的下标j相等时;否则这就没有关系。


因此,我们需要分为两种情况来进一步计算公式(6)的后半部分。

(实际上,我们也可以先认为他们相关,然后进一步处理,这里我先不这么做)


情况一,公式(2)中的分子与wij无关

:也就是以下公式中y与j不相等


公式(2)中分母必然与wij有关,且只有一个与wij有关。那就是公式(2)中分母的下标k与wij的就相等时,而其他都与wij无关。


进一步考虑到e的s次方,s与wij的关系,因此针对情况一,有公式(8)

公式8

继续对第二项展开有公式(9):

在这里插入图片描述

这里还是细细说一下,这个过程,始终记住一点,那就是中间变量与wij是什么关系,可以根据公式看出来。根据公式(1),当且仅当s的下标中是ij时才会与wij有关,而对sij对wij求导时得到的就是xii,(两个i不一样的含义)只需要把公式(1)中的x和w的下标中的点号换成i即可。也就是说,s对w求导时,x的第一个下标是s的第一个下标,x的第二个下标是w的第一个下标。当然,这里我们需要再将s的下标i换成n,这样才能满足以上的推导。

我们将公式(9)根据公式(2)化简一下,再带入公式(6),可以得到公式(10),也就是情况一下的最终一个样本的梯度:

公式(10)

其中,用了一个简写,也就是求和的项简写了,请留意。

写成pnj是因为我们计算过程中会产生这个数,而且这样写起来也更整齐。


情况二,公式(2)中的分子是wij的函数



注意到这里,公式(2)中pny的下标y和wij的下标j是相等的,也就是y=j。

情况2比情况1复杂在公式(2)的分母上,其余相同,因此,对其求导过程如下:

这里先使用ynj(nj是下标)表示样本为n时第j个分类的真实值,要么是0,要么是1,1表示真实分类就是这个j.

情况一根据(1\u)’求导,情况二则根据(v/u)’来求导,因此有一点差别。

以下一步一步的写:

在这里插入图片描述

根据公式(2)将后面展开可得:

在这里插入图片描述

化简一下可以得到:

在这里插入图片描述

根据公式(2)继续化简:

在这里插入图片描述

对上式去括号操作:

在这里插入图片描述

继续求导并且根据公式2化简得公式(11)

在这里插入图片描述

可以看出,这与上面的情况一相差在最后一项上,而前面一项是相等的。

接下来我们一起探讨一下怎么求后面的一项,毕竟这还无法完全理解清楚,因为这还是一个导数,也不是输入或者中间求到的某个数。

前面我们已经说到,情况二下公式(2)中的y和wij的j是相等的。

这时候计算知道:

在这里插入图片描述

所以公式(11)进一步计算可得最终的求导公式:公式(12)

在这里插入图片描述



综合两个情况

情况二比情况一多减去一项。

一般情况下,我们直接使用pnj * xni即可。

而当wij中j是当前样本n的正确分类时要多减去xni。

以上既是多分类器softmax的梯度求导公式。



后话

其实个人感觉梯度的计算还是挺难的,而且本文只是推导公式,还没有真正的编程计算。

实际上,我们通常为了保证我们的程序正确,会写一个数值求导,正确情况下两者不会相差很多。

本文的理论推导,将会在下一篇博客中写明如何进行计算。



版权声明:本文为qq_27261889原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。