自动微分与隐式微分¶
前置自测¶
📋 前置自测(答不出 ≥ 2 题 → 先回凸分析基础与非线性优化复习)
- 什么是链式法则?对复合函数 \(f(g(x))\),如何用 \(f'\) 和 \(g'\) 表达 \((f \circ g)'(x)\)?
- 给定矩阵 \(A \in \mathbb{R}^{m \times n}\),什么是 Jacobian 矩阵?它的维度是多少?
- 什么是 KKT 条件?写出等式约束优化 \(\min f(x) \text{ s.t. } h(x)=0\) 的 KKT 系统。
- 隐函数定理的直觉含义是什么?如果 \(F(x,y)=0\) 且 \(\partial F/\partial y\) 可逆,能得到什么结论?
- 解释有限差分法计算导数 \(f'(x) \approx \frac{f(x+h)-f(x)}{h}\) 的两个误差来源。
本章目标¶
学完本章后,你将能够:(1) 在白板上用 dual number 推导前向模式 AD,用 adjoint 变量推导反向模式 AD,并清楚两者的适用场景和复杂度差异;(2) 在 JAX、PyTorch、CasADi、Ceres 中选择合适的 AD 工具实现梯度计算,并理解它们的设计哲学差异;(3) 从隐函数定理出发推导可微优化层(OptNet/cvxpylayers)的反向传播公式,理解为什么隐式微分优于展开式 AD;(4) 判断可微仿真中 first-order gradient 何时有效、何时崩溃,并据此选择合适的策略梯度方法。
本章知识地图¶
自动微分与隐式微分
├── 基础概念(§1-§2)
│ ├── 计算图与三种微分方式
│ └── 为什么 AD 既不是有限差分也不是符号微分
├── 前向模式 AD(§3)
│ ├── dual number 代数
│ ├── JVP(雅可比-向量积)
│ └── Ceres Jet 模板实现
├── 反向模式 AD(§4)
│ ├── adjoint 变量与反向传播
│ ├── VJP(向量-雅可比积)
│ └── Cheap Gradient Principle
├── 高阶 AD 与工程技巧(§5)
│ ├── Hessian-向量积(forward-over-reverse)
│ ├── hyper-dual number
│ ├── checkpointing 节省内存
│ └── Taylor 模式 AD
├── AD 框架比较与选型(§6)
│ ├── JAX / PyTorch / CasADi / Ceres / CppAD / Enzyme
│ └── 选型决策树
├── 隐式微分(§7-§8)
│ ├── 隐函数定理完整证明
│ ├── VJP 形式的隐式微分
│ └── 可微优化层(OptNet / cvxpylayers)
├── 可微物理仿真(§9)
│ ├── Pinocchio 解析微分
│ ├── Brax / MuJoCo MJX / Dojo
│ ├── 接触微分的困难与 FoBG/ZoBG 分析
│ └── SHAC 截断窗口方法
├── 工程实践(§10)
│ ├── 梯度检查
│ ├── 数值稳定性
│ └── 可微渲染简介(NeRF / 3DGS)
├── 连续伴随方法与 Neural ODE(§11)
│ ├── 连续伴随方程推导
│ ├── Neural ODE 优势与局限
│ └── discretize-then-optimize vs optimize-then-discretize
├── 深度平衡模型与定点隐式微分(§12)
│ ├── DEQ 定点隐式微分推导
│ ├── Neumann 级数与 Anderson 加速
│ └── Blondel 2022 统一框架
├── 可微 MPC 与 PDP(§13)
│ ├── Differentiable MPC(Amos 2018)
│ ├── Pontryagin Differentiable Programming(Jin 2020)
│ └── 工程选型指南
├── 可微渲染与三维表示学习(§14)
│ └── NeRF 与 3D Gaussian Splatting 的 AD 分析
└── 典型例题与累积项目(§15)
1. 计算图:理解一切的起点 ⭐¶
动机¶
你写下一段 Python 代码来计算函数值。对计算机而言,这段代码被分解为一系列**基本运算**——加、减、乘、除、sin、exp 等。这些运算之间的数据依赖关系构成一张**有向无环图**(DAG),称为**计算图**。
理解计算图至关重要,因为自动微分的全部内容都可以归结为一句话:在计算图上系统地应用链式法则。
一个具体例子¶
考虑函数 \(f(x_1, x_2) = x_1 x_2 + \sin(x_1)\)。将其拆解为基本运算:
这些中间变量 \(v_i\) 构成计算图的节点,箭头表示数据流向。在 \(x_1 = \pi/4, x_2 = 2\) 处,前向计算(primal evaluation)依次得到:
| 变量 | 运算 | 数值 |
|---|---|---|
| \(v_1\) | \(x_1\) | \(\pi/4 \approx 0.785\) |
| \(v_2\) | \(x_2\) | \(2\) |
| \(v_3\) | \(v_1 \cdot v_2\) | \(\pi/2 \approx 1.571\) |
| \(v_4\) | \(\sin(v_1)\) | \(\sin(\pi/4) \approx 0.707\) |
| \(v_5\) | \(v_3 + v_4\) | \(2.278\) |
计算图是**静态的结构信息**——它记录了哪些运算以什么顺序执行。有了这张图,我们就可以用两种不同的策略在上面传播导数。
本质洞察:计算图把"一段程序"翻译成"一张数学图"。在这张图上,链式法则不再是一个抽象的数学定理,而是一个**可以机械执行的算法**。前向模式从输入到输出传播导数,反向模式从输出到输入传播导数——两者的区别仅在于遍历方向。
跨领域类比:计算图就像工厂的装配线¶
计算图可以类比为工厂的装配线:每个节点是一个加工工位,输入的原材料(变量)经过一系列工位加工,最终产出成品(函数值)。前向模式 AD 就像在装配线上每个工位贴一张标签"这个零件对原材料 \(x_1\) 的敏感度是多少",标签从头传到尾。反向模式 AD 则是从成品开始,逆着装配线追溯"成品质量对每个工位的调整有多敏感"。
两者都利用了装配线的线性结构(链式法则),但方向不同——这决定了在不同场景下哪种更高效。
⚠️ 常见陷阱¶
💡 概念误区:认为自动微分就是数值微分(有限差分) - 错误想法:"AD 不就是用 \((f(x+h)-f(x))/h\) 来近似导数吗?" - 实际上:AD 是**精确的**(在浮点精度范围内),它通过在计算图上应用链式法则得到导数的精确值,不存在截断误差 - 根本原因:有限差分是对导数定义的数值近似,AD 是对链式法则的精确执行——两者的数学基础完全不同 - 正确理解:AD 与有限差分的关系,就像精确解与数值解的关系
🧠 思维陷阱:认为计算图只存在于深度学习框架中 - 错误想法:"计算图是 PyTorch/TensorFlow 的概念,传统优化不需要" - 实际上:任何可微分的计算过程都有隐式的计算图。Ceres 的模板元编程在编译期构建计算图;CasADi 在符号层面构建计算图;手写的 RNEA 递推也是一张特殊的计算图 - 启示:理解计算图让你能统一理解所有 AD 工具的底层原理
练习¶
- 给出 \(f(x_1, x_2, x_3) = x_1 x_2 \sin(x_3) + e^{x_1} x_2\) 的计算图(画出所有中间变量和运算)。在 \((x_1, x_2, x_3) = (1, 2, \pi/6)\) 处计算所有中间变量的值。
- 对上述计算图,数一数共有多少条边(数据依赖关系)。思考:如果函数有 \(n\) 个输入和 \(m\) 个输出,计算图的边数与 \(n, m\) 有什么关系?
- (跨章综合)回顾凸分析中的仿射函数 \(f(x) = Ax + b\)。画出 \(f: \mathbb{R}^3 \to \mathbb{R}^2\)(\(A\) 为 \(2 \times 3\) 矩阵)的计算图。这个计算图有什么特殊结构?(提示:所有运算都是线性的)
2. 三种微分方式的本质区别 ⭐¶
动机¶
在编写优化求解器时,你总需要计算目标函数的梯度。有三条路可走:手动推导符号导数、有限差分近似、自动微分。它们各有什么优劣?这个问题的答案决定了你选什么工具。
2.1 符号微分(Symbolic Differentiation)¶
符号微分是你在微积分课上学的方法:把 \(f(x) = x^2 \sin(x)\) 用求导法则变换为 \(f'(x) = 2x \sin(x) + x^2 \cos(x)\)。
优点:得到封闭形式的表达式,便于理论分析和进一步化简。
致命缺点——表达式膨胀(expression swell):对复合函数反复应用乘积法则和链式法则,表达式的长度会**指数级增长**。例如,对 \(f(x) = \prod_{i=1}^{100} \sin(ix)\) 求导,符号微分会产生一个包含 \(2^{100}\) 项的和式——而函数值本身只需要 100 次乘法。
如果不做表达式膨胀会怎样?对于复杂的目标函数(如一个 1000 步的 MPC 展开、一个 100 层的神经网络),符号微分产生的导数表达式会大到无法存储,更不用说计算。这就是为什么深度学习不用 SymPy 来求梯度。
2.2 有限差分(Finite Differences)¶
有限差分用导数的极限定义做近似:
其中 \(e_i\) 是第 \(i\) 个单位向量。这是最"朴素"的方法。
优点:实现极其简单(几行代码),不需要了解函数内部结构。
两大致命缺点:
(1) 截断误差与舍入误差的矛盾。\(h\) 太大,截断误差 \(O(h)\) 大(中心差分 \((f(x+h)-f(x-h))/(2h)\) 可减至 \(O(h^2)\),但仍非零)。\(h\) 太小,\(f(x+h)\) 和 \(f(x)\) 的差被浮点舍入误差淹没。对于 double 精度(\(\epsilon_{\text{mach}} \approx 10^{-16}\)),最优的 \(h \approx \sqrt{\epsilon_{\text{mach}}} \approx 10^{-8}\)(前向差分)或 \(h \approx \epsilon_{\text{mach}}^{1/3} \approx 10^{-5}\)(中心差分),精度上限分别约为 8 位和 10 位有效数字。
(2) 计算量与输入维度成正比。计算 \(n\) 维梯度需要 \(n\) 次(前向差分)或 \(2n\) 次(中心差分)函数求值。对于 SLAM 的百万维参数或神经网络的亿级参数,这完全不可行。
反事实推理:如果有限差分的精度是足够的(比如你只需要 3 位精度做一个粗略的调参),那么它仍然是最简单的选择。很多工程师在调试阶段用有限差分来**验证** AD 的正确性——这是它的正确用法。
2.3 自动微分(Automatic Differentiation)¶
自动微分站在符号微分和有限差分的"中间地带":它像符号微分一样**精确**(没有截断误差),又像有限差分一样**高效**(不产生表达式膨胀)。
AD 的核心思想:不对整个函数做符号求导,而是在计算图的**每个基本运算节点**上应用已知的求导规则,然后通过链式法则将这些局部导数"传播"成全局导数。
| 方法 | 精度 | 计算量 | 实现复杂度 | 适用场景 |
|---|---|---|---|---|
| 符号微分 | 精确(封闭形式) | 可能指数级膨胀 | 需要 CAS | 小规模解析推导 |
| 有限差分 | \(O(h^p)\) 近似 | \(O(n \cdot \text{cost}(f))\) | 极低 | 调试验证、低维问题 |
| 前向 AD | 精确(浮点范围内) | \(O(n \cdot \text{cost}(f))\) | 中等 | 输入少输出多 |
| 反向 AD | 精确(浮点范围内) | \(O(m \cdot \text{cost}(f))\)* | 较高 | 输入多输出少(深度学习) |
*注:反向模式的常数因子一般为 2-5 倍,这就是著名的 Cheap Gradient Principle。
历史:AD 的发现之路¶
AD 的历史可以追溯到 1960 年代。Robert E. Wengert 在 1964 年发表了第一篇关于自动微分的论文,提出了前向模式。但反向模式的发现更加曲折:Seppo Linnainmaa 在 1970 年的硕士论文中独立提出了反向累积(reverse accumulation),而 Paul Werbos 在 1974 年的博士论文中将其应用于神经网络训练,称为"反向传播"。直到 Rumelhart、Hinton 和 Williams 在 1986 年的 Nature 论文中推广了反向传播,AD 在机器学习中的重要性才被广泛认识。
Andreas Griewank 在 1989-2008 年的系统化工作将 AD 建立为一个严谨的数学学科,其著作《Evaluating Derivatives》(与 Walther 合著,SIAM 2008)至今仍是该领域的权威参考。
⚠️ 常见陷阱¶
⚠️ 编程陷阱:用有限差分做最终的梯度计算 - 错误做法:在 SLAM 求解器中用
(f(x+1e-8) - f(x)) / 1e-8计算雅可比 - 现象:对于良态问题可能看起来"差不多对",但在病态问题(条件数大)中产生严重偏差,导致优化不收敛 - 根本原因:\(h=10^{-8}\) 时截断误差和舍入误差恰好平衡,但如果函数值本身很大或变化很剧烈,这个平衡点会移动 - 正确做法:用 AD 计算精确梯度,仅在调试时用有限差分做验证💡 概念误区:认为 AD 只能算一阶导数 - 错误想法:"AD 只能算梯度,Hessian 还是得用有限差分" - 实际上:AD 可以嵌套使用(forward-over-reverse 或 reverse-over-forward)来计算任意阶导数。JAX 的
jax.hessian就是自动嵌套两次 AD - 延伸:Hessian-向量积 \(Hv = \nabla(\nabla f \cdot v)\) 只需一次前向+一次反向,成本与计算 \(f\) 本身相当
练习¶
- 用 Python 编写一个函数,分别用前向差分、中心差分和复步微分(complex step)计算 \(f(x) = e^{\sin(x^3)}\) 在 \(x=1.5\) 处的导数。比较三者的精度与真实值 \(f'(1.5)\) 的差异。
- 解释为什么中心差分 \((f(x+h)-f(x-h))/(2h)\) 的截断误差是 \(O(h^2)\) 而不是 \(O(h)\)。(提示:对 \(f(x\pm h)\) 做 Taylor 展开。)
- 如果一个函数 \(f: \mathbb{R}^{1000} \to \mathbb{R}\) 的一次求值需要 1ms,用中心差分计算完整梯度需要多长时间?用反向 AD 呢?
3. 前向模式 AD:dual number 与 JVP ⭐⭐¶
动机¶
我们要计算函数 \(f: \mathbb{R}^n \to \mathbb{R}^m\) 在 \(x_0\) 处沿方向 \(v\) 的方向导数 \(J \cdot v\)(其中 \(J = \partial f / \partial x\) 是 Jacobian 矩阵)。如果 \(n\) 很小(比如 Ceres 中每个残差只依赖 3-4 个参数),前向模式是最自然的选择。
3.1 Dual Number 的代数基础 ⭐⭐¶
定义与直觉¶
Dual number 是实数的一种扩展,定义为:
这个 \(\varepsilon\) 称为**无穷小量**,它满足 \(\varepsilon^2 = 0\) 这一关键性质。
跨领域类比:dual number 之于实数,就像复数之于实数。复数引入了 \(i^2 = -1\) 的虚数单位来扩展实数系统,使得所有多项式都有根。dual number 引入了 \(\varepsilon^2 = 0\) 的无穷小量来扩展实数系统,使得我们可以"自动"提取导数。它们都是**在实数上添加一个新元素并规定其代数规则**的代数扩展。
但两者的"用途"完全不同:复数用来解方程、做旋转;dual number 用来**无截断误差地传播导数信息**。
运算规则¶
Dual number 的四则运算规则完全由 \(\varepsilon^2 = 0\) 推导而来:
加法:\((a + b\varepsilon) + (c + d\varepsilon) = (a+c) + (b+d)\varepsilon\)
乘法:\((a + b\varepsilon)(c + d\varepsilon) = ac + (ad + bc)\varepsilon + bd\varepsilon^2 = ac + (ad + bc)\varepsilon\)
这里 \(bd\varepsilon^2 = 0\) 被丢弃——这正是"无穷小量"的妙处。
除法:\(\frac{a + b\varepsilon}{c + d\varepsilon} = \frac{(a + b\varepsilon)(c - d\varepsilon)}{(c + d\varepsilon)(c - d\varepsilon)} = \frac{ac + (bc - ad)\varepsilon}{c^2} = \frac{a}{c} + \frac{bc - ad}{c^2}\varepsilon\)
为什么 dual number 能自动求导¶
核心观察:对任意解析函数 \(f\),将其在 \(a\) 处 Taylor 展开:
因为 \(\varepsilon^2 = 0\),所有二阶及以上的项全部消失。这意味着**把 \(x_0 + 1 \cdot \varepsilon\) 代入 \(f\),实部给出 \(f(x_0)\),虚部给出 \(f'(x_0)\)**——函数值和导数同时得到,精确无误。
让我们用几个具体的基本函数验证这一点。
sin:\(\sin(a + b\varepsilon) = \sin(a) + \cos(a) \cdot b\varepsilon\)(因为 \(\sin'(x) = \cos(x)\))
exp:\(\exp(a + b\varepsilon) = \exp(a) + \exp(a) \cdot b\varepsilon = \exp(a)(1 + b\varepsilon)\)
log:\(\log(a + b\varepsilon) = \log(a) + \frac{b}{a}\varepsilon\)
幂函数:\((a + b\varepsilon)^n = a^n + n a^{n-1} b\varepsilon\)
这些规则与微积分中的求导法则**完全一致**——这不是巧合,而是 dual number 代数结构的必然结果。
本质洞察:dual number 把"求导"这个分析学操作编码进了代数运算中。你不需要"对程序做分析"来求导,只需要"用 dual number 替换实数运行程序"——导数会自动出现在 \(\varepsilon\) 系数中。这就是为什么 Ceres 的
Jet<T,N>类只需要重载+, -, *, /, sin, cos, exp, log等运算符,就能对任意模板化的 cost function 自动求导。
3.2 从 dual number 到 JVP ⭐⭐¶
多元函数的推广¶
对多元函数 \(f: \mathbb{R}^n \to \mathbb{R}^m\),我们引入 \(n\) 个独立的无穷小方向 \(\dot{x}_1, \ldots, \dot{x}_n\)(称为"tangent"或"种子向量"),计算:
其中 \(J(x_0) = \frac{\partial f}{\partial x}\big|_{x_0}\) 是 \(m \times n\) 的 Jacobian 矩阵。\(\varepsilon\) 系数给出的是 Jacobian-向量积(Jacobian-Vector Product, JVP):
如果选择 \(\dot{x} = e_i\)(第 \(i\) 个单位向量),则 JVP 给出 Jacobian 的第 \(i\) 列。因此,计算完整 Jacobian 需要 \(n\) 次 JVP 调用。
前向传播的逐步执行¶
回到 §1 的例子 \(f(x_1, x_2) = x_1 x_2 + \sin(x_1)\)。选种子 \(\dot{x}_1 = 1, \dot{x}_2 = 0\)(即求 \(\partial f / \partial x_1\)):
| 变量 | primal 值 | tangent(\(\varepsilon\) 系数) | 运算规则 |
|---|---|---|---|
| \(v_1 = x_1\) | \(\pi/4\) | \(\dot{v}_1 = 1\) | 种子 |
| \(v_2 = x_2\) | \(2\) | \(\dot{v}_2 = 0\) | 种子 |
| \(v_3 = v_1 \cdot v_2\) | \(\pi/2\) | \(\dot{v}_3 = \dot{v}_1 v_2 + v_1 \dot{v}_2 = 2\) | 乘法的 dual rule |
| \(v_4 = \sin(v_1)\) | \(\sin(\pi/4)\) | \(\dot{v}_4 = \cos(v_1) \cdot \dot{v}_1 = \cos(\pi/4)\) | sin 的 dual rule |
| \(v_5 = v_3 + v_4\) | \(2.278\) | \(\dot{v}_5 = \dot{v}_3 + \dot{v}_4 = 2.707\) | 加法的 dual rule |
最终 \(\dot{v}_5 = 2 + \cos(\pi/4) \approx 2.707\) 就是 \(\partial f / \partial x_1\)。可以手工验证:\(\partial f / \partial x_1 = x_2 + \cos(x_1) = 2 + \cos(\pi/4) \approx 2.707\)。精确匹配!
复杂度分析¶
一次 JVP 的计算量等于**一次函数求值加一次同样结构的 tangent 传播**。每个基本运算的 tangent 传播最多是常数倍的额外运算(如乘法需要 2 次额外乘法和 1 次加法)。因此:
计算完整 \(m \times n\) Jacobian 需要 \(n\) 次 JVP,总成本 \(O(n \cdot \text{cost}(f))\)。
前向模式的最佳场景:\(n \ll m\)——输入维度远小于输出维度。典型例子:Ceres 中每个残差 \(r_i: \mathbb{R}^6 \to \mathbb{R}^2\)(相机位姿 6 参数 → 重投影误差 2 维),\(n=6, m=2\),前向模式只需 6 次 JVP。
如果 \(n\) 很大会怎样? 考虑一个神经网络 \(f: \mathbb{R}^{10^7} \to \mathbb{R}\)(一亿参数 → 标量 loss)。前向模式需要 \(10^7\) 次 JVP——显然不可行。这就是反向模式登场的理由。
3.3 Ceres Jet 的模板实现 ⭐⭐¶
Ceres Solver 用 C++ 模板元编程实现了前向 AD,核心是 Jet<T, N> 类。理解 Jet 的实现有助于深刻理解前向 AD 的工程化。
// Ceres Jet 的核心结构(简化版,展示设计思想)
template <typename T, int N>
struct Jet {
T a; // primal 值(实部)
T v[N]; // tangent 值(N 个方向的导数)
// 构造函数:从标量构造(所有 tangent 为零)
explicit Jet(const T& value) : a(value) {
for (int i = 0; i < N; ++i) v[i] = T(0);
}
// 构造函数:指定第 k 个方向的种子
Jet(const T& value, int k) : a(value) {
for (int i = 0; i < N; ++i) v[i] = (i == k) ? T(1) : T(0);
}
};
// 乘法重载:(a + v*ε) * (b + w*ε) = ab + (aw + bv)ε
template <typename T, int N>
Jet<T, N> operator*(const Jet<T, N>& f, const Jet<T, N>& g) {
Jet<T, N> result;
result.a = f.a * g.a; // primal: 普通乘法
for (int i = 0; i < N; ++i)
result.v[i] = f.a * g.v[i] + f.v[i] * g.a; // tangent: 乘法法则
return result;
}
// sin 重载:sin(a + v*ε) = sin(a) + cos(a)*v*ε
template <typename T, int N>
Jet<T, N> sin(const Jet<T, N>& f) {
Jet<T, N> result;
result.a = std::sin(f.a); // primal: 普通 sin
T cos_a = std::cos(f.a);
for (int i = 0; i < N; ++i)
result.v[i] = cos_a * f.v[i]; // tangent: cos(a) * v
return result;
}
设计哲学:Ceres 要求用户把 cost function 写成 C++ 模板 template<typename T> bool operator()(const T* x, T* residual)。当 T = double 时,编译器生成计算函数值的代码;当 T = Jet<double, N> 时,编译器自动生成同时计算函数值和导数的代码——两者**共享同一份源码**,只是运算符重载不同。这就是"一份代码,两种用途"的模板元编程范式。
为什么 N 是模板参数而不是运行时参数? 因为 N 在编译期已知时,编译器可以展开循环、使用 SIMD 指令、消除内存分配——对 SLAM 中需要每秒调用上百万次的 cost function,这种优化至关重要。
⚠️ 常见陷阱¶
⚠️ 编程陷阱:在 Ceres cost function 中调用非模板化的外部函数 - 错误做法:
double result = sqrt(x[0]);而不是T result = ceres::sqrt(x[0]);- 现象:当T = Jet<double, N>时,double sqrt(double)会把 Jet 隐式转成 double,丢失所有导数信息,最终得到零导数 - 根本原因:C++ 的隐式类型转换会悄无声息地丢弃 tangent 信息 - 正确做法:cost function 中所有函数调用必须template<typename T>化💡 概念误区:认为 Jet 的 N 等于总参数个数 - 错误想法:"我的问题有 1000 个参数,所以 N=1000" - 实际上:Ceres 的
AutoDiffCostFunction<F, residual_dim, param_block_1_dim, param_block_2_dim, ...>中,N = 所有参数块维度之和。但由于 Ceres 每次只对**一个残差**求 Jacobian,N 通常很小(6-15),不是总参数个数 - 根本原因:稀疏性——每个残差只依赖少数参数块,Jacobian 是稀疏的
练习¶
- 手动实现一个 40 行的 C++
Dual类(不用模板N,只支持一个方向),重载+, -, *, /, sin, cos, exp, log。用它计算 \(f(x) = \sin(x^2) \cdot e^x\) 在 \(x=1\) 的 \(f\) 和 \(f'\),与真实值比较。 - 解释为什么 Jet 的乘法规则 \(\dot{v}_3 = \dot{v}_1 v_2 + v_1 \dot{v}_2\) 与微积分的乘法法则 \((fg)' = f'g + fg'\) 完全一致。
- 假设你有一个 cost function \(r: \mathbb{R}^3 \to \mathbb{R}^2\),使用 Ceres
AutoDiffCostFunction<F, 2, 3>。计算一次 Jacobian 需要多少次 JVP?总共涉及多少个 Jet 的基本运算?
4. 反向模式 AD:adjoint 与 VJP ⭐⭐¶
动机¶
深度学习的成功建立在一个关键数学事实之上:对于 \(f: \mathbb{R}^n \to \mathbb{R}\)(\(n\) 可以是数亿),计算完整梯度 \(\nabla f\) 的代价**与 \(n\) 无关**,只与计算 \(f\) 本身的代价成正比。这就是反向模式 AD(backpropagation 是其特例)的力量。
4.1 Adjoint 变量的定义 ⭐⭐¶
对计算图中的每个中间变量 \(v_i\),定义其 adjoint(伴随变量、灵敏度):
其中 \(y\) 是最终输出。\(\bar{v}_i\) 回答的问题是:"如果 \(v_i\) 变化 \(\delta\),最终输出 \(y\) 变化多少?"
关键观察:输出节点的 adjoint 是 \(\bar{y} = \partial y / \partial y = 1\)。从这个"种子"出发,我们可以用链式法则**逆着计算图的方向**逐层传播 adjoint。
对于节点 \(v_j\),假设它的值被直接用在节点 \(v_{i_1}, v_{i_2}, \ldots, v_{i_k}\) 的计算中。由全微分公式(链式法则的多路版本):
这就是反向模式 AD 的核心递推公式。它从输出向输入**反向**传播,每一步只需要知道"局部导数"\(\partial v_i / \partial v_j\)(这些在前向计算时已经确定了)。
4.2 一个完整的反向传播例子 ⭐⭐¶
继续使用 \(f(x_1, x_2) = x_1 x_2 + \sin(x_1)\),在 \(x_1 = \pi/4, x_2 = 2\) 处。
前向计算(存储所有中间值):
| 变量 | 运算 | 值 |
|---|---|---|
| \(v_1\) | \(x_1\) | \(\pi/4\) |
| \(v_2\) | \(x_2\) | \(2\) |
| \(v_3\) | \(v_1 \cdot v_2\) | \(\pi/2\) |
| \(v_4\) | \(\sin(v_1)\) | \(\sqrt{2}/2\) |
| \(v_5\) | \(v_3 + v_4\) | \(\pi/2 + \sqrt{2}/2\) |
反向传播(从 \(\bar{v}_5 = 1\) 开始):
| adjoint | 递推公式 | 值 |
|---|---|---|
| \(\bar{v}_5\) | 种子 | \(1\) |
| \(\bar{v}_3\) | \(\bar{v}_5 \cdot \frac{\partial v_5}{\partial v_3} = 1 \cdot 1\) | \(1\) |
| \(\bar{v}_4\) | \(\bar{v}_5 \cdot \frac{\partial v_5}{\partial v_4} = 1 \cdot 1\) | \(1\) |
| \(\bar{v}_1\) | \(\bar{v}_3 \cdot \frac{\partial v_3}{\partial v_1} + \bar{v}_4 \cdot \frac{\partial v_4}{\partial v_1}\) | \(1 \cdot v_2 + 1 \cdot \cos(v_1) = 2 + \cos(\pi/4)\) |
| \(\bar{v}_2\) | \(\bar{v}_3 \cdot \frac{\partial v_3}{\partial v_2}\) | \(1 \cdot v_1 = \pi/4\) |
结果:\(\frac{\partial f}{\partial x_1} = 2 + \cos(\pi/4) \approx 2.707\),\(\frac{\partial f}{\partial x_2} = \pi/4 \approx 0.785\)。与前向模式得到的结果完全一致——这不是巧合,两者都是链式法则的精确应用,只是遍历方向不同。
4.3 VJP(向量-雅可比积)⭐⭐¶
反向模式计算的是 VJP(Vector-Jacobian Product):
其中 \(\bar{y}\) 是输出空间的"种子向量"(cotangent)。当 \(m=1\) 时 \(\bar{y}=1\),一次 VJP 直接给出完整梯度——这就是为什么反向模式对标量输出函数如此高效。
如果选择 \(\bar{y} = e_j\)(第 \(j\) 个单位向量),则 VJP 给出 Jacobian 的第 \(j\) 行。计算完整 Jacobian 需要 \(m\) 次 VJP 调用,总成本 \(O(m \cdot \text{cost}(f))\)。
前向与反向的对称关系:
| 前向模式(JVP) | 反向模式(VJP) | |
|---|---|---|
| 计算 | \(J \cdot v\)(Jacobian 乘向量) | \(v^\top J\)(向量乘 Jacobian) |
| 种子 | 输入空间的向量 \(\dot{x}\) | 输出空间的向量 \(\bar{y}\) |
| 遍历方向 | 计算图正方向 | 计算图逆方向 |
| 一次得到 | Jacobian 的一列 | Jacobian 的一行 |
| 全 Jacobian | \(n\) 次调用 | \(m\) 次调用 |
| 最佳场景 | \(n \ll m\) | \(n \gg m\) |
| 额外存储 | \(O(1)\) | \(O(\text{计算图节点数})\) |
反事实推理:如果深度学习的 loss 不是标量而是高维向量(比如 \(m = 10^6\)),那么反向模式就不再高效——每个输出分量需要一次反向传播。在这种情况下(如计算完整 Jacobian),最优策略取决于 \(n\) 和 \(m\) 的相对大小:\(n < m\) 用前向,\(n > m\) 用反向。
4.4 Cheap Gradient Principle ⭐⭐⭐¶
定理(Baur-Strassen 1983;Griewank-Walther Thm. 4.11):对任意 \(f: \mathbb{R}^n \to \mathbb{R}\),反向模式计算完整梯度 \(\nabla f\) 的运算量满足:
其中 \(c\) 是常数,典型值 2-3,与输入维度 \(n\) 无关。
为什么 \(c\) 有上界? 在反向传播中,每个基本运算的反向对应有确定的运算量上界:
| 基本运算 | 前向代价 | 反向附加代价 | \(c\) |
|---|---|---|---|
| \(v = a + b\) | 1 ADD | \(\bar{a} += \bar{v}\), \(\bar{b} += \bar{v}\) | 2 ADD |
| \(v = a \times b\) | 1 MUL | \(\bar{a} += \bar{v} \cdot b\), \(\bar{b} += \bar{v} \cdot a\) | 2 MUL + 2 ADD |
| \(v = \sin(a)\) | 1 SIN | \(\bar{a} += \bar{v} \cdot \cos(a)\) | 1 COS + 1 MUL + 1 ADD |
| \(v = a / b\) | 1 DIV | \(\bar{a} += \bar{v}/b\), \(\bar{b} -= \bar{v} \cdot a / b^2\) | 2 DIV + 1 MUL + 2 ADD |
每个基本运算的反向代价是常数倍的前向代价,因此总的反向代价也是常数倍。
这个定理的深远意义:它意味着无论你的神经网络有多少参数(\(n = 10^9\)),计算完整梯度的代价始终只是"多做几次前向计算"。这是深度学习得以训练超大模型的数学基础——如果梯度计算的代价与 \(n\) 成正比,现代的 GPT 级模型根本不可能训练。
4.5 反向模式的内存代价分析 ⭐⭐⭐¶
反向模式 AD 的一个经常被忽视的代价是**内存**。为什么反向模式需要额外内存,而前向模式不需要?
在前向模式中,tangent \(\dot{v}_i\) 与 primal \(v_i\) 同步计算——当你计算到节点 \(v_i\) 时,所有父节点的 tangent 已经可用。因此,前向模式可以用 \(O(1)\) 额外内存(只存储当前层的 tangent)实现流式计算。
在反向模式中,adjoint \(\bar{v}_i\) 的计算需要 \(v_i\) 的**局部导数** \(\partial v_i / \partial v_j\),而这些局部导数依赖于**前向计算时的 primal 值**。例如,\(v_3 = v_1 \cdot v_2\) 的局部导数是 \(\partial v_3 / \partial v_1 = v_2\)——反向传播到节点 \(v_1\) 时需要用到 \(v_2\) 的值,但此时前向计算早已结束。
因此,反向模式必须**先完成整个前向计算并存储所有中间值**,然后才能开始反向传播。存储这些中间值的数据结构称为 tape(磁带)。
tape 的内存消耗:\(O(W)\),其中 \(W\) 是计算图的总节点数(work)。对于深度学习,\(W\) 与网络参数量和 batch size 成正比;对于 MPC 展开,\(W\) 与 horizon 长度成正比。
| 情况 | tape 大小 | 典型值 |
|---|---|---|
| ResNet-50 | \(\sim 10^7\) 节点 | \(\sim 200\) MB |
| 1000 步 RNN | \(\sim 10^6 \times 1000\) | 数 GB |
| MPC 展开 50 步 | \(\sim 10^4 \times 50\) | \(\sim 10\) MB |
这就是 §5.3 中 checkpointing 技术的必要性来源。
4.6 反向传播就是反向模式 AD 的特例 ⭐⭐¶
"反向传播"(backpropagation)和"反向模式 AD"在数学上完全相同。历史上之所以有两个名字,是因为它们在两个不同的社区(机器学习 vs 科学计算)被独立发现。
具体来说,深度学习中的标准反向传播是反向模式 AD 应用于一类特殊的计算图——层级结构**的计算图(每层的输出是下一层的输入)。但反向模式 AD 适用于**任意 DAG 结构,包括有多路分支、汇合的复杂计算图。
本质洞察:反向传播不是一个独立的算法,而是链式法则在计算图上从输出到输入方向遍历的自然结果。一旦你理解了"adjoint 变量沿反向边传播"这个机制,所有"xx 的反向传播公式"(如 LSTM 的、Transformer 的、GNN 的)都不需要单独记忆——它们都是同一个通用算法在不同图结构上的实例化。
⚠️ 常见陷阱¶
⚠️ 编程陷阱:反向 AD 的内存爆炸 - 错误做法:对一个 1000 步的 RNN 或 MPC 展开做完整反向传播 - 现象:需要存储所有 1000 步的中间激活(前向计算的值),GPU 内存耗尽 - 根本原因:反向模式需要在反向传播时**重用前向计算的中间值**(如 \(\cos(v_1)\)),所以这些值必须保存到反向传播到达该节点时 - 正确做法:使用 checkpointing(§5.3)或切换到隐式微分(§7-§8)
🧠 思维陷阱:认为"反向模式总是比前向模式好" - 错误推理:"Cheap Gradient Principle 说反向模式代价与 \(n\) 无关,所以永远用反向模式" - 实际上:(1) 反向模式需要 \(O(\text{图大小})\) 的额外内存存储中间值,前向模式只需 \(O(1)\);(2) 当 \(n < m\) 时前向模式更快;(3) 某些硬件(如没有足够缓存的嵌入式系统)上,前向模式的简单内存模式更有优势 - 正确思维:前向和反向是**互补**的工具,选择取决于 \(n, m\) 的比例和内存约束
练习¶
- 对 \(f(x_1, x_2, x_3) = (x_1 x_2 + x_3)^2\) 执行完整的反向传播,验证 \(\nabla f\)。
- 考虑函数 \(f: \mathbb{R}^{100} \to \mathbb{R}^{200}\)。分别用前向和反向模式计算完整 Jacobian,哪个更快?需要多少次 JVP/VJP 调用?
- (跨章综合)回顾凸优化中的梯度下降:\(x_{k+1} = x_k - \alpha \nabla f(x_k)\)。如果 \(f\) 是一个神经网络的 loss(\(n = 10^6\) 参数),解释为什么我们必须用反向 AD 而非前向 AD 来计算 \(\nabla f\)。估算两者的计算时间比。
5. 高阶 AD 与工程技巧 ⭐⭐⭐¶
动机¶
一阶导数(梯度/Jacobian)是优化的基本工具,但很多场景需要更高阶的信息:Newton 法需要 Hessian、灵敏度分析需要二阶导数、不确定性传播需要曲率信息。高阶 AD 提供了高效计算这些量的方法。
5.1 Hessian-向量积:forward-over-reverse ⭐⭐⭐¶
对 \(f: \mathbb{R}^n \to \mathbb{R}\),Hessian 矩阵 \(H = \nabla^2 f\) 是 \(n \times n\) 的。直接计算完整 Hessian 的代价是 \(O(n^2 \cdot \text{cost}(f))\)——对大 \(n\) 不可接受。
但 Hessian-向量积 \(Hv = \nabla^2 f \cdot v\) 可以用"forward-over-reverse"技术以 \(O(\text{cost}(f))\) 的代价计算:
步骤: 1. **反向模式**计算 \(\nabla f(x)\)(代价 \(c \cdot \text{cost}(f)\)) 2. 计算内积 \(g(x) = \nabla f(x)^\top v\)(代价 \(O(n)\)) 3. **前向模式**计算 \(\nabla g(x)\)(对 \(g\) 做一次 JVP,代价 \(c' \cdot \text{cost}(g)\))
总代价约 \(O(\text{cost}(f))\),与 \(n\) 无关。这使得 truncated Newton 法、共轭梯度法中的 Hessian-向量积变得高效可行。
在 JAX 中,这可以简洁地表达为:
import jax
import jax.numpy as jnp
def hvp(f, x, v):
"""Hessian-向量积:H @ v = d/dt grad(f)(x + t*v)|_{t=0}"""
return jax.jvp(lambda x: jax.grad(f)(x), (x,), (v,))[1]
# 等价写法(更高效)
def hvp_alt(f, x, v):
return jax.grad(lambda x: jnp.dot(jax.grad(f)(x), v))(x)
5.2 Hyper-dual Number ⭐⭐⭐⭐¶
Dual number 只能提取一阶导数(\(\varepsilon^2 = 0\) 丢弃了二阶信息)。Hyper-dual number 引入两个独立的无穷小量 \(\varepsilon_1, \varepsilon_2\),满足:
一个 hyper-dual number 写作 \(a + b\varepsilon_1 + c\varepsilon_2 + d\varepsilon_1\varepsilon_2\)。
将函数 \(f\) 作用在 \(x_0 + 1 \cdot \varepsilon_1 + 1 \cdot \varepsilon_2 + 0 \cdot \varepsilon_1\varepsilon_2\) 上:
\(\varepsilon_1\varepsilon_2\) 的系数直接给出**精确的二阶导数** \(f''(x_0)\),无截断误差。
与中心差分的二阶导数对比:中心差分 \(f''(x) \approx (f(x+h) - 2f(x) + f(x-h))/h^2\) 有 \(O(h^2)\) 的截断误差和严重的舍入误差(分母 \(h^2\) 放大了数值噪声)。Hyper-dual number 完全避免了这些问题。
反事实推理:如果没有 hyper-dual number,计算精确的二阶导数只能通过(1)符号微分(表达式膨胀)或(2)嵌套 AD(需要构建二层计算图)。Fike 和 Alonso 在 2011 年提出 hyper-dual number 后,一些简单场景下的二阶导数计算变得和一阶一样直接。
5.3 Checkpointing:时间换内存 ⭐⭐⭐¶
反向模式 AD 的主要瓶颈是内存:它需要存储前向计算的所有中间值,供反向传播使用。对于长序列模型(RNN、MPC 展开、ODE 数值积分),内存消耗与序列长度成正比。
Checkpointing(又称 rematerialization)是一种**时间换内存**的策略:
基本思想:不存储所有中间值,只在若干"检查点"处存储状态。反向传播到某个区间时,从最近的检查点开始**重新前向计算**该区间的中间值。
经典结果(Griewank 1992):对 \(K\) 步的计算序列: - 完全存储:内存 \(O(K)\),不需要重计算 - 均匀 checkpointing(每 \(\sqrt{K}\) 步存一次):内存 \(O(\sqrt{K})\),重计算代价 \(O(\sqrt{K})\) - 最优 binomial checkpointing(Revolve 算法):给定 \(c\) 个检查点,最优时间代价为 \(O(K \log K / \log c)\)
| 策略 | 内存 | 时间(相对于无 checkpoint) | 适用场景 |
|---|---|---|---|
| 不用 checkpoint | \(O(K)\) | \(1\times\) | 短序列 (\(K < 100\)) |
| \(\sqrt{K}\)-checkpoint | \(O(\sqrt{K})\) | \(\sim 1.5\times\) | 中等长度 |
| Revolve 最优 | \(O(\log K)\) | \(\sim 2\times\) | 超长序列 |
在 PyTorch 中通过 torch.utils.checkpoint 使用,在 JAX 中通过 jax.checkpoint 使用。
5.4 Taylor 模式 AD ⭐⭐⭐⭐¶
Taylor 模式 AD 是前向模式的高阶推广。它不是用 dual number(\(\varepsilon^2 = 0\)),而是保留 \(\varepsilon\) 的高阶项,到 \(\varepsilon^d\) 为止。这相当于同时计算函数在一点处的 Taylor 系数 \(f(x_0), f'(x_0), f''(x_0)/2!, \ldots, f^{(d)}(x_0)/d!\)。
每个基本运算的传播规则变成了多项式乘法/卷积。例如,乘法 \(w = u \cdot v\) 的第 \(k\) 阶 Taylor 系数满足:
这是 Cauchy 乘积公式——与多项式/形式幂级数的乘法完全相同。
Taylor 模式的应用场景包括:ODE 初值问题的高阶方法(Taylor 积分器比 Runge-Kutta 更高效)、函数的 Padé 近似、以及代数方程的级数解。
⚠️ 常见陷阱¶
🧠 思维陷阱:总是计算完整 Hessian - 错误想法:"Newton 法需要 Hessian,所以我要算出完整的 \(n \times n\) 矩阵" - 实际上:(1) Newton 法的核心步骤 \(H^{-1} \nabla f\) 可以通过 CG 迭代用 Hessian-向量积实现,不需要显式 Hessian;(2) 对 \(n > 1000\) 的问题,存储 Hessian 本身就不可行(\(10^6\) 个 double = 8 MB,\(10^4\) 维就是 800 MB) - 正确思维:先问"我真的需要完整 Hessian 吗?"——多数情况下 Hessian-向量积就够了
⚠️ 编程陷阱:不用 checkpointing 训练长序列 - 错误做法:对 1000 步 RNN 直接
loss.backward()- 现象:OOM(Out of Memory)或 GPU 利用率极低(因为大部分显存被中间值占满) - 正确做法:使用torch.utils.checkpoint包装 RNN cell,或切到 truncated BPTT
练习¶
- 推导 hyper-dual number 下 \(f(x) = x^3\) 的展开:\(f(a + b\varepsilon_1 + c\varepsilon_2 + d\varepsilon_1\varepsilon_2) = ?\),验证 \(\varepsilon_1\varepsilon_2\) 系数为 \(6a \cdot bc\)(当 \(b=c=1, d=0\) 时等于 \(f''(a)=6a\))。
- 对一个 \(K=1024\) 步的 ODE 积分器,如果使用 \(\sqrt{K} = 32\) 个 checkpoint,内存节省多少倍?重计算的时间开销大约增加多少?
- 用 JAX 实现 Hessian-向量积函数
hvp(f, x, v)并验证其正确性。
6. AD 框架比较与选型 ⭐⭐¶
动机¶
"该用哪个 AD 框架?"这是工程师最常问的问题之一。答案取决于你的语言偏好、问题规模、部署环境和性能要求。本节给出系统化的选型指南。
6.1 框架全景¶
| 框架 | 主模式 | 实现机制 | 语言 | 高阶 | GPU | 典型用途 |
|---|---|---|---|---|---|---|
| Ceres Jet | 前向 | 模板+运算符重载 | C++ | 嵌套 | 否 | SLAM/BA |
| CppAD | 反向 | tape(运行时) | C++ | 任意 | 否 | NLP/MPC |
| CppADCodeGen | 反向 | tape→C 代码 | C++ | 是 | 部分 | 嵌入式 MPC |
| CasADi | 前向+反向 | 符号图 SX/MX | C++/Py | 是 | codegen | OCP/MPC |
| Enzyme | 前向+反向 | LLVM IR | C/C++/Rust | 是 | 是 | HPC |
| PyTorch | 反向(+前向) | 动态 tape | Python | 嵌套 | 是 | 深度学习 |
| JAX | 前向+反向 | tracing→XLA | Python | 任意 | 是 | ML/RL/科学计算 |
| autodiff | 前向+反向 | C++17 header | C++ | 是 | 否 | 原型 |
| Pinocchio | 解析+AD | 手工链式 | C++/Py | 一阶 | 有限 | 刚体动力学 |
6.2 六大框架的设计哲学 ⭐⭐¶
JAX:函数式变换的组合
JAX 的核心是四个可组合的函数变换:grad(反向 AD)、jit(JIT 编译到 XLA)、vmap(自动向量化/批处理)、pmap(多设备并行)。这种设计来自函数式编程——每个变换把一个函数映射为另一个函数,变换之间可以任意嵌套。
import jax
import jax.numpy as jnp
def f(x):
return jnp.sum(jnp.sin(x) ** 2)
# grad: f -> grad_f(反向 AD)
grad_f = jax.grad(f)
# jit: grad_f -> 编译后的 grad_f(加速 10-100x)
fast_grad_f = jax.jit(grad_f)
# vmap: fast_grad_f -> 批处理版本(自动向量化)
batched_grad_f = jax.vmap(fast_grad_f)
JAX 的约束:函数必须是"纯"的(无副作用),不能包含 Python 控制流(需用 jax.lax.cond/scan 替代 if/for)。这种约束换来了**可预测的性能**和**任意阶可微**。
PyTorch autograd:动态图的灵活性
PyTorch 在每次前向计算时动态构建计算图(tape),反向传播后图被释放。这意味着你可以在 forward 中自由使用 Python 的 if/else、for 循环,每次前向可以走不同的代码路径。
这种设计对**研究**极为友好(你可以随时 print、debug、改图结构),但对**部署**不利(每次前向都要重建图,无法像 JAX 那样一次编译多次执行)。
CasADi:符号图 + 代码生成
CasADi 走了一条独特的路:用 Python/MATLAB 构建**符号计算图**(不执行数值计算),然后对这张图做符号微分和代码生成(导出 C 代码)。
CasADi 有两种图类型: - SX:标量级别的符号图。每个标量是图的一个节点,适合稠密的小问题(维度 < 100) - MX:矩阵级别的符号图。节点是矩阵运算,适合结构化的大问题(OCP、NLP)
import casadi as ca
# SX 符号变量
x = ca.SX.sym('x', 3) # 3维符号向量
f = ca.sin(x[0]) * x[1] + ca.exp(x[2])
J = ca.jacobian(f, x) # 符号 Jacobian(不执行数值计算!)
# 编译为高效的 C 函数
func = ca.Function('f', [x], [f, J])
func.generate('my_func.c') # 导出 C 代码,可被 acados/OCS2 调用
CasADi 的核心限制:不支持 Python 控制流(因为图是静态的)。如果你的 ODE 右端函数中有 if/else,必须用 CasADi 的 ca.if_else 替代。
Ceres Jet:编译期的 SLAM 专用 AD
Ceres 的 Jet 通过 C++ 模板在**编译期**实例化 AD 代码。优势是零运行时开销和极高的性能(对 SLAM/BA 这种需要每秒调用数百万次 cost function 的场景至关重要)。劣势是只支持前向模式(反向模式需要运行时构建图)。
Enzyme:编译器级别的 AD
Enzyme 工作在 LLVM 中间表示(IR)层面——它直接对编译后的 IR 做 AD 变换,然后再生成机器码。这意味着你可以对**任何编译为 LLVM IR 的语言**(C、C++、Rust、Fortran、Julia)做 AD,且不需要修改源码或使用特殊的数值类型。
Enzyme 的突破性优势是它能对 GPU kernel 做 AD——这在以前只有 JAX/PyTorch 等框架级解决方案才能做到。
6.3 Enzyme:编译器级 AD 的革命 ⭐⭐⭐¶
Enzyme(Moses & Churavy, NeurIPS 2020)代表了 AD 实现的一种全新范式。传统 AD 要么在源码层面重载运算符(如 Ceres Jet),要么在框架层面构建计算图(如 PyTorch/JAX)。Enzyme 则在**编译器的中间表示**(LLVM IR)层面做 AD。
工作原理: 1. 你用任何支持 LLVM 的语言(C、C++、Rust、Fortran、Julia)写正常代码 2. Clang/GCC 将代码编译为 LLVM IR(一种低层级的中间表示) 3. Enzyme 作为 LLVM 的一个 pass,对 IR 做 AD 变换——在 IR 层面自动生成前向/反向模式的导数代码 4. LLVM 后端将变换后的 IR 编译为优化的机器码
关键优势: - 零源码修改:不需要重写为模板函数、不需要用特殊的张量类型、不需要用框架的 API - 跨语言:C、C++、Rust、Fortran、Julia 的代码都可以直接做 AD - GPU 支持(Moses et al., SC 2021):可以对 CUDA kernel 做 AD——这在以前只有 JAX/PyTorch 等框架级方案才能实现 - 性能:由于在 IR 层面操作,可以利用 LLVM 的全部优化 pass(内联、循环展开、向量化),生成的导数代码质量极高
在机器人学中的应用前景:Enzyme 使得对已有的 C/C++ 物理引擎(如 MuJoCo、Drake、Bullet)做 AD 成为理论上可能的事情——不需要从头用 JAX 重写。当然,实践中仍有很多工程挑战(如对复杂控制流、内存分配、虚函数的支持)。
6.4 选型决策树 ⭐⭐¶
你的问题需要 AD?
├── 是深度学习(训练神经网络)?
│ ├── 是 → PyTorch(研究灵活)或 JAX(性能极致)
│ └── 否 ↓
├── 是 SLAM/BA?
│ ├── 是 → Ceres Jet(C++ 模板 AD,行业标准)
│ └── 否 ↓
├── 是 MPC/OCP(需要实时代码生成)?
│ ├── 是 → CasADi(符号→C 代码→ acados/IPOPT)
│ │ 或 CppAD + CppADCodeGen(C++ tape→ C 代码)
│ └── 否 ↓
├── 是刚体动力学微分?
│ ├── 是 → Pinocchio 解析微分(比黑盒 AD 快 3×+)
│ └── 否 ↓
├── 是 RL + 可微仿真?
│ ├── 是 → JAX + Brax/MJX
│ │ 或 Dojo(Julia,精确接触梯度)
│ └── 否 ↓
└── 通用科学计算?
├── C/C++ 环境 → Enzyme 或 autodiff(header-only)
├── Python → JAX
└── Julia → ForwardDiff.jl + Enzyme.jl
反事实推理:如果你为 SLAM 选了 PyTorch 而不是 Ceres,会发生什么?PyTorch 的动态图每次前向都要分配内存、构建 tape、Python 解释器开销——对于 BA 中每秒数百万次的 cost function 调用,Python 开销比计算本身大一个数量级。Ceres 的编译期模板 AD 完全消除了这些开销。所以选对工具不是锦上添花,而是决定系统能不能跑起来。
⚠️ 常见陷阱¶
🧠 思维陷阱:用黑盒 AD 对刚体动力学求微分 - 错误做法:用 CppAD 或 PyTorch 对 RNEA/ABA 做黑盒反向 AD - 现象:对 7 自由度机械臂,黑盒 AD 约需 50 \(\mu\)s,而 Pinocchio 解析微分只需 3 \(\mu\)s - 根本原因:RNEA/ABA 的递推结构天然编码了空间代数的链式法则。Carpentier 和 Mansard(RSS 2018)证明了手工推导的解析微分利用了递推的稀疏结构,避免了黑盒 AD 中大量冗余运算 - 正确做法:使用
pinocchio.computeRNEADerivatives()或pinocchio.computeABADerivatives()💡 概念误区:混淆 CasADi 的符号图和 PyTorch 的动态图 - CasADi 的图是**完全符号的**:构建阶段不执行任何数值计算,只记录运算结构。图构建完成后再"编译"为高效的数值函数 - PyTorch 的图是**边执行边构建**的:每次
forward()同时执行数值计算和构建反向图 - 后果:在 CasADi 中不能用 Python 的if x > 0(因为 x 是符号、没有数值),必须用ca.if_else
练习¶
- 对同一个函数 \(f(x) = \sum_i \sin(x_i)^2\)(\(n = 1000\)),分别用 JAX
jax.grad和 PyTorchtorch.autograd.grad计算梯度,比较 wall time。加上jax.jit后差异如何变化? - 用 CasADi 构建单摆 \(\dot{\theta} = \omega, \dot{\omega} = -g/l \sin\theta + u/(ml^2)\) 的符号模型,生成 Jacobian \(\partial f / \partial x\) 和 \(\partial f / \partial u\),与手工推导对比。
- 列出你目前的研究中使用的 AD 工具,分析是否选择了最合适的方案。如果要切换到更合适的工具,需要修改哪些代码?
7. 隐函数定理与隐式微分 ⭐⭐¶
动机¶
到目前为止,我们讨论的 AD 都是**前向/反向传播**——沿着计算图传播导数。但有一大类重要的函数,其值不是通过一段显式的代码计算出来的,而是**通过求解一个方程或优化问题隐式定义的**。
例如: - 解线性方程组:\(y^*(A, b) = A^{-1}b\),\(y^*\) 是方程 \(Ay = b\) 的解 - 解优化问题:\(y^*(\theta) = \arg\min_y f(y, \theta)\),\(y^*\) 是目标函数的最优解 - 解 ODE:\(y^*(T, \theta)\) 是 \(\dot{z} = f_\theta(z, t)\) 在 \(t=T\) 时的解 - 解不动点方程:\(y^* = T_\theta(y^*)\),\(y^*\) 是映射 \(T\) 的不动点
对这些"隐式定义的函数",如何计算 \(\partial y^* / \partial \theta\)?
朴素方法——展开(unrolling):把求解过程的每一步迭代(如 Newton 迭代、梯度下降迭代)展开成计算图,然后用反向 AD 穿过所有迭代步骤。这种方法可以工作,但有严重的缺点:
- 内存:需要存储所有迭代步骤的中间状态,\(O(K)\)(\(K\) 为迭代步数)
- 梯度质量:如果迭代没有完全收敛,梯度会受到"收敛路径"的污染
- 计算量:反向传播穿过 \(K\) 步迭代,每步都要计算局部导数
正确方法——隐式微分:利用隐函数定理,只在**收敛点** \(y^*\) 处计算导数,完全不需要知道"是怎么求解到 \(y^*\) 的"。内存 \(O(1)\),梯度精确。
跨领域类比:展开法就像在地图上追踪你从家走到学校的每一步路径,然后计算"如果起点偏移 \(\delta\),终点偏移多少"。隐式微分则是直接问:"学校的位置(终点)对家的位置(起点)有什么依赖关系?"——你不需要知道走了哪条路,只需要知道终点满足什么条件。
7.1 隐函数定理(IFT)的严格陈述 ⭐⭐¶
定理(经典隐函数定理):设 \(F: \mathbb{R}^n \times \mathbb{R}^m \to \mathbb{R}^m\) 是 \(C^1\) 映射。如果在点 \((x_0, y_0)\) 处:
- \(F(x_0, y_0) = 0\)(\(y_0\) 满足方程)
- \(\frac{\partial F}{\partial y}(x_0, y_0)\) 是可逆的(\(m \times m\) 矩阵非奇异)
则存在 \(x_0\) 的邻域 \(U\) 和唯一的 \(C^1\) 映射 \(y^*: U \to \mathbb{R}^m\),使得:
并且:
7.2 IFT 的完整证明思路 ⭐⭐⭐¶
证明的核心思想:构造压缩映射,然后用 Banach 不动点定理。
Step 1:定义辅助映射 \(G(x, y) = y - A^{-1} F(x, y)\),其中 \(A = \frac{\partial F}{\partial y}(x_0, y_0)\)。
Step 2:验证 \(G(x_0, y_0) = y_0 - A^{-1} \cdot 0 = y_0\)。
Step 3:计算 \(\frac{\partial G}{\partial y}(x_0, y_0) = I - A^{-1} \frac{\partial F}{\partial y}(x_0, y_0) = I - A^{-1} A = 0\)。
因为 \(\partial G / \partial y\) 在 \((x_0, y_0)\) 处为零(Lipschitz 常数为 0),由连续性,在 \((x_0, y_0)\) 的足够小邻域内 \(G\) 关于 \(y\) 是**压缩映射**。
Step 4:由 Banach 不动点定理,\(G(x, \cdot)\) 有唯一不动点 \(y^*(x) = G(x, y^*(x))\),即 \(y^*(x) - A^{-1} F(x, y^*(x)) = y^*(x)\),化简得 \(F(x, y^*(x)) = 0\)。
Step 5:求导。对恒等式 \(F(x, y^*(x)) \equiv 0\) 两边对 \(x\) 求全微分:
解出:
这就是**隐式微分公式**。
物理直觉:\(\frac{\partial F}{\partial x}\) 衡量"参数变化对方程左边的直接影响",\((\frac{\partial F}{\partial y})^{-1}\) 衡量"解需要调整多少来重新满足方程"。两者的乘积就是"参数变化对解的间接影响"。
7.3 VJP 形式的隐式微分 ⭐⭐⭐¶
在实际应用中,我们通常不需要完整的 Jacobian \(dy^*/dx\),而只需要 VJP:给定上游梯度 \(\bar{y}\),计算 \(\bar{x} = \bar{y}^\top \frac{dy^*}{dx}\)。
其中 \(u\) 满足线性方程:
这意味着隐式微分的反向传播只需要**解一个线性方程组**(与前向求解规模相同),不需要构造或存储完整的 Jacobian 或其逆矩阵。
7.4 IFT 失效的情况与对策 ⭐⭐⭐¶
IFT 的核心条件是 \(\partial F / \partial y\) 非奇异。在实际应用中,这个条件可能失效:
场景 1:退化 QP。当 QP 有多个最优解时(\(Q\) 半正定、不唯一),KKT 矩阵奇异。对策:添加 Tikhonov 正则化 \(Q \to Q + \epsilon I\)。
场景 2:Active set 切换。含不等式约束的优化问题中,当参数 \(\theta\) 变化导致某个约束从 active 变为 inactive(或反之),最优解不可微(存在拐点)。对策:使用 interior-point 方法的 central-path 参数化 \(\mu\) 做光滑化。
场景 3:非唯一解。非凸问题可能有多个局部最优解。参数微小变化可能导致全局最优解从一个局部最优跳到另一个。对策:在明确指定的局部最优附近使用 IFT,不追求全局最优。
反事实推理:如果 \(\partial F / \partial y\) 恰好奇异,IFT 公式 \(dy^*/dx = -(F_y)^{-1} F_x\) 中出现了矩阵求逆失败。但这并不意味着 \(y^*(x)\) 不存在或不连续——它可能仍然存在且连续,只是不可微。这类似于 \(|x|\) 在 \(x=0\) 处连续但不可微。处理这种情况需要更高级的工具:广义隐函数定理(Dontchev & Rockafellar 2014)用 Aubin property 和度量正则性替代经典的非奇异条件。
⚠️ 常见陷阱¶
⚠️ 编程陷阱:展开 QP/iLQR 迭代做反向传播 - 错误做法:把 QP 求解器的每步迭代放进 PyTorch 计算图 - 现象:(1) 内存与迭代步数成正比;(2) 梯度依赖于收敛路径而非收敛解;(3) 如果求解器使用了 warm-start,梯度会被 warm-start 的质量污染 - 根本原因:展开法混淆了"求解过程"和"解的性质"——导数应该只依赖于"解在哪里",不应该依赖于"怎么找到解的" - 正确做法:用隐式微分,在收敛解 \(y^*\) 处通过 IFT 计算梯度
💡 概念误区:认为 IFT 只适用于等式约束 - 实际上:对不等式约束 \(g(y, x) \leq 0\),在严格互补条件下(active 约束上 \(\lambda > 0\)),可以把 KKT 条件写成等式 \(F(z, x) = 0\)(其中 \(z = (y, \lambda, \nu)\)),然后对 \(F\) 用 IFT - 这正是 OptNet/cvxpylayers 的数学基础
练习¶
- 对方程 \(F(x, y) = y^3 + xy - 1 = 0\),在 \((x_0, y_0) = (0, 1)\) 处:(a) 验证 IFT 的条件;(b) 用隐式微分公式计算 \(dy^*/dx\);(c) 用 SymPy 求解 \(y^*(x)\) 的显式表达式并验证导数。
- 考虑线性方程组 \(Ay = b\)(\(A\) 可逆)。将其写成 \(F(A, b, y) = Ay - b = 0\)。用 IFT 推导 \(\partial y^* / \partial b\) 和 \(\partial y^* / \partial A\)(后者需要使用矩阵微分规则)。
- (跨章综合)回顾凸优化中的 proximal 算子 \(\text{prox}_f(x) = \arg\min_y \{f(y) + \frac{1}{2}\|y - x\|^2\}\)。写出最优性条件 \(F(x, y^*) = 0\),用 IFT 推导 \(\partial \text{prox}_f / \partial x\) 的表达式。
8. 可微优化层:KKT 的微分 ⭐⭐⭐¶
动机¶
2017 年 Amos 和 Kolter 提出了一个改变游戏规则的想法:把优化问题本身当作神经网络的一层(OptNet)。这意味着你可以在网络中嵌入一个 QP 求解器,梯度穿过 QP 层继续反向传播——优化的输出对网络的参数是可微的。
这是隐式微分(§7)在优化问题上的直接应用。
8.1 QP 层的隐式微分推导 ⭐⭐⭐¶
考虑参数化的等式约束 QP:
Step 1:写出 KKT 条件。引入 Lagrange 乘子 \(\nu\):
Step 2:KKT 矩阵(Jacobian \(\partial F / \partial (y, \nu)\)):
这就是凸优化中经典的 KKT 矩阵。KKT 矩阵非奇异的条件是:
- \(Q \succ 0\)(正定,保证目标严格凸);且
- \(A\) 满行秩(即 \(\text{rank}(A) = p\),其中 \(p\) 为等式约束数),这是**线性独立约束品性(LICQ)**的要求。
为什么需要 \(A\) 满行秩? KKT 矩阵是 \(2 \times 2\) 分块鞍点矩阵。由 Sylvester 惯性定理,它非奇异当且仅当 \(Q\) 在 \(\ker(A)\) 上正定**且** \(A\) 满行秩。如果 \(A\) 有线性相关的行(冗余约束),KKT 矩阵的零空间非平凡——乘子 \(\nu^*\) 不唯一,隐式微分无法定义。工程中出现冗余等式约束(如多只脚都约束在地面上但几何上有依赖)时,必须先去冗余再做隐式微分。
Step 3:用 IFT 得到隐式微分公式。对参数 \(q\) 的微分最简单:
用 Schur 补提取 \(\partial y^*/\partial q\)。KKT 系统的分块消元过程:从第二行 \(A \cdot dy/dq = 0\)(因为 \(b\) 不依赖 \(q\)),第一行 \(Q \cdot dy/dq + A^\top \cdot d\nu/dq = -I\)。在约束 \(A \cdot dy/dq = 0\) 下求解——这等价于将 \(dy/dq\) 限制在 \(\ker(A)\) 中,再用 \(Q\) 的逆。正确的 Schur 补公式为:
这正是将 \(Q^{-1}\) 投影到 \(\ker(A)\) 上的结果(即 \(Q^{-1}\) 的零空间正交投影)。直觉含义:参数 \(q\) 的变化只能沿着满足约束 \(Ay = b\) 的方向影响最优解——正交于约束的分量被"锁死"了。
Step 4:VJP 形式。给定上游梯度 \(\bar{y}\),需要解:
然后 \(\bar{q} = -u\)。这只需要解**一个 KKT 系统**——与前向求解 QP 的计算量相当。
为什么不直接对 \(y^* = -Q^{-1}q\) 做 AD? 对无约束 QP,最优解有封闭形式,确实可以直接对封闭形式做 AD。但封闭形式有几个问题:(1) 求 \(Q^{-1}\) 的计算量是 \(O(n^3)\),而 KKT 系统的求解(如用 Cholesky 分解)也是 \(O(n^3)\),但后者可以复用前向求解时的分解;(2) 有约束时没有封闭形式;(3) 大规模问题中 \(Q\) 是稀疏的,KKT 求解可以利用稀疏性而显式求逆不能。
所以 IFT/KKT 隐式微分不仅是唯一的选择(有约束时),而且通常是更高效的选择(利用前向求解的中间结果和稀疏结构)。
8.2 含不等式约束的推广(OptNet) ⭐⭐⭐¶
对含不等式约束的 QP \(\min \frac{1}{2}y^\top Qy + q^\top y\) s.t. \(Gy \leq h, Ay = b\),KKT 条件增加了互补松弛条件 \(\lambda \odot (Gy - h) = 0\)。
在**严格互补**条件下(active 约束上 \(\lambda_i > 0\)),可以把 KKT 写成 \(F(z, \theta) = 0\)(\(z = (y, \lambda, \nu)\)),然后直接用 IFT。
关键的实现细节是:需要识别 active 约束集 \(\mathcal{A} = \{i : g_i(y^*) = 0\}\),只对 active 约束建立 KKT 矩阵。
8.3 cvxpylayers:更通用的可微凸优化 ⭐⭐⭐¶
Agrawal et al.(NeurIPS 2019)将 OptNet 推广到**任意凸优化问题**。核心思想是利用 CVXPY 的 DCP(Disciplined Convex Programming)框架:
- 用 CVXPY 定义凸问题(参数用
cp.Parameter标记) - 框架自动将问题 canonicalize 为锥规划
- 用锥求解器(SCS、ECOS)求解
- 用锥 KKT 条件做隐式微分
DPP(Disciplined Parametrized Programming)约束:参数必须以**仿射方式**出现在问题数据中(即 \(Q(\theta) = Q_0 + \sum_i \theta_i Q_i\))。这个约束保证了 canonicalization 是可微的。
import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer
import torch
# 定义参数化 QP
n = 10
Q_sqrt = cp.Parameter((n, n))
q = cp.Parameter(n)
y = cp.Variable(n)
prob = cp.Problem(cp.Minimize(0.5 * cp.sum_squares(Q_sqrt @ y) + q @ y),
[y >= -1, y <= 1])
# 包装为 PyTorch 层
layer = CvxpyLayer(prob, parameters=[Q_sqrt, q], variables=[y])
# 前向 + 反向(梯度穿过 QP 求解!)
Q_sqrt_val = torch.randn(n, n, requires_grad=True)
q_val = torch.randn(n, requires_grad=True)
y_star, = layer(Q_sqrt_val, q_val)
loss = y_star.sum()
loss.backward() # dL/dQ_sqrt 和 dL/dq 通过隐式微分计算
本质洞察:可微优化层的价值不在于"让优化可微"这个技术本身,而在于它建立了**优化与学习的双向通道**。正向:学习系统生成优化问题的参数;反向:优化解的质量反馈给学习系统调整参数。这使得"端到端训练一个包含优化的系统"成为可能——这正是可微 MPC、可微 motion planning 的数学基础。
⚠️ 常见陷阱¶
⚠️ 编程陷阱:在 cvxpylayers 中使用非 DPP 兼容的构造 - 错误做法:参数出现在非线性位置,如
cp.quad_form(x, P)中P是参数 - 现象:cvxpylayers 报错 "problem is not DPP compliant" - 根本原因:DPP 要求参数以仿射方式出现在问题数据中。quad_form(x, P)中 P 在二次项中不是仿射的 - 正确做法:改写为cp.sum_squares(P_sqrt @ x),其中P_sqrt是参数🧠 思维陷阱:认为隐式微分总是比展开式更好 - 例外情况:(1) 当求解器没有精确收敛时,隐式微分给出的是"精确解的梯度"而非"近似解的梯度"——两者之间有 gap;(2) 当 KKT 矩阵接近奇异(如 active set 即将变化时),隐式微分的数值稳定性变差 - 正确思维:隐式微分是默认首选,但需要监控 KKT 矩阵的条件数
练习¶
- 对无约束 QP \(\min \frac{1}{2} y^\top Q y + q^\top y\)(\(Q \succ 0\)),最优解 \(y^* = -Q^{-1}q\)。直接微分得到 \(\partial y^* / \partial q = -Q^{-1}\)。用 IFT 推导同样的结果(把 KKT 条件写成 \(F = Qy + q = 0\)),验证两者一致。
- 用 cvxpylayers 实现一个简单的"学习代价函数"实验:给定专家轨迹 \(y_{\text{expert}}\),训练 QP 参数使得 \(\|y^*(\theta) - y_{\text{expert}}\|^2\) 最小化。
- (跨章综合)回顾凸分析中的 Slater 条件。解释为什么 Slater 条件是隐式微分可行的前提——它保证了什么?如果 Slater 条件不满足,KKT 矩阵会有什么问题?
9. 可微物理仿真 ⭐⭐⭐¶
动机¶
可微仿真是机器人学习的前沿方向:如果物理仿真器是可微的,就可以直接计算"策略参数→仿真轨迹→目标函数"的端到端梯度,而不需要像 RL 那样靠采样估计。这在平滑动力学下可以带来 10-100 倍的样本效率提升。
但接触和碰撞引入了**不连续性**,使得梯度问题变得微妙。这一节讨论可微仿真的原理、主要框架和核心挑战。
9.1 Pinocchio 的解析微分 ⭐⭐⭐¶
Pinocchio 是机器人刚体动力学的 C++ 库。Carpentier 和 Mansard(RSS 2018)证明了一个关键结果:RNEA(递归 Newton-Euler 算法)和 ABA(Articulated Body 算法)可以**手工推导**解析微分公式,比黑盒 AD 快 3 倍以上。
为什么解析微分更快? RNEA 的前向递推和反向递推天然编码了空间代数的链式法则。黑盒 AD 不知道这个结构,会产生大量冗余计算。Carpentier 和 Mansard 利用递推结构,将 Jacobian 计算分解为与前向/反向递推同形的递推过程。
性能数据(Carpentier-Mansard RSS 2018):
| 模型 | 自由度 | 解析微分 | CppAD 黑盒 AD | 加速比 |
|---|---|---|---|---|
| 7-dof arm | 7 | 3 \(\mu\)s | 15 \(\mu\)s | 5\(\times\) |
| 人形 | 36 | 17 \(\mu\)s | 78 \(\mu\)s | 4.6\(\times\) |
import pinocchio as pin
# 加载模型
model = pin.buildModelFromUrdf("robot.urdf")
data = model.createData()
# 计算 RNEA 微分:∂τ/∂q, ∂τ/∂v, ∂τ/∂a
pin.computeRNEADerivatives(model, data, q, v, a)
dtau_dq = data.dtau_dq # n×n 矩阵
dtau_dv = data.dtau_dv # n×n 矩阵
# dtau_da = M(q) 就是质量矩阵(RNEA 关于加速度是线性的)
Pinocchio 解析微分的数学原理。
RNEA 计算的是逆动力学 \(\tau = \text{RNEA}(q, \dot{q}, \ddot{q})\)。其内部结构是从根到叶的前向递推(计算每个关节的速度和加速度)和从叶到根的反向递推(累积力和力矩)。
Carpentier 和 Mansard 的关键观察:RNEA 的递推结构**本身就是一种链式法则的具体实现**——空间代数中的递推关系编码了关节到关节的局部变换和力的传递。对 RNEA 求微分,等价于在这个递推结构上叠加一层"微分递推"。
具体来说,\(\partial \tau / \partial q\) 可以分解为: - 对每个关节 \(i\),计算"关节位形对空间速度/加速度的影响"(前向递推的微分) - 累积"空间力对扭矩的影响"(反向递推的微分)
这两层递推的计算量与原始 RNEA 相当(同样是 \(O(n)\),\(n\) 为关节数),但通过利用稀疏结构,常数因子只有 2-3 倍——而黑盒 AD 的常数因子可达 10-15 倍。
9.2 可微仿真引擎对比 ⭐⭐⭐¶
| 引擎 | 语言 | 可微方式 | 接触模型 | GPU | 主要用途 |
|---|---|---|---|---|---|
| Brax | JAX/Python | JAX AD | 弹簧/PBD | 是 | 大规模 RL |
| MuJoCo MJX | JAX/Python | JAX AD | 凸优化 | 是 | RL + 控制 |
| Dojo | Julia | IFT | NCP+IPM | 否 | 精确接触梯度 |
| DiffTaichi | Python/Taichi | 编译器 AD | 各种 | 是 | 流体/弹性体 |
| NVIDIA Warp | Python→CUDA | 编译器 AD | 各种 | 是 | 通用 GPU 仿真 |
| Nimble | C++/Python | adjoint spatial | LCP | 否 | 接触敏感任务 |
Brax 的设计哲学:用 JAX 的 jit + vmap + grad 组合实现"数千个环境并行仿真 + 整体可微"。接触模型简化为弹簧或 PBD(Position Based Dynamics),牺牲物理精度换取可微性和速度。
Dojo 的设计哲学(Howell et al., RSS 2022):通过 primal-dual 内点法求解接触 NCP(非线性互补问题),然后在 IPM 的 KKT 系统上用 IFT 求梯度。关键参数 \(\kappa\)(central-path 参数)控制接触的"软化"程度——\(\kappa\) 越大接触越软、梯度越平滑,但物理精度下降。
9.3 接触微分的核心困难 ⭐⭐⭐¶
接触力学引入了**Heaviside 类的不连续性**:物体要么接触、要么不接触,接触力要么存在、要么为零。这种不连续性使得 \(\partial f / \partial x\) 在接触切换点上不存在(或为零),first-order gradient 失效。
Suh et al.(ICML 2022 Outstanding Paper)给出了定量分析:
定理(FoBG 偏差):设 \(z = \phi(\theta, \epsilon)\) 包含 Heaviside 不连续性,smooth 目标 \(J(\theta) = \mathbb{E}_\epsilon[R(\phi(\theta, \epsilon))]\)。则:
- FoBG(first-order batched gradient)\(= \mathbb{E}[\frac{\partial R}{\partial z} \cdot \frac{\partial \phi}{\partial \theta}]\) 几乎处处为零(因为 \(\partial \phi / \partial \theta = 0\)),但 \(\nabla J(\theta)\) 非零——FoBG 有偏
- ZoBG(zeroth-order batched gradient,即 REINFORCE)\(= \mathbb{E}[R \cdot \nabla \log p(\epsilon)]\) 不经过 \(\phi\) 的 Jacobian——ZoBG 无偏
实际含义:在有接触的机器人任务中,可微仿真给出的梯度可能完全错误。对策包括:
- 随机光滑化(randomized smoothing):对参数加噪声 \(\theta + \sigma \xi\),在期望意义下获得光滑梯度
- 互补松弛(complementarity relaxation):将接触条件 \(0 \leq \lambda \perp \phi \geq 0\) 松弛为 \(\lambda \phi = \kappa\)(Dojo 方法)
- 混合策略:光滑阶段用可微仿真(FoBG),接触密集阶段切到 REINFORCE(ZoBG)
反事实推理:如果所有物理系统都是光滑的(没有接触、碰撞),那么可微仿真将是 RL 的终极解决方案——样本效率比无模型 RL 高 100 倍。但现实世界充满了不连续性(抓取、行走、碰撞),这使得"用什么梯度"成为一个需要仔细分析的研究问题。
9.4 SHAC:截断窗口解决梯度爆炸 ⭐⭐⭐⭐¶
Xu et al.(ICLR 2022)提出 SHAC(Short-Horizon Actor-Critic),解决可微仿真中长 horizon BPTT 的梯度爆炸问题。核心思想:
- 只在短窗口(如 16-32 步)内通过可微仿真做 BPTT
- 用 critic 估计窗口外的 value function
- 结合短窗口的 FoG 和 critic 的 bootstrap
这在一些连续控制任务(如 Ant locomotion)上实现了比 PPO 更高的样本效率和更快的训练速度。
⚠️ 常见陷阱¶
💡 概念误区:认为可微仿真可以完全替代 model-free RL - 错误推理:"有了可微仿真,直接反向传播就能训练策略,不需要 PPO/SAC 了" - 实际上:(1) 接触导致梯度偏差;(2) 长 horizon 导致梯度爆炸(Lyapunov 指数);(3) 混沌系统的梯度方差随时间指数增长。Suh 2022 的结论是:没有银弹,需要根据任务特性选择 FoBG、ZoBG 或混合方法 - 正确思维:可微仿真是工具箱中的一个强力工具,但不是唯一的
⚠️ 编程陷阱:MuJoCo MJX 的
while_loop反向失败 - 现象:用jax.grad对 MJX 仿真做反向 AD 时报错 - 根本原因:MJX 默认 CG 求解器用jax.lax.while_loop(动态终止条件),JAX 的反向 AD 不支持 - 正确做法:切换到 Newton 求解器或使用jax.lax.fori_loop(静态上界)
练习¶
- 解释 Pinocchio 的
computeRNEADerivatives为什么比黑盒 AD 更快。(提示:利用了 RNEA 递推的什么结构性质?) - 设计一个简单的 1D 接触问题:小球落地 \(x_{t+1} = \max(x_t + v_t \Delta t, 0)\)。分析 \(\partial x_{t+1} / \partial v_t\) 在 \(x_t + v_t \Delta t = 0\) 处的行为。这解释了为什么接触梯度是困难的。
- (跨章综合)回顾 §7 的隐式微分。解释 Dojo 如何用 IFT 绕过接触的不连续性——它把什么写成 \(F(z, \theta) = 0\)?
10. 工程实践 ⭐⭐¶
动机¶
理解 AD 的数学原理是基础,但要在工程中**正确、高效、可靠地使用 AD**,还需要掌握一系列实践技巧。本节覆盖三个关键主题:梯度验证、数值稳定性和可微渲染。
10.1 梯度检查(Gradient Check) ⭐⭐¶
每当你实现了一个自定义的 AD 规则(如 torch.autograd.Function 或 jax.custom_vjp),第一件事就是**用有限差分验证它的正确性**。这就是"有限差分做验证"的正确用法。
import jax
import jax.numpy as jnp
def check_grad(f, x, eps=1e-5):
"""用中心差分验证 jax.grad 的正确性"""
grad_ad = jax.grad(f)(x)
grad_fd = jnp.zeros_like(x)
for i in range(len(x)):
e_i = jnp.zeros_like(x).at[i].set(1.0)
grad_fd = grad_fd.at[i].set(
(f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)
)
# 相对误差应小于 sqrt(eps) ~ 1e-2.5
rel_error = jnp.linalg.norm(grad_ad - grad_fd) / (jnp.linalg.norm(grad_ad) + 1e-10)
print(f"相对误差: {rel_error:.2e}")
return rel_error < 1e-3 # 经验阈值
精度预期:对中心差分(\(h = 10^{-5}\)),截断误差 \(O(h^2) \approx 10^{-10}\),舍入误差 \(O(\epsilon_{\text{mach}} / h) \approx 10^{-11}\)。AD 与有限差分的相对误差应在 \(10^{-5}\) 到 \(10^{-7}\) 之间。如果差异超过 \(10^{-3}\),几乎可以确定 AD 实现有 bug。
10.2 复步微分法(Complex Step Method) ⭐⭐⭐¶
有限差分有截断误差和舍入误差的矛盾。复步微分法是一个巧妙的替代方案:
其中 \(i = \sqrt{-1}\)。关键观察:分子不涉及两个接近的实数相减——\(f(x)\) 出现在实部,\(f'(x) \cdot h\) 出现在虚部,两者在完全不同的"通道"中。因此没有 catastrophic cancellation,可以取极小的 \(h\)(如 \(h = 10^{-200}\))获得接近机器精度的导数。
与 dual number 的关系:复步微分法实际上是 dual number AD 的一个特例。在 \(\mathbb{C}\) 中,\(f(x + ih) \approx f(x) + if'(x)h\)(当 \(h\) 足够小时)。虚部除以 \(h\) 就得到 \(f'(x)\)。但 dual number 更一般——它不要求 \(h\) 趋于零,因为 \(\varepsilon^2 = 0\) **精确地**消除了二阶项,而不是靠 \(h\) 很小来近似消除。
import numpy as np
def complex_step_derivative(f, x, h=1e-30):
"""复步微分法:精度接近机器精度,无 catastrophic cancellation"""
return np.imag(f(x + 1j * h)) / h
# 测试
f = lambda x: np.sin(x**3) * np.exp(x)
x0 = 1.5
# 比较三种方法
deriv_fd = (f(x0 + 1e-8) - f(x0)) / 1e-8 # 前向差分
deriv_cd = (f(x0 + 1e-5) - f(x0 - 1e-5)) / 2e-5 # 中心差分
deriv_cs = complex_step_derivative(f, x0) # 复步微分
# 解析导数
exact = (3*x0**2*np.cos(x0**3) + np.sin(x0**3)) * np.exp(x0)
print(f"前向差分误差: {abs(deriv_fd - exact):.2e}") # ~1e-8
print(f"中心差分误差: {abs(deriv_cd - exact):.2e}") # ~1e-10
print(f"复步微分误差: {abs(deriv_cs - exact):.2e}") # ~1e-16(机器精度!)
复步微分法的局限:它要求函数 \(f\) 可以接受复数输入,且在实轴附近是解析的。对于包含 abs、max、分支判断的函数不适用。
10.3 数值稳定性 ⭐⭐¶
AD 虽然没有截断误差,但仍然受**浮点舍入**和**数学条件数**的影响。以下是最常见的数值问题:
Softmax 的 log-sum-exp 技巧:直接计算 \(\log(\sum_i e^{x_i})\) 在 \(x_i\) 很大时会溢出。标准技巧是减去最大值 \(M = \max_i x_i\):
这不影响 AD 的正确性(因为减去常数不影响梯度),但对数值稳定性至关重要。
梯度消失/爆炸:在长序列反向传播中,梯度通过大量矩阵连乘传播。如果矩阵的谱范数 \(> 1\),梯度指数增长;如果 \(< 1\),梯度指数衰减。对策包括梯度裁剪(gradient clipping)、残差连接、LSTM 的门控机制等。
病态 Hessian:当 Hessian 矩阵的条件数 \(\kappa = \lambda_{\max} / \lambda_{\min}\) 很大时,AD 计算的梯度方向虽然正确,但优化收敛速度极慢(梯度下降的收敛率 \(\sim ((\kappa-1)/(\kappa+1))^k\))。此时需要预条件化或自适应步长。
10.4 AD 中的常见数值灾难 ⭐⭐¶
以下是 AD 实践中最容易遇到的数值问题及其标准解决方案:
灾难 1:\(\sqrt{0}\) 的导数。\(\frac{d}{dx}\sqrt{x} = \frac{1}{2\sqrt{x}}\),在 \(x=0\) 处趋于无穷。如果你的代码中有 loss = sqrt(x**2 + y**2)(如 L2 范数),在 \((x,y) = (0,0)\) 处梯度会出现 NaN。
标准对策:添加小常数 loss = sqrt(x**2 + y**2 + eps),或使用 jnp.linalg.norm(v + eps) 而非手写 sqrt(sum(v**2))。
灾难 2:\(\log(0)\) 和 \(\log(\text{softmax})\)。\(\log(0) = -\infty\),\(\frac{d}{dx}\log(x) = 1/x\) 在 \(x \to 0\) 时趋于无穷。在分类任务中,loss = -log(softmax(logits)[label]) 的梯度在 logits 差异大时可能出问题。
标准对策:使用 log_softmax(一步计算,数值稳定),而非先 softmax 再 log。
灾难 3:条件数的放大效应。AD 的结果在数学上是精确的(浮点范围内),但如果函数本身是病态的(条件数大),AD 给出的梯度虽然"精确",但对输入微小扰动高度敏感。例如,矩阵求逆 \(f(A) = A^{-1}\) 的条件数为 \(\kappa(A)^2\)——如果 \(\kappa(A) = 10^8\),那么 \(A\) 的 8 位精度变化会导致 \(A^{-1}\) 完全不准。
标准对策:避免显式矩阵求逆,改用分解+求解;添加正则化降低条件数。
10.5 可微渲染简介 ⭐⭐⭐⭐¶
可微渲染是 AD 在计算机视觉中的前沿应用。两个里程碑式的工作:
NeRF(Mildenhall et al., ECCV 2020)用 MLP \(F_\theta(\mathbf{x}, \mathbf{d}) \to (c, \sigma)\) 表示场景的颜色和密度,通过**体积渲染**将 3D 表示投影到 2D 图像。体积渲染公式:
整个渲染过程是可微的(通过 PyTorch 反向 AD),因此可以用多视角图像反向传播优化 MLP 参数 \(\theta\)。
3D Gaussian Splatting(Kerbl et al., SIGGRAPH 2023)用三维高斯椭球代替 MLP 表示场景。每个高斯由位置、协方差、颜色和不透明度参数化。渲染通过"splatting"(将 3D 高斯投影到 2D 图像平面并做 alpha compositing)实现。
关键的 AD 挑战:高斯 splatting 的渲染管线涉及 3D→2D 投影、排序和 alpha 合成——这些操作需要**自定义 CUDA kernel 的手工 VJP**(因为通用 AD 框架对 GPU kernel 的支持有限),Kerbl 等人用手写的 CUDA 反向传播实现了这一点。
10.6 AD 工程的黄金法则 ⭐⭐¶
总结本节的实践经验,以下六条法则涵盖了 AD 工程中最重要的注意事项:
-
永远做梯度检查:任何自定义的前向/反向规则,在投入使用前必须用有限差分或复步微分验证。没有通过梯度检查的代码不可信。
-
先用框架提供的 AD,再考虑手写:PyTorch 的
autograd、JAX 的grad已经处理了大量边界情况。只有当性能不满足需求时才考虑手写 VJP(如 3DGS 的 CUDA kernel)。 -
关注 \(n/m\) 比值:\(n \gg m\) 用反向模式,\(n \ll m\) 用前向模式。选错模式可能导致 1000 倍的性能差异。
-
监控条件数:AD 计算的梯度在数学上是精确的,但如果函数本身是病态的,梯度仍然不可靠。定期检查 Hessian 条件数或 KKT 矩阵条件数。
-
隐式优于展开:对任何涉及迭代求解的计算(优化、ODE、不动点),默认使用隐式微分(IFT),除非有充分理由使用展开式 AD。
-
不可微处需要特殊处理:
abs、max、relu、接触、碰撞等不可微运算需要 log-barrier、softplus、随机光滑化等处理。忽略不可微性会导致零梯度或错误梯度。
⚠️ 常见陷阱¶
⚠️ 编程陷阱:忘记做梯度检查就信任自定义 VJP - 错误做法:写了
custom_vjp后直接用于训练 - 现象:训练 loss 不降或出现 NaN,但不知道原因 - 根本原因:手写的 VJP 公式可能有符号错误、转置错误或维度不匹配 - 正确做法:先用有限差分验证,再用于训练。PyTorch 提供torch.autograd.gradcheck💡 概念误区:认为 float32 够用就不需要考虑数值稳定性 - 错误想法:"GPU 上用 float32 训练就行了,不用管 float64" - 实际上:float32 只有 7 位有效数字。当两个接近的数相减时(catastrophic cancellation),有效位数可能骤降到 1-2 位,导致梯度严重失真 - 正确思维:在关键路径上使用数值稳定的实现(如 log-sum-exp、稳定的 softmax),而不是简单地切换到 float64
练习¶
- 实现
gradcheck(f, x)函数,使用中心差分验证 JAX 自动梯度的正确性。对 \(f(x) = \log(\sum_i e^{x_i})\) 在 \(x = [1000, 1001, 1002]\) 处测试——如果不用 log-sum-exp 技巧会发生什么? - 对一个 50 层的全连接网络(无残差连接),实验观察反向传播中梯度范数随层数的变化。在什么条件下出现梯度消失/爆炸?
- 阅读 NeRF 的体积渲染公式,解释为什么这个积分的离散近似是可微的。每个采样点的"贡献权重" \(T(t_i) (1 - \exp(-\sigma_i \delta_i))\) 对 \(\sigma_i\) 的梯度是什么?
11. 连续伴随方法与 Neural ODE ⭐⭐⭐¶
动机¶
到目前为止,我们讨论的隐式微分(§7-§8)处理的是**离散的方程组** \(F(x, y) = 0\)。但在最优控制和科学计算中,许多问题天然是**连续时间**的:ODE 约束下的参数估计、连续深度模型(Neural ODE)、轨迹优化。对这些问题,连续伴随方程(continuous adjoint method)提供了一种内存高效的梯度计算方法。
11.1 问题设定¶
考虑参数化的 ODE 系统:
目标是最小化终端损失 \(L(z(T))\)。问题:如何计算 \(dL/d\theta\)?
朴素方法(discretize-then-optimize):将 ODE 用 Euler/RK4 离散化为 \(K\) 步,把所有步骤展开成计算图,然后反向 AD。内存 \(O(K)\),因为需要存储所有中间状态。
伴随方法(optimize-then-discretize):在连续层面推导梯度公式,然后再离散化计算。内存 \(O(1)\)(只存储终端状态和伴随变量)。
11.2 连续伴随方程的推导 ⭐⭐⭐¶
这是自动微分理论中最优美的推导之一。 它把链式法则从离散图推广到连续时间。
Step 1:构造增广 Lagrangian
引入伴随变量(Lagrange 乘子函数)\(\lambda(t)\),将 ODE 约束纳入目标:
为什么要这样做? 当 \(z(t)\) 满足 ODE 约束 \(\dot{z} = f_\theta\) 时,积分项为零,\(\mathcal{L} = L(z(T))\)。但通过引入 \(\lambda\),我们可以对 \(z(t)\) 做变分(虚位移),而不需要约束 \(z(t)\) 始终满足 ODE。
Step 2:对 \(z(t)\) 做变分
设 \(z(t) \to z(t) + \delta z(t)\),则:
对 \(\delta \dot{z}\) 做分部积分:
代入得:
Step 3:令各阶变分为零
由于 \(\delta z(t)\) 是任意的,各系数必须分别为零:
- 终端条件:\(\lambda(T) = \frac{\partial L}{\partial z}\bigg|_{z(T)}\)
- 伴随 ODE:\(\dot{\lambda} = -\lambda^\top \frac{\partial f}{\partial z}\)(即 \(\dot{\lambda} = -\left(\frac{\partial f}{\partial z}\right)^\top \lambda\))
- 初始条件不变:\(\delta z(0) = 0\)(初始条件是给定的)
Step 4:参数梯度
对 \(\theta\) 求全微分:
这就是**连续伴随方程**:反向从 \(T\) 到 \(0\) 积分 \(\lambda\),同时累积参数梯度。
本质洞察:连续伴随方程是反向模式 AD 在连续时间的自然推广。离散图上的 adjoint 变量 \(\bar{v}_i = \partial y / \partial v_i\) 变成了连续的 \(\lambda(t) = \partial L / \partial z(t)\)。离散图上"逆拓扑序遍历"变成了"反向时间积分"。链式法则的"乘法+求和"变成了"ODE + 积分"。形式不同,本质完全相同。
11.3 Neural ODE(Chen et al., NeurIPS 2018) ⭐⭐⭐¶
Chen et al. 的核心观察:如果把 ResNet 的层 \(z_{k+1} = z_k + f_\theta(z_k, k)\) 视为 Euler 离散化,那么当层数趋于无穷、步长趋于零时,ResNet 变成了 ODE \(\dot{z} = f_\theta(z, t)\)。
Neural ODE 的优势: 1. 自适应计算量:ODE 求解器根据函数的"刚性"自动调整步长,简单区域大步走、困难区域小步走 2. 常数内存:用伴随方法计算梯度,内存不随"深度"增长 3. 可逆性:ODE 的前向和反向积分互为逆过程,无需存储前向状态
Neural ODE 的局限: 1. 表达能力有限:ODE 的轨迹不能交叉(唯一性定理),这限制了模型的表达能力 2. 刚性 ODE 的问题:当 \(f_\theta\) 产生刚性动力学时,adjoint ODE 可能产生 stiff 反向积分,导致数值不稳定 3. discretize-then-optimize vs optimize-then-discretize 的不一致:Gholami et al.(2019)指出,当前向用自适应步长求解器时,连续伴随的离散版本与"展开求解器再反向 AD"给出的梯度不完全相同——因为步长选择依赖于状态,而状态对参数可微
跨领域类比:Neural ODE 之于 ResNet,就像微分方程之于差分方程。ResNet 是"离散时间的动力系统",Neural ODE 是"连续时间的动力系统"。就像连续分析往往比离散分析更优美(可以用微积分工具),Neural ODE 的伴随方法也比 BPTT 更优美(常数内存、自适应精度)。
Neural ODE 在机器人学中的应用:
| 应用场景 | 为什么用 Neural ODE | 代替什么 |
|---|---|---|
| 系统辨识 | 学习连续时间动力学 \(\dot{x} = f_\theta(x, u)\) | 传统 NARX/LSTM |
| 轨迹预测 | 自适应步长,在简单阶段节省计算 | 固定步长 RNN |
| 物理信息模型 | 可嵌入已知物理结构(如 Hamiltonian、Lagrangian) | 纯数据驱动模型 |
| 连续时间 RL | 策略和值函数是 ODE 的解 | 离散时间 RL |
特别值得关注的是 Hamiltonian Neural Networks(Greydanus et al., NeurIPS 2019)和 Lagrangian Neural Networks(Cranmer et al., ICLR 2020)——它们把物理守恒律(能量守恒、动量守恒)嵌入 Neural ODE 的结构中,使学到的动力学模型在长时间预测中保持物理一致性。这对机器人学中的"学习-控制"闭环至关重要:如果学到的模型违反能量守恒,基于模型的控制器可能产生不稳定的行为。
11.4 Discretize-then-Optimize vs Optimize-then-Discretize ⭐⭐⭐⭐¶
这是计算科学中一个经典的二元对立,理解它对正确使用 Neural ODE 至关重要。
| 维度 | Discretize-then-Optimize(展开) | Optimize-then-Discretize(伴随) |
|---|---|---|
| 做法 | 先离散化 ODE(如 RK4),再对离散步骤做 AD | 先在连续层面推导伴随方程,再离散化伴随 ODE |
| 内存 | \(O(K)\)(存所有步的状态) | \(O(1)\)(只存终端状态) |
| 梯度一致性 | 与前向计算严格一致 | 可能有离散化误差 |
| 自适应步长 | 前向步长选择本身是可微的(复杂) | 前向和反向可用不同步长(灵活) |
| 实现复杂度 | 简单(直接用 AD 框架) | 需要实现伴随 ODE 的求解器 |
工程建议:对于大多数应用,先用 checkpointed discretize-then-optimize(如 torch.utils.checkpoint + torchdiffeq),因为它的梯度与前向严格一致。只有在内存极度受限时才切换到纯伴随方法。
⚠️ 常见陷阱¶
⚠️ 编程陷阱:对刚性 ODE 用伴随方法导致梯度不准 - 现象:伴随方程的反向积分数值发散,梯度出现 NaN 或大幅震荡 - 根本原因:刚性 ODE 的 Jacobian \(\partial f / \partial z\) 有很大的负实部特征值,反向积分时这些变成正实部→指数增长 - 正确做法:使用 Gholami et al. 推荐的 checkpointed augmented-state 方法,或切换到隐式 ODE 求解器
💡 概念误区:认为 Neural ODE 总是比 ResNet 好 - 实际上:(1) Neural ODE 的训练通常比 ResNet 慢 5-10 倍(ODE 求解器开销);(2) 表达能力有限(轨迹不交叉);(3) 在图像分类等标准任务上并不比 ResNet 更准确 - Neural ODE 的真正价值在于建模连续时间动力学(如时间序列、物理系统)和理论分析(如动力系统稳定性)
练习¶
- 对简谐振子 \(\ddot{x} = -\omega^2 x\)(改写为一阶系统 \(\dot{z} = Az\),\(A = \begin{pmatrix} 0 & 1 \\ -\omega^2 & 0 \end{pmatrix}\)),写出伴随方程 \(\dot{\lambda} = -A^\top \lambda\)。验证伴随 ODE 也是简谐振子(频率相同)。
- 推导参数梯度 \(dL/d\omega^2\) 的积分表达式。如果 \(L = z_1(T)^2\)(终端位移的平方),\(\lambda(T) = ?\)
- 用
torchdiffeq实现一个简单的 Neural ODE 分类器(如螺旋线分类),比较(1) adjoint 方法和(2) 直接 BPTT 的内存消耗和梯度差异。
12. 深度平衡模型与定点隐式微分 ⭐⭐⭐⭐¶
动机¶
Neural ODE 把无限深度解释为连续 ODE 的积分。Deep Equilibrium Models(DEQ, Bai et al. NeurIPS 2019)走了另一条路:把无限深度解释为**定点迭代的收敛**。
考虑一个共享权重的深度网络 \(z_{k+1} = f_\theta(z_k, x)\)。当 \(k \to \infty\) 时(如果收敛),\(z^* = f_\theta(z^*, x)\)——这就是**定点方程**。DEQ 直接求解定点方程(用 Anderson acceleration、Broyden 法等),然后用 IFT 计算梯度。
12.1 定点隐式微分的推导 ⭐⭐⭐¶
定点条件 \(z^* = f_\theta(z^*, x)\) 可以写成 \(F(z, \theta, x) = z - f_\theta(z, x) = 0\)。对 \(\theta\) 求微分:
VJP 形式:给定上游梯度 \(\bar{z}\),需要解:
然后 \(\bar{\theta} = u^\top \frac{\partial f}{\partial \theta}\)。
关键观察:这个线性方程本身也是一个定点问题!可以用迭代法求解:
这是 Neumann 级数 \((I - A)^{-1} = \sum_{k=0}^{\infty} A^k\) 的迭代版本(当 \(\|\partial f / \partial z\| < 1\) 时收敛)。实际中通常用 Anderson acceleration 加速收敛。
12.2 DEQ 的优势与局限 ⭐⭐⭐⭐¶
优势: 1. 常数内存:前向只需要存储 \(z^*\)(不需要展开 \(K\) 步的中间状态) 2. 反向也是常数内存:VJP 通过解线性方程得到(不需要反向穿过 \(K\) 步迭代) 3. 解耦前向与反向:前向求解器和反向求解器可以完全不同(前向用 Broyden,反向用 CG)
局限: 1. 需要收敛保证:如果 \(f\) 不是压缩映射(\(\|\partial f / \partial z\| \geq 1\)),定点迭代不收敛 2. 隐式微分假设精确收敛:如果前向没有完全收敛到 \(z^*\),隐式微分给出的梯度是"精确解的梯度"而非"近似解的梯度"——两者之间有 gap 3. Jacobian 条件数:\(I - \partial f / \partial z\) 接近奇异时,隐式微分数值不稳定
反事实推理:如果我们不用 IFT 而是展开 1000 步定点迭代做 BPTT,会发生什么?(1) 内存 \(O(1000)\) vs IFT 的 \(O(1)\);(2) 梯度受"收敛路径"污染——如果迭代只到 \(z_{500}\) 就几乎收敛了,后 500 步的梯度几乎为零,但 BPTT 仍然要穿过它们,浪费计算并引入数值噪声;(3) 梯度的质量依赖于迭代步数 \(K\),而 IFT 的梯度只依赖于收敛解 \(z^*\),与 \(K\) 无关。
12.3 Anderson 加速与 Broyden 法 ⭐⭐⭐⭐¶
Neumann 级数收敛速度可能很慢(线性收敛,速率 \(\|\partial f / \partial z\|\))。实践中常用两种加速方法:
Anderson acceleration(混合法):保存最近 \(m\) 步的迭代历史 \(\{z_k, g_k = T(z_k) - z_k\}\),在这 \(m\) 个点的仿射组合中找最小残差:
然后 \(z_{k+1} = \sum_i \alpha_i T(z_{k-i})\)。Anderson 加速通常将线性收敛加速为超线性收敛。
Broyden 法:用秩一更新维护 Jacobian 的近似逆 \(B_k \approx (I - \partial f / \partial z)^{-1}\):
其中 \(\Delta z_k = z_{k+1} - z_k\), \(\Delta g_k = g_{k+1} - g_k\)。Broyden 法不需要计算 Jacobian(\(O(n^2)\) 而非 \(O(n^3)\)),是 DEQ 论文中使用的默认方法。
为什么不直接用 Newton 法? Newton 法 \(z_{k+1} = z_k - (I - \partial f / \partial z)^{-1}(z_k - f(z_k))\) 需要在每步计算完整 Jacobian \(\partial f / \partial z\) 并求解线性方程——对深度网络来说计算量过大。Broyden 法用近似 Jacobian 替代精确 Jacobian,每步只需 \(O(n^2)\)。
12.4 Blondel 2022 的统一框架 ⭐⭐⭐¶
Blondel et al.(NeurIPS 2022)给出了一个统一视角:所有隐式微分问题都可以归结为 \(F(x^*, \theta) = 0\) 的 IFT。
| 问题类型 | \(F\) 的形式 | 例子 |
|---|---|---|
| 定点 | \(F = z - T_\theta(z)\) | DEQ, RNN, Almeida-Pineda |
| 最优性 | \(F = \nabla_z f(z, \theta)\) | 无约束优化、proximal 算子 |
| KKT | \(F = (KKT 条件)\) | OptNet, cvxpylayers, diff-MPC |
| ODE 终端 | \(F = z(T) - \text{ODESolve}(\theta)\) | Neural ODE |
JAXopt 库实现了这个统一框架:用户只需要实现 \(F\) 和前向求解器,隐式微分的 VJP 自动生成。
练习¶
- 对映射 \(f(z) = \tanh(Wz + b)\)(\(W\) 是 \(n \times n\) 矩阵),写出定点条件 \(z^* = f(z^*)\)。什么条件下 \(f\) 是压缩映射?(提示:\(\|\tanh'\|_\infty = 1\),所以需要 \(\|W\| < 1\)。)
- 用 Neumann 级数 \(u_{k+1}^\top = \bar{z}^\top + u_k^\top J\)(\(J = \partial f / \partial z|_{z^*}\))实现 DEQ 的反向传播,验证与 JAX 的
custom_vjp+ 直接求逆的结果一致。 - 比较 DEQ 和 Neural ODE:两者都实现了"无限深度",但机制完全不同。列出至少 3 个维度的对比(内存、计算、表达能力、稳定性)。
13. 可微 MPC 与 Pontryagin Differentiable Programming ⭐⭐⭐¶
动机¶
Model Predictive Control(MPC)是机器人控制的主力工具。如果 MPC 是可微的(即 MPC 的输出——最优控制序列——对其参数如代价权重、动力学参数是可微的),就可以做到:
- 学习代价函数:通过模仿学习,从专家演示中反推代价权重
- 系统辨识:从观测轨迹中估计动力学参数
- 端到端控制:把 MPC 作为策略网络的一层,整体端到端训练
13.1 Differentiable MPC(Amos et al., NeurIPS 2018) ⭐⭐⭐¶
Amos et al. 把 MPC 建模为参数化的 QP,然后用 OptNet 的 KKT 隐式微分技术计算梯度:
对 LQR(线性动力学 + 二次代价),MPC 本身可以用 Riccati 递推精确求解。Amos 等人的关键贡献是**对 Riccati 递推做微分**——即推导"Riccati 解对代价权重的导数"。
实验:在 cartpole 上,用专家控制器生成演示轨迹,然后训练 diff-MPC 恢复专家的代价权重。结果表明 diff-MPC 比行为克隆在分布外测试中更鲁棒。
13.2 Pontryagin Differentiable Programming(Jin et al., NeurIPS 2020) ⭐⭐⭐⭐¶
Jin et al. 提出了一种更优雅的方法:不对 MPC 求解器做微分,而是**对 Pontryagin 最大化原理(PMP)本身做微分**。
PMP 给出最优控制的必要条件(协态方程):
对 PMP 关于参数 \(\theta\) 求微分,得到一个 auxiliary LQR 问题——这个辅助 LQR 可以用 Riccati 回传 \(O(N)\) 时间内精确求解。
PDP 相比 diff-MPC 的优势: 1. 不需要展开求解器迭代或存储 KKT 系统 2. 辅助 LQR 的计算量与前向 MPC 相当 3. 梯度精度不受求解器收敛程度的影响
PDP 的核心数学:
对 Hamiltonian 系统的最优性条件(\(\dot{x} = \partial H / \partial p, \dot{p} = -\partial H / \partial x\))关于 \(\theta\) 求微分,设 \(\delta x = \partial x / \partial \theta, \delta p = \partial p / \partial \theta\),得到线性 ODE 系统:
这是一个**线性时变 ODE**,可以用 Riccati 递推高效求解。
跨领域类比:PDP 之于 MPC,就像解析梯度之于有限差分。MPC 的有限差分做法是"扰动参数→重新求解→观察输出变化",diff-MPC 用 KKT IFT 相当于"在解的定义方程上用 IFT",而 PDP 直接在最优性的必要条件上做微分——每一步都更深入、更高效。
13.3 可微 MPC 的工程选型 ⭐⭐⭐¶
| 方法 | 数学基础 | 内存 | 精度 | 适用范围 |
|---|---|---|---|---|
| 展开求解器 + BPTT | 反向 AD | \(O(K \times N)\) | 依赖收敛 | 通用但低效 |
| KKT 隐式微分 | IFT | \(O(1)\) | 精确(在解处) | 凸 QP/SOCP |
| PDP | PMP + 辅助 LQR | \(O(N)\) | 精确 | 连续时间 OC |
| 灵敏度分析 | NLP 后处理 | \(O(1)\) | 精确 | IPOPT/acados |
实际建议:如果你使用 acados 或 CasADi,灵敏度分析(sensitivity analysis)是最方便的——求解器求解 NLP 后,可以直接提取 \(dy^*/d\theta\),不需要额外的反向传播。
⚠️ 常见陷阱¶
🧠 思维陷阱:认为可微 MPC 可以替代 RL - 错误推理:"MPC 比 RL 更有结构,可微 MPC 又能端到端训练,所以不需要 RL" - 实际上:(1) 可微 MPC 需要可微的动力学模型——如果模型不可微(如接触),梯度就崩溃了;(2) MPC 的 horizon \(N\) 有限,长期行为需要 value function 估计;(3) MPC 的计算量比 RL 策略网络大几个数量级 - 正确思维:可微 MPC 适合"有精确模型 + 需要满足约束 + 代价函数需要学习"的场景
⚠️ 编程陷阱:展开 iLQR 迭代做反向传播 - 这是 §7 已经警告过的陷阱,在 MPC 场景下特别常见 - 额外的问题:iLQR 使用 line search,line search 的步长选择本身不可微(argmin 不可微) - 正确做法:用 IFT 或 PDP
练习¶
- 对 LQR 问题(\(\min \sum x_k^\top Q x_k + u_k^\top R u_k\), s.t. \(x_{k+1} = Ax_k + Bu_k\)),用 KKT 隐式微分推导 \(\partial u_0^* / \partial Q\)。
- 解释 PDP 的辅助 LQR 的物理含义:它在"优化"什么?为什么它的解给出了参数灵敏度?
- (跨章综合)回顾 §8 的 OptNet。MPC 的 QP 形式化 \(\min \frac{1}{2} z^\top H z + g^\top z\) s.t. \(Cz = d, Gz \leq h\) 中,\(z = [x_{0:N}; u_{0:N-1}]\)。如果代价权重 \(Q, R\) 是参数,它们出现在 \(H\) 的哪些位置?用 §8 的 KKT 微分公式写出 \(\partial z^* / \partial Q\)。
14. 可微渲染与三维表示学习 ⭐⭐⭐⭐¶
动机¶
可微渲染(differentiable rendering)是 AD 在计算机视觉和机器人感知中的重要应用。它使得"从 2D 图像反推 3D 结构"变成一个可以用梯度下降求解的优化问题。
14.1 体积渲染与 NeRF ⭐⭐⭐⭐¶
NeRF 的体积渲染公式(见 §10.4)的离散近似是:
其中 \(\sigma_i\) 是第 \(i\) 个采样点的密度(由 MLP \(F_\theta\) 预测),\(\delta_i\) 是采样间隔,\(\mathbf{c}_i\) 是颜色。
关键的 AD 分析:
\(\partial \hat{C} / \partial \sigma_i\) 通过链式法则涉及两条路径: 1. 通过 \(\alpha_i\):增大 \(\sigma_i\) 使第 \(i\) 个点更不透明,贡献更多颜色 2. 通过 \(T_{j>i}\):增大 \(\sigma_i\) 使后续点的透射率 \(T_j\) 降低,遮挡后面的颜色
这两条路径的梯度方向可能相反——当第 \(i\) 个点的颜色与最终像素颜色不匹配时,增大 \(\sigma_i\) 既增加了"错误颜色"的贡献又减少了"正确颜色"的贡献。反向 AD 自动处理了这种复杂的梯度交互。
14.2 3D Gaussian Splatting 中的 AD 挑战 ⭐⭐⭐⭐¶
3D Gaussian Splatting 的渲染管线涉及几个 AD 需要特殊处理的步骤:
- 3D→2D 投影:将 3D 高斯投影到图像平面。这一步是光滑可微的(投影是分式线性变换)
- 深度排序:按深度对高斯排序。排序是不连续的(相邻高斯交换顺序时发生跳变),但 Kerbl et al. 忽略了排序对参数的梯度(近似处理)
- Alpha compositing:前到后的 alpha 合成。结构类似于 NeRF 的离散体积渲染
Kerbl et al. 为了在 CUDA 上高效运行,实现了**手写的反向传播 CUDA kernel**——通用 AD 框架(如 PyTorch autograd)的开销在像素级并行渲染中不可接受。
与机器人学的联系:可微渲染使得"从多视角图像重建 3D 场景"变成端到端可训练的管线。在机器人抓取、导航等任务中,可微渲染提供了一种从视觉输入直接优化动作的途径——这与可微 MPC 的思想相呼应。
可微渲染在机器人学中的具体应用场景:
| 应用 | 可微渲染的角色 | 代表工作 |
|---|---|---|
| 抓取位姿估计 | 从 RGB 图像反推物体 6D 位姿 | NeRF-based pose estimation |
| 导航地图构建 | 用 NeRF/3DGS 构建可微的场景表示 | NeRF-Nav, SplaTAM |
| 模仿学习 | 从视觉演示中提取 3D 信息 | 3D-aware imitation |
| 视觉伺服 | 渲染 → 误差图像 → 梯度回传优化相机运动 | photometric visual servoing |
这些应用的共同模式是:可微渲染充当"3D 世界 ↔ 2D 图像"的可微桥梁,使得 3D 空间中的优化可以用 2D 图像作为监督信号。
⚠️ 常见陷阱¶
🧠 思维陷阱:认为可微渲染可以替代传统的几何方法 - 错误推理:"有了 NeRF/3DGS,不需要点云、网格、SLAM 了" - 实际上:(1) NeRF/3DGS 的训练需要大量多视角图像和精确的相机位姿——获取这些数据本身就需要传统方法;(2) 渲染速度虽然近年大幅提升(3DGS 可达实时),但与传统光栅化相比仍有差距;(3) 对动态场景的泛化能力有限 - 正确思维:可微渲染是传统几何方法的**补充**(提供了新的优化方式),而非替代
练习¶
- 对 NeRF 的离散体积渲染公式,手动推导 \(\partial \hat{C} / \partial \sigma_i\)。验证它包含两项(直接贡献和遮挡贡献),并解释两项的物理含义。
- 解释为什么 3D Gaussian Splatting 的深度排序步骤在 AD 中需要特殊处理。如果忽略排序梯度,会引入什么偏差?在什么条件下这个偏差可以忽略?
- 如果你要对一个 NeRF 场景做"视角优化"(找到使某个目标最大的相机位姿),梯度链条是:相机位姿 → 射线参数 → 采样点 → MLP 输出 → 体积渲染 → 目标。这个链条中每一步是否可微?如果有不可微的步骤,应该如何处理?
15. 典型例题与代码实战 ⭐⭐¶
15.1 例题:用 JAX 实现教学级 dual number 前向 AD¶
import jax.numpy as jnp
class Dual:
"""教学级 dual number 实现"""
def __init__(self, real, dual=0.0):
self.real = real # 函数值
self.dual = dual # 导数值
def __add__(self, other):
if isinstance(other, Dual):
return Dual(self.real + other.real, self.dual + other.dual)
return Dual(self.real + other, self.dual)
def __mul__(self, other):
if isinstance(other, Dual):
# 乘法法则:(a + bε)(c + dε) = ac + (ad + bc)ε
return Dual(self.real * other.real,
self.real * other.dual + self.dual * other.real)
return Dual(self.real * other, self.dual * other)
def __truediv__(self, other):
if isinstance(other, Dual):
# 除法法则:(a + bε)/(c + dε) = a/c + (bc - ad)/c² ε
return Dual(self.real / other.real,
(self.dual * other.real - self.real * other.dual)
/ other.real**2)
return Dual(self.real / other, self.dual / other)
def __radd__(self, other):
return Dual(other + self.real, self.dual)
def __rmul__(self, other):
return Dual(other * self.real, other * self.dual)
def sin(x):
if isinstance(x, Dual):
import math
return Dual(math.sin(x.real), math.cos(x.real) * x.dual)
return jnp.sin(x)
def exp(x):
if isinstance(x, Dual):
import math
e = math.exp(x.real)
return Dual(e, e * x.dual)
return jnp.exp(x)
# 测试:f(x) = sin(x²) * exp(x)
def f(x):
return sin(x * x) * exp(x)
# x = 1.0, seed = 1(求 f'(1))
x = Dual(1.0, 1.0)
result = f(x)
print(f"f(1.0) = {result.real:.10f}") # 函数值
print(f"f'(1.0) = {result.dual:.10f}") # 导数值(精确!)
# 与解析导数比较:f'(x) = [2x cos(x²) + sin(x²)] * exp(x)
import math
exact = (2*1.0*math.cos(1.0) + math.sin(1.0)) * math.exp(1.0)
print(f"解析导数 = {exact:.10f}")
print(f"误差 = {abs(result.dual - exact):.2e}") # 应为 ~1e-16(浮点精度)
15.2 例题:用 CasADi 构建 NLP 并提取灵敏度¶
import casadi as ca
import numpy as np
# 单摆最优控制:min ∫u²dt s.t. θ̈ = -g/l sin(θ) + u/(ml²)
N = 50 # 离散点数
dt = 0.05 # 时间步长
# 符号变量
theta = ca.SX.sym('theta', N+1) # 角度轨迹
omega = ca.SX.sym('omega', N+1) # 角速度轨迹
u = ca.SX.sym('u', N) # 控制输入
g, l, m = 9.81, 1.0, 1.0
# 目标函数:最小化控制能量
J = ca.sum1(u**2) * dt
# 动力学约束(Euler 积分)
constraints = []
for k in range(N):
# θ_{k+1} = θ_k + ω_k * dt
constraints.append(theta[k+1] - theta[k] - omega[k] * dt)
# ω_{k+1} = ω_k + (-g/l sin(θ_k) + u_k/(ml²)) * dt
constraints.append(omega[k+1] - omega[k]
- (-g/l * ca.sin(theta[k]) + u[k]/(m*l**2)) * dt)
# 边界条件
constraints.append(theta[0] - 0) # 初始角度 = 0
constraints.append(omega[0] - 0) # 初始角速度 = 0
constraints.append(theta[N] - ca.pi) # 终端角度 = π(摆上去)
constraints.append(omega[N] - 0) # 终端角速度 = 0
g_vec = ca.vertcat(*constraints)
x_vec = ca.vertcat(theta, omega, u)
# 构建 NLP
nlp = {'x': x_vec, 'f': J, 'g': g_vec}
opts = {'ipopt.print_level': 0, 'print_time': 0}
solver = ca.nlpsol('solver', 'ipopt', nlp, opts)
# 求解
n_vars = 2*(N+1) + N
sol = solver(x0=np.zeros(n_vars),
lbg=np.zeros(g_vec.shape[0]),
ubg=np.zeros(g_vec.shape[0]))
print(f"最优控制代价: {float(sol['f']):.4f}")
# 灵敏度分析:∂J*/∂g(重力加速度的灵敏度)
# 这需要对 KKT 系统做隐式微分——CasADi 自动处理
15.3 例题:用 JAX 实现可微 QP 层(IFT 版) ⭐⭐⭐¶
这个例子展示了如何用隐式微分实现一个可微 QP 层。理解这段代码就等于理解了 OptNet 的核心原理。
import jax
import jax.numpy as jnp
from functools import partial
def solve_qp(Q, q):
"""
求解无约束 QP: min 0.5 * y^T Q y + q^T y
最优解: y* = -Q^{-1} q
"""
return jnp.linalg.solve(Q, -q)
@partial(jax.custom_vjp)
def differentiable_qp(Q, q):
"""可微 QP 层:前向求解 + 自定义反向传播"""
return solve_qp(Q, q)
def differentiable_qp_fwd(Q, q):
"""前向传播:求解 QP 并保存 residuals 供反向使用"""
y_star = solve_qp(Q, q)
return y_star, (Q, q, y_star) # 返回结果 + 保存用于反向的数据
def differentiable_qp_bwd(res, g):
"""
反向传播:用 IFT 计算 VJP。
KKT 条件: F(y, Q, q) = Qy + q = 0
IFT: dy*/dq = -Q^{-1} (显然)
dy*/dQ = -Q^{-1} (dy* ⊗ I) (需要矩阵微分)
VJP: 给定上游梯度 g = dL/dy*
dL/dq = g^T * dy*/dq = -Q^{-1} g
dL/dQ = -Q^{-1} g * y*^T (通过矩阵微分得到)
"""
Q, q, y_star = res
# 解 KKT 线性方程: Q u = g
u = jnp.linalg.solve(Q, g)
# VJP 结果
dL_dq = -u # dL/dq
dL_dQ = -jnp.outer(u, y_star) # dL/dQ(矩阵微分)
return dL_dQ, dL_dq
differentiable_qp.defvjp(differentiable_qp_fwd, differentiable_qp_bwd)
# 测试:验证自定义 VJP 的正确性
n = 5
key = jax.random.PRNGKey(42)
Q = jax.random.normal(key, (n, n))
Q = Q @ Q.T + 0.1 * jnp.eye(n) # 确保正定
q = jax.random.normal(jax.random.PRNGKey(1), (n,))
# 定义 loss: L = sum(y*)
def loss_fn(Q, q):
y = differentiable_qp(Q, q)
return jnp.sum(y)
# 用我们的 IFT VJP 计算梯度
grad_Q, grad_q = jax.grad(loss_fn, argnums=(0, 1))(Q, q)
print(f"dL/dq (IFT): {grad_q[:3]}")
# 用有限差分验证
eps = 1e-5
grad_q_fd = jnp.zeros(n)
for i in range(n):
e_i = jnp.zeros(n).at[i].set(1.0)
grad_q_fd = grad_q_fd.at[i].set(
(loss_fn(Q, q + eps * e_i) - loss_fn(Q, q - eps * e_i)) / (2 * eps)
)
print(f"dL/dq (FD): {grad_q_fd[:3]}")
print(f"相对误差: {jnp.linalg.norm(grad_q - grad_q_fd) / jnp.linalg.norm(grad_q):.2e}")
代码解读:
这段代码的核心是 differentiable_qp_bwd 函数。它实现了 §7-§8 中推导的 KKT 隐式微分:
- 给定上游梯度 \(\bar{y} = dL/dy^*\)
- 解 KKT 线性方程 \(Qu = \bar{y}\)(即 \(u = Q^{-1}\bar{y}\))
- VJP 结果:\(dL/dq = -u\),\(dL/dQ = -u \cdot y^{*\top}\)
注意**不需要展开 QP 求解器**——不管 \(y^*\) 是用 Cholesky 分解、共轭梯度法还是 ADMM 求解的,反向传播只需要解一个额外的 KKT 线性系统。这就是隐式微分的力量。
15.4 例题:JAX 中的 forward-over-reverse Hessian-向量积 ⭐⭐⭐¶
import jax
import jax.numpy as jnp
def rosenbrock(x):
"""Rosenbrock 函数:经典的非凸优化测试函数"""
return jnp.sum(100.0 * (x[1:] - x[:-1]**2)**2 + (1.0 - x[:-1])**2)
def hvp_forward_over_reverse(f, x, v):
"""
Hessian-向量积:H(x) @ v
实现:forward-over-reverse
- 内层(reverse):计算 grad(f)(x)
- 外层(forward):计算 d/dt [grad(f)(x + t*v)]|_{t=0}
总代价约 4-5 倍 cost(f),与 n 无关
"""
return jax.jvp(jax.grad(f), (x,), (v,))[1]
# 测试
n = 100
x = jnp.ones(n) # Rosenbrock 的全局最小值在 x = [1, 1, ..., 1]
v = jax.random.normal(jax.random.PRNGKey(0), (n,))
# Hessian-向量积
Hv = hvp_forward_over_reverse(rosenbrock, x, v)
print(f"H @ v 的前 5 个元素: {Hv[:5]}")
# 与完整 Hessian 的乘积比较(小 n 才能做)
if n <= 50:
H = jax.hessian(rosenbrock)(x)
Hv_exact = H @ v
print(f"误差: {jnp.linalg.norm(Hv - Hv_exact):.2e}")
# 性能比较:hvp vs 完整 Hessian
import time
x_large = jnp.ones(1000)
v_large = jax.random.normal(jax.random.PRNGKey(0), (1000,))
# JIT 编译
hvp_jit = jax.jit(lambda x, v: hvp_forward_over_reverse(rosenbrock, x, v))
_ = hvp_jit(x_large, v_large) # 预热
t0 = time.time()
for _ in range(100):
_ = hvp_jit(x_large, v_large).block_until_ready()
t_hvp = (time.time() - t0) / 100
print(f"n=1000, HVP 时间: {t_hvp*1000:.2f} ms")
print(f"完整 Hessian 大小: {1000*1000*8/1024/1024:.1f} MB — 对比 HVP 几乎不用额外内存")
练习¶
- 修改 15.3 的代码,增加等式约束 \(Ay = b\)。KKT 矩阵变成 \(\begin{pmatrix} Q & A^\top \\ A & 0 \end{pmatrix}\),反向传播需要解什么线性系统?
- 用 15.4 的 HVP 实现一个 truncated Newton 法求解 Rosenbrock 函数:在每次 Newton 迭代中,用 CG + HVP 近似求解 \(Hd = -g\),而不是显式构造 Hessian。比较 CG 迭代次数 = 5, 10, 20 时的收敛速度。
- (跨章综合)回顾凸分析中的强凸性:强凸函数满足 \(\nabla^2 f(x) \succeq \mu I\)。解释为什么对强凸函数,CG + HVP 总是能在有限步内收敛到精确 Newton 方向。
本章小结¶
| 主题 | 核心要点 | 关键公式/概念 | 难度 |
|---|---|---|---|
| 计算图 | 一切 AD 的基础——程序 = DAG | 中间变量 \(v_i\)、基本运算 | ⭐ |
| 三种微分 | AD ≠ 有限差分 ≠ 符号微分 | 截断误差 vs 表达式膨胀 | ⭐ |
| 前向 AD | dual number + JVP | \(f(a+b\varepsilon) = f(a) + f'(a)b\varepsilon\) | ⭐⭐ |
| Ceres Jet | C++ 模板前向 AD | Jet<T,N> 运算符重载 |
⭐⭐ |
| 反向 AD | adjoint + VJP | \(\bar{v}_j = \sum_i \bar{v}_i \partial v_i / \partial v_j\) | ⭐⭐ |
| Cheap Gradient | \(\text{cost}(\nabla f) \leq 5 \cdot \text{cost}(f)\) | Baur-Strassen 1983 | ⭐⭐⭐ |
| 高阶 AD | Hv = forward-over-reverse | hyper-dual, Taylor 模式 | ⭐⭐⭐ |
| Checkpointing | 时间换内存 | \(O(\sqrt{K})\) 内存 | ⭐⭐⭐ |
| 框架选型 | 匹配问题特性选工具 | JAX/PyTorch/CasADi/Ceres/Enzyme | ⭐⭐ |
| 隐函数定理 | \(dy^*/dx = -(F_y)^{-1} F_x\) | 压缩映射 + 全微分 | ⭐⭐ |
| 可微优化层 | KKT 上的 IFT | OptNet / cvxpylayers | ⭐⭐⭐ |
| 可微仿真 | 平滑→FoBG,接触→ZoBG | Suh 2022 bias-variance | ⭐⭐⭐ |
| 工程实践 | 梯度检查 + 数值稳定性 | gradcheck, log-sum-exp | ⭐⭐ |
| 连续伴随 | \(\dot{\lambda} = -\lambda^\top \partial f/\partial z\) | 终端条件 + 反向积分 | ⭐⭐⭐ |
| Neural ODE | 连续深度 + 常数内存 | adjoint ODE, Chen 2018 | ⭐⭐⭐ |
| DEQ | 定点 + IFT | \((I-\partial f/\partial z)^{-1}\) | ⭐⭐⭐⭐ |
| 统一框架 | \(F(x^*,\theta)=0\) | Blondel 2022 | ⭐⭐⭐ |
| 可微 MPC | KKT IFT / PDP | Amos 2018, Jin 2020 | ⭐⭐⭐ |
| 可微渲染 | 体积渲染 + alpha compositing | NeRF, 3DGS | ⭐⭐⭐⭐ |
三大核心洞见:
洞见一:AD 不是一个工具,而是"链式法则的程序化"。前向即 JVP、反向即 VJP,Jet/tape/source-to-source/LLVM-IR 只是实现形式。选前向还是反向由输入输出维度比决定,而非语言或框架。
洞见二:隐式微分是"展开式 AD"的严格优越替代。只要能把收敛点表达为 \(F(x^*, \theta) = 0\),就可用 IFT 以 \(O(1)\) 内存得到收敛解的精确梯度。这一原理统一了 DEQ、Neural ODE、OptNet、cvxpylayers、可微 MPC。
洞见三:在 RL 与机器人交叉领域,梯度的来源决定算法成败。光滑动力学下 first-order gradient 带来 10-100 倍样本效率,但接触、混沌、刚性会让它崩溃。Suh 2022 的 bias-variance 定理是选择 PPO、MPPI、SHAC 还是 diff-MPC 的数学判据。
累积项目:本章新增模块¶
数学库项目进度:
| 前序章节 | 模块 | 本章 |
|---|---|---|
| Ch1-2 向量与矩阵 | 向量类 + 矩阵乘法 | - |
| Ch3 分解 | LU/QR/SVD | - |
| Ch4 最小二乘 | 伪逆 + 正规方程 | - |
| 本章 | Dual 类 + 前向/反向 AD + 可微 QP | 新增 |
本章新增四个模块:
模块 1:Dual 类与前向 AD ⭐⭐
实现 Dual<T> 类,重载全部基本运算和数学函数:
- 四则运算:+, -, *, /
- 三角函数:sin, cos, tan, asin, acos, atan2
- 指数/对数:exp, log, sqrt, pow
- 比较运算(仅用实部比较):<, >, ==
实现 forward_diff(f, x, v) 函数:给定函数 \(f\)、求值点 \(x\) 和方向 \(v\),用 dual number 计算 JVP \(Jv\)。
验证标准:对 10 个随机函数和随机方向,JVP 结果与中心差分的相对误差 \(< 10^{-10}\)。
模块 2:Tape 反向 AD ⭐⭐⭐
实现一个简化的反向 AD 引擎:
- Tape 类:记录计算图(每个节点的运算类型、父节点、局部导数)
- Variable 类:重载运算符,每次运算自动记录到 Tape
- backward(loss) 函数:从 loss 节点开始,逆拓扑序传播 adjoint
这个模块的目标不是性能,而是**理解反向 AD 的完整工作机制**。用不超过 200 行代码实现一个"可以工作但很慢"的反向 AD 引擎。
验证标准:对 5 个测试函数,反向 AD 的梯度与模块 1 的前向 AD 完全一致(浮点精度内)。
模块 3:梯度检查工具 ⭐⭐
实现 grad_check(f, x, eps=1e-5) 函数:
- 对每个分量用中心差分计算数值梯度
- 与 AD 梯度比较,输出逐分量的相对误差
- 如果最大相对误差 > 阈值(默认 \(10^{-3}\)),打印警告
这个工具将在后续所有涉及自定义 AD 规则的章节中使用。
模块 4:可微 QP 层 ⭐⭐⭐
实现 DifferentiableQP(Q, q, A, b) 类:
- 前向:用 scipy 或手写 KKT 求解器求解等式约束 QP
- 反向:用 KKT 隐式微分计算 VJP
- 接口兼容模块 2 的 Tape 引擎
验证标准:VJP 与有限差分的相对误差 \(< 10^{-5}\)。
项目整合测试:用模块 4 的可微 QP + 模块 3 的梯度检查,实现一个简单的"学习 QP 参数"实验——给定目标解 \(y_{\text{target}}\),用梯度下降优化 \(q\) 使得 \(\|y^*(Q, q) - y_{\text{target}}\|^2\) 最小化。
延伸阅读¶
入门级 ⭐:
- Baydin et al., "Automatic Differentiation in Machine Learning: a Survey", JMLR 2018 — AD 与机器学习的最佳入门综述,30 页即可建立完整概念框架
- 3Blue1Brown 的 backpropagation 视频 — 链式法则直觉的最佳可视化
- Matthew Johnson, "JAX Autodiff Cookbook" (docs.jax.dev) — JAX 中 jvp/vjp/jacfwd/jacrev 的权威 how-to
- Roger Grosse, CSC321 Backprop 讲义 — 从矩阵视角理解反向传播
核心级 ⭐⭐: - Griewank & Walther, "Evaluating Derivatives", SIAM 2008 — AD 学科圣经,必读 Ch.3-4(前向/反向基础)、Ch.12(checkpointing)和 Ch.15(隐式/迭代微分) - Amos & Kolter, "OptNet: Differentiable Optimization as a Layer", ICML 2017 — 可微优化层的开山之作 - Blondel et al., "Efficient and Modular Implicit Differentiation", NeurIPS 2022 — 隐式微分的现代统一处理,JAXopt 的数学基础 - Nocedal & Wright, "Numerical Optimization", Springer 2006 — Ch.8(计算导数)和 Ch.12(KKT 理论) - Andersson et al., "CasADi: A software framework for nonlinear optimization", Math. Prog. Comp. 2019 — CasADi 的设计与架构 - Ceres Solver 官方文档 (ceres-solver.org) — Jet 实现细节和 SLAM 应用
进阶级 ⭐⭐⭐: - Carpentier & Mansard, "Analytical Derivatives of Rigid Body Dynamics", RSS 2018 — Pinocchio 解析微分的理论基础 - Suh et al., "Do Differentiable Simulators Give Better Policy Gradients?", ICML 2022 — 可微仿真梯度的 bias-variance 分析(Outstanding Paper) - Howell et al., "Dojo: A Differentiable Physics Engine for Robotics", RSS 2022 — 精确接触梯度的引擎 - Chen et al., "Neural Ordinary Differential Equations", NeurIPS 2018 — 连续深度模型与 adjoint 方法 - Bai, Kolter, Koltun, "Deep Equilibrium Models", NeurIPS 2019 — 定点隐式微分的深度学习应用 - Agrawal et al., "Differentiable Convex Optimization Layers", NeurIPS 2019 — cvxpylayers 的数学与实现 - Chris Rackauckas, MIT 18.337 "Parallel Computing and Scientific Machine Learning" (book.sciml.ai) — adjoint sensitivity 的多种推导
研究前沿 ⭐⭐⭐⭐: - Xu et al., "SHAC", ICLR 2022 — 截断窗口 + actor-critic 的可微仿真 RL - Pineda et al., "Theseus", NeurIPS 2022 — PyTorch 上的可微非线性最小二乘 - Moses & Churavy, "Enzyme", NeurIPS 2020 — LLVM IR 级编译器 AD - Amos et al., "Differentiable MPC for End-to-end Planning and Control", NeurIPS 2018 — 可微 MPC 的开创性工作 - Jin et al., "Pontryagin Differentiable Programming", NeurIPS 2020 — PMP + 辅助 LQR 的优雅方法 - Dontchev & Rockafellar, "Implicit Functions and Solution Mappings", Springer 2014 — IFT 的变分分析版权威参考 - Kolter, Duvenaud, Johnson, NeurIPS 2020 Tutorial "Deep Implicit Layers" (implicit-layers-tutorial.org) — DEQ/ODE/OptNet 三合一讲义 - Brandon Amos 个人博客 "On Differentiable Optimization" (bamos.github.io) — OptNet 作者的系列深度博文
🔧 故障排查手册¶
| 症状 | 可能原因 | 排查步骤 | 相关章节 |
|---|---|---|---|
| AD 导数全为零 | 计算图断裂或类型转换丢失 tangent | 1. 检查是否有 double() 转换截断了 Jet 2. 检查 detach() 是否在梯度路径上 3. 用 gradcheck 验证 |
§3, §6 |
| AD 导数与有限差分差异大(>1e-3) | 自定义 VJP 有 bug 或数值不稳定 | 1. 逐运算拆开做 gradcheck 2. 检查是否有 log(0) 或 sqrt(0) 3. 尝试 float64 排除舍入 | §10.1 |
| 反向传播 OOM | 计算图过长或中间激活未释放 | 1. 检查序列长度 2. 使用 torch.utils.checkpoint 3. 用隐式微分替代展开 |
§5.3, §7 |
| CasADi 符号爆炸(构建图极慢) | 在 SX 中展开了大循环 | 1. 改用 MX 矩阵运算 2. 避免 Python for 循环 3. 用 ca.horzcat/vertcat 向量化 |
§6.2 |
| 可微仿真梯度方向错误 | 接触不连续导致 FoBG 偏差 | 1. 检查是否有 Heaviside 类运算 2. 加随机光滑化 3. 比较 FoBG 与 ZoBG | §9.3 |
| Ceres Jet 编译错误 | cost function 调用了非模板化函数 | 1. 检查所有数学函数是否用 ceres::sin 等 2. 确保全部 template<typename T> 3. 参数块维度之和与 Jet N 一致 |
§3.3 |
| 隐式微分结果数值不稳定 | KKT 矩阵条件数过大 | 1. 检查 active set 是否在切换边界 2. 添加 Tikhonov 正则化 3. 监控 \(\kappa(K)\) | §7, §8 |
| Neural ODE 伴随梯度出现 NaN | 刚性 ODE 反向积分发散 | 1. 检查 Jacobian 特征值 2. 用隐式 ODE solver 3. 切换到 checkpointed BPTT | §11.2 |
MuJoCo MJX jax.grad 报错 |
CG 求解器用 while_loop 不支持反向 |
1. 切换到 Newton 求解器 2. 用 fori_loop 固定迭代上界 3. 考虑用 Brax 替代 |
§9 |
| DEQ 前向不收敛 | \(\|\\partial f/\\partial z\| \geq 1\) 不是压缩映射 | 1. 检查 spectral norm 2. 添加 spectral normalization 3. 减小网络宽度或深度 | §12 |
| cvxpylayers 报 "not DPP compliant" | 参数出现在非仿射位置 | 1. 检查哪个参数违反 DPP 2. 改写为仿射形式(如 sum_squares(P_sqrt @ x) 替代 quad_form(x, P)) |
§8.3 |
| Pinocchio 解析微分与 CppAD 结果不一致 | 坐标约定或符号约定差异 | 1. 对齐 body frame 约定(local vs world) 2. 检查 Jacobian 是 body 还是 spatial 3. 打印中间值逐步比较 | §9.1 |
| PDP 辅助 LQR 发散 | Hamiltonian Hessian 不正定 | 1. 检查 \(H_{uu}\) 是否正定 2. 添加正则化 \(H_{uu} + \epsilon I\) 3. 缩短 MPC horizon 减少病态性 | §13.2 |
使用本手册的建议:遇到问题时,先在本表中搜索匹配的"症状"。如果找不到完全匹配的条目,从"可能原因"列中找最接近的猜测,按"排查步骤"逐步定位。如果仍无法解决,回到对应章节重新理解底层原理——大多数 AD 的 bug 都源于对数学原理的误解,而非代码层面的技术问题。
排查的通用策略:对任何 AD 相关的问题,第一步始终是**梯度检查**(§10.1)。如果 AD 梯度通过了梯度检查但结果仍然不好,问题出在数学建模层面(如条件数、不连续性)而非 AD 实现层面。如果梯度检查未通过,问题出在 AD 实现(如类型转换、图断裂、自定义规则 bug)。这个二分法可以快速定位问题的根源。