首页 > 技术文章 > 【知识相关】让向量、矩阵和张量的求导更简洁些吧

lky-learning 2020-05-30 19:24 原文

本文是我在阅读Erik Learned-Miller的《Vector, Matrix, and Tensor Derivatives》时的记录,点此下载

本文的主要内容是帮助你学习如何进行向量、矩阵以及高阶张量(三维及以上的数组)的求导。并一步步引导你来进行向量、矩阵和张量的求导。

1  简化、简化,还是简化(重要的事情说三遍)

在求解涉及到数组的导数时,大部分的困难是因为试图一次性做太多事情。比如说同时求解多个组成部分的导数,在求和符号存在的情况下求解导数,或者使用链式法则。在有丰富的求导经验之前,同时执行所有的这些操作,我们就很容易出错。

1.1  将矩阵计算分解为单个标量的计算

为了简化给定的计算,我们将矩阵的求导分解为每个单独标量元素的表达式,每个表达式只包含标量变量。在写出单个标量元素与其他标量值的表达式后,就可以使用微积分来计算。这比同时进行矩阵的求和以及求导要容易一些。(看起来有点晕,没关系,看后面的案例就清晰了)。

In order to simplify a given calculation, it is often useful to write out the explicit formula for a single scalar element of the output in terms of nothing but scalar variables. Once one has an explicit formula for a single scalar element of the output in terms of other scalar values, then one can use the calculus that you used as a beginner, which is much easier than trying to do matrix math, summations, and derivatives all at the same time.

例如:假设我们有一个\(C\)阶列向量\(\overrightarrow{y}\),它是由\(C\times D\)维矩阵\(W\)和\(D\)阶列向量\(\overrightarrow{x}\)计算得到:

\(\overrightarrow{y} = W\overrightarrow{x}\tag{1}\)

假设我们计算\(\overrightarrow{y}\)关于\(\overrightarrow{x}\)的导数。要完完全全的求解导数,就需要计算\(\overrightarrow{y}\)中的每一个元素对\(\overrightarrow{x}\)中的每一个元素的(偏)导数。那么在本例中,因为\(\overrightarrow{y}\)中有\(C\)个元素,\(\overrightarrow{x}\)中有\(D\)个元素,所以一个包含\(C\times D\)次运算。

比如说,我们要计算\(\overrightarrow{y}\)的第3个元素对\(\overrightarrow{x}\)的第7个元素的(偏)导数,这就是向量中的一个标量对其他向量中的一个标量求导:

\(\frac{\partial \overrightarrow{y_3}}{\partial \overrightarrow{x_7}}\)

在求导之前,首先要做的就是写下计算\(\overrightarrow{y_3}\)的公式, 根据矩阵-向量乘法的定义,\(\overrightarrow{y_3}\)等于矩阵\(W\)中的第3行和向量\(\overrightarrow{x}\)的点积。

\(\overrightarrow{y_3} = \sum_{j=1}^{D}W_{3,j}\cdot \overrightarrow{x_j}\tag{2}\) 

现在,我们将原始的矩阵方程式(1)简化成了标量方程式。此时再进行求导就简单多了。

1.2  去除求和符号

虽然可以直接在公式(2)中求导,但是在包含求和符号(\(\sum_{}^{}\))或者连乘符号(\(\prod_{}^{}\))的方程式中求导很容易出错。在求导之前,最好先去掉求和符号,把各项相加的表达式写出来,确保每一项不出错。去掉求和符号的表达式如下所示(下标从1开始):

\(\overrightarrow{y_3}=W_{3,1}\overrightarrow{x_1}+W_{3,2}\overrightarrow{x_2}+...+W_{3,7}\overrightarrow{x_7}+...+W_{3,D}\overrightarrow{x_D}\tag{3}\)

在这个表达式中,我们专门把\(\overrightarrow{x_7}\)凸显出来,这是因为这一项正是我们要求导的项。显然,可以看出在求\(\overrightarrow{y_3}\)对\(\overrightarrow{x_7}\)的偏导数时,我们只需要关心\(W_{3,7}\overrightarrow{x_7}\)这一项即可。因为其他项都不包含\(\overrightarrow{x_7}\),它们对\(\overrightarrow{x_7}\)的偏导数均为0。接下来就很清晰了:

\(\begin{equation}  \begin{aligned} \overrightarrow{y_3}&=W_{3,1}\overrightarrow{x_1}+W_{3,2}\overrightarrow{x_2}+...+W_{3,7}\overrightarrow{x_7}+...+W_{3,D}\overrightarrow{x_D}\\ &=0 +0+...+\frac{\partial}{\partial \overrightarrow{x_7}}\left [W_{3,7}\overrightarrow{x_7} \right ]+...+0\\ &=\frac{\partial}{\partial \overrightarrow{x_7}}\left [W_{3,7}\overrightarrow{x_7} \right ]\\ &=W_{3,7} \end{aligned} \tag{4} \end{equation}\)

