什么是 SSM
状态空间模型(State Space Models,SSM),与 线性系统入门-动态系统 中我们介绍过线性时不变系统就是一个东西。
⎩⎨⎧dtdx(t)=Ax(t)+Bu(t)∈RNy(t)=Cx(t)+Du(t)∈R
其中每个矩阵是具有含义的,其中的 A 矩阵是状态矩阵(系统矩阵),它描述了系统状态是如何变化的,B 矩阵是输入矩阵描述了输入是如何影响到系统状态的。而 C 矩阵是输出矩阵,描述了当前状态是如何作用到输出量上的,D 矩阵是直接传递矩阵描述了输入是如何直接作用到输出量上的。
HiPPO(High-order Polynomial Projection Operators)
HiPPO1 是 2020 年 NIPS 上的一篇工作,长序列依赖建模的核心问题是如何通过有限的 memory 来尽可能的记住之前所有的历史信息,这个 memory 就体现在上述状态矩阵上。当前的主流序列建模模型(即 Transformer 和 RNN)存在着普遍的遗忘问题
- fixed-size context windows:Transformer 的 window size 通常是有限的,一般来说 Quadratic 的 Attention 最多建模到大约 10K 的 Token 就到计算极限了
- vanishing gradient:RNN 通过 hidden state 来存储历史信息,理论上能记住之前所有内容,但是实际上 effective memory 大概是 <1K 个 Token,可能的原因是 gradient vanishing
在线函数逼近(Online Function Approximation)
问题的设定是:
考虑一个一维函数,我们能否用一个固定大小的 representation c(t)∈RN 来拟合 f 在 [0,t](记为 f≤t) 的曲线?并且随着 t 的增加,例如从 t1 到 t2,我们可以在线的根据 c(t1) 求 c(t2) 来拟合 f≤t2
为了判断拟合的效果,我们需要一个测度 (measure) 来判定拟合出来的连续函数和原来的连续函数的相似度,并且假设对于不同的 time step x 有一个权重 μ(x)。每个 measure 都需要在函数空间里定义一个距离,即定义函数的内积:
⟨f,g⟩μ=∫0∞f(x)g(x)dμ(x)
如何用
N 维向量来 encoder
f≤t假设一组多项式正交基 G={gn}n<N 满足 ⟨gi,gj⟩μ=0(正交基是由 measure μ 决定的,不同的 μ 对应不同 G)。把 f≤t 投影到多项式正交基 G 上,让每个投影分量为 c(t) 的分量:
cn(t):=⟨f≤t,gn⟩μ(t)也就是说 c(t) 是多项式正交基的系数向量,有点像傅立叶变换后的频率上的值,每一个 cn 都对应一个频率
从离散的角度来理解的话,从 t=0 开始,每次输入一个 f(t),然后来更新 c(t) 用来 encodef≤t,接着再输入 f(t+1) 并更新 c(t+1) 来 encodef≤t+1。可以看出,c(t+1) 的值 受到 c(t) 和 f(t+1) 的影响,HiPPO 证明,在连续情况下,这个过程可以用一个一阶 ODE 来建模:
c˙(t)=A(t)c(t)+B(t)f(t)只要给定 measure μ 我们就能确定 A∈RN×N,B∈RN×1。如果按照 ODE 来求解,c(t) 就是我们要找的 optimally encodef≤t 的系数。
下面是一个实例,其中的蓝色线是 c(t),黑色线给定的时间序列,红色线是解码出来的值。


这两个 measure 都是在给定的窗口内的 uniform measures
第一个例子是 Translated Legendre Measure(LegT),它的 window size 是固定的,也就是说,它只在乎 recent history(within the window),而不在乎更早的 history。第二个例子是 scaled Legendre Measure(LegS),它的 window size 随着时间变换的,并且 window size 等于整个 history,所有的历史都同等重要。相应的,为了归一化,对每个时刻的 measure 的 scale 大小会对应缩放。
其中的 HiPPO 矩阵长这个样子:

