反向传播算法图解:为什么梯度是学习的本质(从直觉到计算图)
反向传播不是“神秘公式”,本质是把损失函数的变化,沿着计算图用链式法则分摊到每个参数的责任上。理解它,你就能解释为什么深度网络能学、为什么会梯度消失/爆炸、以及为什么自动求导可行。本文从直觉出发,推到可实现的计算图版本,并给出工程自测与排障清单。

📷 Photo by Maxim Landolfi via Pexels
先把直觉讲清:学习 = 找到“往哪边改参数能让损失下降”
训练神经网络的核心操作是更新参数 $\theta$:
$$\theta \leftarrow \theta - \eta \nabla_\theta L$$
这里 $L$ 是损失函数,$\nabla_\theta L$ 是损失对参数的梯度。
所以“学习”的本质其实是:
- 知道参数微小变化会让损失怎么变
- 然后沿着让损失下降的方向走一步
反向传播解决的唯一问题就是:如何高效算出这个梯度。
一、从链式法则开始:反向传播只是“复合函数求导”
设有复合函数:
$$y = f(g(x))$$
链式法则告诉你:
$$\frac{dy}{dx} = \frac{dy}{dg} \cdot \frac{dg}{dx}$$
神经网络就是一个超大规模的复合函数:
$$L = L(a^{(n)}(\cdots a^{(2)}(a^{(1)}(x;\theta_1);\theta_2)\cdots);\theta_n)$$
直接展开求导会爆炸;反向传播做的是:
- 把网络拆成很多“局部函数”
- 复用局部导数
- 用一次从后往前的遍历,把所有参数的梯度都算出来
二、计算图视角:把“函数”变成“节点”
工程里理解反向传播,最好从计算图(Computation Graph)入手。
以一个极简例子:
- $z = wx + b$
- $a = \sigma(z)$
- $L = (a - y)^2$
你可以画成图:
- 输入 $x,w,b,y$
- 中间节点 $z,a$
- 输出 $L$
反向传播就是在图上计算每条边的“局部导数”,并把它们按链式法则组合。
1)两个关键量:局部导数与“上游梯度”
对任意节点 $v$,我们关心的是 $\frac{\partial L}{\partial v}$。
如果 $v$ 由上一层节点 $u$ 计算得到:$v = h(u)$,那么:
$$\frac{\partial L}{\partial u} = \frac{\partial L}{\partial v} \cdot \frac{\partial v}{\partial u}$$
其中:
- $\frac{\partial L}{\partial v}$ 是上游梯度(从后面传来)
- $\frac{\partial v}{\partial u}$ 是局部导数(由当前运算决定)
这就是反向传播的“乘一下”。
2)为什么能高效:每个节点只算一次上游梯度
关键在于“缓存”。
- 前向:把中间变量 $z,a$ 缓存下来
- 反向:用缓存的中间变量计算局部导数,然后乘上上游梯度
因此,反向传播的计算量与前向传播同阶(都是遍历一次计算图),而不是对每个参数都做一次全图求导。
三、反向传播的通用算法(工程可实现版)
在工程实现里,你可以把每个算子都实现两个函数:
forward(inputs) -> outputbackward(upstream_grad, cache) -> grads_for_inputs
伪代码:
# forward pass
for op in graph.topo_order:
op.out, op.cache = op.forward(op.inputs)
# backward pass
grad[L] = 1
for op in reverse(graph.topo_order):
grads = op.backward(grad[op.out], op.cache)
accumulate(grad[op.inputs], grads)
注意 accumulate:
- 如果一个节点有多个下游(分叉),梯度要相加
这是很多初学者写错的地方。
四、为什么会梯度消失/爆炸:乘积的数值性质
在深网络里,上游梯度会经过许多层的连乘:
$$\frac{\partial L}{\partial x} = \prod_^{n} \frac{\partial a^{(k)}}{\partial a^{(k-1)}}$$
如果这些局部导数大多 $|\cdot| < 1$,乘积会迅速趋近 0(消失);大多 $> 1$ 就会变得很大(爆炸)。
工程缓解手段(你至少要能说出 3 个)
- 合理初始化(例如让激活方差稳定)
- 选择更合适的激活函数(避免长期处于饱和区)
- 归一化(LayerNorm/BatchNorm)
- 残差连接(让梯度有“捷径”)
- 梯度裁剪(clipping)
这些手段都不是“玄学”,本质是在控制连乘的数值范围。
五、工程自测:怎么确认你的梯度是对的
1)数值梯度检查(最强通用武器)
对某个参数 $\theta$:
$$\frac{\partial L}{\partial \theta} \approx \frac{L(\theta+\epsilon)-L(\theta-\epsilon)}{2\epsilon}$$
做法:
- 在小网络、小 batch 上跑
- 选一个较小的 $\epsilon$(例如 $10^{-4}$)
- 比较解析梯度与数值梯度的相对误差
2)维度与广播检查(工程里更常见)
很多 bug 不是数学错,而是:
- shape 不对
- broadcast 导致梯度累计错
- batch 维度被误当成特征维度
建议把每个算子的 backward 写成 shape 断言 + 单测。
3)梯度流可视化
训练不收敛时,别只看 loss。
- 看每层梯度范数(是否某层变成 0 或爆炸)
- 看激活分布(是否全部饱和)
六、把反向传播和大模型工程连起来:为什么你关心它
即使你不手写反向传播,你也会在大模型工程里遇到它的影子:
- 为什么 LR 调整能救训练
- 为什么某些层需要裁剪梯度
- 为什么 LayerNorm 对稳定性关键
- 为什么位置编码与深度会影响梯度流
如果你在做 LLM/Agent 系统,这些理解会直接影响你的“排障速度”。
想看更多 LLM 基础能力文章见:
常见问题
反向传播和自动求导(autograd)是什么关系?
主流框架的自动求导,本质就是把计算图构建出来,然后对每个算子调用对应的 backward,按拓扑逆序做一次反向遍历。你理解了反向传播,就理解了 autograd 的核心工作方式。
我只做推理/应用层,还需要懂这些吗?
需要“够用的理解”。尤其当你要评估模型训练侧的取舍(例如微调、蒸馏、LoRA)或排查数值稳定性问题时,反向传播的直觉能让你不再靠猜。