在求导过程中,只关注\(\overrightarrow{y}\)中的一个量和\(\overrightarrow{x}\)中的一个量,能够把求导过程简化很多。如果以后进行求导时遇到问题,采取这种方式可以帮助我们把问题简化至最基础的程度,这样便于理清思绪、找出问题所在。

1.2.1  完成求导:雅可比矩阵

我们的最终目标是计算出\(\overrightarrow{y}\)中的每个元素对\(\overrightarrow{x}\)中每个元素的导数,共计\(C\times D\)个。下面的这个雅克比矩阵直观的表示了这些导数:

\(\begin{bmatrix} \frac{\partial \overrightarrow{y_1}}{\partial \overrightarrow{x_1}}& \frac{\partial \overrightarrow{y_1}}{\partial \overrightarrow{x_2}}& \frac{\partial \overrightarrow{y_1}}{\partial \overrightarrow{x_3}}&\cdots  & \frac{\partial \overrightarrow{y_1}}{\partial \overrightarrow{x_D}}\\ \frac{\partial \overrightarrow{y_2}}{\partial \overrightarrow{x_1}}& \frac{\partial \overrightarrow{y_2}}{\partial \overrightarrow{x_2}}& \frac{\partial \overrightarrow{y_2}}{\partial \overrightarrow{x_3}}&\cdots  & \frac{\partial \overrightarrow{y_2}}{\partial \overrightarrow{x_D}}\\ \vdots & \vdots & \vdots & \ddots & \vdots \\ \frac{\partial \overrightarrow{y_C}}{\partial \overrightarrow{x_1}}& \frac{\partial \overrightarrow{y_C}}{\partial \overrightarrow{x_2}}& \frac{\partial \overrightarrow{y_C}}{\partial \overrightarrow{x_3}}&\cdots  & \frac{\partial \overrightarrow{y_C}}{\partial \overrightarrow{x_D}}\\ \end{bmatrix}\) 

对于公式\(\overrightarrow{y} = W\overrightarrow{x}\)来说,\(\overrightarrow{y_3}\)对\(\overrightarrow{x_7}\)的偏导数可以用\(W_{3,7}\)来表示。实际上对于所有的\(i\)和\(j\)来说,都有

\(\frac{\partial \overrightarrow{y_i}}{\partial \overrightarrow{x_j}}=W_{i.j}\)

即上述的偏导数矩阵等于:

\(\begin{bmatrix} \frac{\partial \overrightarrow{y_1}}{\partial \overrightarrow{x_1}}& \frac{\partial \overrightarrow{y_1}}{\partial \overrightarrow{x_2}}& \frac{\partial \overrightarrow{y_1}}{\partial \overrightarrow{x_3}}&\cdots  & \frac{\partial \overrightarrow{y_1}}{\partial \overrightarrow{x_D}}\\ \frac{\partial \overrightarrow{y_2}}{\partial \overrightarrow{x_1}}& \frac{\partial \overrightarrow{y_2}}{\partial \overrightarrow{x_2}}& \frac{\partial \overrightarrow{y_2}}{\partial \overrightarrow{x_3}}&\cdots  & \frac{\partial \overrightarrow{y_2}}{\partial \overrightarrow{x_D}}\\ \vdots & \vdots & \vdots & \ddots & \vdots \\ \frac{\partial \overrightarrow{y_C}}{\partial \overrightarrow{x_1}}& \frac{\partial \overrightarrow{y_C}}{\partial \overrightarrow{x_2}}& \frac{\partial \overrightarrow{y_C}}{\partial \overrightarrow{x_3}}&\cdots  & \frac{\partial \overrightarrow{y_C}}{\partial \overrightarrow{x_D}}\\ \end{bmatrix}=\begin{bmatrix} W_{1,1}& W_{1,2}& W_{1,3}&\cdots  & W_{1,D}\\ W_{2,1}& W_{2,2}& W_{2,3}&\cdots  & W_{2,D}\\ \vdots & \vdots & \vdots & \ddots & \vdots \\ W_{C,1}& W_{C,2}& W_{C,3}&\cdots  & W_{C,D}\\ \end{bmatrix}\) 

显然,就是\(W\)本身嘛。

因此,我们最终可以得出,对于\(\overrightarrow{y} = W\overrightarrow{x}\),\(\overrightarrow{y}\)对于\(\overrightarrow{x}\)的偏导数为:

\(\frac{\partial \overrightarrow{y}}{\partial \overrightarrow{x}}=W\)

2  行向量的情况