上述连续的情形 c˙(t)=A(t)c(t)+B(t)f(t) 其实就对应着一个 state space model。由于除了连续的输入之外,还会碰到离散的输入(如文本序列),可以用离散化的操作来写成一个线性 RNN 的形式
ct+1=Aˉtct+Bˉtft
离散化的办法有很多,例如欧拉法,零阶保持等等。
采用如下公式来近似微分:
x˙=Tx(k+1)−x(k)欧拉法是一阶数值方法,取其曲线在 k 处的切线进行近似
已知一定常连续系统的状态空间方程为:
{x˙=Ax+Buy=Cx+Du由 x˙=Tx(k+1)−x(k) 可得:
x˙=Tx(k+1)−x(k)=Ax(k)+Bu(k)x(k+1)−x(k)=T[Ax(k)+Bu(k)]x(k+1)=(I+TA)x(k)+TBu(k)=Φx(k)+Gu(k)其中:Φ=I+TA;G=TB
输出方程同样可以得到:
y=Cx+Du⇒y(k)=Hx(k)+Ju(k)其中:H=C;J=D
综上述,离散化后的状态空间为:
{x(k+1)=Φx(k)+Gu(k)y(k)=Hx(k)+Ju(k)
S4(Sequences With Structured State Spaces)
并行化加速
上述 HiPPO 构建了一个方法用于计算,估计系统矩阵 A,但是在前向时依然需要类似于 RNN 的循环进行预测:
对序列进行展开:

为了并行化推理,我们 S4 引入了 SSM kernel。考虑输出 y2,有公式:
y2=Ch2=C(Aˉh1+Bˉx2)=C(Aˉ(Aˉh0+Bˉx1)+Bˉx2)=C(Aˉ(Aˉ⋅Bˉx0+Bˉx1)+Bˉx2)=C(Aˉ⋅Aˉ⋅Bˉx0+Aˉ⋅Bˉx1+Bˉx2)=C⋅AˉAˉ2⋅Bˉx0+C⋅Aˉ⋅Bˉ⋅x1+C⋅Bˉx2
由此类推,可得:
y3=CAAABx0+CAABx1+CABx2+CBx3
写成矩阵的形式可得:
y3=(CAAABCAABCABCB)x0x1x2x3
其中矩阵 A,B,C 都是常数,因此可以预先计算左侧向量值并作为卷积核。
对角化
对于上述我们将我们需要计算矩阵的高次幂依然具有很高的复杂度,我们可以尝试对其进行对角化降低计算复杂度。
(A,B,C)∼(V−1AV,V−1B,CV)如果上述 (A,B,C) 可以对角化,并保持输出不变,那么输出 y 的计算复杂度将从 O(N2) 变成 O(N)。
可以构造矩阵 A 的相似矩阵(具有同样的特征值),表示如下:
A~=1−11⋮2−3⋮3⋮⋱
即:
Ank~=⎩⎨⎧(−1)(n−k)(2k+1)k+10 if n>k if n=k if n<k
那么可以找到一个可逆矩阵:
V=(Ci+ji−j)ij=111⋮13⋮1⋮⋱
其中 V3i,i=C4i2i≈24i,因此会指数爆炸
所以 S4 将矩阵 A 表示为正规矩阵 + 低秩矩阵的形式,即:
A=V−1ΛV−PQT=V−1(Λ−(VP)(PTV−1))V
其中 V∈RN×N,Λ 是对角矩阵,P,Q∈RN×r 是低秩矩阵。(不懂这里如何得到的,逃)
Mamba
Mamba3 一个主要的贡献就是将上述时不变系统,建模成了时变系统;原本系统为:
{x˙=Ax(t)+Bμ(t)y=Cx(t)
其中的 (A,B,C) 矩阵都是 “ 固定的 “(这里固定指的是对于输入来说,对于不同的输入其矩阵一样的)
而 Mamba 将该部分进行 ” 函数化 “,利用一个网络计算该部分。上图是二者的区别。
在原本 S4 网络中其 A,B,C 是被特殊初始化后的可训练矩阵,如下图所示:

原本矩阵 A 的大小为 N×N,这里是 D×N。
这是因为矩阵 A 可被对角化,变成 N 维的,并且 S4 是 Single -input-single-output(SISO)的,所以对于输入的每一个维度都有一套 SSM 参数,所以 ×D
而 Mamba 是利用 DNN 网络输出该矩阵:
SB(x)=LinearN(x)SC(x)=LinearN(x)SΔ(x)=Broadcast D( Linear 1(x))τΔ= softplus
从而做到了与输入相关
为什么矩阵
A 不是设计成网络的形式?
这是因为 A 矩阵在进行 discretize 时 A 矩阵与 τΔ 会发生运算,所以计算后的 Aˉ 依然是与输入相关的。
并行化
上述变成时变后,我们就不能利用 SSM kernel 进行并行化。在 Mamba 中使用了一种 parallel scan 的技术。下方是 CUDA 关于该技术的介绍:
给定 n 个序列的数组
[a0,a1,…,an−1]求出序列
[1,a0,(a0⊕a1),…,(a0⊕a1⊕⋯⊕an−2)]
A Naive Parallel Scan

这个算法的效率在 O(nlog2(n))