Skip to main content

Gumbel Softmax

· 7 min read
PuQing
AI, CVer, Pythoner, Half-stack Developer

之前已经写过 Reparameterization trick,这里主要是想重新讲讲整个重参数化的逻辑。

强化学习-基本组件 中说强化学习会将动作建模一个随机变量。即:

atπ(st)a_{t} \sim \pi(\cdot \mid s_{t})

深度强化学习将会预测其动作的分布参数 θ\theta,然后在计算奖励函数时输入 ata_{t},但是问题是该 ata_{t} 是从参数 θ\theta 下分布采样得到的。也就是说这个地方的梯度无法反传。

SF估计 中说过我们可以通过对数技巧,将期望的导数转化为导数的梯度,即:

θEq[f(x)]=Eq[f(x)θlogqθ(x)]\nabla_{\theta} E_{q}\left[f(x)\right]=E_{q}\left[f(x) \nabla_{\theta} \log q_{\theta}(x)\right]

该式其实是利用采样点估计目标点的梯度,之后在 Reparameterization trick#Trick 中介绍了重参数方法,将上述需要采样的操作变形为 采样+变换\text{采样}+\text{变换} 的操作,这样分布参数 θ\theta 便通过变换的操作参与了运算,所以可以求导数。

而对于那些需要使用采样得到的中间隐变量,比如 VAE 的采样,然后解码同样需要重参数操作,特别的这里可以根据随机变量的类型区分重参数的方法,对于连续随机变量,我们添加一个变换即可完成,对于离散变量,就引出了这篇 《CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX》

离散

对于离散随机变量,深度网络输出的逻辑值 oio_{i} 表示了序号 ii 被选中的概率,当然这可以视作为一个 kk 分类模型,还是同样的问题,如果我们不需要保证这里的随机变量是采样得到的,我们可以直接使用软化后的 argmax\mathrm{argmax} 来使得此处可导 (可见 不可导函数的可导逼近),但是为了保证该变量具有随机性,我们还是需要重参数化进行保证。 1

argmaxi(logpilog(logεi))i=1k,εiU[0,1]\underset{i}{\operatorname{argmax}}\left(\log p_{i}-\log \left(-\log \varepsilon_{i}\right)\right)_{i=1}^{k}, \quad \varepsilon_{i} \sim U[0,1]

这被称作为Gumbel-Max Trick,这个可以看起来没有连续随机变量那样美观,这是因为没有一些前置知识。

累计分布函数与逆变换采样 2

^7ade6d

定理

XX 为连续型随机变量, 取值于区间 (a,b)(a, b) (可包括 ±\pm\infty 和端点), XX 的密度在 (a,b)(a,b) 上取正值, XX 的分布函数为 F(x)F(x), UU(0,1),\quad U \sim \mathrm{U}(0,1), \quadY=F1(U)F()Y=F^{-1}(U) \sim F(\cdot)3

这告诉我们使用一个均匀分布和累积分布函数就可以得到任意形式的分布。

info

服从指数分布 Exp(λ)(λ>0)\mathrm{Exp}(\lambda)(\lambda>0) 的随机变量 XX 的概率密度函数和累积分布函数为:

p(x)=λeλx,x>0F(x)=1eλx,x>0\begin{aligned} p(x) & =\lambda e^{-\lambda x}, x>0 \\ F(x) & =1-e^{-\lambda x}, x>0 \end{aligned}

反函数为

F1(u)=λ1log(1u).F^{-1}(u)=-\lambda^{-1} \log (1-u) .

所以 UU(0,1)U \sim \mathrm{U}(0,1)X=λ1log(1U)X=-\lambda^{-1} \log (1-U) 服从 Exp(λ)\operatorname{Exp}(\lambda) 。因为 1U1-UUU 同分布, 所以取 X=λ1logUX=-\lambda^{-1} \log U 也服从 Exp(λ)\operatorname{Exp}(\lambda)

定理

定理 6.2 设 XX 为离散型随机变量, 取值于集合 {a1,a2,}(a1<a2<),F(x)\left\{a_{1}, a_{2}, \ldots\right\}\left(a_{1}<a_{2}<\ldots\right), \quad F(x)XX 的分布函数, UU(0,1)U \sim \mathrm{U}(0,1) , 根据 UU 的值定义随机变量 YY

Y=ai 当且仅当 F(ai1)<UF(ai),i=1,2,\begin{array}{c} Y=a_{i} \text { 当且仅当 } F\left(a_{i-1}\right)<U \leq F\left(a_{i}\right), i=1,2, \ldots \end{array}

(定义 F(a0)=0F\left(a_{0}\right)=0) 则 YF(y)Y \sim F(y)

