跳转至

自动微分与隐式微分

前置自测

📋 前置自测(答不出 ≥ 2 题 → 先回凸分析基础与非线性优化复习)

  1. 什么是链式法则?对复合函数 \(f(g(x))\),如何用 \(f'\)\(g'\) 表达 \((f \circ g)'(x)\)
  2. 给定矩阵 \(A \in \mathbb{R}^{m \times n}\),什么是 Jacobian 矩阵?它的维度是多少?
  3. 什么是 KKT 条件?写出等式约束优化 \(\min f(x) \text{ s.t. } h(x)=0\) 的 KKT 系统。
  4. 隐函数定理的直觉含义是什么?如果 \(F(x,y)=0\)\(\partial F/\partial y\) 可逆,能得到什么结论?
  5. 解释有限差分法计算导数 \(f'(x) \approx \frac{f(x+h)-f(x)}{h}\) 的两个误差来源。

本章目标

学完本章后,你将能够:(1) 在白板上用 dual number 推导前向模式 AD,用 adjoint 变量推导反向模式 AD,并清楚两者的适用场景和复杂度差异;(2) 在 JAX、PyTorch、CasADi、Ceres 中选择合适的 AD 工具实现梯度计算,并理解它们的设计哲学差异;(3) 从隐函数定理出发推导可微优化层(OptNet/cvxpylayers)的反向传播公式,理解为什么隐式微分优于展开式 AD;(4) 判断可微仿真中 first-order gradient 何时有效、何时崩溃,并据此选择合适的策略梯度方法。


本章知识地图

自动微分与隐式微分
├── 基础概念(§1-§2)
│   ├── 计算图与三种微分方式
│   └── 为什么 AD 既不是有限差分也不是符号微分
├── 前向模式 AD(§3)
│   ├── dual number 代数
│   ├── JVP(雅可比-向量积)
│   └── Ceres Jet 模板实现
├── 反向模式 AD(§4)
│   ├── adjoint 变量与反向传播
│   ├── VJP(向量-雅可比积)
│   └── Cheap Gradient Principle
├── 高阶 AD 与工程技巧(§5)
│   ├── Hessian-向量积(forward-over-reverse)
│   ├── hyper-dual number
│   ├── checkpointing 节省内存
│   └── Taylor 模式 AD
├── AD 框架比较与选型(§6)
│   ├── JAX / PyTorch / CasADi / Ceres / CppAD / Enzyme
│   └── 选型决策树
├── 隐式微分(§7-§8)
│   ├── 隐函数定理完整证明
│   ├── VJP 形式的隐式微分
│   └── 可微优化层(OptNet / cvxpylayers)
├── 可微物理仿真(§9)
│   ├── Pinocchio 解析微分
│   ├── Brax / MuJoCo MJX / Dojo
│   ├── 接触微分的困难与 FoBG/ZoBG 分析
│   └── SHAC 截断窗口方法
├── 工程实践(§10)
│   ├── 梯度检查
│   ├── 数值稳定性
│   └── 可微渲染简介(NeRF / 3DGS)
├── 连续伴随方法与 Neural ODE(§11)
│   ├── 连续伴随方程推导
│   ├── Neural ODE 优势与局限
│   └── discretize-then-optimize vs optimize-then-discretize
├── 深度平衡模型与定点隐式微分(§12)
│   ├── DEQ 定点隐式微分推导
│   ├── Neumann 级数与 Anderson 加速
│   └── Blondel 2022 统一框架
├── 可微 MPC 与 PDP(§13)
│   ├── Differentiable MPC(Amos 2018)
│   ├── Pontryagin Differentiable Programming(Jin 2020)
│   └── 工程选型指南
├── 可微渲染与三维表示学习(§14)
│   └── NeRF 与 3D Gaussian Splatting 的 AD 分析
└── 典型例题与累积项目(§15)

1. 计算图:理解一切的起点 ⭐

动机

你写下一段 Python 代码来计算函数值。对计算机而言,这段代码被分解为一系列**基本运算**——加、减、乘、除、sin、exp 等。这些运算之间的数据依赖关系构成一张**有向无环图**(DAG),称为**计算图**。

理解计算图至关重要,因为自动微分的全部内容都可以归结为一句话:在计算图上系统地应用链式法则

一个具体例子

考虑函数 \(f(x_1, x_2) = x_1 x_2 + \sin(x_1)\)。将其拆解为基本运算:

\[v_1 = x_1, \quad v_2 = x_2, \quad v_3 = v_1 \cdot v_2, \quad v_4 = \sin(v_1), \quad v_5 = v_3 + v_4 = y\]

这些中间变量 \(v_i\) 构成计算图的节点,箭头表示数据流向。在 \(x_1 = \pi/4, x_2 = 2\) 处,前向计算(primal evaluation)依次得到:

变量 运算 数值
\(v_1\) \(x_1\) \(\pi/4 \approx 0.785\)
\(v_2\) \(x_2\) \(2\)
\(v_3\) \(v_1 \cdot v_2\) \(\pi/2 \approx 1.571\)
\(v_4\) \(\sin(v_1)\) \(\sin(\pi/4) \approx 0.707\)
\(v_5\) \(v_3 + v_4\) \(2.278\)

计算图是**静态的结构信息**——它记录了哪些运算以什么顺序执行。有了这张图,我们就可以用两种不同的策略在上面传播导数。

本质洞察:计算图把"一段程序"翻译成"一张数学图"。在这张图上,链式法则不再是一个抽象的数学定理,而是一个**可以机械执行的算法**。前向模式从输入到输出传播导数,反向模式从输出到输入传播导数——两者的区别仅在于遍历方向。

跨领域类比:计算图就像工厂的装配线

计算图可以类比为工厂的装配线:每个节点是一个加工工位,输入的原材料(变量)经过一系列工位加工,最终产出成品(函数值)。前向模式 AD 就像在装配线上每个工位贴一张标签"这个零件对原材料 \(x_1\) 的敏感度是多少",标签从头传到尾。反向模式 AD 则是从成品开始,逆着装配线追溯"成品质量对每个工位的调整有多敏感"。

两者都利用了装配线的线性结构(链式法则),但方向不同——这决定了在不同场景下哪种更高效。

⚠️ 常见陷阱

💡 概念误区:认为自动微分就是数值微分(有限差分) - 错误想法:"AD 不就是用 \((f(x+h)-f(x))/h\) 来近似导数吗?" - 实际上:AD 是**精确的**(在浮点精度范围内),它通过在计算图上应用链式法则得到导数的精确值,不存在截断误差 - 根本原因:有限差分是对导数定义的数值近似,AD 是对链式法则的精确执行——两者的数学基础完全不同 - 正确理解:AD 与有限差分的关系,就像精确解与数值解的关系

🧠 思维陷阱:认为计算图只存在于深度学习框架中 - 错误想法:"计算图是 PyTorch/TensorFlow 的概念,传统优化不需要" - 实际上:任何可微分的计算过程都有隐式的计算图。Ceres 的模板元编程在编译期构建计算图;CasADi 在符号层面构建计算图;手写的 RNEA 递推也是一张特殊的计算图 - 启示:理解计算图让你能统一理解所有 AD 工具的底层原理

练习

  1. 给出 \(f(x_1, x_2, x_3) = x_1 x_2 \sin(x_3) + e^{x_1} x_2\) 的计算图(画出所有中间变量和运算)。在 \((x_1, x_2, x_3) = (1, 2, \pi/6)\) 处计算所有中间变量的值。
  2. 对上述计算图,数一数共有多少条边(数据依赖关系)。思考:如果函数有 \(n\) 个输入和 \(m\) 个输出,计算图的边数与 \(n, m\) 有什么关系?
  3. (跨章综合)回顾凸分析中的仿射函数 \(f(x) = Ax + b\)。画出 \(f: \mathbb{R}^3 \to \mathbb{R}^2\)\(A\)\(2 \times 3\) 矩阵)的计算图。这个计算图有什么特殊结构?(提示:所有运算都是线性的)

2. 三种微分方式的本质区别 ⭐

动机

在编写优化求解器时,你总需要计算目标函数的梯度。有三条路可走:手动推导符号导数、有限差分近似、自动微分。它们各有什么优劣?这个问题的答案决定了你选什么工具。

2.1 符号微分(Symbolic Differentiation)

符号微分是你在微积分课上学的方法:把 \(f(x) = x^2 \sin(x)\) 用求导法则变换为 \(f'(x) = 2x \sin(x) + x^2 \cos(x)\)

优点:得到封闭形式的表达式,便于理论分析和进一步化简。

致命缺点——表达式膨胀(expression swell):对复合函数反复应用乘积法则和链式法则,表达式的长度会**指数级增长**。例如,对 \(f(x) = \prod_{i=1}^{100} \sin(ix)\) 求导,符号微分会产生一个包含 \(2^{100}\) 项的和式——而函数值本身只需要 100 次乘法。

如果不做表达式膨胀会怎样?对于复杂的目标函数(如一个 1000 步的 MPC 展开、一个 100 层的神经网络),符号微分产生的导数表达式会大到无法存储,更不用说计算。这就是为什么深度学习不用 SymPy 来求梯度。

2.2 有限差分(Finite Differences)

有限差分用导数的极限定义做近似:

\[\frac{\partial f}{\partial x_i} \approx \frac{f(x + h e_i) - f(x)}{h}\]

其中 \(e_i\) 是第 \(i\) 个单位向量。这是最"朴素"的方法。

优点:实现极其简单(几行代码),不需要了解函数内部结构。

两大致命缺点

(1) 截断误差与舍入误差的矛盾\(h\) 太大,截断误差 \(O(h)\) 大(中心差分 \((f(x+h)-f(x-h))/(2h)\) 可减至 \(O(h^2)\),但仍非零)。\(h\) 太小,\(f(x+h)\)\(f(x)\) 的差被浮点舍入误差淹没。对于 double 精度(\(\epsilon_{\text{mach}} \approx 10^{-16}\)),最优的 \(h \approx \sqrt{\epsilon_{\text{mach}}} \approx 10^{-8}\)(前向差分)或 \(h \approx \epsilon_{\text{mach}}^{1/3} \approx 10^{-5}\)(中心差分),精度上限分别约为 8 位和 10 位有效数字。

(2) 计算量与输入维度成正比。计算 \(n\) 维梯度需要 \(n\) 次(前向差分)或 \(2n\) 次(中心差分)函数求值。对于 SLAM 的百万维参数或神经网络的亿级参数,这完全不可行。

反事实推理:如果有限差分的精度是足够的(比如你只需要 3 位精度做一个粗略的调参),那么它仍然是最简单的选择。很多工程师在调试阶段用有限差分来**验证** AD 的正确性——这是它的正确用法。

2.3 自动微分(Automatic Differentiation)

自动微分站在符号微分和有限差分的"中间地带":它像符号微分一样**精确**(没有截断误差),又像有限差分一样**高效**(不产生表达式膨胀)。

AD 的核心思想:不对整个函数做符号求导,而是在计算图的**每个基本运算节点**上应用已知的求导规则,然后通过链式法则将这些局部导数"传播"成全局导数。

方法 精度 计算量 实现复杂度 适用场景
符号微分 精确(封闭形式) 可能指数级膨胀 需要 CAS 小规模解析推导
有限差分 \(O(h^p)\) 近似 \(O(n \cdot \text{cost}(f))\) 极低 调试验证、低维问题
前向 AD 精确(浮点范围内) \(O(n \cdot \text{cost}(f))\) 中等 输入少输出多
反向 AD 精确(浮点范围内) \(O(m \cdot \text{cost}(f))\)* 较高 输入多输出少(深度学习)

*注:反向模式的常数因子一般为 2-5 倍,这就是著名的 Cheap Gradient Principle。

AD 的历史可以追溯到 1960 年代。Robert E. Wengert 在 1964 年发表了第一篇关于自动微分的论文,提出了前向模式。但反向模式的发现更加曲折:Seppo Linnainmaa 在 1970 年的硕士论文中独立提出了反向累积(reverse accumulation),而 Paul Werbos 在 1974 年的博士论文中将其应用于神经网络训练,称为"反向传播"。直到 Rumelhart、Hinton 和 Williams 在 1986 年的 Nature 论文中推广了反向传播,AD 在机器学习中的重要性才被广泛认识。

Andreas Griewank 在 1989-2008 年的系统化工作将 AD 建立为一个严谨的数学学科,其著作《Evaluating Derivatives》(与 Walther 合著,SIAM 2008)至今仍是该领域的权威参考。

⚠️ 常见陷阱

⚠️ 编程陷阱:用有限差分做最终的梯度计算 - 错误做法:在 SLAM 求解器中用 (f(x+1e-8) - f(x)) / 1e-8 计算雅可比 - 现象:对于良态问题可能看起来"差不多对",但在病态问题(条件数大)中产生严重偏差,导致优化不收敛 - 根本原因:\(h=10^{-8}\) 时截断误差和舍入误差恰好平衡,但如果函数值本身很大或变化很剧烈,这个平衡点会移动 - 正确做法:用 AD 计算精确梯度,仅在调试时用有限差分做验证

💡 概念误区:认为 AD 只能算一阶导数 - 错误想法:"AD 只能算梯度,Hessian 还是得用有限差分" - 实际上:AD 可以嵌套使用(forward-over-reverse 或 reverse-over-forward)来计算任意阶导数。JAX 的 jax.hessian 就是自动嵌套两次 AD - 延伸:Hessian-向量积 \(Hv = \nabla(\nabla f \cdot v)\) 只需一次前向+一次反向,成本与计算 \(f\) 本身相当

练习

  1. 用 Python 编写一个函数,分别用前向差分、中心差分和复步微分(complex step)计算 \(f(x) = e^{\sin(x^3)}\)\(x=1.5\) 处的导数。比较三者的精度与真实值 \(f'(1.5)\) 的差异。
  2. 解释为什么中心差分 \((f(x+h)-f(x-h))/(2h)\) 的截断误差是 \(O(h^2)\) 而不是 \(O(h)\)。(提示:对 \(f(x\pm h)\) 做 Taylor 展开。)
  3. 如果一个函数 \(f: \mathbb{R}^{1000} \to \mathbb{R}\) 的一次求值需要 1ms,用中心差分计算完整梯度需要多长时间?用反向 AD 呢?

3. 前向模式 AD:dual number 与 JVP ⭐⭐

动机

我们要计算函数 \(f: \mathbb{R}^n \to \mathbb{R}^m\)\(x_0\) 处沿方向 \(v\) 的方向导数 \(J \cdot v\)(其中 \(J = \partial f / \partial x\) 是 Jacobian 矩阵)。如果 \(n\) 很小(比如 Ceres 中每个残差只依赖 3-4 个参数),前向模式是最自然的选择。

3.1 Dual Number 的代数基础 ⭐⭐

定义与直觉

Dual number 是实数的一种扩展,定义为:

\[a + b\varepsilon, \quad \text{其中 } a, b \in \mathbb{R}, \quad \varepsilon^2 = 0 \text{ 但 } \varepsilon \neq 0\]

这个 \(\varepsilon\) 称为**无穷小量**,它满足 \(\varepsilon^2 = 0\) 这一关键性质。

跨领域类比:dual number 之于实数,就像复数之于实数。复数引入了 \(i^2 = -1\) 的虚数单位来扩展实数系统,使得所有多项式都有根。dual number 引入了 \(\varepsilon^2 = 0\) 的无穷小量来扩展实数系统,使得我们可以"自动"提取导数。它们都是**在实数上添加一个新元素并规定其代数规则**的代数扩展。

但两者的"用途"完全不同:复数用来解方程、做旋转;dual number 用来**无截断误差地传播导数信息**。

运算规则

Dual number 的四则运算规则完全由 \(\varepsilon^2 = 0\) 推导而来:

加法\((a + b\varepsilon) + (c + d\varepsilon) = (a+c) + (b+d)\varepsilon\)

乘法\((a + b\varepsilon)(c + d\varepsilon) = ac + (ad + bc)\varepsilon + bd\varepsilon^2 = ac + (ad + bc)\varepsilon\)

这里 \(bd\varepsilon^2 = 0\) 被丢弃——这正是"无穷小量"的妙处。

除法\(\frac{a + b\varepsilon}{c + d\varepsilon} = \frac{(a + b\varepsilon)(c - d\varepsilon)}{(c + d\varepsilon)(c - d\varepsilon)} = \frac{ac + (bc - ad)\varepsilon}{c^2} = \frac{a}{c} + \frac{bc - ad}{c^2}\varepsilon\)

为什么 dual number 能自动求导

核心观察:对任意解析函数 \(f\),将其在 \(a\) 处 Taylor 展开:

\[f(a + b\varepsilon) = f(a) + f'(a) \cdot b\varepsilon + \frac{f''(a)}{2!}(b\varepsilon)^2 + \cdots = f(a) + f'(a) \cdot b\varepsilon\]

因为 \(\varepsilon^2 = 0\),所有二阶及以上的项全部消失。这意味着**把 \(x_0 + 1 \cdot \varepsilon\) 代入 \(f\),实部给出 \(f(x_0)\),虚部给出 \(f'(x_0)\)**——函数值和导数同时得到,精确无误。

让我们用几个具体的基本函数验证这一点。

sin\(\sin(a + b\varepsilon) = \sin(a) + \cos(a) \cdot b\varepsilon\)(因为 \(\sin'(x) = \cos(x)\)

exp\(\exp(a + b\varepsilon) = \exp(a) + \exp(a) \cdot b\varepsilon = \exp(a)(1 + b\varepsilon)\)

log\(\log(a + b\varepsilon) = \log(a) + \frac{b}{a}\varepsilon\)

幂函数\((a + b\varepsilon)^n = a^n + n a^{n-1} b\varepsilon\)

这些规则与微积分中的求导法则**完全一致**——这不是巧合,而是 dual number 代数结构的必然结果。

本质洞察:dual number 把"求导"这个分析学操作编码进了代数运算中。你不需要"对程序做分析"来求导,只需要"用 dual number 替换实数运行程序"——导数会自动出现在 \(\varepsilon\) 系数中。这就是为什么 Ceres 的 Jet<T,N> 类只需要重载 +, -, *, /, sin, cos, exp, log 等运算符,就能对任意模板化的 cost function 自动求导。

3.2 从 dual number 到 JVP ⭐⭐

多元函数的推广

对多元函数 \(f: \mathbb{R}^n \to \mathbb{R}^m\),我们引入 \(n\) 个独立的无穷小方向 \(\dot{x}_1, \ldots, \dot{x}_n\)(称为"tangent"或"种子向量"),计算:

\[f(x_0 + \dot{x} \cdot \varepsilon) = f(x_0) + J(x_0) \cdot \dot{x} \cdot \varepsilon\]

其中 \(J(x_0) = \frac{\partial f}{\partial x}\big|_{x_0}\)\(m \times n\) 的 Jacobian 矩阵。\(\varepsilon\) 系数给出的是 Jacobian-向量积(Jacobian-Vector Product, JVP):

\[\text{JVP}: \quad (x, \dot{x}) \mapsto (f(x), \; J(x) \cdot \dot{x})\]

如果选择 \(\dot{x} = e_i\)(第 \(i\) 个单位向量),则 JVP 给出 Jacobian 的第 \(i\) 列。因此,计算完整 Jacobian 需要 \(n\) 次 JVP 调用。

前向传播的逐步执行

回到 §1 的例子 \(f(x_1, x_2) = x_1 x_2 + \sin(x_1)\)。选种子 \(\dot{x}_1 = 1, \dot{x}_2 = 0\)(即求 \(\partial f / \partial x_1\)):

变量 primal 值 tangent(\(\varepsilon\) 系数) 运算规则
\(v_1 = x_1\) \(\pi/4\) \(\dot{v}_1 = 1\) 种子
\(v_2 = x_2\) \(2\) \(\dot{v}_2 = 0\) 种子
\(v_3 = v_1 \cdot v_2\) \(\pi/2\) \(\dot{v}_3 = \dot{v}_1 v_2 + v_1 \dot{v}_2 = 2\) 乘法的 dual rule
\(v_4 = \sin(v_1)\) \(\sin(\pi/4)\) \(\dot{v}_4 = \cos(v_1) \cdot \dot{v}_1 = \cos(\pi/4)\) sin 的 dual rule
\(v_5 = v_3 + v_4\) \(2.278\) \(\dot{v}_5 = \dot{v}_3 + \dot{v}_4 = 2.707\) 加法的 dual rule

最终 \(\dot{v}_5 = 2 + \cos(\pi/4) \approx 2.707\) 就是 \(\partial f / \partial x_1\)。可以手工验证:\(\partial f / \partial x_1 = x_2 + \cos(x_1) = 2 + \cos(\pi/4) \approx 2.707\)。精确匹配!

复杂度分析

一次 JVP 的计算量等于**一次函数求值加一次同样结构的 tangent 传播**。每个基本运算的 tangent 传播最多是常数倍的额外运算(如乘法需要 2 次额外乘法和 1 次加法)。因此:

\[\text{cost(一次 JVP)} \leq c \cdot \text{cost}(f), \quad c \approx 2\text{-}3\]

计算完整 \(m \times n\) Jacobian 需要 \(n\) 次 JVP,总成本 \(O(n \cdot \text{cost}(f))\)

前向模式的最佳场景\(n \ll m\)——输入维度远小于输出维度。典型例子:Ceres 中每个残差 \(r_i: \mathbb{R}^6 \to \mathbb{R}^2\)(相机位姿 6 参数 → 重投影误差 2 维),\(n=6, m=2\),前向模式只需 6 次 JVP。

如果 \(n\) 很大会怎样? 考虑一个神经网络 \(f: \mathbb{R}^{10^7} \to \mathbb{R}\)(一亿参数 → 标量 loss)。前向模式需要 \(10^7\) 次 JVP——显然不可行。这就是反向模式登场的理由。

3.3 Ceres Jet 的模板实现 ⭐⭐

Ceres Solver 用 C++ 模板元编程实现了前向 AD,核心是 Jet<T, N> 类。理解 Jet 的实现有助于深刻理解前向 AD 的工程化。

// Ceres Jet 的核心结构(简化版,展示设计思想)
template <typename T, int N>
struct Jet {
    T a;       // primal 值(实部)
    T v[N];    // tangent 值(N 个方向的导数)

    // 构造函数:从标量构造(所有 tangent 为零)
    explicit Jet(const T& value) : a(value) {
        for (int i = 0; i < N; ++i) v[i] = T(0);
    }

    // 构造函数:指定第 k 个方向的种子
    Jet(const T& value, int k) : a(value) {
        for (int i = 0; i < N; ++i) v[i] = (i == k) ? T(1) : T(0);
    }
};

// 乘法重载:(a + v*ε) * (b + w*ε) = ab + (aw + bv)ε
template <typename T, int N>
Jet<T, N> operator*(const Jet<T, N>& f, const Jet<T, N>& g) {
    Jet<T, N> result;
    result.a = f.a * g.a;                    // primal: 普通乘法
    for (int i = 0; i < N; ++i)
        result.v[i] = f.a * g.v[i] + f.v[i] * g.a;  // tangent: 乘法法则
    return result;
}

// sin 重载:sin(a + v*ε) = sin(a) + cos(a)*v*ε
template <typename T, int N>
Jet<T, N> sin(const Jet<T, N>& f) {
    Jet<T, N> result;
    result.a = std::sin(f.a);               // primal: 普通 sin
    T cos_a = std::cos(f.a);
    for (int i = 0; i < N; ++i)
        result.v[i] = cos_a * f.v[i];       // tangent: cos(a) * v
    return result;
}

设计哲学:Ceres 要求用户把 cost function 写成 C++ 模板 template<typename T> bool operator()(const T* x, T* residual)。当 T = double 时,编译器生成计算函数值的代码;当 T = Jet<double, N> 时,编译器自动生成同时计算函数值和导数的代码——两者**共享同一份源码**,只是运算符重载不同。这就是"一份代码,两种用途"的模板元编程范式。

为什么 N 是模板参数而不是运行时参数? 因为 N 在编译期已知时,编译器可以展开循环、使用 SIMD 指令、消除内存分配——对 SLAM 中需要每秒调用上百万次的 cost function,这种优化至关重要。

⚠️ 常见陷阱

⚠️ 编程陷阱:在 Ceres cost function 中调用非模板化的外部函数 - 错误做法:double result = sqrt(x[0]); 而不是 T result = ceres::sqrt(x[0]); - 现象:当 T = Jet<double, N> 时,double sqrt(double) 会把 Jet 隐式转成 double,丢失所有导数信息,最终得到零导数 - 根本原因:C++ 的隐式类型转换会悄无声息地丢弃 tangent 信息 - 正确做法:cost function 中所有函数调用必须 template<typename T>

💡 概念误区:认为 Jet 的 N 等于总参数个数 - 错误想法:"我的问题有 1000 个参数,所以 N=1000" - 实际上:Ceres 的 AutoDiffCostFunction<F, residual_dim, param_block_1_dim, param_block_2_dim, ...> 中,N = 所有参数块维度之和。但由于 Ceres 每次只对**一个残差**求 Jacobian,N 通常很小(6-15),不是总参数个数 - 根本原因:稀疏性——每个残差只依赖少数参数块,Jacobian 是稀疏的

练习

  1. 手动实现一个 40 行的 C++ Dual 类(不用模板 N,只支持一个方向),重载 +, -, *, /, sin, cos, exp, log。用它计算 \(f(x) = \sin(x^2) \cdot e^x\)\(x=1\)\(f\)\(f'\),与真实值比较。
  2. 解释为什么 Jet 的乘法规则 \(\dot{v}_3 = \dot{v}_1 v_2 + v_1 \dot{v}_2\) 与微积分的乘法法则 \((fg)' = f'g + fg'\) 完全一致。
  3. 假设你有一个 cost function \(r: \mathbb{R}^3 \to \mathbb{R}^2\),使用 Ceres AutoDiffCostFunction<F, 2, 3>。计算一次 Jacobian 需要多少次 JVP?总共涉及多少个 Jet 的基本运算?

4. 反向模式 AD:adjoint 与 VJP ⭐⭐

动机

深度学习的成功建立在一个关键数学事实之上:对于 \(f: \mathbb{R}^n \to \mathbb{R}\)\(n\) 可以是数亿),计算完整梯度 \(\nabla f\) 的代价**与 \(n\) 无关**,只与计算 \(f\) 本身的代价成正比。这就是反向模式 AD(backpropagation 是其特例)的力量。

4.1 Adjoint 变量的定义 ⭐⭐

对计算图中的每个中间变量 \(v_i\),定义其 adjoint(伴随变量、灵敏度):

\[\bar{v}_i = \frac{\partial y}{\partial v_i}\]

其中 \(y\) 是最终输出。\(\bar{v}_i\) 回答的问题是:"如果 \(v_i\) 变化 \(\delta\),最终输出 \(y\) 变化多少?"

关键观察:输出节点的 adjoint 是 \(\bar{y} = \partial y / \partial y = 1\)。从这个"种子"出发,我们可以用链式法则**逆着计算图的方向**逐层传播 adjoint。

对于节点 \(v_j\),假设它的值被直接用在节点 \(v_{i_1}, v_{i_2}, \ldots, v_{i_k}\) 的计算中。由全微分公式(链式法则的多路版本):

\[\bar{v}_j = \sum_{i : j \in \text{parents}(i)} \bar{v}_i \cdot \frac{\partial v_i}{\partial v_j}\]

这就是反向模式 AD 的核心递推公式。它从输出向输入**反向**传播,每一步只需要知道"局部导数"\(\partial v_i / \partial v_j\)(这些在前向计算时已经确定了)。

4.2 一个完整的反向传播例子 ⭐⭐

继续使用 \(f(x_1, x_2) = x_1 x_2 + \sin(x_1)\),在 \(x_1 = \pi/4, x_2 = 2\) 处。

前向计算(存储所有中间值):

变量 运算
\(v_1\) \(x_1\) \(\pi/4\)
\(v_2\) \(x_2\) \(2\)
\(v_3\) \(v_1 \cdot v_2\) \(\pi/2\)
\(v_4\) \(\sin(v_1)\) \(\sqrt{2}/2\)
\(v_5\) \(v_3 + v_4\) \(\pi/2 + \sqrt{2}/2\)

反向传播(从 \(\bar{v}_5 = 1\) 开始):

adjoint 递推公式
\(\bar{v}_5\) 种子 \(1\)
\(\bar{v}_3\) \(\bar{v}_5 \cdot \frac{\partial v_5}{\partial v_3} = 1 \cdot 1\) \(1\)
\(\bar{v}_4\) \(\bar{v}_5 \cdot \frac{\partial v_5}{\partial v_4} = 1 \cdot 1\) \(1\)
\(\bar{v}_1\) \(\bar{v}_3 \cdot \frac{\partial v_3}{\partial v_1} + \bar{v}_4 \cdot \frac{\partial v_4}{\partial v_1}\) \(1 \cdot v_2 + 1 \cdot \cos(v_1) = 2 + \cos(\pi/4)\)
\(\bar{v}_2\) \(\bar{v}_3 \cdot \frac{\partial v_3}{\partial v_2}\) \(1 \cdot v_1 = \pi/4\)

结果:\(\frac{\partial f}{\partial x_1} = 2 + \cos(\pi/4) \approx 2.707\)\(\frac{\partial f}{\partial x_2} = \pi/4 \approx 0.785\)。与前向模式得到的结果完全一致——这不是巧合,两者都是链式法则的精确应用,只是遍历方向不同。

4.3 VJP(向量-雅可比积)⭐⭐

反向模式计算的是 VJP(Vector-Jacobian Product):

\[\text{VJP}: \quad (x, \bar{y}) \mapsto (f(x), \; \bar{y}^\top J(x))\]

其中 \(\bar{y}\) 是输出空间的"种子向量"(cotangent)。当 \(m=1\)\(\bar{y}=1\),一次 VJP 直接给出完整梯度——这就是为什么反向模式对标量输出函数如此高效。

如果选择 \(\bar{y} = e_j\)(第 \(j\) 个单位向量),则 VJP 给出 Jacobian 的第 \(j\) 行。计算完整 Jacobian 需要 \(m\) 次 VJP 调用,总成本 \(O(m \cdot \text{cost}(f))\)

前向与反向的对称关系

前向模式(JVP) 反向模式(VJP)
计算 \(J \cdot v\)(Jacobian 乘向量) \(v^\top J\)(向量乘 Jacobian)
种子 输入空间的向量 \(\dot{x}\) 输出空间的向量 \(\bar{y}\)
遍历方向 计算图正方向 计算图逆方向
一次得到 Jacobian 的一列 Jacobian 的一行
全 Jacobian \(n\) 次调用 \(m\) 次调用
最佳场景 \(n \ll m\) \(n \gg m\)
额外存储 \(O(1)\) \(O(\text{计算图节点数})\)

反事实推理:如果深度学习的 loss 不是标量而是高维向量(比如 \(m = 10^6\)),那么反向模式就不再高效——每个输出分量需要一次反向传播。在这种情况下(如计算完整 Jacobian),最优策略取决于 \(n\)\(m\) 的相对大小:\(n < m\) 用前向,\(n > m\) 用反向。

4.4 Cheap Gradient Principle ⭐⭐⭐

定理(Baur-Strassen 1983;Griewank-Walther Thm. 4.11):对任意 \(f: \mathbb{R}^n \to \mathbb{R}\),反向模式计算完整梯度 \(\nabla f\) 的运算量满足:

\[\text{cost}(\nabla f) \leq c \cdot \text{cost}(f), \quad c \leq 5\]

其中 \(c\) 是常数,典型值 2-3,与输入维度 \(n\) 无关

为什么 \(c\) 有上界? 在反向传播中,每个基本运算的反向对应有确定的运算量上界:

基本运算 前向代价 反向附加代价 \(c\)
\(v = a + b\) 1 ADD \(\bar{a} += \bar{v}\), \(\bar{b} += \bar{v}\) 2 ADD
\(v = a \times b\) 1 MUL \(\bar{a} += \bar{v} \cdot b\), \(\bar{b} += \bar{v} \cdot a\) 2 MUL + 2 ADD
\(v = \sin(a)\) 1 SIN \(\bar{a} += \bar{v} \cdot \cos(a)\) 1 COS + 1 MUL + 1 ADD
\(v = a / b\) 1 DIV \(\bar{a} += \bar{v}/b\), \(\bar{b} -= \bar{v} \cdot a / b^2\) 2 DIV + 1 MUL + 2 ADD

每个基本运算的反向代价是常数倍的前向代价,因此总的反向代价也是常数倍。

这个定理的深远意义:它意味着无论你的神经网络有多少参数(\(n = 10^9\)),计算完整梯度的代价始终只是"多做几次前向计算"。这是深度学习得以训练超大模型的数学基础——如果梯度计算的代价与 \(n\) 成正比,现代的 GPT 级模型根本不可能训练。

4.5 反向模式的内存代价分析 ⭐⭐⭐

反向模式 AD 的一个经常被忽视的代价是**内存**。为什么反向模式需要额外内存,而前向模式不需要?

在前向模式中,tangent \(\dot{v}_i\) 与 primal \(v_i\) 同步计算——当你计算到节点 \(v_i\) 时,所有父节点的 tangent 已经可用。因此,前向模式可以用 \(O(1)\) 额外内存(只存储当前层的 tangent)实现流式计算。

在反向模式中,adjoint \(\bar{v}_i\) 的计算需要 \(v_i\) 的**局部导数** \(\partial v_i / \partial v_j\),而这些局部导数依赖于**前向计算时的 primal 值**。例如,\(v_3 = v_1 \cdot v_2\) 的局部导数是 \(\partial v_3 / \partial v_1 = v_2\)——反向传播到节点 \(v_1\) 时需要用到 \(v_2\) 的值,但此时前向计算早已结束。

因此,反向模式必须**先完成整个前向计算并存储所有中间值**,然后才能开始反向传播。存储这些中间值的数据结构称为 tape(磁带)。

tape 的内存消耗\(O(W)\),其中 \(W\) 是计算图的总节点数(work)。对于深度学习,\(W\) 与网络参数量和 batch size 成正比;对于 MPC 展开,\(W\) 与 horizon 长度成正比。

情况 tape 大小 典型值
ResNet-50 \(\sim 10^7\) 节点 \(\sim 200\) MB
1000 步 RNN \(\sim 10^6 \times 1000\) 数 GB
MPC 展开 50 步 \(\sim 10^4 \times 50\) \(\sim 10\) MB

这就是 §5.3 中 checkpointing 技术的必要性来源。

4.6 反向传播就是反向模式 AD 的特例 ⭐⭐

"反向传播"(backpropagation)和"反向模式 AD"在数学上完全相同。历史上之所以有两个名字,是因为它们在两个不同的社区(机器学习 vs 科学计算)被独立发现。

具体来说,深度学习中的标准反向传播是反向模式 AD 应用于一类特殊的计算图——层级结构**的计算图(每层的输出是下一层的输入)。但反向模式 AD 适用于**任意 DAG 结构,包括有多路分支、汇合的复杂计算图。

本质洞察:反向传播不是一个独立的算法,而是链式法则在计算图上从输出到输入方向遍历的自然结果。一旦你理解了"adjoint 变量沿反向边传播"这个机制,所有"xx 的反向传播公式"(如 LSTM 的、Transformer 的、GNN 的)都不需要单独记忆——它们都是同一个通用算法在不同图结构上的实例化。

⚠️ 常见陷阱

⚠️ 编程陷阱:反向 AD 的内存爆炸 - 错误做法:对一个 1000 步的 RNN 或 MPC 展开做完整反向传播 - 现象:需要存储所有 1000 步的中间激活(前向计算的值),GPU 内存耗尽 - 根本原因:反向模式需要在反向传播时**重用前向计算的中间值**(如 \(\cos(v_1)\)),所以这些值必须保存到反向传播到达该节点时 - 正确做法:使用 checkpointing(§5.3)或切换到隐式微分(§7-§8)

🧠 思维陷阱:认为"反向模式总是比前向模式好" - 错误推理:"Cheap Gradient Principle 说反向模式代价与 \(n\) 无关,所以永远用反向模式" - 实际上:(1) 反向模式需要 \(O(\text{图大小})\) 的额外内存存储中间值,前向模式只需 \(O(1)\);(2) 当 \(n < m\) 时前向模式更快;(3) 某些硬件(如没有足够缓存的嵌入式系统)上,前向模式的简单内存模式更有优势 - 正确思维:前向和反向是**互补**的工具,选择取决于 \(n, m\) 的比例和内存约束

练习

  1. \(f(x_1, x_2, x_3) = (x_1 x_2 + x_3)^2\) 执行完整的反向传播,验证 \(\nabla f\)
  2. 考虑函数 \(f: \mathbb{R}^{100} \to \mathbb{R}^{200}\)。分别用前向和反向模式计算完整 Jacobian,哪个更快?需要多少次 JVP/VJP 调用?
  3. (跨章综合)回顾凸优化中的梯度下降:\(x_{k+1} = x_k - \alpha \nabla f(x_k)\)。如果 \(f\) 是一个神经网络的 loss(\(n = 10^6\) 参数),解释为什么我们必须用反向 AD 而非前向 AD 来计算 \(\nabla f\)。估算两者的计算时间比。

5. 高阶 AD 与工程技巧 ⭐⭐⭐

动机

一阶导数(梯度/Jacobian)是优化的基本工具,但很多场景需要更高阶的信息:Newton 法需要 Hessian、灵敏度分析需要二阶导数、不确定性传播需要曲率信息。高阶 AD 提供了高效计算这些量的方法。

5.1 Hessian-向量积:forward-over-reverse ⭐⭐⭐

\(f: \mathbb{R}^n \to \mathbb{R}\),Hessian 矩阵 \(H = \nabla^2 f\)\(n \times n\) 的。直接计算完整 Hessian 的代价是 \(O(n^2 \cdot \text{cost}(f))\)——对大 \(n\) 不可接受。

Hessian-向量积 \(Hv = \nabla^2 f \cdot v\) 可以用"forward-over-reverse"技术以 \(O(\text{cost}(f))\) 的代价计算:

\[Hv = \nabla(\nabla f \cdot v) = \nabla(g(x)), \quad \text{其中 } g(x) = \nabla f(x) \cdot v\]

步骤: 1. **反向模式**计算 \(\nabla f(x)\)(代价 \(c \cdot \text{cost}(f)\)) 2. 计算内积 \(g(x) = \nabla f(x)^\top v\)(代价 \(O(n)\)) 3. **前向模式**计算 \(\nabla g(x)\)(对 \(g\) 做一次 JVP,代价 \(c' \cdot \text{cost}(g)\)

总代价约 \(O(\text{cost}(f))\),与 \(n\) 无关。这使得 truncated Newton 法、共轭梯度法中的 Hessian-向量积变得高效可行。

在 JAX 中,这可以简洁地表达为:

import jax
import jax.numpy as jnp

def hvp(f, x, v):
    """Hessian-向量积:H @ v = d/dt grad(f)(x + t*v)|_{t=0}"""
    return jax.jvp(lambda x: jax.grad(f)(x), (x,), (v,))[1]

# 等价写法(更高效)
def hvp_alt(f, x, v):
    return jax.grad(lambda x: jnp.dot(jax.grad(f)(x), v))(x)

5.2 Hyper-dual Number ⭐⭐⭐⭐

Dual number 只能提取一阶导数(\(\varepsilon^2 = 0\) 丢弃了二阶信息)。Hyper-dual number 引入两个独立的无穷小量 \(\varepsilon_1, \varepsilon_2\),满足:

\[\varepsilon_1^2 = \varepsilon_2^2 = 0, \quad \varepsilon_1 \varepsilon_2 = \varepsilon_2 \varepsilon_1 \neq 0\]

一个 hyper-dual number 写作 \(a + b\varepsilon_1 + c\varepsilon_2 + d\varepsilon_1\varepsilon_2\)

将函数 \(f\) 作用在 \(x_0 + 1 \cdot \varepsilon_1 + 1 \cdot \varepsilon_2 + 0 \cdot \varepsilon_1\varepsilon_2\) 上:

\[f(x_0 + \varepsilon_1 + \varepsilon_2) = f(x_0) + f'(x_0)\varepsilon_1 + f'(x_0)\varepsilon_2 + f''(x_0)\varepsilon_1\varepsilon_2\]

\(\varepsilon_1\varepsilon_2\) 的系数直接给出**精确的二阶导数** \(f''(x_0)\),无截断误差。

与中心差分的二阶导数对比:中心差分 \(f''(x) \approx (f(x+h) - 2f(x) + f(x-h))/h^2\)\(O(h^2)\) 的截断误差和严重的舍入误差(分母 \(h^2\) 放大了数值噪声)。Hyper-dual number 完全避免了这些问题。

反事实推理:如果没有 hyper-dual number,计算精确的二阶导数只能通过(1)符号微分(表达式膨胀)或(2)嵌套 AD(需要构建二层计算图)。Fike 和 Alonso 在 2011 年提出 hyper-dual number 后,一些简单场景下的二阶导数计算变得和一阶一样直接。

5.3 Checkpointing:时间换内存 ⭐⭐⭐

反向模式 AD 的主要瓶颈是内存:它需要存储前向计算的所有中间值,供反向传播使用。对于长序列模型(RNN、MPC 展开、ODE 数值积分),内存消耗与序列长度成正比。

Checkpointing(又称 rematerialization)是一种**时间换内存**的策略:

基本思想:不存储所有中间值,只在若干"检查点"处存储状态。反向传播到某个区间时,从最近的检查点开始**重新前向计算**该区间的中间值。

经典结果(Griewank 1992):对 \(K\) 步的计算序列: - 完全存储:内存 \(O(K)\),不需要重计算 - 均匀 checkpointing(每 \(\sqrt{K}\) 步存一次):内存 \(O(\sqrt{K})\),重计算代价 \(O(\sqrt{K})\) - 最优 binomial checkpointing(Revolve 算法):给定 \(c\) 个检查点,最优时间代价为 \(O(K \log K / \log c)\)

策略 内存 时间(相对于无 checkpoint) 适用场景
不用 checkpoint \(O(K)\) \(1\times\) 短序列 (\(K < 100\))
\(\sqrt{K}\)-checkpoint \(O(\sqrt{K})\) \(\sim 1.5\times\) 中等长度
Revolve 最优 \(O(\log K)\) \(\sim 2\times\) 超长序列

在 PyTorch 中通过 torch.utils.checkpoint 使用,在 JAX 中通过 jax.checkpoint 使用。

5.4 Taylor 模式 AD ⭐⭐⭐⭐

Taylor 模式 AD 是前向模式的高阶推广。它不是用 dual number(\(\varepsilon^2 = 0\)),而是保留 \(\varepsilon\) 的高阶项,到 \(\varepsilon^d\) 为止。这相当于同时计算函数在一点处的 Taylor 系数 \(f(x_0), f'(x_0), f''(x_0)/2!, \ldots, f^{(d)}(x_0)/d!\)

每个基本运算的传播规则变成了多项式乘法/卷积。例如,乘法 \(w = u \cdot v\) 的第 \(k\) 阶 Taylor 系数满足:

\[w_k = \sum_{j=0}^{k} u_j v_{k-j}\]

这是 Cauchy 乘积公式——与多项式/形式幂级数的乘法完全相同。

Taylor 模式的应用场景包括:ODE 初值问题的高阶方法(Taylor 积分器比 Runge-Kutta 更高效)、函数的 Padé 近似、以及代数方程的级数解。

⚠️ 常见陷阱

🧠 思维陷阱:总是计算完整 Hessian - 错误想法:"Newton 法需要 Hessian,所以我要算出完整的 \(n \times n\) 矩阵" - 实际上:(1) Newton 法的核心步骤 \(H^{-1} \nabla f\) 可以通过 CG 迭代用 Hessian-向量积实现,不需要显式 Hessian;(2) 对 \(n > 1000\) 的问题,存储 Hessian 本身就不可行(\(10^6\) 个 double = 8 MB,\(10^4\) 维就是 800 MB) - 正确思维:先问"我真的需要完整 Hessian 吗?"——多数情况下 Hessian-向量积就够了

⚠️ 编程陷阱:不用 checkpointing 训练长序列 - 错误做法:对 1000 步 RNN 直接 loss.backward() - 现象:OOM(Out of Memory)或 GPU 利用率极低(因为大部分显存被中间值占满) - 正确做法:使用 torch.utils.checkpoint 包装 RNN cell,或切到 truncated BPTT

练习

  1. 推导 hyper-dual number 下 \(f(x) = x^3\) 的展开:\(f(a + b\varepsilon_1 + c\varepsilon_2 + d\varepsilon_1\varepsilon_2) = ?\),验证 \(\varepsilon_1\varepsilon_2\) 系数为 \(6a \cdot bc\)(当 \(b=c=1, d=0\) 时等于 \(f''(a)=6a\))。
  2. 对一个 \(K=1024\) 步的 ODE 积分器,如果使用 \(\sqrt{K} = 32\) 个 checkpoint,内存节省多少倍?重计算的时间开销大约增加多少?
  3. 用 JAX 实现 Hessian-向量积函数 hvp(f, x, v) 并验证其正确性。

6. AD 框架比较与选型 ⭐⭐

动机

"该用哪个 AD 框架?"这是工程师最常问的问题之一。答案取决于你的语言偏好、问题规模、部署环境和性能要求。本节给出系统化的选型指南。

6.1 框架全景

框架 主模式 实现机制 语言 高阶 GPU 典型用途
Ceres Jet 前向 模板+运算符重载 C++ 嵌套 SLAM/BA
CppAD 反向 tape(运行时) C++ 任意 NLP/MPC
CppADCodeGen 反向 tape→C 代码 C++ 部分 嵌入式 MPC
CasADi 前向+反向 符号图 SX/MX C++/Py codegen OCP/MPC
Enzyme 前向+反向 LLVM IR C/C++/Rust HPC
PyTorch 反向(+前向) 动态 tape Python 嵌套 深度学习
JAX 前向+反向 tracing→XLA Python 任意 ML/RL/科学计算
autodiff 前向+反向 C++17 header C++ 原型
Pinocchio 解析+AD 手工链式 C++/Py 一阶 有限 刚体动力学

6.2 六大框架的设计哲学 ⭐⭐

JAX:函数式变换的组合

JAX 的核心是四个可组合的函数变换:grad(反向 AD)、jit(JIT 编译到 XLA)、vmap(自动向量化/批处理)、pmap(多设备并行)。这种设计来自函数式编程——每个变换把一个函数映射为另一个函数,变换之间可以任意嵌套。

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

# grad: f -> grad_f(反向 AD)
grad_f = jax.grad(f)
# jit: grad_f -> 编译后的 grad_f(加速 10-100x)
fast_grad_f = jax.jit(grad_f)
# vmap: fast_grad_f -> 批处理版本(自动向量化)
batched_grad_f = jax.vmap(fast_grad_f)

JAX 的约束:函数必须是"纯"的(无副作用),不能包含 Python 控制流(需用 jax.lax.cond/scan 替代 if/for)。这种约束换来了**可预测的性能**和**任意阶可微**。

PyTorch autograd:动态图的灵活性

PyTorch 在每次前向计算时动态构建计算图(tape),反向传播后图被释放。这意味着你可以在 forward 中自由使用 Python 的 if/else、for 循环,每次前向可以走不同的代码路径。

这种设计对**研究**极为友好(你可以随时 print、debug、改图结构),但对**部署**不利(每次前向都要重建图,无法像 JAX 那样一次编译多次执行)。

CasADi:符号图 + 代码生成

CasADi 走了一条独特的路:用 Python/MATLAB 构建**符号计算图**(不执行数值计算),然后对这张图做符号微分和代码生成(导出 C 代码)。

CasADi 有两种图类型: - SX:标量级别的符号图。每个标量是图的一个节点,适合稠密的小问题(维度 < 100) - MX:矩阵级别的符号图。节点是矩阵运算,适合结构化的大问题(OCP、NLP)

import casadi as ca

# SX 符号变量
x = ca.SX.sym('x', 3)  # 3维符号向量
f = ca.sin(x[0]) * x[1] + ca.exp(x[2])
J = ca.jacobian(f, x)   # 符号 Jacobian(不执行数值计算!)

# 编译为高效的 C 函数
func = ca.Function('f', [x], [f, J])
func.generate('my_func.c')  # 导出 C 代码,可被 acados/OCS2 调用

CasADi 的核心限制:不支持 Python 控制流(因为图是静态的)。如果你的 ODE 右端函数中有 if/else,必须用 CasADi 的 ca.if_else 替代。

Ceres Jet:编译期的 SLAM 专用 AD

Ceres 的 Jet 通过 C++ 模板在**编译期**实例化 AD 代码。优势是零运行时开销和极高的性能(对 SLAM/BA 这种需要每秒调用数百万次 cost function 的场景至关重要)。劣势是只支持前向模式(反向模式需要运行时构建图)。

Enzyme:编译器级别的 AD

Enzyme 工作在 LLVM 中间表示(IR)层面——它直接对编译后的 IR 做 AD 变换,然后再生成机器码。这意味着你可以对**任何编译为 LLVM IR 的语言**(C、C++、Rust、Fortran、Julia)做 AD,且不需要修改源码或使用特殊的数值类型。

Enzyme 的突破性优势是它能对 GPU kernel 做 AD——这在以前只有 JAX/PyTorch 等框架级解决方案才能做到。

6.3 Enzyme:编译器级 AD 的革命 ⭐⭐⭐

Enzyme(Moses & Churavy, NeurIPS 2020)代表了 AD 实现的一种全新范式。传统 AD 要么在源码层面重载运算符(如 Ceres Jet),要么在框架层面构建计算图(如 PyTorch/JAX)。Enzyme 则在**编译器的中间表示**(LLVM IR)层面做 AD。

工作原理: 1. 你用任何支持 LLVM 的语言(C、C++、Rust、Fortran、Julia)写正常代码 2. Clang/GCC 将代码编译为 LLVM IR(一种低层级的中间表示) 3. Enzyme 作为 LLVM 的一个 pass,对 IR 做 AD 变换——在 IR 层面自动生成前向/反向模式的导数代码 4. LLVM 后端将变换后的 IR 编译为优化的机器码

关键优势: - 零源码修改:不需要重写为模板函数、不需要用特殊的张量类型、不需要用框架的 API - 跨语言:C、C++、Rust、Fortran、Julia 的代码都可以直接做 AD - GPU 支持(Moses et al., SC 2021):可以对 CUDA kernel 做 AD——这在以前只有 JAX/PyTorch 等框架级方案才能实现 - 性能:由于在 IR 层面操作,可以利用 LLVM 的全部优化 pass(内联、循环展开、向量化),生成的导数代码质量极高

在机器人学中的应用前景:Enzyme 使得对已有的 C/C++ 物理引擎(如 MuJoCo、Drake、Bullet)做 AD 成为理论上可能的事情——不需要从头用 JAX 重写。当然,实践中仍有很多工程挑战(如对复杂控制流、内存分配、虚函数的支持)。

6.4 选型决策树 ⭐⭐

你的问题需要 AD?
├── 是深度学习(训练神经网络)?
│   ├── 是 → PyTorch(研究灵活)或 JAX(性能极致)
│   └── 否 ↓
├── 是 SLAM/BA?
│   ├── 是 → Ceres Jet(C++ 模板 AD,行业标准)
│   └── 否 ↓
├── 是 MPC/OCP(需要实时代码生成)?
│   ├── 是 → CasADi(符号→C 代码→ acados/IPOPT)
│   │       或 CppAD + CppADCodeGen(C++ tape→ C 代码)
│   └── 否 ↓
├── 是刚体动力学微分?
│   ├── 是 → Pinocchio 解析微分(比黑盒 AD 快 3×+)
│   └── 否 ↓
├── 是 RL + 可微仿真?
│   ├── 是 → JAX + Brax/MJX
│   │       或 Dojo(Julia,精确接触梯度)
│   └── 否 ↓
└── 通用科学计算?
    ├── C/C++ 环境 → Enzyme 或 autodiff(header-only)
    ├── Python → JAX
    └── Julia → ForwardDiff.jl + Enzyme.jl

反事实推理:如果你为 SLAM 选了 PyTorch 而不是 Ceres,会发生什么?PyTorch 的动态图每次前向都要分配内存、构建 tape、Python 解释器开销——对于 BA 中每秒数百万次的 cost function 调用,Python 开销比计算本身大一个数量级。Ceres 的编译期模板 AD 完全消除了这些开销。所以选对工具不是锦上添花,而是决定系统能不能跑起来。

⚠️ 常见陷阱

🧠 思维陷阱:用黑盒 AD 对刚体动力学求微分 - 错误做法:用 CppAD 或 PyTorch 对 RNEA/ABA 做黑盒反向 AD - 现象:对 7 自由度机械臂,黑盒 AD 约需 50 \(\mu\)s,而 Pinocchio 解析微分只需 3 \(\mu\)s - 根本原因:RNEA/ABA 的递推结构天然编码了空间代数的链式法则。Carpentier 和 Mansard(RSS 2018)证明了手工推导的解析微分利用了递推的稀疏结构,避免了黑盒 AD 中大量冗余运算 - 正确做法:使用 pinocchio.computeRNEADerivatives()pinocchio.computeABADerivatives()

💡 概念误区:混淆 CasADi 的符号图和 PyTorch 的动态图 - CasADi 的图是**完全符号的**:构建阶段不执行任何数值计算,只记录运算结构。图构建完成后再"编译"为高效的数值函数 - PyTorch 的图是**边执行边构建**的:每次 forward() 同时执行数值计算和构建反向图 - 后果:在 CasADi 中不能用 Python 的 if x > 0 (因为 x 是符号、没有数值),必须用 ca.if_else

练习

  1. 对同一个函数 \(f(x) = \sum_i \sin(x_i)^2\)\(n = 1000\)),分别用 JAX jax.grad 和 PyTorch torch.autograd.grad 计算梯度,比较 wall time。加上 jax.jit 后差异如何变化?
  2. 用 CasADi 构建单摆 \(\dot{\theta} = \omega, \dot{\omega} = -g/l \sin\theta + u/(ml^2)\) 的符号模型,生成 Jacobian \(\partial f / \partial x\)\(\partial f / \partial u\),与手工推导对比。
  3. 列出你目前的研究中使用的 AD 工具,分析是否选择了最合适的方案。如果要切换到更合适的工具,需要修改哪些代码?

7. 隐函数定理与隐式微分 ⭐⭐

动机

到目前为止,我们讨论的 AD 都是**前向/反向传播**——沿着计算图传播导数。但有一大类重要的函数,其值不是通过一段显式的代码计算出来的,而是**通过求解一个方程或优化问题隐式定义的**。

例如: - 解线性方程组:\(y^*(A, b) = A^{-1}b\)\(y^*\) 是方程 \(Ay = b\) 的解 - 解优化问题:\(y^*(\theta) = \arg\min_y f(y, \theta)\)\(y^*\) 是目标函数的最优解 - 解 ODE:\(y^*(T, \theta)\)\(\dot{z} = f_\theta(z, t)\)\(t=T\) 时的解 - 解不动点方程:\(y^* = T_\theta(y^*)\)\(y^*\) 是映射 \(T\) 的不动点

对这些"隐式定义的函数",如何计算 \(\partial y^* / \partial \theta\)

朴素方法——展开(unrolling):把求解过程的每一步迭代(如 Newton 迭代、梯度下降迭代)展开成计算图,然后用反向 AD 穿过所有迭代步骤。这种方法可以工作,但有严重的缺点:

  1. 内存:需要存储所有迭代步骤的中间状态,\(O(K)\)\(K\) 为迭代步数)
  2. 梯度质量:如果迭代没有完全收敛,梯度会受到"收敛路径"的污染
  3. 计算量:反向传播穿过 \(K\) 步迭代,每步都要计算局部导数

正确方法——隐式微分:利用隐函数定理,只在**收敛点** \(y^*\) 处计算导数,完全不需要知道"是怎么求解到 \(y^*\) 的"。内存 \(O(1)\),梯度精确。

跨领域类比:展开法就像在地图上追踪你从家走到学校的每一步路径,然后计算"如果起点偏移 \(\delta\),终点偏移多少"。隐式微分则是直接问:"学校的位置(终点)对家的位置(起点)有什么依赖关系?"——你不需要知道走了哪条路,只需要知道终点满足什么条件。

7.1 隐函数定理(IFT)的严格陈述 ⭐⭐

定理(经典隐函数定理):设 \(F: \mathbb{R}^n \times \mathbb{R}^m \to \mathbb{R}^m\)\(C^1\) 映射。如果在点 \((x_0, y_0)\) 处:

  1. \(F(x_0, y_0) = 0\)\(y_0\) 满足方程)
  2. \(\frac{\partial F}{\partial y}(x_0, y_0)\) 是可逆的(\(m \times m\) 矩阵非奇异)

则存在 \(x_0\) 的邻域 \(U\) 和唯一的 \(C^1\) 映射 \(y^*: U \to \mathbb{R}^m\),使得:

\[F(x, y^*(x)) = 0, \quad \forall x \in U\]

并且:

\[\frac{dy^*}{dx} = -\left(\frac{\partial F}{\partial y}\right)^{-1} \frac{\partial F}{\partial x}\]

7.2 IFT 的完整证明思路 ⭐⭐⭐

证明的核心思想:构造压缩映射,然后用 Banach 不动点定理。

Step 1:定义辅助映射 \(G(x, y) = y - A^{-1} F(x, y)\),其中 \(A = \frac{\partial F}{\partial y}(x_0, y_0)\)

Step 2:验证 \(G(x_0, y_0) = y_0 - A^{-1} \cdot 0 = y_0\)

Step 3:计算 \(\frac{\partial G}{\partial y}(x_0, y_0) = I - A^{-1} \frac{\partial F}{\partial y}(x_0, y_0) = I - A^{-1} A = 0\)

因为 \(\partial G / \partial y\)\((x_0, y_0)\) 处为零(Lipschitz 常数为 0),由连续性,在 \((x_0, y_0)\) 的足够小邻域内 \(G\) 关于 \(y\) 是**压缩映射**。

Step 4:由 Banach 不动点定理,\(G(x, \cdot)\) 有唯一不动点 \(y^*(x) = G(x, y^*(x))\),即 \(y^*(x) - A^{-1} F(x, y^*(x)) = y^*(x)\),化简得 \(F(x, y^*(x)) = 0\)

Step 5:求导。对恒等式 \(F(x, y^*(x)) \equiv 0\) 两边对 \(x\) 求全微分:

\[\frac{\partial F}{\partial x} + \frac{\partial F}{\partial y} \cdot \frac{dy^*}{dx} = 0\]

解出:

\[\frac{dy^*}{dx} = -\left(\frac{\partial F}{\partial y}\right)^{-1} \frac{\partial F}{\partial x}\]

这就是**隐式微分公式**。

物理直觉\(\frac{\partial F}{\partial x}\) 衡量"参数变化对方程左边的直接影响",\((\frac{\partial F}{\partial y})^{-1}\) 衡量"解需要调整多少来重新满足方程"。两者的乘积就是"参数变化对解的间接影响"。

7.3 VJP 形式的隐式微分 ⭐⭐⭐

在实际应用中,我们通常不需要完整的 Jacobian \(dy^*/dx\),而只需要 VJP:给定上游梯度 \(\bar{y}\),计算 \(\bar{x} = \bar{y}^\top \frac{dy^*}{dx}\)

\[\bar{x}^\top = \bar{y}^\top \cdot \left(-\left(\frac{\partial F}{\partial y}\right)^{-1} \frac{\partial F}{\partial x}\right) = -u^\top \frac{\partial F}{\partial x}\]

其中 \(u\) 满足线性方程:

\[\left(\frac{\partial F}{\partial y}\right)^\top u = \bar{y}\]

这意味着隐式微分的反向传播只需要**解一个线性方程组**(与前向求解规模相同),不需要构造或存储完整的 Jacobian 或其逆矩阵。

7.4 IFT 失效的情况与对策 ⭐⭐⭐

IFT 的核心条件是 \(\partial F / \partial y\) 非奇异。在实际应用中,这个条件可能失效:

场景 1:退化 QP。当 QP 有多个最优解时(\(Q\) 半正定、不唯一),KKT 矩阵奇异。对策:添加 Tikhonov 正则化 \(Q \to Q + \epsilon I\)

场景 2:Active set 切换。含不等式约束的优化问题中,当参数 \(\theta\) 变化导致某个约束从 active 变为 inactive(或反之),最优解不可微(存在拐点)。对策:使用 interior-point 方法的 central-path 参数化 \(\mu\) 做光滑化。

场景 3:非唯一解。非凸问题可能有多个局部最优解。参数微小变化可能导致全局最优解从一个局部最优跳到另一个。对策:在明确指定的局部最优附近使用 IFT,不追求全局最优。

反事实推理:如果 \(\partial F / \partial y\) 恰好奇异,IFT 公式 \(dy^*/dx = -(F_y)^{-1} F_x\) 中出现了矩阵求逆失败。但这并不意味着 \(y^*(x)\) 不存在或不连续——它可能仍然存在且连续,只是不可微。这类似于 \(|x|\)\(x=0\) 处连续但不可微。处理这种情况需要更高级的工具:广义隐函数定理(Dontchev & Rockafellar 2014)用 Aubin property 和度量正则性替代经典的非奇异条件。

⚠️ 常见陷阱

⚠️ 编程陷阱:展开 QP/iLQR 迭代做反向传播 - 错误做法:把 QP 求解器的每步迭代放进 PyTorch 计算图 - 现象:(1) 内存与迭代步数成正比;(2) 梯度依赖于收敛路径而非收敛解;(3) 如果求解器使用了 warm-start,梯度会被 warm-start 的质量污染 - 根本原因:展开法混淆了"求解过程"和"解的性质"——导数应该只依赖于"解在哪里",不应该依赖于"怎么找到解的" - 正确做法:用隐式微分,在收敛解 \(y^*\) 处通过 IFT 计算梯度

💡 概念误区:认为 IFT 只适用于等式约束 - 实际上:对不等式约束 \(g(y, x) \leq 0\),在严格互补条件下(active 约束上 \(\lambda > 0\)),可以把 KKT 条件写成等式 \(F(z, x) = 0\)(其中 \(z = (y, \lambda, \nu)\)),然后对 \(F\) 用 IFT - 这正是 OptNet/cvxpylayers 的数学基础

练习

  1. 对方程 \(F(x, y) = y^3 + xy - 1 = 0\),在 \((x_0, y_0) = (0, 1)\) 处:(a) 验证 IFT 的条件;(b) 用隐式微分公式计算 \(dy^*/dx\);(c) 用 SymPy 求解 \(y^*(x)\) 的显式表达式并验证导数。
  2. 考虑线性方程组 \(Ay = b\)\(A\) 可逆)。将其写成 \(F(A, b, y) = Ay - b = 0\)。用 IFT 推导 \(\partial y^* / \partial b\)\(\partial y^* / \partial A\)(后者需要使用矩阵微分规则)。
  3. (跨章综合)回顾凸优化中的 proximal 算子 \(\text{prox}_f(x) = \arg\min_y \{f(y) + \frac{1}{2}\|y - x\|^2\}\)。写出最优性条件 \(F(x, y^*) = 0\),用 IFT 推导 \(\partial \text{prox}_f / \partial x\) 的表达式。

8. 可微优化层:KKT 的微分 ⭐⭐⭐

动机

2017 年 Amos 和 Kolter 提出了一个改变游戏规则的想法:把优化问题本身当作神经网络的一层(OptNet)。这意味着你可以在网络中嵌入一个 QP 求解器,梯度穿过 QP 层继续反向传播——优化的输出对网络的参数是可微的。

这是隐式微分(§7)在优化问题上的直接应用。

8.1 QP 层的隐式微分推导 ⭐⭐⭐

考虑参数化的等式约束 QP:

\[y^*(\theta) = \arg\min_y \frac{1}{2} y^\top Q(\theta) y + q(\theta)^\top y \quad \text{s.t.} \quad A(\theta) y = b(\theta)\]

Step 1:写出 KKT 条件。引入 Lagrange 乘子 \(\nu\)

\[F(y, \nu; \theta) = \begin{pmatrix} Q y + q + A^\top \nu \\ A y - b \end{pmatrix} = 0\]

Step 2:KKT 矩阵(Jacobian \(\partial F / \partial (y, \nu)\)):

\[\frac{\partial F}{\partial (y, \nu)} = \begin{pmatrix} Q & A^\top \\ A & 0 \end{pmatrix}\]

这就是凸优化中经典的 KKT 矩阵。KKT 矩阵非奇异的条件是:

  1. \(Q \succ 0\)(正定,保证目标严格凸);
  2. \(A\) 满行秩(即 \(\text{rank}(A) = p\),其中 \(p\) 为等式约束数),这是**线性独立约束品性(LICQ)**的要求。

为什么需要 \(A\) 满行秩? KKT 矩阵是 \(2 \times 2\) 分块鞍点矩阵。由 Sylvester 惯性定理,它非奇异当且仅当 \(Q\)\(\ker(A)\) 上正定**且** \(A\) 满行秩。如果 \(A\) 有线性相关的行(冗余约束),KKT 矩阵的零空间非平凡——乘子 \(\nu^*\) 不唯一,隐式微分无法定义。工程中出现冗余等式约束(如多只脚都约束在地面上但几何上有依赖)时,必须先去冗余再做隐式微分。

Step 3:用 IFT 得到隐式微分公式。对参数 \(q\) 的微分最简单:

\[\frac{\partial F}{\partial q} = \begin{pmatrix} I \\ 0 \end{pmatrix}\]
\[\frac{d(y^*, \nu^*)}{dq} = -\begin{pmatrix} Q & A^\top \\ A & 0 \end{pmatrix}^{-1} \begin{pmatrix} I \\ 0 \end{pmatrix}\]

用 Schur 补提取 \(\partial y^*/\partial q\)。KKT 系统的分块消元过程:从第二行 \(A \cdot dy/dq = 0\)(因为 \(b\) 不依赖 \(q\)),第一行 \(Q \cdot dy/dq + A^\top \cdot d\nu/dq = -I\)。在约束 \(A \cdot dy/dq = 0\) 下求解——这等价于将 \(dy/dq\) 限制在 \(\ker(A)\) 中,再用 \(Q\) 的逆。正确的 Schur 补公式为:

\[\frac{\partial y^*}{\partial q} = -\left(Q^{-1} - Q^{-1}A^\top(AQ^{-1}A^\top)^{-1}AQ^{-1}\right)\]

这正是将 \(Q^{-1}\) 投影到 \(\ker(A)\) 上的结果(即 \(Q^{-1}\) 的零空间正交投影)。直觉含义:参数 \(q\) 的变化只能沿着满足约束 \(Ay = b\) 的方向影响最优解——正交于约束的分量被"锁死"了。

Step 4:VJP 形式。给定上游梯度 \(\bar{y}\),需要解:

\[\begin{pmatrix} Q & A^\top \\ A & 0 \end{pmatrix} \begin{pmatrix} u \\ w \end{pmatrix} = \begin{pmatrix} \bar{y} \\ 0 \end{pmatrix}\]

然后 \(\bar{q} = -u\)。这只需要解**一个 KKT 系统**——与前向求解 QP 的计算量相当。

为什么不直接对 \(y^* = -Q^{-1}q\) 做 AD? 对无约束 QP,最优解有封闭形式,确实可以直接对封闭形式做 AD。但封闭形式有几个问题:(1) 求 \(Q^{-1}\) 的计算量是 \(O(n^3)\),而 KKT 系统的求解(如用 Cholesky 分解)也是 \(O(n^3)\),但后者可以复用前向求解时的分解;(2) 有约束时没有封闭形式;(3) 大规模问题中 \(Q\) 是稀疏的,KKT 求解可以利用稀疏性而显式求逆不能。

所以 IFT/KKT 隐式微分不仅是唯一的选择(有约束时),而且通常是更高效的选择(利用前向求解的中间结果和稀疏结构)。

8.2 含不等式约束的推广(OptNet) ⭐⭐⭐

对含不等式约束的 QP \(\min \frac{1}{2}y^\top Qy + q^\top y\) s.t. \(Gy \leq h, Ay = b\),KKT 条件增加了互补松弛条件 \(\lambda \odot (Gy - h) = 0\)

在**严格互补**条件下(active 约束上 \(\lambda_i > 0\)),可以把 KKT 写成 \(F(z, \theta) = 0\)\(z = (y, \lambda, \nu)\)),然后直接用 IFT。

关键的实现细节是:需要识别 active 约束集 \(\mathcal{A} = \{i : g_i(y^*) = 0\}\),只对 active 约束建立 KKT 矩阵。

8.3 cvxpylayers:更通用的可微凸优化 ⭐⭐⭐

Agrawal et al.(NeurIPS 2019)将 OptNet 推广到**任意凸优化问题**。核心思想是利用 CVXPY 的 DCP(Disciplined Convex Programming)框架:

  1. 用 CVXPY 定义凸问题(参数用 cp.Parameter 标记)
  2. 框架自动将问题 canonicalize 为锥规划
  3. 用锥求解器(SCS、ECOS)求解
  4. 用锥 KKT 条件做隐式微分

DPP(Disciplined Parametrized Programming)约束:参数必须以**仿射方式**出现在问题数据中(即 \(Q(\theta) = Q_0 + \sum_i \theta_i Q_i\))。这个约束保证了 canonicalization 是可微的。

import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer
import torch

# 定义参数化 QP
n = 10
Q_sqrt = cp.Parameter((n, n))
q = cp.Parameter(n)
y = cp.Variable(n)
prob = cp.Problem(cp.Minimize(0.5 * cp.sum_squares(Q_sqrt @ y) + q @ y),
                  [y >= -1, y <= 1])

# 包装为 PyTorch 层
layer = CvxpyLayer(prob, parameters=[Q_sqrt, q], variables=[y])

# 前向 + 反向(梯度穿过 QP 求解!)
Q_sqrt_val = torch.randn(n, n, requires_grad=True)
q_val = torch.randn(n, requires_grad=True)
y_star, = layer(Q_sqrt_val, q_val)
loss = y_star.sum()
loss.backward()  # dL/dQ_sqrt 和 dL/dq 通过隐式微分计算

本质洞察:可微优化层的价值不在于"让优化可微"这个技术本身,而在于它建立了**优化与学习的双向通道**。正向:学习系统生成优化问题的参数;反向:优化解的质量反馈给学习系统调整参数。这使得"端到端训练一个包含优化的系统"成为可能——这正是可微 MPC、可微 motion planning 的数学基础。

⚠️ 常见陷阱

⚠️ 编程陷阱:在 cvxpylayers 中使用非 DPP 兼容的构造 - 错误做法:参数出现在非线性位置,如 cp.quad_form(x, P)P 是参数 - 现象:cvxpylayers 报错 "problem is not DPP compliant" - 根本原因:DPP 要求参数以仿射方式出现在问题数据中。quad_form(x, P) 中 P 在二次项中不是仿射的 - 正确做法:改写为 cp.sum_squares(P_sqrt @ x),其中 P_sqrt 是参数

🧠 思维陷阱:认为隐式微分总是比展开式更好 - 例外情况:(1) 当求解器没有精确收敛时,隐式微分给出的是"精确解的梯度"而非"近似解的梯度"——两者之间有 gap;(2) 当 KKT 矩阵接近奇异(如 active set 即将变化时),隐式微分的数值稳定性变差 - 正确思维:隐式微分是默认首选,但需要监控 KKT 矩阵的条件数

练习

  1. 对无约束 QP \(\min \frac{1}{2} y^\top Q y + q^\top y\)\(Q \succ 0\)),最优解 \(y^* = -Q^{-1}q\)。直接微分得到 \(\partial y^* / \partial q = -Q^{-1}\)。用 IFT 推导同样的结果(把 KKT 条件写成 \(F = Qy + q = 0\)),验证两者一致。
  2. 用 cvxpylayers 实现一个简单的"学习代价函数"实验:给定专家轨迹 \(y_{\text{expert}}\),训练 QP 参数使得 \(\|y^*(\theta) - y_{\text{expert}}\|^2\) 最小化。
  3. (跨章综合)回顾凸分析中的 Slater 条件。解释为什么 Slater 条件是隐式微分可行的前提——它保证了什么?如果 Slater 条件不满足,KKT 矩阵会有什么问题?

9. 可微物理仿真 ⭐⭐⭐

动机

可微仿真是机器人学习的前沿方向:如果物理仿真器是可微的,就可以直接计算"策略参数→仿真轨迹→目标函数"的端到端梯度,而不需要像 RL 那样靠采样估计。这在平滑动力学下可以带来 10-100 倍的样本效率提升。

但接触和碰撞引入了**不连续性**,使得梯度问题变得微妙。这一节讨论可微仿真的原理、主要框架和核心挑战。

9.1 Pinocchio 的解析微分 ⭐⭐⭐

Pinocchio 是机器人刚体动力学的 C++ 库。Carpentier 和 Mansard(RSS 2018)证明了一个关键结果:RNEA(递归 Newton-Euler 算法)和 ABA(Articulated Body 算法)可以**手工推导**解析微分公式,比黑盒 AD 快 3 倍以上。

为什么解析微分更快? RNEA 的前向递推和反向递推天然编码了空间代数的链式法则。黑盒 AD 不知道这个结构,会产生大量冗余计算。Carpentier 和 Mansard 利用递推结构,将 Jacobian 计算分解为与前向/反向递推同形的递推过程。

性能数据(Carpentier-Mansard RSS 2018):

模型 自由度 解析微分 CppAD 黑盒 AD 加速比
7-dof arm 7 3 \(\mu\)s 15 \(\mu\)s 5\(\times\)
人形 36 17 \(\mu\)s 78 \(\mu\)s 4.6\(\times\)
import pinocchio as pin

# 加载模型
model = pin.buildModelFromUrdf("robot.urdf")
data = model.createData()

# 计算 RNEA 微分:∂τ/∂q, ∂τ/∂v, ∂τ/∂a
pin.computeRNEADerivatives(model, data, q, v, a)
dtau_dq = data.dtau_dq   # n×n 矩阵
dtau_dv = data.dtau_dv   # n×n 矩阵
# dtau_da = M(q) 就是质量矩阵(RNEA 关于加速度是线性的)

Pinocchio 解析微分的数学原理

RNEA 计算的是逆动力学 \(\tau = \text{RNEA}(q, \dot{q}, \ddot{q})\)。其内部结构是从根到叶的前向递推(计算每个关节的速度和加速度)和从叶到根的反向递推(累积力和力矩)。

Carpentier 和 Mansard 的关键观察:RNEA 的递推结构**本身就是一种链式法则的具体实现**——空间代数中的递推关系编码了关节到关节的局部变换和力的传递。对 RNEA 求微分,等价于在这个递推结构上叠加一层"微分递推"。

具体来说,\(\partial \tau / \partial q\) 可以分解为: - 对每个关节 \(i\),计算"关节位形对空间速度/加速度的影响"(前向递推的微分) - 累积"空间力对扭矩的影响"(反向递推的微分)

这两层递推的计算量与原始 RNEA 相当(同样是 \(O(n)\)\(n\) 为关节数),但通过利用稀疏结构,常数因子只有 2-3 倍——而黑盒 AD 的常数因子可达 10-15 倍。

9.2 可微仿真引擎对比 ⭐⭐⭐

引擎 语言 可微方式 接触模型 GPU 主要用途
Brax JAX/Python JAX AD 弹簧/PBD 大规模 RL
MuJoCo MJX JAX/Python JAX AD 凸优化 RL + 控制
Dojo Julia IFT NCP+IPM 精确接触梯度
DiffTaichi Python/Taichi 编译器 AD 各种 流体/弹性体
NVIDIA Warp Python→CUDA 编译器 AD 各种 通用 GPU 仿真
Nimble C++/Python adjoint spatial LCP 接触敏感任务

Brax 的设计哲学:用 JAX 的 jit + vmap + grad 组合实现"数千个环境并行仿真 + 整体可微"。接触模型简化为弹簧或 PBD(Position Based Dynamics),牺牲物理精度换取可微性和速度。

Dojo 的设计哲学(Howell et al., RSS 2022):通过 primal-dual 内点法求解接触 NCP(非线性互补问题),然后在 IPM 的 KKT 系统上用 IFT 求梯度。关键参数 \(\kappa\)(central-path 参数)控制接触的"软化"程度——\(\kappa\) 越大接触越软、梯度越平滑,但物理精度下降。

9.3 接触微分的核心困难 ⭐⭐⭐

接触力学引入了**Heaviside 类的不连续性**:物体要么接触、要么不接触,接触力要么存在、要么为零。这种不连续性使得 \(\partial f / \partial x\) 在接触切换点上不存在(或为零),first-order gradient 失效。

Suh et al.(ICML 2022 Outstanding Paper)给出了定量分析:

定理(FoBG 偏差):设 \(z = \phi(\theta, \epsilon)\) 包含 Heaviside 不连续性,smooth 目标 \(J(\theta) = \mathbb{E}_\epsilon[R(\phi(\theta, \epsilon))]\)。则:

  • FoBG(first-order batched gradient)\(= \mathbb{E}[\frac{\partial R}{\partial z} \cdot \frac{\partial \phi}{\partial \theta}]\) 几乎处处为零(因为 \(\partial \phi / \partial \theta = 0\)),但 \(\nabla J(\theta)\) 非零——FoBG 有偏
  • ZoBG(zeroth-order batched gradient,即 REINFORCE)\(= \mathbb{E}[R \cdot \nabla \log p(\epsilon)]\) 不经过 \(\phi\) 的 Jacobian——ZoBG 无偏

实际含义:在有接触的机器人任务中,可微仿真给出的梯度可能完全错误。对策包括:

  1. 随机光滑化(randomized smoothing):对参数加噪声 \(\theta + \sigma \xi\),在期望意义下获得光滑梯度
  2. 互补松弛(complementarity relaxation):将接触条件 \(0 \leq \lambda \perp \phi \geq 0\) 松弛为 \(\lambda \phi = \kappa\)(Dojo 方法)
  3. 混合策略:光滑阶段用可微仿真(FoBG),接触密集阶段切到 REINFORCE(ZoBG)

反事实推理:如果所有物理系统都是光滑的(没有接触、碰撞),那么可微仿真将是 RL 的终极解决方案——样本效率比无模型 RL 高 100 倍。但现实世界充满了不连续性(抓取、行走、碰撞),这使得"用什么梯度"成为一个需要仔细分析的研究问题。

9.4 SHAC:截断窗口解决梯度爆炸 ⭐⭐⭐⭐

Xu et al.(ICLR 2022)提出 SHAC(Short-Horizon Actor-Critic),解决可微仿真中长 horizon BPTT 的梯度爆炸问题。核心思想:

  1. 只在短窗口(如 16-32 步)内通过可微仿真做 BPTT
  2. 用 critic 估计窗口外的 value function
  3. 结合短窗口的 FoG 和 critic 的 bootstrap

这在一些连续控制任务(如 Ant locomotion)上实现了比 PPO 更高的样本效率和更快的训练速度。

⚠️ 常见陷阱

💡 概念误区:认为可微仿真可以完全替代 model-free RL - 错误推理:"有了可微仿真,直接反向传播就能训练策略,不需要 PPO/SAC 了" - 实际上:(1) 接触导致梯度偏差;(2) 长 horizon 导致梯度爆炸(Lyapunov 指数);(3) 混沌系统的梯度方差随时间指数增长。Suh 2022 的结论是:没有银弹,需要根据任务特性选择 FoBG、ZoBG 或混合方法 - 正确思维:可微仿真是工具箱中的一个强力工具,但不是唯一的

⚠️ 编程陷阱:MuJoCo MJX 的 while_loop 反向失败 - 现象:用 jax.grad 对 MJX 仿真做反向 AD 时报错 - 根本原因:MJX 默认 CG 求解器用 jax.lax.while_loop(动态终止条件),JAX 的反向 AD 不支持 - 正确做法:切换到 Newton 求解器或使用 jax.lax.fori_loop(静态上界)

练习

  1. 解释 Pinocchio 的 computeRNEADerivatives 为什么比黑盒 AD 更快。(提示:利用了 RNEA 递推的什么结构性质?)
  2. 设计一个简单的 1D 接触问题:小球落地 \(x_{t+1} = \max(x_t + v_t \Delta t, 0)\)。分析 \(\partial x_{t+1} / \partial v_t\)\(x_t + v_t \Delta t = 0\) 处的行为。这解释了为什么接触梯度是困难的。
  3. (跨章综合)回顾 §7 的隐式微分。解释 Dojo 如何用 IFT 绕过接触的不连续性——它把什么写成 \(F(z, \theta) = 0\)

10. 工程实践 ⭐⭐

动机

理解 AD 的数学原理是基础,但要在工程中**正确、高效、可靠地使用 AD**,还需要掌握一系列实践技巧。本节覆盖三个关键主题:梯度验证、数值稳定性和可微渲染。

10.1 梯度检查(Gradient Check) ⭐⭐

每当你实现了一个自定义的 AD 规则(如 torch.autograd.Functionjax.custom_vjp),第一件事就是**用有限差分验证它的正确性**。这就是"有限差分做验证"的正确用法。

import jax
import jax.numpy as jnp

def check_grad(f, x, eps=1e-5):
    """用中心差分验证 jax.grad 的正确性"""
    grad_ad = jax.grad(f)(x)
    grad_fd = jnp.zeros_like(x)
    for i in range(len(x)):
        e_i = jnp.zeros_like(x).at[i].set(1.0)
        grad_fd = grad_fd.at[i].set(
            (f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)
        )
    # 相对误差应小于 sqrt(eps) ~ 1e-2.5
    rel_error = jnp.linalg.norm(grad_ad - grad_fd) / (jnp.linalg.norm(grad_ad) + 1e-10)
    print(f"相对误差: {rel_error:.2e}")
    return rel_error < 1e-3  # 经验阈值

精度预期:对中心差分(\(h = 10^{-5}\)),截断误差 \(O(h^2) \approx 10^{-10}\),舍入误差 \(O(\epsilon_{\text{mach}} / h) \approx 10^{-11}\)。AD 与有限差分的相对误差应在 \(10^{-5}\)\(10^{-7}\) 之间。如果差异超过 \(10^{-3}\),几乎可以确定 AD 实现有 bug。

10.2 复步微分法(Complex Step Method) ⭐⭐⭐

有限差分有截断误差和舍入误差的矛盾。复步微分法是一个巧妙的替代方案:

\[f'(x) = \text{Im}\left[\frac{f(x + ih)}{h}\right] + O(h^2)\]

其中 \(i = \sqrt{-1}\)。关键观察:分子不涉及两个接近的实数相减——\(f(x)\) 出现在实部,\(f'(x) \cdot h\) 出现在虚部,两者在完全不同的"通道"中。因此没有 catastrophic cancellation,可以取极小的 \(h\)(如 \(h = 10^{-200}\))获得接近机器精度的导数。

与 dual number 的关系:复步微分法实际上是 dual number AD 的一个特例。在 \(\mathbb{C}\) 中,\(f(x + ih) \approx f(x) + if'(x)h\)(当 \(h\) 足够小时)。虚部除以 \(h\) 就得到 \(f'(x)\)。但 dual number 更一般——它不要求 \(h\) 趋于零,因为 \(\varepsilon^2 = 0\) **精确地**消除了二阶项,而不是靠 \(h\) 很小来近似消除。

import numpy as np

def complex_step_derivative(f, x, h=1e-30):
    """复步微分法:精度接近机器精度,无 catastrophic cancellation"""
    return np.imag(f(x + 1j * h)) / h

# 测试
f = lambda x: np.sin(x**3) * np.exp(x)
x0 = 1.5

# 比较三种方法
deriv_fd = (f(x0 + 1e-8) - f(x0)) / 1e-8         # 前向差分
deriv_cd = (f(x0 + 1e-5) - f(x0 - 1e-5)) / 2e-5  # 中心差分
deriv_cs = complex_step_derivative(f, x0)           # 复步微分

# 解析导数
exact = (3*x0**2*np.cos(x0**3) + np.sin(x0**3)) * np.exp(x0)

print(f"前向差分误差: {abs(deriv_fd - exact):.2e}")   # ~1e-8
print(f"中心差分误差: {abs(deriv_cd - exact):.2e}")   # ~1e-10
print(f"复步微分误差: {abs(deriv_cs - exact):.2e}")   # ~1e-16(机器精度!)

复步微分法的局限:它要求函数 \(f\) 可以接受复数输入,且在实轴附近是解析的。对于包含 absmax、分支判断的函数不适用。

10.3 数值稳定性 ⭐⭐

AD 虽然没有截断误差,但仍然受**浮点舍入**和**数学条件数**的影响。以下是最常见的数值问题:

Softmax 的 log-sum-exp 技巧:直接计算 \(\log(\sum_i e^{x_i})\)\(x_i\) 很大时会溢出。标准技巧是减去最大值 \(M = \max_i x_i\)

\[\text{logsumexp}(x) = M + \log\left(\sum_i e^{x_i - M}\right)\]

这不影响 AD 的正确性(因为减去常数不影响梯度),但对数值稳定性至关重要。

梯度消失/爆炸:在长序列反向传播中,梯度通过大量矩阵连乘传播。如果矩阵的谱范数 \(> 1\),梯度指数增长;如果 \(< 1\),梯度指数衰减。对策包括梯度裁剪(gradient clipping)、残差连接、LSTM 的门控机制等。

病态 Hessian:当 Hessian 矩阵的条件数 \(\kappa = \lambda_{\max} / \lambda_{\min}\) 很大时,AD 计算的梯度方向虽然正确,但优化收敛速度极慢(梯度下降的收敛率 \(\sim ((\kappa-1)/(\kappa+1))^k\))。此时需要预条件化或自适应步长。

10.4 AD 中的常见数值灾难 ⭐⭐

以下是 AD 实践中最容易遇到的数值问题及其标准解决方案:

灾难 1:\(\sqrt{0}\) 的导数\(\frac{d}{dx}\sqrt{x} = \frac{1}{2\sqrt{x}}\),在 \(x=0\) 处趋于无穷。如果你的代码中有 loss = sqrt(x**2 + y**2)(如 L2 范数),在 \((x,y) = (0,0)\) 处梯度会出现 NaN。

标准对策:添加小常数 loss = sqrt(x**2 + y**2 + eps),或使用 jnp.linalg.norm(v + eps) 而非手写 sqrt(sum(v**2))

灾难 2:\(\log(0)\)\(\log(\text{softmax})\)\(\log(0) = -\infty\)\(\frac{d}{dx}\log(x) = 1/x\)\(x \to 0\) 时趋于无穷。在分类任务中,loss = -log(softmax(logits)[label]) 的梯度在 logits 差异大时可能出问题。

标准对策:使用 log_softmax(一步计算,数值稳定),而非先 softmaxlog

灾难 3:条件数的放大效应。AD 的结果在数学上是精确的(浮点范围内),但如果函数本身是病态的(条件数大),AD 给出的梯度虽然"精确",但对输入微小扰动高度敏感。例如,矩阵求逆 \(f(A) = A^{-1}\) 的条件数为 \(\kappa(A)^2\)——如果 \(\kappa(A) = 10^8\),那么 \(A\) 的 8 位精度变化会导致 \(A^{-1}\) 完全不准。

标准对策:避免显式矩阵求逆,改用分解+求解;添加正则化降低条件数。

10.5 可微渲染简介 ⭐⭐⭐⭐

可微渲染是 AD 在计算机视觉中的前沿应用。两个里程碑式的工作:

NeRF(Mildenhall et al., ECCV 2020)用 MLP \(F_\theta(\mathbf{x}, \mathbf{d}) \to (c, \sigma)\) 表示场景的颜色和密度,通过**体积渲染**将 3D 表示投影到 2D 图像。体积渲染公式:

\[C(\mathbf{r}) = \int_0^{+\infty} T(t) \sigma(\mathbf{r}(t)) \mathbf{c}(\mathbf{r}(t), \mathbf{d}) \, dt, \quad T(t) = \exp\left(-\int_0^t \sigma(\mathbf{r}(s))\,ds\right)\]

整个渲染过程是可微的(通过 PyTorch 反向 AD),因此可以用多视角图像反向传播优化 MLP 参数 \(\theta\)

3D Gaussian Splatting(Kerbl et al., SIGGRAPH 2023)用三维高斯椭球代替 MLP 表示场景。每个高斯由位置、协方差、颜色和不透明度参数化。渲染通过"splatting"(将 3D 高斯投影到 2D 图像平面并做 alpha compositing)实现。

关键的 AD 挑战:高斯 splatting 的渲染管线涉及 3D→2D 投影、排序和 alpha 合成——这些操作需要**自定义 CUDA kernel 的手工 VJP**(因为通用 AD 框架对 GPU kernel 的支持有限),Kerbl 等人用手写的 CUDA 反向传播实现了这一点。

10.6 AD 工程的黄金法则 ⭐⭐

总结本节的实践经验,以下六条法则涵盖了 AD 工程中最重要的注意事项:

  1. 永远做梯度检查:任何自定义的前向/反向规则,在投入使用前必须用有限差分或复步微分验证。没有通过梯度检查的代码不可信。

  2. 先用框架提供的 AD,再考虑手写:PyTorch 的 autograd、JAX 的 grad 已经处理了大量边界情况。只有当性能不满足需求时才考虑手写 VJP(如 3DGS 的 CUDA kernel)。

  3. 关注 \(n/m\) 比值\(n \gg m\) 用反向模式,\(n \ll m\) 用前向模式。选错模式可能导致 1000 倍的性能差异。

  4. 监控条件数:AD 计算的梯度在数学上是精确的,但如果函数本身是病态的,梯度仍然不可靠。定期检查 Hessian 条件数或 KKT 矩阵条件数。

  5. 隐式优于展开:对任何涉及迭代求解的计算(优化、ODE、不动点),默认使用隐式微分(IFT),除非有充分理由使用展开式 AD。

  6. 不可微处需要特殊处理absmaxrelu、接触、碰撞等不可微运算需要 log-barrier、softplus、随机光滑化等处理。忽略不可微性会导致零梯度或错误梯度。

⚠️ 常见陷阱

⚠️ 编程陷阱:忘记做梯度检查就信任自定义 VJP - 错误做法:写了 custom_vjp 后直接用于训练 - 现象:训练 loss 不降或出现 NaN,但不知道原因 - 根本原因:手写的 VJP 公式可能有符号错误、转置错误或维度不匹配 - 正确做法:先用有限差分验证,再用于训练。PyTorch 提供 torch.autograd.gradcheck

💡 概念误区:认为 float32 够用就不需要考虑数值稳定性 - 错误想法:"GPU 上用 float32 训练就行了,不用管 float64" - 实际上:float32 只有 7 位有效数字。当两个接近的数相减时(catastrophic cancellation),有效位数可能骤降到 1-2 位,导致梯度严重失真 - 正确思维:在关键路径上使用数值稳定的实现(如 log-sum-exp、稳定的 softmax),而不是简单地切换到 float64

练习

  1. 实现 gradcheck(f, x) 函数,使用中心差分验证 JAX 自动梯度的正确性。对 \(f(x) = \log(\sum_i e^{x_i})\)\(x = [1000, 1001, 1002]\) 处测试——如果不用 log-sum-exp 技巧会发生什么?
  2. 对一个 50 层的全连接网络(无残差连接),实验观察反向传播中梯度范数随层数的变化。在什么条件下出现梯度消失/爆炸?
  3. 阅读 NeRF 的体积渲染公式,解释为什么这个积分的离散近似是可微的。每个采样点的"贡献权重" \(T(t_i) (1 - \exp(-\sigma_i \delta_i))\)\(\sigma_i\) 的梯度是什么?

11. 连续伴随方法与 Neural ODE ⭐⭐⭐

动机

到目前为止,我们讨论的隐式微分(§7-§8)处理的是**离散的方程组** \(F(x, y) = 0\)。但在最优控制和科学计算中,许多问题天然是**连续时间**的:ODE 约束下的参数估计、连续深度模型(Neural ODE)、轨迹优化。对这些问题,连续伴随方程(continuous adjoint method)提供了一种内存高效的梯度计算方法。

11.1 问题设定

考虑参数化的 ODE 系统:

\[\dot{z}(t) = f_\theta(z(t), t), \quad z(0) = z_0\]

目标是最小化终端损失 \(L(z(T))\)。问题:如何计算 \(dL/d\theta\)

朴素方法(discretize-then-optimize):将 ODE 用 Euler/RK4 离散化为 \(K\) 步,把所有步骤展开成计算图,然后反向 AD。内存 \(O(K)\),因为需要存储所有中间状态。

伴随方法(optimize-then-discretize):在连续层面推导梯度公式,然后再离散化计算。内存 \(O(1)\)(只存储终端状态和伴随变量)。

11.2 连续伴随方程的推导 ⭐⭐⭐

这是自动微分理论中最优美的推导之一。 它把链式法则从离散图推广到连续时间。

Step 1:构造增广 Lagrangian

引入伴随变量(Lagrange 乘子函数)\(\lambda(t)\),将 ODE 约束纳入目标:

\[\mathcal{L} = L(z(T)) + \int_0^T \lambda(t)^\top (f_\theta(z, t) - \dot{z}) \, dt\]

为什么要这样做?\(z(t)\) 满足 ODE 约束 \(\dot{z} = f_\theta\) 时,积分项为零,\(\mathcal{L} = L(z(T))\)。但通过引入 \(\lambda\),我们可以对 \(z(t)\) 做变分(虚位移),而不需要约束 \(z(t)\) 始终满足 ODE。

Step 2:对 \(z(t)\) 做变分

\(z(t) \to z(t) + \delta z(t)\),则:

\[\delta \mathcal{L} = \frac{\partial L}{\partial z}\bigg|_{z(T)} \delta z(T) + \int_0^T \lambda^\top \left(\frac{\partial f}{\partial z} \delta z - \delta \dot{z}\right) dt\]

\(\delta \dot{z}\) 做分部积分:

\[\int_0^T \lambda^\top (-\delta \dot{z}) \, dt = -\lambda^\top \delta z \bigg|_0^T + \int_0^T \dot{\lambda}^\top \delta z \, dt\]

代入得:

\[\delta \mathcal{L} = \left(\frac{\partial L}{\partial z}\bigg|_{z(T)} - \lambda(T)\right)^\top \delta z(T) + \int_0^T \left(\lambda^\top \frac{\partial f}{\partial z} + \dot{\lambda}^\top\right) \delta z \, dt + \lambda(0)^\top \delta z(0)\]

Step 3:令各阶变分为零

由于 \(\delta z(t)\) 是任意的,各系数必须分别为零:

  • 终端条件\(\lambda(T) = \frac{\partial L}{\partial z}\bigg|_{z(T)}\)
  • 伴随 ODE\(\dot{\lambda} = -\lambda^\top \frac{\partial f}{\partial z}\)(即 \(\dot{\lambda} = -\left(\frac{\partial f}{\partial z}\right)^\top \lambda\)
  • 初始条件不变\(\delta z(0) = 0\)(初始条件是给定的)

Step 4:参数梯度

\(\theta\) 求全微分:

\[\frac{dL}{d\theta} = \int_0^T \lambda(t)^\top \frac{\partial f}{\partial \theta}\bigg|_{z(t)} dt\]

这就是**连续伴随方程**:反向从 \(T\)\(0\) 积分 \(\lambda\),同时累积参数梯度。

本质洞察:连续伴随方程是反向模式 AD 在连续时间的自然推广。离散图上的 adjoint 变量 \(\bar{v}_i = \partial y / \partial v_i\) 变成了连续的 \(\lambda(t) = \partial L / \partial z(t)\)。离散图上"逆拓扑序遍历"变成了"反向时间积分"。链式法则的"乘法+求和"变成了"ODE + 积分"。形式不同,本质完全相同。

11.3 Neural ODE(Chen et al., NeurIPS 2018) ⭐⭐⭐

Chen et al. 的核心观察:如果把 ResNet 的层 \(z_{k+1} = z_k + f_\theta(z_k, k)\) 视为 Euler 离散化,那么当层数趋于无穷、步长趋于零时,ResNet 变成了 ODE \(\dot{z} = f_\theta(z, t)\)

Neural ODE 的优势: 1. 自适应计算量:ODE 求解器根据函数的"刚性"自动调整步长,简单区域大步走、困难区域小步走 2. 常数内存:用伴随方法计算梯度,内存不随"深度"增长 3. 可逆性:ODE 的前向和反向积分互为逆过程,无需存储前向状态

Neural ODE 的局限: 1. 表达能力有限:ODE 的轨迹不能交叉(唯一性定理),这限制了模型的表达能力 2. 刚性 ODE 的问题:当 \(f_\theta\) 产生刚性动力学时,adjoint ODE 可能产生 stiff 反向积分,导致数值不稳定 3. discretize-then-optimize vs optimize-then-discretize 的不一致:Gholami et al.(2019)指出,当前向用自适应步长求解器时,连续伴随的离散版本与"展开求解器再反向 AD"给出的梯度不完全相同——因为步长选择依赖于状态,而状态对参数可微

跨领域类比:Neural ODE 之于 ResNet,就像微分方程之于差分方程。ResNet 是"离散时间的动力系统",Neural ODE 是"连续时间的动力系统"。就像连续分析往往比离散分析更优美(可以用微积分工具),Neural ODE 的伴随方法也比 BPTT 更优美(常数内存、自适应精度)。

Neural ODE 在机器人学中的应用

应用场景 为什么用 Neural ODE 代替什么
系统辨识 学习连续时间动力学 \(\dot{x} = f_\theta(x, u)\) 传统 NARX/LSTM
轨迹预测 自适应步长,在简单阶段节省计算 固定步长 RNN
物理信息模型 可嵌入已知物理结构(如 Hamiltonian、Lagrangian) 纯数据驱动模型
连续时间 RL 策略和值函数是 ODE 的解 离散时间 RL

特别值得关注的是 Hamiltonian Neural Networks(Greydanus et al., NeurIPS 2019)和 Lagrangian Neural Networks(Cranmer et al., ICLR 2020)——它们把物理守恒律(能量守恒、动量守恒)嵌入 Neural ODE 的结构中,使学到的动力学模型在长时间预测中保持物理一致性。这对机器人学中的"学习-控制"闭环至关重要:如果学到的模型违反能量守恒,基于模型的控制器可能产生不稳定的行为。

11.4 Discretize-then-Optimize vs Optimize-then-Discretize ⭐⭐⭐⭐

这是计算科学中一个经典的二元对立,理解它对正确使用 Neural ODE 至关重要。

维度 Discretize-then-Optimize(展开) Optimize-then-Discretize(伴随)
做法 先离散化 ODE(如 RK4),再对离散步骤做 AD 先在连续层面推导伴随方程,再离散化伴随 ODE
内存 \(O(K)\)(存所有步的状态) \(O(1)\)(只存终端状态)
梯度一致性 与前向计算严格一致 可能有离散化误差
自适应步长 前向步长选择本身是可微的(复杂) 前向和反向可用不同步长(灵活)
实现复杂度 简单(直接用 AD 框架) 需要实现伴随 ODE 的求解器

工程建议:对于大多数应用,先用 checkpointed discretize-then-optimize(如 torch.utils.checkpoint + torchdiffeq),因为它的梯度与前向严格一致。只有在内存极度受限时才切换到纯伴随方法。

⚠️ 常见陷阱

⚠️ 编程陷阱:对刚性 ODE 用伴随方法导致梯度不准 - 现象:伴随方程的反向积分数值发散,梯度出现 NaN 或大幅震荡 - 根本原因:刚性 ODE 的 Jacobian \(\partial f / \partial z\) 有很大的负实部特征值,反向积分时这些变成正实部→指数增长 - 正确做法:使用 Gholami et al. 推荐的 checkpointed augmented-state 方法,或切换到隐式 ODE 求解器

💡 概念误区:认为 Neural ODE 总是比 ResNet 好 - 实际上:(1) Neural ODE 的训练通常比 ResNet 慢 5-10 倍(ODE 求解器开销);(2) 表达能力有限(轨迹不交叉);(3) 在图像分类等标准任务上并不比 ResNet 更准确 - Neural ODE 的真正价值在于建模连续时间动力学(如时间序列、物理系统)和理论分析(如动力系统稳定性)

练习

  1. 对简谐振子 \(\ddot{x} = -\omega^2 x\)(改写为一阶系统 \(\dot{z} = Az\)\(A = \begin{pmatrix} 0 & 1 \\ -\omega^2 & 0 \end{pmatrix}\)),写出伴随方程 \(\dot{\lambda} = -A^\top \lambda\)。验证伴随 ODE 也是简谐振子(频率相同)。
  2. 推导参数梯度 \(dL/d\omega^2\) 的积分表达式。如果 \(L = z_1(T)^2\)(终端位移的平方),\(\lambda(T) = ?\)
  3. torchdiffeq 实现一个简单的 Neural ODE 分类器(如螺旋线分类),比较(1) adjoint 方法和(2) 直接 BPTT 的内存消耗和梯度差异。

12. 深度平衡模型与定点隐式微分 ⭐⭐⭐⭐

动机

Neural ODE 把无限深度解释为连续 ODE 的积分。Deep Equilibrium Models(DEQ, Bai et al. NeurIPS 2019)走了另一条路:把无限深度解释为**定点迭代的收敛**。

考虑一个共享权重的深度网络 \(z_{k+1} = f_\theta(z_k, x)\)。当 \(k \to \infty\) 时(如果收敛),\(z^* = f_\theta(z^*, x)\)——这就是**定点方程**。DEQ 直接求解定点方程(用 Anderson acceleration、Broyden 法等),然后用 IFT 计算梯度。

12.1 定点隐式微分的推导 ⭐⭐⭐

定点条件 \(z^* = f_\theta(z^*, x)\) 可以写成 \(F(z, \theta, x) = z - f_\theta(z, x) = 0\)。对 \(\theta\) 求微分:

\[\frac{\partial F}{\partial z} \frac{dz^*}{d\theta} + \frac{\partial F}{\partial \theta} = 0\]
\[\left(I - \frac{\partial f}{\partial z}\bigg|_{z^*}\right) \frac{dz^*}{d\theta} = \frac{\partial f}{\partial \theta}\bigg|_{z^*}\]
\[\frac{dz^*}{d\theta} = \left(I - \frac{\partial f}{\partial z}\bigg|_{z^*}\right)^{-1} \frac{\partial f}{\partial \theta}\bigg|_{z^*}\]

VJP 形式:给定上游梯度 \(\bar{z}\),需要解:

\[u^\top \left(I - \frac{\partial f}{\partial z}\bigg|_{z^*}\right) = \bar{z}^\top\]

然后 \(\bar{\theta} = u^\top \frac{\partial f}{\partial \theta}\)

关键观察:这个线性方程本身也是一个定点问题!可以用迭代法求解:

\[u_{k+1}^\top = \bar{z}^\top + u_k^\top \frac{\partial f}{\partial z}\]

这是 Neumann 级数 \((I - A)^{-1} = \sum_{k=0}^{\infty} A^k\) 的迭代版本(当 \(\|\partial f / \partial z\| < 1\) 时收敛)。实际中通常用 Anderson acceleration 加速收敛。

12.2 DEQ 的优势与局限 ⭐⭐⭐⭐

优势: 1. 常数内存:前向只需要存储 \(z^*\)(不需要展开 \(K\) 步的中间状态) 2. 反向也是常数内存:VJP 通过解线性方程得到(不需要反向穿过 \(K\) 步迭代) 3. 解耦前向与反向:前向求解器和反向求解器可以完全不同(前向用 Broyden,反向用 CG)

局限: 1. 需要收敛保证:如果 \(f\) 不是压缩映射(\(\|\partial f / \partial z\| \geq 1\)),定点迭代不收敛 2. 隐式微分假设精确收敛:如果前向没有完全收敛到 \(z^*\),隐式微分给出的梯度是"精确解的梯度"而非"近似解的梯度"——两者之间有 gap 3. Jacobian 条件数\(I - \partial f / \partial z\) 接近奇异时,隐式微分数值不稳定

反事实推理:如果我们不用 IFT 而是展开 1000 步定点迭代做 BPTT,会发生什么?(1) 内存 \(O(1000)\) vs IFT 的 \(O(1)\);(2) 梯度受"收敛路径"污染——如果迭代只到 \(z_{500}\) 就几乎收敛了,后 500 步的梯度几乎为零,但 BPTT 仍然要穿过它们,浪费计算并引入数值噪声;(3) 梯度的质量依赖于迭代步数 \(K\),而 IFT 的梯度只依赖于收敛解 \(z^*\),与 \(K\) 无关。

12.3 Anderson 加速与 Broyden 法 ⭐⭐⭐⭐

Neumann 级数收敛速度可能很慢(线性收敛,速率 \(\|\partial f / \partial z\|\))。实践中常用两种加速方法:

Anderson acceleration(混合法):保存最近 \(m\) 步的迭代历史 \(\{z_k, g_k = T(z_k) - z_k\}\),在这 \(m\) 个点的仿射组合中找最小残差:

\[\alpha^* = \arg\min_\alpha \left\|\sum_{i=0}^{m} \alpha_i g_{k-i}\right\|^2 \quad \text{s.t.} \quad \sum_i \alpha_i = 1\]

然后 \(z_{k+1} = \sum_i \alpha_i T(z_{k-i})\)。Anderson 加速通常将线性收敛加速为超线性收敛。

Broyden 法:用秩一更新维护 Jacobian 的近似逆 \(B_k \approx (I - \partial f / \partial z)^{-1}\)

\[B_{k+1} = B_k + \frac{(\Delta z_k - B_k \Delta g_k) \Delta z_k^\top B_k}{\Delta z_k^\top B_k \Delta g_k}\]

其中 \(\Delta z_k = z_{k+1} - z_k\), \(\Delta g_k = g_{k+1} - g_k\)。Broyden 法不需要计算 Jacobian(\(O(n^2)\) 而非 \(O(n^3)\)),是 DEQ 论文中使用的默认方法。

为什么不直接用 Newton 法? Newton 法 \(z_{k+1} = z_k - (I - \partial f / \partial z)^{-1}(z_k - f(z_k))\) 需要在每步计算完整 Jacobian \(\partial f / \partial z\) 并求解线性方程——对深度网络来说计算量过大。Broyden 法用近似 Jacobian 替代精确 Jacobian,每步只需 \(O(n^2)\)

12.4 Blondel 2022 的统一框架 ⭐⭐⭐

Blondel et al.(NeurIPS 2022)给出了一个统一视角:所有隐式微分问题都可以归结为 \(F(x^*, \theta) = 0\) 的 IFT

问题类型 \(F\) 的形式 例子
定点 \(F = z - T_\theta(z)\) DEQ, RNN, Almeida-Pineda
最优性 \(F = \nabla_z f(z, \theta)\) 无约束优化、proximal 算子
KKT \(F = (KKT 条件)\) OptNet, cvxpylayers, diff-MPC
ODE 终端 \(F = z(T) - \text{ODESolve}(\theta)\) Neural ODE

JAXopt 库实现了这个统一框架:用户只需要实现 \(F\) 和前向求解器,隐式微分的 VJP 自动生成。

练习

  1. 对映射 \(f(z) = \tanh(Wz + b)\)\(W\)\(n \times n\) 矩阵),写出定点条件 \(z^* = f(z^*)\)。什么条件下 \(f\) 是压缩映射?(提示:\(\|\tanh'\|_\infty = 1\),所以需要 \(\|W\| < 1\)。)
  2. 用 Neumann 级数 \(u_{k+1}^\top = \bar{z}^\top + u_k^\top J\)\(J = \partial f / \partial z|_{z^*}\))实现 DEQ 的反向传播,验证与 JAX 的 custom_vjp + 直接求逆的结果一致。
  3. 比较 DEQ 和 Neural ODE:两者都实现了"无限深度",但机制完全不同。列出至少 3 个维度的对比(内存、计算、表达能力、稳定性)。

13. 可微 MPC 与 Pontryagin Differentiable Programming ⭐⭐⭐

动机

Model Predictive Control(MPC)是机器人控制的主力工具。如果 MPC 是可微的(即 MPC 的输出——最优控制序列——对其参数如代价权重、动力学参数是可微的),就可以做到:

  1. 学习代价函数:通过模仿学习,从专家演示中反推代价权重
  2. 系统辨识:从观测轨迹中估计动力学参数
  3. 端到端控制:把 MPC 作为策略网络的一层,整体端到端训练

13.1 Differentiable MPC(Amos et al., NeurIPS 2018) ⭐⭐⭐

Amos et al. 把 MPC 建模为参数化的 QP,然后用 OptNet 的 KKT 隐式微分技术计算梯度:

\[u^*(\theta) = \text{MPC}(\theta) = \arg\min_{u_{0:N-1}} \sum_{k=0}^{N} \ell_k(x_k, u_k; \theta) \quad \text{s.t.} \quad x_{k+1} = f(x_k, u_k)\]

对 LQR(线性动力学 + 二次代价),MPC 本身可以用 Riccati 递推精确求解。Amos 等人的关键贡献是**对 Riccati 递推做微分**——即推导"Riccati 解对代价权重的导数"。

实验:在 cartpole 上,用专家控制器生成演示轨迹,然后训练 diff-MPC 恢复专家的代价权重。结果表明 diff-MPC 比行为克隆在分布外测试中更鲁棒。

13.2 Pontryagin Differentiable Programming(Jin et al., NeurIPS 2020) ⭐⭐⭐⭐

Jin et al. 提出了一种更优雅的方法:不对 MPC 求解器做微分,而是**对 Pontryagin 最大化原理(PMP)本身做微分**。

PMP 给出最优控制的必要条件(协态方程):

\[\dot{p} = -\frac{\partial H}{\partial x}, \quad H(x, u, p) = \ell(x, u) + p^\top f(x, u)\]

对 PMP 关于参数 \(\theta\) 求微分,得到一个 auxiliary LQR 问题——这个辅助 LQR 可以用 Riccati 回传 \(O(N)\) 时间内精确求解。

PDP 相比 diff-MPC 的优势: 1. 不需要展开求解器迭代或存储 KKT 系统 2. 辅助 LQR 的计算量与前向 MPC 相当 3. 梯度精度不受求解器收敛程度的影响

PDP 的核心数学

对 Hamiltonian 系统的最优性条件(\(\dot{x} = \partial H / \partial p, \dot{p} = -\partial H / \partial x\))关于 \(\theta\) 求微分,设 \(\delta x = \partial x / \partial \theta, \delta p = \partial p / \partial \theta\),得到线性 ODE 系统:

\[\frac{d}{dt} \begin{pmatrix} \delta x \\ \delta p \end{pmatrix} = \begin{pmatrix} H_{xp} & H_{pp} \\ -H_{xx} & -H_{px} \end{pmatrix} \begin{pmatrix} \delta x \\ \delta p \end{pmatrix} + \begin{pmatrix} H_{p\theta} \\ -H_{x\theta} \end{pmatrix}\]

这是一个**线性时变 ODE**,可以用 Riccati 递推高效求解。

跨领域类比:PDP 之于 MPC,就像解析梯度之于有限差分。MPC 的有限差分做法是"扰动参数→重新求解→观察输出变化",diff-MPC 用 KKT IFT 相当于"在解的定义方程上用 IFT",而 PDP 直接在最优性的必要条件上做微分——每一步都更深入、更高效。

13.3 可微 MPC 的工程选型 ⭐⭐⭐

方法 数学基础 内存 精度 适用范围
展开求解器 + BPTT 反向 AD \(O(K \times N)\) 依赖收敛 通用但低效
KKT 隐式微分 IFT \(O(1)\) 精确(在解处) 凸 QP/SOCP
PDP PMP + 辅助 LQR \(O(N)\) 精确 连续时间 OC
灵敏度分析 NLP 后处理 \(O(1)\) 精确 IPOPT/acados

实际建议:如果你使用 acados 或 CasADi,灵敏度分析(sensitivity analysis)是最方便的——求解器求解 NLP 后,可以直接提取 \(dy^*/d\theta\),不需要额外的反向传播。

⚠️ 常见陷阱

🧠 思维陷阱:认为可微 MPC 可以替代 RL - 错误推理:"MPC 比 RL 更有结构,可微 MPC 又能端到端训练,所以不需要 RL" - 实际上:(1) 可微 MPC 需要可微的动力学模型——如果模型不可微(如接触),梯度就崩溃了;(2) MPC 的 horizon \(N\) 有限,长期行为需要 value function 估计;(3) MPC 的计算量比 RL 策略网络大几个数量级 - 正确思维:可微 MPC 适合"有精确模型 + 需要满足约束 + 代价函数需要学习"的场景

⚠️ 编程陷阱:展开 iLQR 迭代做反向传播 - 这是 §7 已经警告过的陷阱,在 MPC 场景下特别常见 - 额外的问题:iLQR 使用 line search,line search 的步长选择本身不可微(argmin 不可微) - 正确做法:用 IFT 或 PDP

练习

  1. 对 LQR 问题(\(\min \sum x_k^\top Q x_k + u_k^\top R u_k\), s.t. \(x_{k+1} = Ax_k + Bu_k\)),用 KKT 隐式微分推导 \(\partial u_0^* / \partial Q\)
  2. 解释 PDP 的辅助 LQR 的物理含义:它在"优化"什么?为什么它的解给出了参数灵敏度?
  3. (跨章综合)回顾 §8 的 OptNet。MPC 的 QP 形式化 \(\min \frac{1}{2} z^\top H z + g^\top z\) s.t. \(Cz = d, Gz \leq h\) 中,\(z = [x_{0:N}; u_{0:N-1}]\)。如果代价权重 \(Q, R\) 是参数,它们出现在 \(H\) 的哪些位置?用 §8 的 KKT 微分公式写出 \(\partial z^* / \partial Q\)

14. 可微渲染与三维表示学习 ⭐⭐⭐⭐

动机

可微渲染(differentiable rendering)是 AD 在计算机视觉和机器人感知中的重要应用。它使得"从 2D 图像反推 3D 结构"变成一个可以用梯度下降求解的优化问题。

14.1 体积渲染与 NeRF ⭐⭐⭐⭐

NeRF 的体积渲染公式(见 §10.4)的离散近似是:

\[\hat{C}(\mathbf{r}) = \sum_{i=1}^{K} T_i \alpha_i \mathbf{c}_i, \quad T_i = \prod_{j=1}^{i-1}(1 - \alpha_j), \quad \alpha_i = 1 - \exp(-\sigma_i \delta_i)\]

其中 \(\sigma_i\) 是第 \(i\) 个采样点的密度(由 MLP \(F_\theta\) 预测),\(\delta_i\) 是采样间隔,\(\mathbf{c}_i\) 是颜色。

关键的 AD 分析

\(\partial \hat{C} / \partial \sigma_i\) 通过链式法则涉及两条路径: 1. 通过 \(\alpha_i\):增大 \(\sigma_i\) 使第 \(i\) 个点更不透明,贡献更多颜色 2. 通过 \(T_{j>i}\):增大 \(\sigma_i\) 使后续点的透射率 \(T_j\) 降低,遮挡后面的颜色

这两条路径的梯度方向可能相反——当第 \(i\) 个点的颜色与最终像素颜色不匹配时,增大 \(\sigma_i\) 既增加了"错误颜色"的贡献又减少了"正确颜色"的贡献。反向 AD 自动处理了这种复杂的梯度交互。

14.2 3D Gaussian Splatting 中的 AD 挑战 ⭐⭐⭐⭐

3D Gaussian Splatting 的渲染管线涉及几个 AD 需要特殊处理的步骤:

  1. 3D→2D 投影:将 3D 高斯投影到图像平面。这一步是光滑可微的(投影是分式线性变换)
  2. 深度排序:按深度对高斯排序。排序是不连续的(相邻高斯交换顺序时发生跳变),但 Kerbl et al. 忽略了排序对参数的梯度(近似处理)
  3. Alpha compositing:前到后的 alpha 合成。结构类似于 NeRF 的离散体积渲染

Kerbl et al. 为了在 CUDA 上高效运行,实现了**手写的反向传播 CUDA kernel**——通用 AD 框架(如 PyTorch autograd)的开销在像素级并行渲染中不可接受。

与机器人学的联系:可微渲染使得"从多视角图像重建 3D 场景"变成端到端可训练的管线。在机器人抓取、导航等任务中,可微渲染提供了一种从视觉输入直接优化动作的途径——这与可微 MPC 的思想相呼应。

可微渲染在机器人学中的具体应用场景

应用 可微渲染的角色 代表工作
抓取位姿估计 从 RGB 图像反推物体 6D 位姿 NeRF-based pose estimation
导航地图构建 用 NeRF/3DGS 构建可微的场景表示 NeRF-Nav, SplaTAM
模仿学习 从视觉演示中提取 3D 信息 3D-aware imitation
视觉伺服 渲染 → 误差图像 → 梯度回传优化相机运动 photometric visual servoing

这些应用的共同模式是:可微渲染充当"3D 世界 ↔ 2D 图像"的可微桥梁,使得 3D 空间中的优化可以用 2D 图像作为监督信号。

⚠️ 常见陷阱

🧠 思维陷阱:认为可微渲染可以替代传统的几何方法 - 错误推理:"有了 NeRF/3DGS,不需要点云、网格、SLAM 了" - 实际上:(1) NeRF/3DGS 的训练需要大量多视角图像和精确的相机位姿——获取这些数据本身就需要传统方法;(2) 渲染速度虽然近年大幅提升(3DGS 可达实时),但与传统光栅化相比仍有差距;(3) 对动态场景的泛化能力有限 - 正确思维:可微渲染是传统几何方法的**补充**(提供了新的优化方式),而非替代

练习

  1. 对 NeRF 的离散体积渲染公式,手动推导 \(\partial \hat{C} / \partial \sigma_i\)。验证它包含两项(直接贡献和遮挡贡献),并解释两项的物理含义。
  2. 解释为什么 3D Gaussian Splatting 的深度排序步骤在 AD 中需要特殊处理。如果忽略排序梯度,会引入什么偏差?在什么条件下这个偏差可以忽略?
  3. 如果你要对一个 NeRF 场景做"视角优化"(找到使某个目标最大的相机位姿),梯度链条是:相机位姿 → 射线参数 → 采样点 → MLP 输出 → 体积渲染 → 目标。这个链条中每一步是否可微?如果有不可微的步骤,应该如何处理?

15. 典型例题与代码实战 ⭐⭐

15.1 例题:用 JAX 实现教学级 dual number 前向 AD

import jax.numpy as jnp

class Dual:
    """教学级 dual number 实现"""
    def __init__(self, real, dual=0.0):
        self.real = real    # 函数值
        self.dual = dual    # 导数值

    def __add__(self, other):
        if isinstance(other, Dual):
            return Dual(self.real + other.real, self.dual + other.dual)
        return Dual(self.real + other, self.dual)

    def __mul__(self, other):
        if isinstance(other, Dual):
            # 乘法法则:(a + bε)(c + dε) = ac + (ad + bc)ε
            return Dual(self.real * other.real,
                       self.real * other.dual + self.dual * other.real)
        return Dual(self.real * other, self.dual * other)

    def __truediv__(self, other):
        if isinstance(other, Dual):
            # 除法法则:(a + bε)/(c + dε) = a/c + (bc - ad)/c² ε
            return Dual(self.real / other.real,
                       (self.dual * other.real - self.real * other.dual)
                       / other.real**2)
        return Dual(self.real / other, self.dual / other)

    def __radd__(self, other):
        return Dual(other + self.real, self.dual)

    def __rmul__(self, other):
        return Dual(other * self.real, other * self.dual)

def sin(x):
    if isinstance(x, Dual):
        import math
        return Dual(math.sin(x.real), math.cos(x.real) * x.dual)
    return jnp.sin(x)

def exp(x):
    if isinstance(x, Dual):
        import math
        e = math.exp(x.real)
        return Dual(e, e * x.dual)
    return jnp.exp(x)

# 测试:f(x) = sin(x²) * exp(x)
def f(x):
    return sin(x * x) * exp(x)

# x = 1.0, seed = 1(求 f'(1))
x = Dual(1.0, 1.0)
result = f(x)
print(f"f(1.0) = {result.real:.10f}")     # 函数值
print(f"f'(1.0) = {result.dual:.10f}")    # 导数值(精确!)

# 与解析导数比较:f'(x) = [2x cos(x²) + sin(x²)] * exp(x)
import math
exact = (2*1.0*math.cos(1.0) + math.sin(1.0)) * math.exp(1.0)
print(f"解析导数 = {exact:.10f}")
print(f"误差 = {abs(result.dual - exact):.2e}")  # 应为 ~1e-16(浮点精度)

15.2 例题:用 CasADi 构建 NLP 并提取灵敏度

import casadi as ca
import numpy as np

# 单摆最优控制:min ∫u²dt  s.t. θ̈ = -g/l sin(θ) + u/(ml²)
N = 50     # 离散点数
dt = 0.05  # 时间步长

# 符号变量
theta = ca.SX.sym('theta', N+1)   # 角度轨迹
omega = ca.SX.sym('omega', N+1)   # 角速度轨迹
u = ca.SX.sym('u', N)             # 控制输入
g, l, m = 9.81, 1.0, 1.0

# 目标函数:最小化控制能量
J = ca.sum1(u**2) * dt

# 动力学约束(Euler 积分)
constraints = []
for k in range(N):
    # θ_{k+1} = θ_k + ω_k * dt
    constraints.append(theta[k+1] - theta[k] - omega[k] * dt)
    # ω_{k+1} = ω_k + (-g/l sin(θ_k) + u_k/(ml²)) * dt
    constraints.append(omega[k+1] - omega[k]
                      - (-g/l * ca.sin(theta[k]) + u[k]/(m*l**2)) * dt)

# 边界条件
constraints.append(theta[0] - 0)       # 初始角度 = 0
constraints.append(omega[0] - 0)       # 初始角速度 = 0
constraints.append(theta[N] - ca.pi)   # 终端角度 = π(摆上去)
constraints.append(omega[N] - 0)       # 终端角速度 = 0

g_vec = ca.vertcat(*constraints)
x_vec = ca.vertcat(theta, omega, u)

# 构建 NLP
nlp = {'x': x_vec, 'f': J, 'g': g_vec}
opts = {'ipopt.print_level': 0, 'print_time': 0}
solver = ca.nlpsol('solver', 'ipopt', nlp, opts)

# 求解
n_vars = 2*(N+1) + N
sol = solver(x0=np.zeros(n_vars),
             lbg=np.zeros(g_vec.shape[0]),
             ubg=np.zeros(g_vec.shape[0]))

print(f"最优控制代价: {float(sol['f']):.4f}")

# 灵敏度分析:∂J*/∂g(重力加速度的灵敏度)
# 这需要对 KKT 系统做隐式微分——CasADi 自动处理

15.3 例题:用 JAX 实现可微 QP 层(IFT 版) ⭐⭐⭐

这个例子展示了如何用隐式微分实现一个可微 QP 层。理解这段代码就等于理解了 OptNet 的核心原理。

import jax
import jax.numpy as jnp
from functools import partial

def solve_qp(Q, q):
    """
    求解无约束 QP: min 0.5 * y^T Q y + q^T y
    最优解: y* = -Q^{-1} q
    """
    return jnp.linalg.solve(Q, -q)

@partial(jax.custom_vjp)
def differentiable_qp(Q, q):
    """可微 QP 层:前向求解 + 自定义反向传播"""
    return solve_qp(Q, q)

def differentiable_qp_fwd(Q, q):
    """前向传播:求解 QP 并保存 residuals 供反向使用"""
    y_star = solve_qp(Q, q)
    return y_star, (Q, q, y_star)  # 返回结果 + 保存用于反向的数据

def differentiable_qp_bwd(res, g):
    """
    反向传播:用 IFT 计算 VJP。

    KKT 条件: F(y, Q, q) = Qy + q = 0
    IFT: dy*/dq = -Q^{-1}  (显然)
         dy*/dQ = -Q^{-1} (dy* ⊗ I)  (需要矩阵微分)

    VJP: 给定上游梯度 g = dL/dy*
         dL/dq = g^T * dy*/dq = -Q^{-1} g
         dL/dQ = -Q^{-1} g * y*^T  (通过矩阵微分得到)
    """
    Q, q, y_star = res
    # 解 KKT 线性方程: Q u = g
    u = jnp.linalg.solve(Q, g)
    # VJP 结果
    dL_dq = -u                          # dL/dq
    dL_dQ = -jnp.outer(u, y_star)       # dL/dQ(矩阵微分)
    return dL_dQ, dL_dq

differentiable_qp.defvjp(differentiable_qp_fwd, differentiable_qp_bwd)

# 测试:验证自定义 VJP 的正确性
n = 5
key = jax.random.PRNGKey(42)
Q = jax.random.normal(key, (n, n))
Q = Q @ Q.T + 0.1 * jnp.eye(n)  # 确保正定
q = jax.random.normal(jax.random.PRNGKey(1), (n,))

# 定义 loss: L = sum(y*)
def loss_fn(Q, q):
    y = differentiable_qp(Q, q)
    return jnp.sum(y)

# 用我们的 IFT VJP 计算梯度
grad_Q, grad_q = jax.grad(loss_fn, argnums=(0, 1))(Q, q)
print(f"dL/dq (IFT): {grad_q[:3]}")

# 用有限差分验证
eps = 1e-5
grad_q_fd = jnp.zeros(n)
for i in range(n):
    e_i = jnp.zeros(n).at[i].set(1.0)
    grad_q_fd = grad_q_fd.at[i].set(
        (loss_fn(Q, q + eps * e_i) - loss_fn(Q, q - eps * e_i)) / (2 * eps)
    )
print(f"dL/dq (FD):  {grad_q_fd[:3]}")
print(f"相对误差: {jnp.linalg.norm(grad_q - grad_q_fd) / jnp.linalg.norm(grad_q):.2e}")

代码解读

这段代码的核心是 differentiable_qp_bwd 函数。它实现了 §7-§8 中推导的 KKT 隐式微分:

  1. 给定上游梯度 \(\bar{y} = dL/dy^*\)
  2. 解 KKT 线性方程 \(Qu = \bar{y}\)(即 \(u = Q^{-1}\bar{y}\)
  3. VJP 结果:\(dL/dq = -u\)\(dL/dQ = -u \cdot y^{*\top}\)

注意**不需要展开 QP 求解器**——不管 \(y^*\) 是用 Cholesky 分解、共轭梯度法还是 ADMM 求解的,反向传播只需要解一个额外的 KKT 线性系统。这就是隐式微分的力量。

15.4 例题:JAX 中的 forward-over-reverse Hessian-向量积 ⭐⭐⭐

import jax
import jax.numpy as jnp

def rosenbrock(x):
    """Rosenbrock 函数:经典的非凸优化测试函数"""
    return jnp.sum(100.0 * (x[1:] - x[:-1]**2)**2 + (1.0 - x[:-1])**2)

def hvp_forward_over_reverse(f, x, v):
    """
    Hessian-向量积:H(x) @ v
    实现:forward-over-reverse
    - 内层(reverse):计算 grad(f)(x)
    - 外层(forward):计算 d/dt [grad(f)(x + t*v)]|_{t=0}
    总代价约 4-5 倍 cost(f),与 n 无关
    """
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

# 测试
n = 100
x = jnp.ones(n)  # Rosenbrock 的全局最小值在 x = [1, 1, ..., 1]
v = jax.random.normal(jax.random.PRNGKey(0), (n,))

# Hessian-向量积
Hv = hvp_forward_over_reverse(rosenbrock, x, v)
print(f"H @ v 的前 5 个元素: {Hv[:5]}")

# 与完整 Hessian 的乘积比较(小 n 才能做)
if n <= 50:
    H = jax.hessian(rosenbrock)(x)
    Hv_exact = H @ v
    print(f"误差: {jnp.linalg.norm(Hv - Hv_exact):.2e}")

# 性能比较:hvp vs 完整 Hessian
import time
x_large = jnp.ones(1000)
v_large = jax.random.normal(jax.random.PRNGKey(0), (1000,))

# JIT 编译
hvp_jit = jax.jit(lambda x, v: hvp_forward_over_reverse(rosenbrock, x, v))
_ = hvp_jit(x_large, v_large)  # 预热

t0 = time.time()
for _ in range(100):
    _ = hvp_jit(x_large, v_large).block_until_ready()
t_hvp = (time.time() - t0) / 100

print(f"n=1000, HVP 时间: {t_hvp*1000:.2f} ms")
print(f"完整 Hessian 大小: {1000*1000*8/1024/1024:.1f} MB — 对比 HVP 几乎不用额外内存")

练习

  1. 修改 15.3 的代码,增加等式约束 \(Ay = b\)。KKT 矩阵变成 \(\begin{pmatrix} Q & A^\top \\ A & 0 \end{pmatrix}\),反向传播需要解什么线性系统?
  2. 用 15.4 的 HVP 实现一个 truncated Newton 法求解 Rosenbrock 函数:在每次 Newton 迭代中,用 CG + HVP 近似求解 \(Hd = -g\),而不是显式构造 Hessian。比较 CG 迭代次数 = 5, 10, 20 时的收敛速度。
  3. (跨章综合)回顾凸分析中的强凸性:强凸函数满足 \(\nabla^2 f(x) \succeq \mu I\)。解释为什么对强凸函数,CG + HVP 总是能在有限步内收敛到精确 Newton 方向。

本章小结

主题 核心要点 关键公式/概念 难度
计算图 一切 AD 的基础——程序 = DAG 中间变量 \(v_i\)、基本运算
三种微分 AD ≠ 有限差分 ≠ 符号微分 截断误差 vs 表达式膨胀
前向 AD dual number + JVP \(f(a+b\varepsilon) = f(a) + f'(a)b\varepsilon\) ⭐⭐
Ceres Jet C++ 模板前向 AD Jet<T,N> 运算符重载 ⭐⭐
反向 AD adjoint + VJP \(\bar{v}_j = \sum_i \bar{v}_i \partial v_i / \partial v_j\) ⭐⭐
Cheap Gradient \(\text{cost}(\nabla f) \leq 5 \cdot \text{cost}(f)\) Baur-Strassen 1983 ⭐⭐⭐
高阶 AD Hv = forward-over-reverse hyper-dual, Taylor 模式 ⭐⭐⭐
Checkpointing 时间换内存 \(O(\sqrt{K})\) 内存 ⭐⭐⭐
框架选型 匹配问题特性选工具 JAX/PyTorch/CasADi/Ceres/Enzyme ⭐⭐
隐函数定理 \(dy^*/dx = -(F_y)^{-1} F_x\) 压缩映射 + 全微分 ⭐⭐
可微优化层 KKT 上的 IFT OptNet / cvxpylayers ⭐⭐⭐
可微仿真 平滑→FoBG,接触→ZoBG Suh 2022 bias-variance ⭐⭐⭐
工程实践 梯度检查 + 数值稳定性 gradcheck, log-sum-exp ⭐⭐
连续伴随 \(\dot{\lambda} = -\lambda^\top \partial f/\partial z\) 终端条件 + 反向积分 ⭐⭐⭐
Neural ODE 连续深度 + 常数内存 adjoint ODE, Chen 2018 ⭐⭐⭐
DEQ 定点 + IFT \((I-\partial f/\partial z)^{-1}\) ⭐⭐⭐⭐
统一框架 \(F(x^*,\theta)=0\) Blondel 2022 ⭐⭐⭐
可微 MPC KKT IFT / PDP Amos 2018, Jin 2020 ⭐⭐⭐
可微渲染 体积渲染 + alpha compositing NeRF, 3DGS ⭐⭐⭐⭐

三大核心洞见

洞见一:AD 不是一个工具,而是"链式法则的程序化"。前向即 JVP、反向即 VJP,Jet/tape/source-to-source/LLVM-IR 只是实现形式。选前向还是反向由输入输出维度比决定,而非语言或框架。

洞见二:隐式微分是"展开式 AD"的严格优越替代。只要能把收敛点表达为 \(F(x^*, \theta) = 0\),就可用 IFT 以 \(O(1)\) 内存得到收敛解的精确梯度。这一原理统一了 DEQ、Neural ODE、OptNet、cvxpylayers、可微 MPC。

洞见三:在 RL 与机器人交叉领域,梯度的来源决定算法成败。光滑动力学下 first-order gradient 带来 10-100 倍样本效率,但接触、混沌、刚性会让它崩溃。Suh 2022 的 bias-variance 定理是选择 PPO、MPPI、SHAC 还是 diff-MPC 的数学判据。


累积项目:本章新增模块

数学库项目进度

前序章节 模块 本章
Ch1-2 向量与矩阵 向量类 + 矩阵乘法 -
Ch3 分解 LU/QR/SVD -
Ch4 最小二乘 伪逆 + 正规方程 -
本章 Dual 类 + 前向/反向 AD + 可微 QP 新增

本章新增四个模块

模块 1:Dual 类与前向 AD ⭐⭐

实现 Dual<T> 类,重载全部基本运算和数学函数: - 四则运算:+, -, *, / - 三角函数:sin, cos, tan, asin, acos, atan2 - 指数/对数:exp, log, sqrt, pow - 比较运算(仅用实部比较):<, >, ==

实现 forward_diff(f, x, v) 函数:给定函数 \(f\)、求值点 \(x\) 和方向 \(v\),用 dual number 计算 JVP \(Jv\)

验证标准:对 10 个随机函数和随机方向,JVP 结果与中心差分的相对误差 \(< 10^{-10}\)

模块 2:Tape 反向 AD ⭐⭐⭐

实现一个简化的反向 AD 引擎: - Tape 类:记录计算图(每个节点的运算类型、父节点、局部导数) - Variable 类:重载运算符,每次运算自动记录到 Tape - backward(loss) 函数:从 loss 节点开始,逆拓扑序传播 adjoint

这个模块的目标不是性能,而是**理解反向 AD 的完整工作机制**。用不超过 200 行代码实现一个"可以工作但很慢"的反向 AD 引擎。

验证标准:对 5 个测试函数,反向 AD 的梯度与模块 1 的前向 AD 完全一致(浮点精度内)。

模块 3:梯度检查工具 ⭐⭐

实现 grad_check(f, x, eps=1e-5) 函数: - 对每个分量用中心差分计算数值梯度 - 与 AD 梯度比较,输出逐分量的相对误差 - 如果最大相对误差 > 阈值(默认 \(10^{-3}\)),打印警告

这个工具将在后续所有涉及自定义 AD 规则的章节中使用。

模块 4:可微 QP 层 ⭐⭐⭐

实现 DifferentiableQP(Q, q, A, b) 类: - 前向:用 scipy 或手写 KKT 求解器求解等式约束 QP - 反向:用 KKT 隐式微分计算 VJP - 接口兼容模块 2 的 Tape 引擎

验证标准:VJP 与有限差分的相对误差 \(< 10^{-5}\)

项目整合测试:用模块 4 的可微 QP + 模块 3 的梯度检查,实现一个简单的"学习 QP 参数"实验——给定目标解 \(y_{\text{target}}\),用梯度下降优化 \(q\) 使得 \(\|y^*(Q, q) - y_{\text{target}}\|^2\) 最小化。


延伸阅读

入门级 ⭐: - Baydin et al., "Automatic Differentiation in Machine Learning: a Survey", JMLR 2018 — AD 与机器学习的最佳入门综述,30 页即可建立完整概念框架 - 3Blue1Brown 的 backpropagation 视频 — 链式法则直觉的最佳可视化 - Matthew Johnson, "JAX Autodiff Cookbook" (docs.jax.dev) — JAX 中 jvp/vjp/jacfwd/jacrev 的权威 how-to - Roger Grosse, CSC321 Backprop 讲义 — 从矩阵视角理解反向传播

核心级 ⭐⭐: - Griewank & Walther, "Evaluating Derivatives", SIAM 2008 — AD 学科圣经,必读 Ch.3-4(前向/反向基础)、Ch.12(checkpointing)和 Ch.15(隐式/迭代微分) - Amos & Kolter, "OptNet: Differentiable Optimization as a Layer", ICML 2017 — 可微优化层的开山之作 - Blondel et al., "Efficient and Modular Implicit Differentiation", NeurIPS 2022 — 隐式微分的现代统一处理,JAXopt 的数学基础 - Nocedal & Wright, "Numerical Optimization", Springer 2006 — Ch.8(计算导数)和 Ch.12(KKT 理论) - Andersson et al., "CasADi: A software framework for nonlinear optimization", Math. Prog. Comp. 2019 — CasADi 的设计与架构 - Ceres Solver 官方文档 (ceres-solver.org) — Jet 实现细节和 SLAM 应用

进阶级 ⭐⭐⭐: - Carpentier & Mansard, "Analytical Derivatives of Rigid Body Dynamics", RSS 2018 — Pinocchio 解析微分的理论基础 - Suh et al., "Do Differentiable Simulators Give Better Policy Gradients?", ICML 2022 — 可微仿真梯度的 bias-variance 分析(Outstanding Paper) - Howell et al., "Dojo: A Differentiable Physics Engine for Robotics", RSS 2022 — 精确接触梯度的引擎 - Chen et al., "Neural Ordinary Differential Equations", NeurIPS 2018 — 连续深度模型与 adjoint 方法 - Bai, Kolter, Koltun, "Deep Equilibrium Models", NeurIPS 2019 — 定点隐式微分的深度学习应用 - Agrawal et al., "Differentiable Convex Optimization Layers", NeurIPS 2019 — cvxpylayers 的数学与实现 - Chris Rackauckas, MIT 18.337 "Parallel Computing and Scientific Machine Learning" (book.sciml.ai) — adjoint sensitivity 的多种推导

研究前沿 ⭐⭐⭐⭐: - Xu et al., "SHAC", ICLR 2022 — 截断窗口 + actor-critic 的可微仿真 RL - Pineda et al., "Theseus", NeurIPS 2022 — PyTorch 上的可微非线性最小二乘 - Moses & Churavy, "Enzyme", NeurIPS 2020 — LLVM IR 级编译器 AD - Amos et al., "Differentiable MPC for End-to-end Planning and Control", NeurIPS 2018 — 可微 MPC 的开创性工作 - Jin et al., "Pontryagin Differentiable Programming", NeurIPS 2020 — PMP + 辅助 LQR 的优雅方法 - Dontchev & Rockafellar, "Implicit Functions and Solution Mappings", Springer 2014 — IFT 的变分分析版权威参考 - Kolter, Duvenaud, Johnson, NeurIPS 2020 Tutorial "Deep Implicit Layers" (implicit-layers-tutorial.org) — DEQ/ODE/OptNet 三合一讲义 - Brandon Amos 个人博客 "On Differentiable Optimization" (bamos.github.io) — OptNet 作者的系列深度博文


🔧 故障排查手册

症状 可能原因 排查步骤 相关章节
AD 导数全为零 计算图断裂或类型转换丢失 tangent 1. 检查是否有 double() 转换截断了 Jet 2. 检查 detach() 是否在梯度路径上 3. 用 gradcheck 验证 §3, §6
AD 导数与有限差分差异大(>1e-3) 自定义 VJP 有 bug 或数值不稳定 1. 逐运算拆开做 gradcheck 2. 检查是否有 log(0) 或 sqrt(0) 3. 尝试 float64 排除舍入 §10.1
反向传播 OOM 计算图过长或中间激活未释放 1. 检查序列长度 2. 使用 torch.utils.checkpoint 3. 用隐式微分替代展开 §5.3, §7
CasADi 符号爆炸(构建图极慢) 在 SX 中展开了大循环 1. 改用 MX 矩阵运算 2. 避免 Python for 循环 3. 用 ca.horzcat/vertcat 向量化 §6.2
可微仿真梯度方向错误 接触不连续导致 FoBG 偏差 1. 检查是否有 Heaviside 类运算 2. 加随机光滑化 3. 比较 FoBG 与 ZoBG §9.3
Ceres Jet 编译错误 cost function 调用了非模板化函数 1. 检查所有数学函数是否用 ceres::sin 等 2. 确保全部 template<typename T> 3. 参数块维度之和与 Jet N 一致 §3.3
隐式微分结果数值不稳定 KKT 矩阵条件数过大 1. 检查 active set 是否在切换边界 2. 添加 Tikhonov 正则化 3. 监控 \(\kappa(K)\) §7, §8
Neural ODE 伴随梯度出现 NaN 刚性 ODE 反向积分发散 1. 检查 Jacobian 特征值 2. 用隐式 ODE solver 3. 切换到 checkpointed BPTT §11.2
MuJoCo MJX jax.grad 报错 CG 求解器用 while_loop 不支持反向 1. 切换到 Newton 求解器 2. 用 fori_loop 固定迭代上界 3. 考虑用 Brax 替代 §9
DEQ 前向不收敛 \(\|\\partial f/\\partial z\| \geq 1\) 不是压缩映射 1. 检查 spectral norm 2. 添加 spectral normalization 3. 减小网络宽度或深度 §12
cvxpylayers 报 "not DPP compliant" 参数出现在非仿射位置 1. 检查哪个参数违反 DPP 2. 改写为仿射形式(如 sum_squares(P_sqrt @ x) 替代 quad_form(x, P) §8.3
Pinocchio 解析微分与 CppAD 结果不一致 坐标约定或符号约定差异 1. 对齐 body frame 约定(local vs world) 2. 检查 Jacobian 是 body 还是 spatial 3. 打印中间值逐步比较 §9.1
PDP 辅助 LQR 发散 Hamiltonian Hessian 不正定 1. 检查 \(H_{uu}\) 是否正定 2. 添加正则化 \(H_{uu} + \epsilon I\) 3. 缩短 MPC horizon 减少病态性 §13.2

使用本手册的建议:遇到问题时,先在本表中搜索匹配的"症状"。如果找不到完全匹配的条目,从"可能原因"列中找最接近的猜测,按"排查步骤"逐步定位。如果仍无法解决,回到对应章节重新理解底层原理——大多数 AD 的 bug 都源于对数学原理的误解,而非代码层面的技术问题。

排查的通用策略:对任何 AD 相关的问题,第一步始终是**梯度检查**(§10.1)。如果 AD 梯度通过了梯度检查但结果仍然不好,问题出在数学建模层面(如条件数、不连续性)而非 AD 实现层面。如果梯度检查未通过,问题出在 AD 实现(如类型转换、图断裂、自定义规则 bug)。这个二分法可以快速定位问题的根源。