Skip to main content

从 SSM 到 Mamba

· 12 min read
PuQing
AI, CVer, Pythoner, Half-stack Developer

什么是 SSM

状态空间模型(State Space Models,SSM),与 线性系统入门-动态系统 中我们介绍过线性时不变系统就是一个东西。

{dx(t)dt=Ax(t)+Bu(t)RNy(t)=Cx(t)+Du(t)R\begin{cases} \displaystyle \frac{d \boldsymbol{x}(t)}{d t}=\boldsymbol{A} \boldsymbol{x}(t)+\boldsymbol{B} u(t) \in \mathbb{R}^{N} \\ y(t)=\boldsymbol{C} \boldsymbol{x}(t)+\boldsymbol{D} u(t) \in \mathbb{R} \end{cases}

其中每个矩阵是具有含义的,其中的 A\boldsymbol{A} 矩阵是状态矩阵(系统矩阵),它描述了系统状态是如何变化的,B\boldsymbol{B} 矩阵是输入矩阵描述了输入是如何影响到系统状态的。而 C\boldsymbol{C} 矩阵是输出矩阵,描述了当前状态是如何作用到输出量上的,D\boldsymbol{D} 矩阵是直接传递矩阵描述了输入是如何直接作用到输出量上的。

HiPPO(High-order Polynomial Projection Operators)

HiPPO1 是 2020 年 NIPS 上的一篇工作,长序列依赖建模的核心问题是如何通过有限的 memory 来尽可能的记住之前所有的历史信息,这个 memory 就体现在上述状态矩阵上。当前的主流序列建模模型(即 Transformer 和 RNN)存在着普遍的遗忘问题

tip
  • fixed-size context windows:Transformer 的 window size 通常是有限的,一般来说 Quadratic 的 Attention 最多建模到大约 10K 的 Token 就到计算极限了
  • vanishing gradient:RNN 通过 hidden state 来存储历史信息,理论上能记住之前所有内容,但是实际上 effective memory 大概是 <1K<1K 个 Token,可能的原因是 gradient vanishing

在线函数逼近(Online Function Approximation)

问题的设定是:

info

考虑一个一维函数,我们能否用一个固定大小的 representation c(t)RNc(t) \in \mathbb{R}^N 来拟合 ff[0,t][0,t](记为 ftf_{\leq t}) 的曲线?并且随着 tt 的增加,例如从 t1t_{1}t2t_{2},我们可以在线的根据 c(t1)c(t_{1})c(t2)c(t_{2}) 来拟合 ft2f_{\leq t_{2}}

为了判断拟合的效果,我们需要一个测度 (measure) 来判定拟合出来的连续函数和原来的连续函数的相似度,并且假设对于不同的 time step xx 有一个权重 μ(x)\mu(x)。每个 measure 都需要在函数空间里定义一个距离,即定义函数的内积:

f,gμ=0f(x)g(x)dμ(x)\langle f, g\rangle_{\mu}=\int_{0}^{\infty} f(x) g(x) \mathrm{d} \mu(x)
如何用 NN 维向量来 encoder ftf_{\leq t}

假设一组多项式正交基 G={gn}n<N\mathcal{G}=\{g_{n}\}_{n<N} 满足 gi,gjμ=0\langle g_{i}, g_{j}\rangle_{\mu}=0(正交基是由 measure μ\mu 决定的,不同的 μ\mu 对应不同 G\mathcal{G})。把 ftf_{\leq t} 投影到多项式正交基 G\mathcal{G} 上,让每个投影分量为 c(t)c(t) 的分量:

cn(t):=ft,gnμ(t)c_{n}^{(t)}:=\left\langle f_{\leq t}, g_{n}\right\rangle_{\mu^{(t)}}

也就是说 c(t)c(t) 是多项式正交基的系数向量,有点像傅立叶变换后的频率上的值,每一个 cnc_{n} 都对应一个频率

如何计算 c(t)c(t)

从离散的角度来理解的话,从 t=0t=0 开始,每次输入一个 f(t)f(t),然后来更新 c(t)c(t) 用来 encodeftf_{\leq t},接着再输入 f(t+1)f(t+1) 并更新 c(t+1)c(t+1) 来 encodeft+1f_{\leq t+1}。可以看出,c(t+1)c(t+1) 的值受到 c(t)c(t)f(t+1)f(t+1) 的影响,HiPPO 证明,在连续情况下,这个过程可以用一个一阶 ODE 来建模:

c˙(t)=A(t)c(t)+B(t)f(t)\dot{c}(t) = A(t)c(t)+B(t)f(t)

只要给定 measure μ\mu 我们就能确定 ARN×N,BRN×1A \in \mathbb{R}^{N\times N},B\in \mathbb{R}^{N\times 1}。如果按照 ODE 来求解,c(t)c(t) 就是我们要找的 optimally encodeftf_{\leq t} 的系数。

下面是一个实例,其中的蓝色线是 c(t)c(t),黑色线给定的时间序列,红色线是解码出来的值。

两个实例

image.png

这两个 measure 都是在给定的窗口内的 uniform measures

第一个例子是 Translated Legendre Measure(LegT),它的 window size 是固定的,也就是说,它只在乎 recent history(within the window),而不在乎更早的 history。第二个例子是 scaled Legendre Measure(LegS),它的 window size 随着时间变换的,并且 window size 等于整个 history,所有的历史都同等重要。相应的,为了归一化,对每个时刻的 measure 的 scale 大小会对应缩放。 其中的 HiPPO 矩阵长这个样子:

上述连续的情形 c˙(t)=A(t)c(t)+B(t)f(t)\dot{c}(t)=A(t)c(t)+B(t)f(t) 其实就对应着一个 state space model。由于除了连续的输入之外,还会碰到离散的输入(如文本序列),可以用离散化的操作来写成一个线性 RNN 的形式

ct+1=Aˉtct+Bˉtftc_{t+1}=\bar{A}_{t} c_{t}+\bar{B}_{t} f_{t}

离散化的办法有很多,例如欧拉法,零阶保持等等。

欧拉法

采用如下公式来近似微分:

x˙=x(k+1)x(k)T\dot{x}=\frac{x(k+1)-x(k)}{T}

欧拉法是一阶数值方法,取其曲线在 kk 处的切线进行近似

已知一定常连续系统的状态空间方程为:

{x˙=Ax+Buy=Cx+Du\begin{cases} \dot{\mathbf{x}}=\mathbf{A} \mathbf{x}+\mathbf{B} u \\ y=\mathbf{C x}+\mathbf{D} u \end{cases}

x˙=x(k+1)x(k)T\dot{x}=\frac{x(k+1)-x(k)}{T} 可得:

x˙=x(k+1)x(k)T=Ax(k)+Bu(k)x(k+1)x(k)=T[Ax(k)+Bu(k)]x(k+1)=(I+TA)x(k)+TBu(k)=Φx(k)+Gu(k)\begin{array}{l} \dot{\mathbf{x}}=\frac{\mathbf{x}(k+1)-\mathbf{x}(k)}{T}=\mathbf{A} \mathbf{x}(k)+\mathbf{B} u(k) \\ \mathbf{x}(k+1)-\mathbf{x}(k)=T[\mathbf{A} \mathbf{x}(k)+\mathbf{B} u(k)] \\ \mathbf{x}(k+1)=(\mathbf{I}+T \mathbf{A}) \mathbf{x}(k)+T \mathbf{B} u(k)=\mathbf{\Phi} \mathbf{x}(k)+\mathbf{G} u(k) \end{array}

其中:Φ=I+TA;G=TB\mathbf{\Phi}=\mathbf{I}+T \mathbf{A} ; \mathbf{G}=T \mathbf{B}

输出方程同样可以得到:

y=Cx+Duy(k)=Hx(k)+Ju(k)y=\mathbf{C x}+\mathbf{D} u \Rightarrow y(k)=\mathbf{H} \mathbf{x}(k)+\mathbf{J} u(k)

其中:H=C;J=D\mathbf{H}=\mathbf{C} ; \mathbf{J}=\mathbf{D}

综上述,离散化后的状态空间为:

{x(k+1)=Φx(k)+Gu(k)y(k)=Hx(k)+Ju(k)\begin{cases} \mathbf{x}(k+1)=\boldsymbol{\Phi} \mathbf{x}(k)+\mathbf{G} u(k) \\ y(k)=\mathbf{H} \mathbf{x}(k)+\mathbf{J} u(k) \end{cases}

S4(Sequences With Structured State Spaces)

并行化加速

上述 HiPPO 构建了一个方法用于计算,估计系统矩阵 A\boldsymbol{A},但是在前向时依然需要类似于 RNN 的循环进行预测: image.png 对序列进行展开: image.png

为了并行化推理,我们 S4 引入了 SSM kernel。考虑输出 y2y_{2},有公式:

y2=Ch2=C(Aˉh1+Bˉx2)=C(Aˉ(Aˉh0+Bˉx1)+Bˉx2)=C(Aˉ(AˉBˉx0+Bˉx1)+Bˉx2)=C(AˉAˉBˉx0+AˉBˉx1+Bˉx2)=CAˉAˉ2Bˉx0+CAˉBˉx1+CBˉx2\begin{align} y_{2} & = Ch_{2} \\ &=C\left(\bar{A} h_{1}+\bar{B} x_{2}\right) \\ &=C\left(\bar{A}\left(\bar{A} h_{0}+\bar{B} x_{1}\right)+\bar{B} x_{2}\right) \\ &=C\left(\bar{A}\left(\bar{A} \cdot \bar{B} x_{0}+\bar{B} x_{1}\right)+\bar{B} x_{2}\right) \\ &=C\left(\bar{A} \cdot \bar{A} \cdot \bar{B} x_{0}+\bar{A} \cdot \bar{B} x_{1}+\bar{B} x_{2}\right) \\ &=C \cdot \bar{A} \bar{A}^{2} \cdot \bar{B} x_{0}+C \cdot \bar{A} \cdot \bar{B} \cdot x_{1}+C \cdot \bar{B} x_{2} \end{align}

由此类推,可得:

y3=CAAABx0+CAABx1+CABx2+CBx3y_{3}=\mathbf{C} \overline{\mathbf{A A} \mathbf{A B}} x_{0}+\mathbf{C} \overline{\mathbf{A A B}} x_{1}+\mathbf{C} \overline{\mathbf{A B}} x_{2}+\mathbf{C} \overline{\mathbf{B}} x_{3}

写成矩阵的形式可得:

y3=(CAAABCAABCABCB)(x0x1x2x3)y_{3} = \begin{pmatrix} \mathbf{C} \overline{\mathrm{AAAB}} & \mathbf{C} \overline{\mathrm{AAB}} & \mathbf{C} \overline{\mathbf{A}} \overline{\mathbf{B}} & \mathbf{C} \overline{\mathbf{B}} \end{pmatrix} \begin{pmatrix} x_{0} \\ x_{1} \\ x_{2} \\ x_{3} \end{pmatrix}

其中矩阵 A,B,CA,B,C 都是常数,因此可以预先计算左侧向量值并作为卷积核。

对角化

对于上述我们将我们需要计算矩阵的高次幂依然具有很高的复杂度,我们可以尝试对其进行对角化降低计算复杂度。

定理 3.12
(A,B,C)(V1AV,V1B,CV)(\boldsymbol{A}, \boldsymbol{B}, \boldsymbol{C}) \sim\left(\boldsymbol{V}^{-1} \boldsymbol{A} \boldsymbol{V}, \boldsymbol{V}^{-1} \boldsymbol{B}, \boldsymbol{C} \boldsymbol{V}\right)

如果上述 (A,B,C)(A,B,C) 可以对角化,并保持输出不变,那么输出 yy 的计算复杂度将从 O(N2)\mathcal{O}(N^{2}) 变成 O(N)\mathcal{O}(N)

可以构造矩阵 AA 的相似矩阵(具有同样的特征值),表示如下:

A~=[112133]\tilde{\boldsymbol{A}}= \begin{bmatrix} 1 & & & \\ -1 & 2 & & \\ 1 & -3 & 3 & \\ \vdots & \vdots & \vdots & \ddots \end{bmatrix}

即:

Ank~={(1)(nk)(2k+1) if n>kk+1 if n=k0 if n<k\tilde{\boldsymbol{A}_{n k}} = \begin{cases} (-1)^{(n-k)}(2 k+1) & \text { if } n>k \\ k+1 & \text { if } n=k \\ 0 & \text { if } n<k \end{cases}

那么可以找到一个可逆矩阵:

V=(Ci+jij)ij=[111131]\boldsymbol{V}=\left(C_{i+j}^{i-j}\right)_{i j}= \begin{bmatrix} 1 & & & \\ 1 & 1 & & \\ 1 & 3 & 1 & \\ \vdots & \vdots & \vdots & \ddots \end{bmatrix}

其中 V3i,i=C4i2i24iV_{3i,i}=C_{4i}^{2i}\approx 2^{4i},因此会指数爆炸

所以 S4 将矩阵 A\boldsymbol{A} 表示为正规矩阵 + 低秩矩阵的形式,即:

A=V1ΛVPQT=V1(Λ(VP)(PTV1))V\boldsymbol{A}=\boldsymbol{V}^{-1} \boldsymbol{\Lambda} \boldsymbol{V}-\boldsymbol{P} \boldsymbol{Q}^{T}=\boldsymbol{V}^{-1}\left(\boldsymbol{\Lambda}-(\boldsymbol{V} \boldsymbol{P})\left(\boldsymbol{P}^{T} \boldsymbol{V}^{-1}\right)\right) \boldsymbol{V}

其中 VRN×N\boldsymbol{V} \in \mathbb{R}^{N\times N}Λ\Lambda 是对角矩阵,P,QRN×rP,Q \in \mathbb{R}^{N\times r} 是低秩矩阵。(不懂这里如何得到的,逃)

Mamba

时变

Mamba3 一个主要的贡献就是将上述时不变系统,建模成了时变系统;原本系统为:

{x˙=Ax(t)+Bμ(t)y=Cx(t)\begin{cases} \dot{\mathbf{x}} = \mathbf{A}x(t) + \mathbf{B}\mu(t) \\ y = \mathbf{C}x(t) \end{cases}

其中的 (A,B,C)(\mathbf{A,B,C}) 矩阵都是 “ 固定的 “(这里固定指的是对于输入来说,对于不同的输入其矩阵一样的) image.png 而 Mamba 将该部分进行 ” 函数化 “,利用一个网络计算该部分。上图是二者的区别。

在原本 S4 网络中其 A,B,CA,B,C 是被特殊初始化后的可训练矩阵,如下图所示:

image.png

注意

原本矩阵 A\mathbf{A} 的大小为 N×NN\times N,这里是 D×ND\times N

这是因为矩阵 A\mathbf{A} 可被对角化,变成 NN 维的,并且 S4 是 Single -input-single-output(SISO)的,所以对于输入的每一个维度都有一套 SSM 参数,所以 ×D\times D

而 Mamba 是利用 DNN 网络输出该矩阵:

SB(x)=LinearN(x)SC(x)=LinearN(x)SΔ(x)=Broadcast D( Linear 1(x))τΔ= softplus \begin{array}{l} S_{B}(x)=\operatorname{Linear}_{N}(x) \\ S_{C}(x)=\operatorname{Linear}_{N}(x) \\ S_{\Delta}(x)=\operatorname{Broadcast~}_{D}\left(\text { Linear }_{1}(x)\right) \\ \tau_{\Delta}=\text { softplus } \end{array}

从而做到了与输入相关

为什么矩阵 A\mathbf{A} 不是设计成网络的形式?

这是因为 A\mathbf{A} 矩阵在进行 discretize 时 A\mathbf{A} 矩阵与 τΔ\tau_{\Delta} 会发生运算,所以计算后的 Aˉ\bar{\mathbf{A}} 依然是与输入相关的。

并行化

上述变成时变后,我们就不能利用 SSM kernel 进行并行化。在 Mamba 中使用了一种 parallel scan 的技术。下方是 CUDA 关于该技术的介绍:

所有前缀和任务

给定 nn 个序列的数组

[a0,a1,,an1][a_{0},a_{1},\dots,a_{n-1}]

求出序列

[1,a0,(a0a1),,(a0a1an2)][1,a_{0},(a_{0} \oplus a_{1}),\dots,(a_{0} \oplus a_{1}\oplus \dots \oplus a_{n-2})]

A Naive Parallel Scan

这个算法的效率在 O(nlog2(n))\mathcal{O}(n\log_{2}(n))

前缀和并行算法

整个算法分为两个步骤,上扫和下扫。

在上扫过程,利用树形结构,相邻节点相加得到整个数组的和,而中间结果可以被后续过程所利用。

然后从上往下依次拼凑出前缀和。

而 Mamba 正是利用该算法实现了并行化训练

而上述的动态矩阵,并行扫描算法一起被叫做选择性扫描算法 (selective scan algorithm)

相关资料

Footnotes

  1. Gu A, Dao T, Ermon S, et al. Hippo: Recurrent memory with optimal polynomial projections[J]. Advances in neural information processing systems, 2020, 33: 1474-1487.

  2. Gu A, Goel K, Ré C. Efficiently modeling long sequences with structured state spaces[J]. arXiv preprint arXiv:2111.00396, 2021.

  3. Gu A, Dao T. Mamba: Linear-time sequence modeling with selective state spaces[J]. arXiv preprint arXiv:2312.00752, 2023.