info

例 6.4 (几何分布随机数) 设随机变量 XX 表示在成功概率为 p(0<p<1)p(0<p <1) 的独立重复试验中首次成功所需的试验次数, 则 XX 的概率分布为

P(X=k)=pqk1,k=1,2,,(q=1p)P(X=k)=p q^{k-1}, k=1,2, \ldots,(q=1-p)

XX 服从几何分布, 记为 XGeom(p)X \sim \operatorname{Geom}(p).

UU(0,1)U\sim \mathrm{U}(0,1), 注意到

F(k)=P(Xk)=P( 在前 k 次试验中至少  次成功 )=1P( 前 k 次试验都失败 )=1qk,k=1,2,\begin{aligned} F(k) & =P(X \leq k)=P(\text { 在前 } k \text { 次试验中至少 }- \text { 次成功 }) \\ & =1-P(\text { 前 } k \text { 次试验都失败 }) \\ & =1-q^{k}, k=1,2, \ldots \end{aligned}

利用上述定理,生成 XX 的方法当且仅当 1qk1<U1qk1-q^{k-1}<U \leq 1-q^{k} 时,取 X=k,k=1,2,X=k,k=1,2,\dots

等价于

qk1U<qk1q^{k} \leq 1-U<q^{k-1}

X=min{k:qk1U}=min{k:klog(q)log(1U)}=min{k:klog(1U)log(q)}=ceil(log(1U)log(q))\begin{aligned} X & =\min \left\{k: q^{k} \leq 1-U\right\} \\ & =\min \{k: k \log (q) \leq \log (1-U)\} \\ & =\min \left\{k: k \geq \frac{\log (1-U)}{\log (q)}\right\} \\ & =\operatorname{ceil}\left(\frac{\log (1-U)}{\log (q)}\right) \end{aligned}

注意到 1U1-U 也是服从 U(0,1)\mathrm{U}(0,1) 分布的,所以只要取

X=ceil(ln(U)ln(q))X=\operatorname{ceil}\left(\frac{\ln (U)}{\ln (q)}\right)

XX 服从几何分布

Gumel

终于可以回到我们的问题了,Gumel(μ,β)\mathrm{Gumel}(\mu,\beta) 分布的累积分布函数为:

F(x;μ,β)=ee(xμ)/βF(x ; \mu, \beta)=e^{-e^{-(x-\mu) / \beta}}

所以其反函数为

F1(y,μ,β)=βlog(eμβlog(1x))F^{-1}(y,\mu,\beta) = - \beta \log{\left(e^{- \frac{\mu}{\beta}} \log{\left(\frac{1}{x} \right)} \right)}

当然,在 Gumbel max\mathrm{Gumbel \ max } 中使用的是标准 Gumbel\mathrm{Gumbel} 分布,所以 μ=0,β=1\mu=0,\beta=1,所以上式化简为:

F1(y)=log(log(x))F^{-1}(y) = -\log \left( -\log(x) \right)

所以 Gumbel Max Trick\mathrm{Gumbel\ Max\ Trick} 就是:

z=argmaxi(log(πi)+gi),z=\operatorname{argmax}_{i}\left(\log \left(\pi_{i}\right)+g_{i}\right),

其中 gi=log(log(ui)),uiU(0,1)g_{i}=-\log \left(-\log \left(u_{i}\right)\right), u_{i} \sim U(0,1),这一项就是从 Gumbel\mathrm{Gumbel} 分布采样得到的噪声。

所以也就是说相当于为每个逻辑值添加了一定的噪声。但是 argmax\arg \max 也不是可导的,所以我们在得再软化一下。

由此,我们得到 Gumbel Max\mathrm{Gumbel\ Max} 的光滑近似版本——Gumbel Softmax\mathrm{Gumbel\ Softmax}

softmax((logpilog(logεi))/τ)i=1k,εiU[0,1]\operatorname{softmax}\left(\left(\log p_{i}-\log \left(-\log \varepsilon_{i}\right)\right) / \tau\right)_{i=1}^{k}, \quad \varepsilon_{i} \sim U[0,1]

其中参数 τ>0\tau>0 称为退火参数, 它越小输出结果就越接近 one hot\mathrm{one\ hot} 形式 (但同时梯度消失就越严重)。

问题

为什么要使用 Gumbel 分布,而不是其他的分布?

在进行变换后如何保证与变换前的概率一致?

引用

Footnotes

  1. 漫谈重参数:从正态分布到Gumbel Softmax - 科学空间|Scientific Spaces

  2. Inverse transform sampling - Wikipedia

  3. 6 非均匀随机数生成 | 统计计算