现在关于神经网络的第三方包特别多,在使用这些包的时候,要特别关注权值矩阵、数据矩阵等的排列。例如:数据矩阵\(X\)中包含非常多的向量,每个向量代表一个输入,那到底是矩阵中的每一行代表一个输入,还是每一列代表一个输入呢?

在第一节中,我们介绍的示例中使用的向量\(\overrightarrow{x}\)是列向量。不过当\(\overrightarrow{x}\)是行向量时,求导的基本思想是一致的。

2.1  示例2

在本例中,\(\overrightarrow{y}\)是一个\(C\)阶行向量,它是由\(D\)阶行向量\(\overrightarrow{x}\)和\(D\times C\)维矩阵\(W\)和计算得到:

\(\overrightarrow{y} =\overrightarrow{x}W\)

虽然\(\overrightarrow{y}\)和\(\overrightarrow{x}\)的元素数量和之前的列向量是一样的,但矩阵\(W\)相当于第一节使用的矩阵\(W\)的转置。并且本例中是矩阵\(W\)左乘\(\overrightarrow{x}\),而不是之前的右乘。

在本例中,我们同样可以写出\(\overrightarrow{y_3}\)的表达式:

\(\overrightarrow{y_3} = \overrightarrow{x_j}\sum_{j=1}^{D}W_{j,3} \) 

同样地,

\(\frac{\partial \overrightarrow{y_3}}{\partial \overrightarrow{x_7}}=W_{7,3}\)

注意本例中的\(W\)的下标和第一节中的相反。如果我们写出完整的雅克比矩阵的话, 我们仍然可以得出完整的求导结果:

\(\frac{\partial \overrightarrow{y}}{\partial \overrightarrow{x}}=W\)

3  维度大于2的情况

让我们考虑另一个密切相关的情形,如下式:\(\frac{\partial\overrightarrow{y}}{\partial W} \) 

在这种情形中,\(\overrightarrow{y}\)沿着一个坐标变化,而\(W\)沿着两个坐标变化。因此,整个导数自然是一个三维数组。一般避免使用“三维矩阵”这种术语,因为矩阵乘法和其他矩阵操作在三维数组中的定义尚不明确。

在处理三维数组时,试图去找到一种展示它们的方法可能带来不必要的麻烦。直接将结果定义为公式会更简单一些,这些公式可用于计算三维中的任何元素。

我们继续从计算标量的导数开始,比如\(\overrightarrow{y}\)中的一个元素\(\overrightarrow{y_3}\)和\(W\)中的一个元素\(W_{7,8}\)。首先要做的还是写出\(\overrightarrow{y_3}\)的表达式:

\(\overrightarrow{y_3}=\overrightarrow{x_1}W_{1,3}+\overrightarrow{x_2}W_{2,3}+...+\overrightarrow{x_D}W_{D,3}\tag{5}\)

显然, \(W_{7,8}\)在\(\overrightarrow{y}\)的表达式中没有起到任何作用,因此,\frac{\partial\overrightarrow{y_i}}{\partial W_{7,8}}=0

同时,\(\overrightarrow{y}\)对\(W\)中第3列元素的求导结果是非零的,正如公式(5)中展示的那样。例如\(\overrightarrow{y_3}\)对\(W_{2,3}\)的偏导数为:

\(\frac{\partial\overrightarrow{y_3}}{\partial W_{2,3}}=\overrightarrow{x_2}\tag{6}\)

一般来说,当\(\overrightarrow{y}\)中元素的下标等于\(W\)中元素的第二个下标时,其偏导数就是非零的,其他情况则为零。整理如下:

\(\frac{\partial\overrightarrow{y_j}}{\partial W_{i,j}}=\overrightarrow{x_i}\tag{7}\)

除此之外,三维数组中其他的元素都是0。如果我们用\(F\)来表示\(\overrightarrow{y}\)对\(W\)的导数,

\(F_{i,j,k}= \frac{\partial\overrightarrow{y_i}}{\partial W_{j,k}}\)

那么,\(F_{i,j,i}= \overrightarrow{x_j}\),其余的情况等于0

此时如果我们使用一个二维数组\(G\)来表示三维数组\(F\),

\(G_{i,j}=F_{i,j,i}\)

可以看出,三维数组\(F\)中的全部数据实际上都可以使用二维数组\(G\)来存储,也就是说,\(F\)中的非零部分其实是二维的,而非三维的。

以更加紧凑的方式来表示导数数组对于神经网络的高效实现来说,意义重大。

4  多维数据

前面提到的实例中,不论是\(\overrightarrow{y}\)还是\(\overrightarrow{x}\)都只是一个向量。当需要多条数据时,例如多个向量\(\overrightarrow{x}\)组成一个矩阵\(X\)时,又该如何计算呢?

