Skip to main content

Reparameterization Trick

· 5 min read
PuQing
AI, CVer, Pythoner, Half-stack Developer

Motivation

假设我们有个在参数 θ\theta 下的正态分布 qq。我们想要求解下面这样一个问题

minθEq[f(x)]\min_{\theta} E_{q}[f(x)]

其中 Eq[f(x)]E_{q}[f(x)] 的意思是求满足 qq 分布下的随机变量函数 f(x)f(x) 的均值,而最外层的 minθ\min_{\theta} 则是求使得该均值最小时θ\theta

有一种做法就是直接对该期望求 θ\theta 的导数 θEq[f(x)]\nabla_{\theta} E_{q}\left[f(x)\right]

θEq[f(x)]=θqθ(x)f(x)dx=f(x)θqθ(x)qθ(x)qθ(x)dx(积分变量是x)=qθ(x)θlogqθ(x)f(x)dx(Log Derivative Trick)=Eq[f(x)θlogqθ(x)]\begin{split} \nabla_{\theta} E_{q}\left[f(x)\right] &=\displaystyle \nabla_{\theta} \int q_{\theta}(x) f(x) d x \\ &=\displaystyle \int f(x) \nabla_{\theta} q_{\theta}(x) \frac{q_{\theta}(x)}{q_{\theta}(x)} d x \quad \text{(积分变量是$x$)} \\ &=\displaystyle \int q_{\theta}(x) \nabla_{\theta} \log q_{\theta}(x) f(x) d x\quad \text{(Log Derivative Trick)} \\ &=E_{q}\left[f(x) \nabla_{\theta} \log q_{\theta}(x)\right] \end{split}
tip

上面公式告诉我们,期望的梯度,可以转化为梯度的期望。

Log Derivative Trick(This page is not published)

于是我们可以利用 Eq[f(x)θlogqθ(x)]E_{q}\left[f(x) \nabla_{\theta} \log q_{\theta}(x)\right] 去估计梯度。

warning

这样的梯度估计式称为 SFSF 估计,Score Function Estimator,在强化学习中 qq 代表着策略,那么上式就是一个基本的策略梯度,有时也叫成 Reinforce

^c5d6c1

并且上述的推导对于任意的随机变量 xx 不管是连续还是离散变量,都是通用的。这样我们可以直接从 pθ(x)p_{\theta}(x) 采样若干个点来估算损失函数的梯度了。

方差

既然上述 SFSF 估计对于连续离散都适用,为什么我们还需要重参数化呢?

主要的原因是:SFSF 估计的方差太大。

Trick

我们将随机变量 xx 视为另一个随机变量经过变换 gθ(ϵ)g_{\theta}(\epsilon) 得到的

ϵp(ϵ)x=gθ(ϵ)\begin{split} \epsilon \sim p(\epsilon) \\ x = g_{\theta}(\epsilon) \end{split}

这样我们对于原式可以写为:

Eq[f(x)]=Ep[f(gθ(ϵ))]E_{q}\left[f(x)\right]=E_{p}\left[f(g_{\theta}(\epsilon))\right]

现在我们对其求梯度:

θEq[f(x)]=θEp[f(gθ(ϵ))]=Ep[fggθ]\nabla_{\theta} E_{q}\left[f(x)\right]=\nabla_{\theta} E_{p}\left[f(g_{\theta}(\epsilon))\right]=E_{p}[\frac{\partial f}{\partial g} \cdot \frac{\partial g}{\partial \theta}]

此时 θ\theta 分布参数参与了前向过程,所以保留了 θ\theta 的梯度,使得能够优化参数 θ\theta。但是这对于随机变量 ϵ\epsilon 有什么要求呢?

info
  1. ϵ\epsilon 应该是方便计算机采样得到的
  2. gg 是可微分的

梯度估计角度

既然上述都在讲梯度估计,我们自然很关心梯度估计的稳定性,我们不如求一下上面两个公式的方差

info

为了便于求解,我们取 f(x)=x2f(x)=x^2q=N(μ,1)q=N(\mu,1) 另外 g(ϵ)=μ+xg(\epsilon) = \mu+ xp=N(0,1)p =N(0,1)

{Var[f(x)θlogqθ(x)]=μ4+14μ2+15Var[fggθ]=4\begin{cases} \operatorname{Var}[f(x) \nabla_{\theta} \log q_{\theta}(x)] = \displaystyle \mu^4+14\mu^2+15 \\ \displaystyle \operatorname{Var}[\frac{\partial f}{\partial g} \cdot \frac{\partial g}{\partial \theta}] = 4 \end{cases}

所以在 一般情况下 使用从参数化后的梯度估计方差更小,更稳定。(显然你可以找出一个反例,只是说一般情况)

并且对比两个式子:
{θEq[f(x)]=Eq[f(x)θlogqθ(x)]θEq[f(x)]=Ep[fggθ]\begin{cases} \nabla_{\theta} E_{q}\left[f(x)\right] = E_{q}\left[f(x) \nabla_{\theta} \log q_{\theta}(x)\right] \\ \nabla_{\theta} E_{q}\left[f(x)\right] = E_{p}[\frac{\partial f}{\partial g} \cdot \frac{\partial g}{\partial \theta}] \end{cases}

可以看到 SFSF 估计具有 log\log。我们知道,作为一个合理的概率分布,一般都在无穷远处(即 x\left \| x \right \|\to \infty,都会有 qθ(x)0q_{\theta}(x)\to 0,而 log\log 将远处的扰动噪声进行了一定程度的放大,所以方差会大

相关资料