Skip to content

上一篇文章学习矩阵求导的方法,这篇文章以上一篇文章为基础,推导一个三层的MLP的反向传播算法并给出它的代码实现。

前言

进行推导前需要先列出几个后面需要使用的函数

首先是大名的鼎鼎的交叉熵(Cross Entropy)函数,它的定义如下,其中 y 是独热码编码后的类别,softmax的函数输出可以衡量输出是每个类别的概率大小。

L=yTlog(softmax(x))softmax(x)=exij=1KexjL(x)=softmax(x)y

我们后面的激活函数采用比较常见的σ函数,它的定义和导数是

σ(x)=11+exp(x)σ(x)=σ(x)(1σ(x))

当然,后面的推导也不会使用σ函数求导的完整结果(因为太长了),而是用σ表示,如果要更换ReLu或者别的激活函数,如只需要将求导结果带进去就可以了。

符号定义

假设我们的网络有 n层,wibi分别表示第i层的参数(wight)和偏置(bias),使用ai表示第i层还没有经过激活函数的结果,使用 hi表示第i层的输出同时也是第i+1层的输入,也就是相当于hi=σ(ai)σ函数作为激活函数。整个网络的前向传播的Loss表达式可以写成下面的这种形式。

L=yTlog(softmax(h3))h3=h2w3+b3a2=h1w2+b2h2=σ(a2)a1=h0w1+b1h1=σ(a1)h0=x

这里假设输入的向量x和输出向量y都是列向量

当然,如果不定义符号进行替换的话就可以写成下面的这种比较长的形式。

L=yTlog(softmax(σ(σ(xw1+b1)w2+b2)w3+b3))

可以看到,MLP的本质其实就是我们最常见的复合函数的形式,只是我们的输入的变量x是一个向量或者矩阵的形式。我们的任务就是借助上一篇文章中所提到的矩阵求导的手段求出 L 对每个wibi的偏导表达式。

第三层推导

根据前言部分log(softmax(x))函数的求导,可以直接求得h3的偏导数

Lh3=softmax(h3)y

dL 写成迹的形式

tr(dL)=tr((Lh3)Tdh3)=tr((Lh3)Td(h2w3+b3))=tr((Lh3)Td(h2w3))+tr((Lh3)Td(b3))

通过红色的部分就得到了我们可以看到L关于b3的微分形式,通过这个式子就可以写出L关于b3的偏导数:

Lb3=((Lh3)T)T=Lh3

将前面蓝色的部分单独写下来,继续经过变形

tr((Lh3)Td(h2w3))=tr((Lh3)Td(h2)w3)+tr((Lh3)Th2d(w3))=tr(w3(Lh3)Td(h2))+tr((Lh3)Th2d(w3))

同样的由红色部分,我们可以得到L关于w3的偏导,蓝色的部分我们后续的推导需要用。

Lw3=(Lh3)Th2)T=h2TLh3

到这里,我们第三层的所有参数 w3b3的梯度我们都算出来了。

第二层推导

接着把第三层推导剩下的蓝色部分单独提出来

tr(w3(Lh3)Td(h2))=tr(w3(Lh3)Td(σ(a2)))=tr(w3(Lh3)Tσ(a2)d(a2))

根据上篇文章提到的tr(AT(BC))=tr((AB)TC)可以化简为

tr(w3(Lh3)Tσ(a2)d(a2))=tr((Lh3w3T)Tσ(a2)d(a2))=tr((Lh3w3Tσ(a2))Td(a2))=tr((Lh3w3Tσ(a2))Td(h1w2+b2))=tr((Lh3w3Tσ(a2))Td(h1w2))+tr((Lh3w3Tσ(a2))Td(b2))=tr((Lh3w3Tσ(a2))Td(h1)w2)+tr((Lh3w3Tσ(a2))Th1d(w2))+tr((Lh3w3Tσ(a2))Td(b2))

可以得到

Lw2=((Lh3w3Tσ(a2))Th1)T=h1T(Lh3w3Tσ(a2))Lb2=Lh3w3Tσ(a2)

第一层推导

经过前两节的推导,我们得到了第二和第三层的表达式,但是他们的规律还不是非常的明显,我们继续推导第一层的偏导数。

tr((Lh3w3Tσ(a2))Td(h1)w2)=tr(w2(Lh3w3Tσ(a2))Td(σ(a1))=tr(w2(Lh3w3Tσ(a2))Tσ(a1)d(a1))=tr((Lh3w3Tσ(a2)w2T)Tσ(a1)d(a1))=tr((Lh3w3Tσ(a2)w2Tσ(a1))Td(h0w1+b1))=tr((Lh3w3Tσ(a2)w2Tσ(a1))Td(h0w1))+tr((Lh3w3Tσ(a2)w2Tσ(a1))Td(b1))=tr((Lh3w3Tσ(a2)w2Tσ(a1))Th0d(w1))+tr((Lh3w3Tσ(a2)w2Tσ(a1))Td(b1))+tr((Lh3w3Tσ(a2)w2Tσ(a1))Td(h0)w1)

可以得出

Lw1=(Lh3w3Tσ(a2)w2Tσ(a1))Th0)T=h0T(Lh3w3Tσ(a2)w2Tσ(a1))Lb2=Lh3w3Tσ(a2)w2Tσ(a1)

总结规律

将分别将 w1,w2,w3的偏导表达式写出来

Lw1=h0T(Lh3w3Tσ(a2)w2Tσ(a1))Lw2=h1T(Lh3w3Tσ(a2))Lw3=h2TLh3

假设网络有 n层,编号依次为 1, 2 ...,损失函数关于网络最终输出的偏导为Lout,可以比较明显地发现以下规律:

Lwi=hi1T[Loutk=nk=i+1(wkTσ(ak1))]

注意:这里的连乘是倒序的

b1,b2,b3的表达式为

Lb1=Lh3w3Tσ(a2)w2Tσ(a1)Lb2=Lh3w3Tσ(a2)Lb3=Lh3

wi类似,写出关于bi的偏导数,其实就是wi的偏导数去掉前面的hi1

Lwi=Loutk=nk=i+1(wkTσ(ak1))

注意:这里的连乘是倒序的

代码实现

有了前面的推导过程,接下来就是代码的实现,代码将包含3个部分。

  • 数据准备
  • 模型类实现
  • 全连接层实现

测试和训练数据生成

在开始实现之前,我们需要先准备一些数据用于验证我们的模型工作是否正常。这里就直接实用sklearn库所提供的make_classification函数生成测试用的数据

python
N_CLASS = 3
from sklearn.datasets import make_classification
x, y = make_classification(n_samples=2000, n_features=5, n_classes=N_CLASS,
                           n_informative=5, n_repeated=0, n_redundant=0, n_clusters_per_class=1)
sns.scatterplot(x=x[:,0], y=x[:,1], hue=y)

测试用的数据将包含5个特征,也就是说输入的向量维度是5,总共有 3个类别

最新更新: