XLNet Generalized Autoregressive Pretraining for Language Understanding

论文地址:

参考资料:

整体思路以及计算方式

XLNET给出一种新的预训练方式,结合了AR(GPT),AE(Bert)的特点。

给定句子x=[x1,,xT]\mathbf{x}=\left[x_{1}, \cdots, x_{T}\right],AR语言模型的目标为:

maxθlogpθ(x)=t=1logpθ(xtx<t)=t=1logexp(hθ(x1:t1)e(xt))xexp(hθ(x1:t1)e(x))\max _{\theta} \log p_{\theta}(\mathbf{x})=\sum_{t=1}^{\top} \log p_{\theta}\left(x_{t} \mid \mathrm{x}_{<t}\right)=\sum_{t=1}^{\top} \log \frac{\exp \left(h_{\theta}\left(\mathbf{x}_{1: t-1}\right)^{\top} \mathbf e\left(x_{t}\right)\right)}{\sum_{x^{\prime}} \exp \left(h_{\theta}\left(\mathbf{x}_{1: t-1}\right)^{\top} \mathbf e\left(x^{\prime}\right)\right)}

AE语言模型的目标为:

maxθlogpθ(xx^)t=1mtlogpθ(xtx^)=t=1mtlogexp(Hθ(x^)te(xt))xexp(Hθ(x^)te(x))\max _{\theta} \log p_{\theta}(\overline{\mathbf{x}} \mid \hat{\mathbf{x}}) \approx \sum_{t=1}^{\top} m_{t} \log p_{\theta}\left(x_{t} \mid \hat{\mathbf{x}}\right)=\sum_{t=1}^{\top} m_{t} \log \frac{\exp \left(H_{\theta}(\hat{\mathbf{x}})_{t}^{\top} \mathbf e\left(x_{t}\right)\right)}{\sum_{x^{\prime}} \exp \left(H_{\theta}(\hat{\mathbf{x}})_{t}^{\top} \mathbf e\left(x^{\prime}\right)\right)}

其中mt=1m_t=1表示xtx_t被mask。

两者的缺点是:

  • AE模型有独立性假设;

  • AE模型在训练的时候有噪声,在测试的时候没有噪声;

  • AR模型只能看到单侧信息;

为了解决这点,论文提出了Permutation Language Modeling,即对长度为TT的句子,考虑全部T!T!种排列:

maxθEzZT[t=1logpθ(xztxz<t)]\max _{\theta} \quad \mathbb{E}_{\mathrm{z} \sim \mathcal{Z}_{T}}\left[\sum_{t=1}^{\top} \log p_{\theta}\left(x_{z_{t}} \mid \mathbf{x}_{\mathrm{z}<t}\right)\right]

其中ZT\mathcal Z_T表示长度为TT的全排列集合。

下一步是计算pθ(xztxz<t)p_{\theta}\left(x_{z_{t}} \mid \mathrm{x}_{\mathrm{z}<t}\right),模型的计算方式为:

pθ(Xzt=xxz<t)=exp(e(x)hθ(xz<t))xexp(e(x)hθ(xz<t))p_{\theta}\left(\mathbf X_{z_{t}}=x \mid \mathbf{x}_{\mathrm{z}<t}\right)=\frac{\exp \left(\mathbf e(\mathbf x)^{\top} h_{\theta}\left(\mathbf{x}_{\mathrm{z}<t}\right)\right)}{\sum_{x^{\prime}} \exp \left(\mathbf e\left(\mathbf x^{\prime}\right)^{\top} h_{\theta}\left(\mathbf{x}_{\mathrm{z}<t}\right)\right)}

但是该方法有问题,因为没有考虑ztz_t,所以作者提出了如下计算方式:

pθ(Xzt=xxz<t)=exp(e(x)gθ(xz<t,zt))xexp(e(x)gθ(xz<t,zt))p_{\theta}\left(\mathbf X_{z_{t}}=x \mid \mathbf{x}_{z_{<t}}\right)=\frac{\exp \left(\mathbf e(\mathbf x)^{\top} g_{\theta}\left(\mathbf{x}_{\mathrm{z}_{<t}}, z_{t}\right)\right)}{\sum_{x^{\prime}} \exp \left(\mathbf e\left(\mathbf x^{\prime}\right)^{\top} g_{\theta}\left(\mathbf{x}_{\mathrm{z}_{<t}}, z_{t}\right)\right)}

作者将hθ,gθh_\theta, g_\theta分别称为content representation和query representation,计算方式为:

gzt(m)Attention(Q=gzt(m1),KV=hz<t(mθ))( query stream: use zt but cannot see xzt).hzt(m)Attention(Q=hzt(m1),KV=hzt(m1);θ),( content stream: use both zt and xzt).\begin{aligned} g_{z_{t}}^{(m)}& \leftarrow \operatorname{Attention} \left(\mathbf{Q}=g_{z_{t}}^{(m-1)}, \mathbf{KV}=\mathrm{h}_{\mathrm{z}_{<t}}^{(m-\theta)} \right) \quad\left(\text { query stream: use } z_{t} \text { but cannot see } x_{z_{t}}\right). \\ h_{z_{t}}^{(m)} &\leftarrow \operatorname{Attention}\left(\mathbf{Q}=h_{z_{t}}^{(m-1)}, \mathbf{KV}=\mathrm{h}_{\mathrm{z}_{\leq t}}^{(m-1)} ; \theta\right), \quad\left(\text { content stream: use both } z_{t} \text { and } x_{z_{t}}\right) . \end{aligned}

作者还借鉴了Transformer-XL的想法,将hh的计算方式修改为:

hzt(m)Attention(Q=hzt(m1),KV=[h~(m1),hzt(m1)];θ)h_{z_{t}}^{(m)} \leftarrow \operatorname{Attention}\left(\mathbf{Q}=h_{z_{t}}^{(m-1)}, \mathbf{KV}=\left[\tilde{\mathrm{h}}^{(m-1)}, \mathrm{h}_{\mathrm{z}_{\leq t}}^{(m-1)}\right] ; \theta\right)

目标函数:

注意穷举全部排列显然是不现实的,所以作者将目标函数定义为:

maxθEzZT[logpθ(xz>cxzc)]=EzZT[t=c+1zlogpθ(xztxz<t)]\max _{\theta} \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_{T}}\left[\log p_{\theta}\left(\mathbf{x}_{\mathbf{z}_{>c}} \mid \mathbf{x}_{\mathbf{z}_{\leq c}}\right)\right]=\mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_{T}}\left[\sum_{t=c+1}^{|\mathbf{z}|} \log p_{\theta}\left(x_{z_{t}} \mid \mathbf{x}_{\mathbf{z}_{<t}}\right)\right]

时间复杂度

因为是预训练任务,所以不考虑这点。

训练以及loss

已经讨论过。

代码

实验以及适用场景

从实验来看,带来了非常大的提升。

细节

暂无,需要复现之后才能了解细节。

简评

非常有意思的想法,虽然时间有点久远,但是个人觉得很值得复现。

Last updated