前言
在 CMU 10-414/714 Deep Learning System 第二个 homework 有一个小任务要对数值稳定形式的 LogSumExp 的梯度进行推导,查阅了不少资料 1,琢磨好半天才搞懂,特此记录。
推导过程
符号说明
推导过程中使用的符号说明如下:
$$
\begin{align*}
z &\in \mathbb{R}^n\\
z_k &= \max{z}\\
\hat{z} &= z - \max{z}\\
f &= \log{\sum_{i=1}^n{\exp{(z_i - \max{z})}}+\max{z}}\\
&=\log{\sum_{i=1}^n\exp\hat{z}_i}+z_k
\end{align*}
$$
非最大情况推导
当 $z_j\neq z_k$ 时,$\frac{\partial{f}}{\partial{z_j}}$ 推导如下:
$$
\begin{align*}
\frac{\partial{f}}{\partial{z_j}}
&=\frac{\partial{(\log{\sum_{i=1}^n\exp\hat{z}_i)}}}{\partial z_j} + \frac{\partial z_k}{\partial{z_j}} \\
&= \frac{\partial{(\log{\sum_{i=1}^n\exp\hat{z}_i)}}}{\sum_{i=1}^n\exp\hat{z}_i}\cdot \frac{\sum_{i=1}^n\exp\hat{z}_i}{\partial{z_j}}+0 \\
&=\frac{1}{\sum_{i=1}^n\exp\hat{z}_i}\cdot(\sum_{i\neq j} \frac{\partial\exp{\hat z_i}}{\partial z_j}+\frac{\partial \exp{\hat z_j}}{\partial z_j}) \\
&=\frac{1}{\sum_{i=1}^n\exp\hat{z}_i}\cdot(0+\exp{\hat{z}_j}) \\
&=\frac{\exp{\hat{z}_j}}{\sum_{i=1}^n\exp\hat{z}_i}
\end{align*}
$$
最大情况推导
当 $z_j= z_k$ 时,$\frac{\partial{f}}{\partial{z_j}}$ 推导如下:
$$
\begin{align*}
\frac{\partial{f}}{\partial{z_j}}
&=\frac{\partial{(\log{\sum_{i=1}^n\exp\hat{z}_i)}}}{\partial z_j} + \frac{\partial z_k}{\partial{z_j}} \\
&= \frac{\partial{(\log{\sum_{i=1}^n\exp\hat{z}_i)}}}{\sum_{i=1}^n\exp\hat{z}_i}\cdot \frac{\sum_{i=1}^n\exp\hat{z}_i}{\partial{z_j}}+1 \\
&=\frac{1}{\sum_{i=1}^n\exp\hat{z}_i}\cdot [\sum_{z_i \neq z_k}{\frac{\partial \exp{(z_i-z_k)}}{\partial z_j}}+\sum_{z_i=z_k}{\frac{\partial \exp{(z_i-z_k)}}{\partial z_j}}]+1\\
&\text{注意,上式中有}z_j=z_k\\
&=\frac{1}{\sum_{i=1}^n\exp\hat{z}_i}\cdot[\sum_{z_i \neq z_k}{-\exp(z_i-z_k)}+0]+1 \\
&= 1-\frac{\sum_{z_i \neq z_k}{\exp(z_i-z_k)}}{\sum_{i=1}^n\exp\hat{z}_i} \\
&=\frac{\exp{\hat{z}_j}}{\sum_{i=1}^n\exp\hat{z}_i}
\end{align*}
$$
一般情况
注意到无论 $z_j$ 是不是最大值,都有:
$$
\frac{\partial{f}}{\partial{z_j}}=\frac{\exp{\hat{z}_j}}{\sum_{i=1}^n\exp\hat{z}_i}
$$
这里我们讨论的是 $f\in \mathbb{R}$ 且 $z\in\mathbb{R}^n$ 的情况,实际情况中,$f$ 和 $z$ 都是高维张量,我们要求 $z$ 关于 $z$ 的梯度,即 $\nabla_z f$。
代码实现
首先感谢 yofufufufu 的不吝赐教,代码实现主要参考他的解释 2。我们继续来化简公式:
$$
\begin{align*}
\frac{\partial{f}}{\partial{z_j}}
&=\frac{\exp{\hat{z}_j}}{\sum_{i=1}^n\exp\hat{z}_i}\\
&=\exp(z_j - \log \sum_{i=1}^n\exp\hat{z}_i)\\
&=\exp(z_j - f)
\end{align*}
$$
惊喜地发现,LogSumExp 这个函数的梯度可以用其输入和输出来表示,那在代码实现中,只要获取该节点的输入和输出就可以计算出梯度,即在 cmu10414 课程,该节点实现如下:
|
|