什么是 SSM
状态空间模型(State Space Models,SSM),与 线性系统入门-动态系统 中我们介绍过线性时不变系统就是一个东西。
{ d x ( t ) d t = A x ( t ) + B u ( t ) ∈ R N y ( t ) = C x ( t ) + D u ( t ) ∈ R \begin{cases}
\displaystyle \frac{d \boldsymbol{x}(t)}{d t}=\boldsymbol{A} \boldsymbol{x}(t)+\boldsymbol{B} u(t) \in \mathbb{R}^{N} \\
y(t)=\boldsymbol{C} \boldsymbol{x}(t)+\boldsymbol{D} u(t) \in \mathbb{R}
\end{cases} ⎩ ⎨ ⎧ d t d x ( t ) = A x ( t ) + B u ( t ) ∈ R N y ( t ) = C x ( t ) + D u ( t ) ∈ R
其中每个矩阵是具有含义的,其中的 A \boldsymbol{A} A 矩阵是状态矩阵(系统矩阵),它描述了系统状态是如何变化的,B \boldsymbol{B} B 矩阵是输入矩阵描述了输入是如何影响到系统状态的。而 C \boldsymbol{C} C 矩阵是输出矩阵,描述了当前状态是如何作用到输出量上的,D \boldsymbol{D} D 矩阵是直接传递矩阵描述了输入是如何直接作用到输出量上的。
HiPPO(High-order Polynomial Projection Operators)
HiPPO1 是 2020 年 NIPS 上的一篇工作,长序列依赖建模的核心问题是如何通过有限的 memory 来尽可能的记住之前所有的历史信息,这个 memory 就体现在上述状态矩阵上。当前的主流序列建模模型(即 Transformer 和 RNN)存在着普遍的遗忘问题
fixed-size context windows:Transformer 的 window size 通常是有限的,一般来说 Quadratic 的 Attention 最多建模到大约 10K 的 Token 就到计算极限了
vanishing gradient:RNN 通过 hidden state 来存储历史信息,理论上能记住之前所有内容,但是实际上 effective memory 大概是 < 1 K <1K < 1 K 个 Token,可能的原因是 gradient vanishing
在线函数逼近(Online Function Approximation)
问题的设定是:
考虑一个一维函数,我们能否用一个固定大小的 representation c ( t ) ∈ R N c(t) \in \mathbb{R}^N c ( t ) ∈ R N 来拟合 f f f 在 [ 0 , t ] [0,t] [ 0 , t ] (记为 f ≤ t f_{\leq t} f ≤ t ) 的曲线?并且随着 t t t 的增加,例如从 t 1 t_{1} t 1 到 t 2 t_{2} t 2 ,我们可以在线的根据 c ( t 1 ) c(t_{1}) c ( t 1 ) 求 c ( t 2 ) c(t_{2}) c ( t 2 ) 来拟合 f ≤ t 2 f_{\leq t_{2}} f ≤ t 2
为了判断拟合的效果,我们需要一个测度 (measure) 来判定拟合出来的连续函数和原来的连续函数的相似度,并且假设对于不同的 time step x x x 有一个权重 μ ( x ) \mu(x) μ ( x ) 。每个 measure 都需要在函数空间里定义一个距离,即定义函数的内积:
⟨ f , g ⟩ μ = ∫ 0 ∞ f ( x ) g ( x ) d μ ( x ) \langle f, g\rangle_{\mu}=\int_{0}^{\infty} f(x) g(x) \mathrm{d} \mu(x) ⟨ f , g ⟩ μ = ∫ 0 ∞ f ( x ) g ( x ) d μ ( x )
如何用
N N N 维向量来 encoder
f ≤ t f_{\leq t} f ≤ t 假设一组多项式正交基 G = { g n } n < N \mathcal{G}=\{g_{n}\}_{n<N} G = { g n } n < N 满足 ⟨ g i , g j ⟩ μ = 0 \langle g_{i}, g_{j}\rangle_{\mu}=0 ⟨ g i , g j ⟩ μ = 0 (正交基是由 measure μ \mu μ 决定的,不同的 μ \mu μ 对应不同 G \mathcal{G} G )。把 f ≤ t f_{\leq t} f ≤ t 投影到多项式正交基 G \mathcal{G} G 上,让每个投影分量为 c ( t ) c(t) c ( t ) 的分量:
c n ( t ) : = ⟨ f ≤ t , g n ⟩ μ ( t ) c_{n}^{(t)}:=\left\langle f_{\leq t}, g_{n}\right\rangle_{\mu^{(t)}} c n ( t ) := ⟨ f ≤ t , g n ⟩ μ ( t ) 也就是说 c ( t ) c(t) c ( t ) 是多项式正交基的系数向量,有点像傅立叶变换后的频率上的值,每一个 c n c_{n} c n 都对应一个频率
从离散的角度来理解的话,从 t = 0 t=0 t = 0 开始,每次输入一个 f ( t ) f(t) f ( t ) ,然后来更新 c ( t ) c(t) c ( t ) 用来 encodef ≤ t f_{\leq t} f ≤ t ,接着再输入 f ( t + 1 ) f(t+1) f ( t + 1 ) 并更新 c ( t + 1 ) c(t+1) c ( t + 1 ) 来 encodef ≤ t + 1 f_{\leq t+1} f ≤ t + 1 。可以看出,c ( t + 1 ) c(t+1) c ( t + 1 ) 的值受到 c ( t ) c(t) c ( t ) 和 f ( t + 1 ) f(t+1) f ( t + 1 ) 的影响,HiPPO 证明,在连续情况下,这个过程可以用一个一阶 ODE 来建模:
c ˙ ( t ) = A ( t ) c ( t ) + B ( t ) f ( t ) \dot{c}(t) = A(t)c(t)+B(t)f(t) c ˙ ( t ) = A ( t ) c ( t ) + B ( t ) f ( t ) 只要给定 measure μ \mu μ 我们就能确定 A ∈ R N × N , B ∈ R N × 1 A \in \mathbb{R}^{N\times N},B\in \mathbb{R}^{N\times 1} A ∈ R N × N , B ∈ R N × 1 。如果按照 ODE 来求解,c ( t ) c(t) c ( t ) 就是我们要找的 optimally encodef ≤ t f_{\leq t} f ≤ t 的系数。
下面是一个实例,其中的蓝色线是 c ( t ) c(t) c ( t ) ,黑色线给定的时间序列,红色线是解码出来的值。
这两个 measure 都是在给定的窗口内的 uniform measures
第一个例子是 Translated Legendre Measure(LegT),它的 window size 是固定的,也就是说,它只在乎 recent history(within the window),而不在乎更早的 history。第二个例子是 scaled Legendre Measure(LegS),它的 window size 随着时间变换的,并且 window size 等于整个 history,所有的历史都同等重要。相应的,为了归一化,对每个时刻的 measure 的 scale 大小会对应缩放。
其中的 HiPPO 矩阵长这个样子:
上述连续的情形 c ˙ ( t ) = A ( t ) c ( t ) + B ( t ) f ( t ) \dot{c}(t)=A(t)c(t)+B(t)f(t) c ˙ ( t ) = A ( t ) c ( t ) + B ( t ) f ( t ) 其实就对应着一个 state space model。由于除了连续的输入之外,还会碰到离散的输入(如文本序列),可以用离散化的操作来写成一个线性 RNN 的形式
c t + 1 = A ˉ t c t + B ˉ t f t c_{t+1}=\bar{A}_{t} c_{t}+\bar{B}_{t} f_{t} c t + 1 = A ˉ t c t + B ˉ t f t
离散化的办法有很多,例如欧拉法,零阶保持等等。
采用如下公式来近似微分:
x ˙ = x ( k + 1 ) − x ( k ) T \dot{x}=\frac{x(k+1)-x(k)}{T} x ˙ = T x ( k + 1 ) − x ( k ) 欧拉法是一阶数值方法,取其曲线在 k k k 处的切线进行近似
已知一定常连续系统的状态空间方程为:
{ x ˙ = A x + B u y = C x + D u \begin{cases}
\dot{\mathbf{x}}=\mathbf{A} \mathbf{x}+\mathbf{B} u \\
y=\mathbf{C x}+\mathbf{D} u
\end{cases} { x ˙ = Ax + B u y = Cx + D u 由 x ˙ = x ( k + 1 ) − x ( k ) T \dot{x}=\frac{x(k+1)-x(k)}{T} x ˙ = T x ( k + 1 ) − x ( k ) 可得:
x ˙ = x ( k + 1 ) − x ( k ) T = A x ( k ) + B u ( k ) x ( k + 1 ) − x ( k ) = T [ A x ( k ) + B u ( k ) ] x ( k + 1 ) = ( I + T A ) x ( k ) + T B u ( k ) = Φ x ( k ) + G u ( k ) \begin{array}{l}
\dot{\mathbf{x}}=\frac{\mathbf{x}(k+1)-\mathbf{x}(k)}{T}=\mathbf{A} \mathbf{x}(k)+\mathbf{B} u(k) \\
\mathbf{x}(k+1)-\mathbf{x}(k)=T[\mathbf{A} \mathbf{x}(k)+\mathbf{B} u(k)] \\
\mathbf{x}(k+1)=(\mathbf{I}+T \mathbf{A}) \mathbf{x}(k)+T \mathbf{B} u(k)=\mathbf{\Phi} \mathbf{x}(k)+\mathbf{G} u(k)
\end{array} x ˙ = T x ( k + 1 ) − x ( k ) = Ax ( k ) + B u ( k ) x ( k + 1 ) − x ( k ) = T [ Ax ( k ) + B u ( k )] x ( k + 1 ) = ( I + T A ) x ( k ) + T B u ( k ) = Φx ( k ) + G u ( k ) 其中:Φ = I + T A ; G = T B \mathbf{\Phi}=\mathbf{I}+T \mathbf{A} ; \mathbf{G}=T \mathbf{B} Φ = I + T A ; G = T B
输出方程同样可以得到:
y = C x + D u ⇒ y ( k ) = H x ( k ) + J u ( k ) y=\mathbf{C x}+\mathbf{D} u \Rightarrow y(k)=\mathbf{H} \mathbf{x}(k)+\mathbf{J} u(k) y = Cx + D u ⇒ y ( k ) = Hx ( k ) + J u ( k ) 其中:H = C ; J = D \mathbf{H}=\mathbf{C} ; \mathbf{J}=\mathbf{D} H = C ; J = D
综上述,离散化后的状态空间为:
{ x ( k + 1 ) = Φ x ( k ) + G u ( k ) y ( k ) = H x ( k ) + J u ( k ) \begin{cases}
\mathbf{x}(k+1)=\boldsymbol{\Phi} \mathbf{x}(k)+\mathbf{G} u(k) \\
y(k)=\mathbf{H} \mathbf{x}(k)+\mathbf{J} u(k)
\end{cases} { x ( k + 1 ) = Φ x ( k ) + G u ( k ) y ( k ) = Hx ( k ) + J u ( k )
S4(Sequences With Structured State Spaces)
并行化加速
上述 HiPPO 构建了一个方法用于计算,估计系统矩阵 A \boldsymbol{A} A ,但是在前向时依然需要类似于 RNN 的循环进行预测:
对序列进行展开:
为了并行化推理,我们 S4 引入了 SSM kernel。考虑输出 y 2 y_{2} y 2 ,有公式:
y 2 = C h 2 = C ( A ˉ h 1 + B ˉ x 2 ) = C ( A ˉ ( A ˉ h 0 + B ˉ x 1 ) + B ˉ x 2 ) = C ( A ˉ ( A ˉ ⋅ B ˉ x 0 + B ˉ x 1 ) + B ˉ x 2 ) = C ( A ˉ ⋅ A ˉ ⋅ B ˉ x 0 + A ˉ ⋅ B ˉ x 1 + B ˉ x 2 ) = C ⋅ A ˉ A ˉ 2 ⋅ B ˉ x 0 + C ⋅ A ˉ ⋅ B ˉ ⋅ x 1 + C ⋅ B ˉ x 2 \begin{align}
y_{2} & = Ch_{2} \\
&=C\left(\bar{A} h_{1}+\bar{B} x_{2}\right) \\
&=C\left(\bar{A}\left(\bar{A} h_{0}+\bar{B} x_{1}\right)+\bar{B} x_{2}\right) \\
&=C\left(\bar{A}\left(\bar{A} \cdot \bar{B} x_{0}+\bar{B} x_{1}\right)+\bar{B} x_{2}\right) \\
&=C\left(\bar{A} \cdot \bar{A} \cdot \bar{B} x_{0}+\bar{A} \cdot \bar{B} x_{1}+\bar{B} x_{2}\right) \\
&=C \cdot \bar{A} \bar{A}^{2} \cdot \bar{B} x_{0}+C \cdot \bar{A} \cdot \bar{B} \cdot x_{1}+C \cdot \bar{B} x_{2}
\end{align} y 2 = C h 2 = C ( A ˉ h 1 + B ˉ x 2 ) = C ( A ˉ ( A ˉ h 0 + B ˉ x 1 ) + B ˉ x 2 ) = C ( A ˉ ( A ˉ ⋅ B ˉ x 0 + B ˉ x 1 ) + B ˉ x 2 ) = C ( A ˉ ⋅ A ˉ ⋅ B ˉ x 0 + A ˉ ⋅ B ˉ x 1 + B ˉ x 2 ) = C ⋅ A ˉ A ˉ 2 ⋅ B ˉ x 0 + C ⋅ A ˉ ⋅ B ˉ ⋅ x 1 + C ⋅ B ˉ x 2
由此类推,可得:
y 3 = C A A A B ‾ x 0 + C A A B ‾ x 1 + C A B ‾ x 2 + C B ‾ x 3 y_{3}=\mathbf{C} \overline{\mathbf{A A} \mathbf{A B}} x_{0}+\mathbf{C} \overline{\mathbf{A A B}} x_{1}+\mathbf{C} \overline{\mathbf{A B}} x_{2}+\mathbf{C} \overline{\mathbf{B}} x_{3} y 3 = C AA AB x 0 + C AAB x 1 + C AB x 2 + C B x 3
写成矩阵的形式可得:
y 3 = ( C A A A B ‾ C A A B ‾ C A ‾ B ‾ C B ‾ ) ( x 0 x 1 x 2 x 3 ) y_{3} = \begin{pmatrix}
\mathbf{C} \overline{\mathrm{AAAB}} & \mathbf{C} \overline{\mathrm{AAB}} & \mathbf{C} \overline{\mathbf{A}} \overline{\mathbf{B}} & \mathbf{C} \overline{\mathbf{B}}
\end{pmatrix} \begin{pmatrix}
x_{0} \\
x_{1} \\
x_{2} \\
x_{3}
\end{pmatrix} y 3 = ( C AAAB C AAB C A B C B ) x 0 x 1 x 2 x 3
其中矩阵 A , B , C A,B,C A , B , C 都是常数,因此可以预先计算左侧向量值并作为卷积核。
对角化
对于上述我们将我们需要计算矩阵的高次幂依然具有很高的复杂度,我们可以尝试对其进行对角化降低计算复杂度。
( A , B , C ) ∼ ( V − 1 A V , V − 1 B , C V ) (\boldsymbol{A}, \boldsymbol{B}, \boldsymbol{C}) \sim\left(\boldsymbol{V}^{-1} \boldsymbol{A} \boldsymbol{V}, \boldsymbol{V}^{-1} \boldsymbol{B}, \boldsymbol{C} \boldsymbol{V}\right) ( A , B , C ) ∼ ( V − 1 A V , V − 1 B , C V ) 如果上述 ( A , B , C ) (A,B,C) ( A , B , C ) 可以对角化,并保持输出不变,那么输出 y y y 的计算复杂度将从 O ( N 2 ) \mathcal{O}(N^{2}) O ( N 2 ) 变成 O ( N ) \mathcal{O}(N) O ( N ) 。
可以构造矩阵 A A A 的相似矩阵(具有同样的特征值),表示如下:
A ~ = [ 1 − 1 2 1 − 3 3 ⋮ ⋮ ⋮ ⋱ ] \tilde{\boldsymbol{A}}= \begin{bmatrix}
1 & & & \\
-1 & 2 & & \\
1 & -3 & 3 & \\
\vdots & \vdots & \vdots & \ddots
\end{bmatrix} A ~ = 1 − 1 1 ⋮ 2 − 3 ⋮ 3 ⋮ ⋱
即:
A n k ~ = { ( − 1 ) ( n − k ) ( 2 k + 1 ) if n > k k + 1 if n = k 0 if n < k \tilde{\boldsymbol{A}_{n k}} = \begin{cases}
(-1)^{(n-k)}(2 k+1) & \text { if } n>k \\
k+1 & \text { if } n=k \\
0 & \text { if } n<k
\end{cases} A nk ~ = ⎩ ⎨ ⎧ ( − 1 ) ( n − k ) ( 2 k + 1 ) k + 1 0 if n > k if n = k if n < k
那么可以找到一个可逆矩阵:
V = ( C i + j i − j ) i j = [ 1 1 1 1 3 1 ⋮ ⋮ ⋮ ⋱ ] \boldsymbol{V}=\left(C_{i+j}^{i-j}\right)_{i j}= \begin{bmatrix}
1 & & & \\
1 & 1 & & \\
1 & 3 & 1 & \\
\vdots & \vdots & \vdots & \ddots
\end{bmatrix} V = ( C i + j i − j ) ij = 1 1 1 ⋮ 1 3 ⋮ 1 ⋮ ⋱
其中 V 3 i , i = C 4 i 2 i ≈ 2 4 i V_{3i,i}=C_{4i}^{2i}\approx 2^{4i} V 3 i , i = C 4 i 2 i ≈ 2 4 i ,因此会指数爆炸
所以 S4 将矩阵 A \boldsymbol{A} A 表示为正规矩阵 + 低秩矩阵的形式,即 :
A = V − 1 Λ V − P Q T = V − 1 ( Λ − ( V P ) ( P T V − 1 ) ) V \boldsymbol{A}=\boldsymbol{V}^{-1} \boldsymbol{\Lambda} \boldsymbol{V}-\boldsymbol{P} \boldsymbol{Q}^{T}=\boldsymbol{V}^{-1}\left(\boldsymbol{\Lambda}-(\boldsymbol{V} \boldsymbol{P})\left(\boldsymbol{P}^{T} \boldsymbol{V}^{-1}\right)\right) \boldsymbol{V} A = V − 1 Λ V − P Q T = V − 1 ( Λ − ( V P ) ( P T V − 1 ) ) V
其中 V ∈ R N × N \boldsymbol{V} \in \mathbb{R}^{N\times N} V ∈ R N × N ,Λ \Lambda Λ 是对角矩阵,P , Q ∈ R N × r P,Q \in \mathbb{R}^{N\times r} P , Q ∈ R N × r 是低秩矩阵。(不懂这里如何得到的,逃)
Mamba
Mamba3 一个主要的贡献就是将上述时不变系统,建模成了时变系统;原本系统为:
{ x ˙ = A x ( t ) + B μ ( t ) y = C x ( t ) \begin{cases}
\dot{\mathbf{x}} = \mathbf{A}x(t) + \mathbf{B}\mu(t) \\
y = \mathbf{C}x(t)
\end{cases} { x ˙ = A x ( t ) + B μ ( t ) y = C x ( t )
其中的 ( A , B , C ) (\mathbf{A,B,C}) ( A , B , C ) 矩阵都是 “ 固定的 “(这里固定指的是对于输入来说,对于不同的输入其矩阵一样的)
而 Mamba 将该部分进行 ” 函数化 “,利用一个网络计算该部分。上图是二者的区别。
在原本 S4 网络中其 A , B , C A,B,C A , B , C 是被特殊初始化后的可训练矩阵,如下图所示:
原本矩阵 A \mathbf{A} A 的大小为 N × N N\times N N × N ,这里是 D × N D\times N D × N 。
这是因为矩阵 A \mathbf{A} A 可被对角化,变成 N N N 维的,并且 S4 是 Single -input-single-output(SISO)的,所以对于输入的每一个维度都有一套 SSM 参数,所以 × D \times D × D
而 Mamba 是利用 DNN 网络输出该矩阵:
S B ( x ) = Linear N ( x ) S C ( x ) = Linear N ( x ) S Δ ( x ) = Broadcast D ( Linear 1 ( x ) ) τ Δ = softplus \begin{array}{l}
S_{B}(x)=\operatorname{Linear}_{N}(x) \\
S_{C}(x)=\operatorname{Linear}_{N}(x) \\
S_{\Delta}(x)=\operatorname{Broadcast~}_{D}\left(\text { Linear }_{1}(x)\right) \\
\tau_{\Delta}=\text { softplus }
\end{array} S B ( x ) = Linear N ( x ) S C ( x ) = Linear N ( x ) S Δ ( x ) = Broadcast D ( Linear 1 ( x ) ) τ Δ = softplus
从而做到了与输入相关
为什么矩阵
A \mathbf{A} A 不是设计成网络的形式?
这是因为 A \mathbf{A} A 矩阵在进行 discretize 时 A \mathbf{A} A 矩阵与 τ Δ \tau_{\Delta} τ Δ 会发生运算,所以计算后的 A ˉ \bar{\mathbf{A}} A ˉ 依然是与输入相关的。
并行化
上述变成时变后,我们就不能利用 SSM kernel 进行并行化。在 Mamba 中使用了一种 parallel scan 的技术。下方是 CUDA 关于该技术的介绍:
给定 n n n 个序列的数组
[ a 0 , a 1 , … , a n − 1 ] [a_{0},a_{1},\dots,a_{n-1}] [ a 0 , a 1 , … , a n − 1 ] 求出序列
[ 1 , a 0 , ( a 0 ⊕ a 1 ) , … , ( a 0 ⊕ a 1 ⊕ ⋯ ⊕ a n − 2 ) ] [1,a_{0},(a_{0} \oplus a_{1}),\dots,(a_{0} \oplus a_{1}\oplus \dots \oplus a_{n-2})] [ 1 , a 0 , ( a 0 ⊕ a 1 ) , … , ( a 0 ⊕ a 1 ⊕ ⋯ ⊕ a n − 2 )]
A Naive Parallel Scan
这个算法的效率在 O ( n log 2 ( n ) ) \mathcal{O}(n\log_{2}(n)) O ( n log 2 ( n ))
前缀和并行算法
整个算法分为两个步骤,上扫和下扫。
在上扫过程,利用树形结构,相邻节点相加得到整个数组的和,而中间结果可以被后续过程所利用。
然后从上往下依次拼凑出前缀和。
而 Mamba 正是利用该算法实现了并行化训练
而上述的动态矩阵,并行扫描算法一起被叫做选择性扫描算法 (selective scan algorithm)
相关资料