我们假设每个单独的\(\overrightarrow{x}\)都是一个\(D\)阶行向量,矩阵\(X\)则是一个\(N\times D\)的二维数组。而矩阵\(W\)和之前实例中的一样,为\(D\times C\)的矩阵。此时\(Y\)的表达式为:

\(Y = XW\)

 \(Y\)是一个\(N\)行\(C\)列的矩阵。因此, \(Y\)中的每一行给出一个与输入\(X\)中对应行相关的行向量。按照之前的方式,可以写出如下表达式:

\(Y_{i,j} = \sum_{k=1}^{D}X_{i,k}W_{k,j}\)

 从这个方程式可以看出,对于偏导数\(\frac{\partial Y_{a,b}}{\partial X_{c,d}}\),只有当\(a=c\)的情况下不为0,其他情况均为0。因为 \(Y\)中的每一个元素都只对 与\(X\)中对应的那一行求导, \(Y\)与 \(X\)的不同行元素之间的导数均为0。

还可以进一步看出,计算偏导数

\(\frac{\partial Y_{i,j}}{\partial X_{i,k}}=W_{k.j}\tag{8}\)

与\(Y\)和 \(X\)的行没关系。

实际上,矩阵\(W\)包含了所有的偏导数,我们只需要根据公式(8)来找到我们想要的某个具体地偏导数。

如果用\(Y_{i,:}\)来表示\(Y\)中的第\(i\)行,用\(X_{i,:}\)来表示\(X\)中的第\(i\)行,那么

\(\frac{\partial Y_{i,:}}{\partial X_{i,:}}=W\)

5  链式法则

上面介绍了两个基本示例和求导方法,本节将上述方法和链式法则结合起来。同样,假设\(\overrightarrow{y}\)和\(\overrightarrow{x}\)为两个列向量,

\(\overrightarrow{y}=VW\overrightarrow{x}\)

在计算\(\overrightarrow{y}\)对\(\overrightarrow{x}\)的导数时,我们可以直观地将两个矩阵\(V\)和\(W\)的乘积视为另一个矩阵\(U\),则

\(\frac{\mathrm{d}\overrightarrow{y}}{\mathrm{d}\overrightarrow{x}}=VW=U\)

但是,我们想明确使用链式法则来定义中间量的过程,从而观察非标量求导是如何应用链式法则的。我们将中间量定义为\(\overrightarrow{m}=W\overrightarrow{x}\)

此时,\(\overrightarrow{y}=V\overrightarrow{m}\)

那么在求导时,我们使用链式法则:

\(\frac{\mathrm{d}\overrightarrow{y}}{\mathrm{d}\overrightarrow{x}}=\frac{\mathrm{d}\overrightarrow{y}}{\mathrm{d}\overrightarrow{m}}\frac{\mathrm{d}\overrightarrow{m}}{\mathrm{d}\overrightarrow{x}}\)

为了确保确切地清楚该式的含义,我们还是使用每次只分析一个元素的方法,\(\overrightarrow{y}\)中的一个元素对\(\overrightarrow{x}\)中的一个元素的导数为:

\(\frac{\mathrm{d}\overrightarrow{y_i}}{\mathrm{d}\overrightarrow{x_j}}=\frac{\mathrm{d}\overrightarrow{y_i}}{\mathrm{d}\overrightarrow{m}}\frac{\mathrm{d}\overrightarrow{m}}{\mathrm{d}\overrightarrow{x_j}}\)

链式法则的思想是当某个函数由复合函数表示,那么该复合函数的导师,可以用构成复合函数的各个函数的导数乘积来表示。

如果\(\overrightarrow{m}\)中有M个元素,那么上式可以写成:

\(\frac{\mathrm{d}\overrightarrow{y_i}}{\mathrm{d}\overrightarrow{x_j}}=\sum_{k=1}^{M}\frac{\mathrm{d}\overrightarrow{y_i}}{\mathrm{d}\overrightarrow{m_k}}\frac{\mathrm{d}\overrightarrow{m_k}}{\mathrm{d}\overrightarrow{x_j}}\)

回忆一下之前向量对向量的求导方法,我们可以发现,

\(\left\{\begin{matrix} \frac{\mathrm{d}\overrightarrow{y_i}}{\mathrm{d}\overrightarrow{m_k}}= V_{i,k}& \\ \frac{\mathrm{d}\overrightarrow{m_k}}{\mathrm{d}\overrightarrow{x_j}}= W_{k,j}& \end{matrix}\right.\)

整理可得:

\(\frac{\mathrm{d}\overrightarrow{y_i}}{\mathrm{d}\overrightarrow{x_j}}=\sum_{k=1}^{M}V_{i,k}W_{k,j}\)

至此,我们用\(V\)和\(W\)中的元素表示出了求导表达式。

推荐阅读