XLNet: Generalized Autoregressive Pretraining for Language Understanding
Paper: https://arxiv.org/abs/1906.08237
Overall Idea and Computation
XLNet proposes a new pretraining objective that combines the strengths of autoregressive (AR, e.g. GPT) and autoencoding (AE, e.g. BERT) language models.
Given a sentence $\mathbf{x} = [x_1, \cdots, x_T]$, the objective of an AR language model is:
$$\max_{\theta} \; \log p_{\theta}(\mathbf{x}) = \sum_{t=1}^{T} \log p_{\theta}\left(x_t \mid \mathbf{x}_{<t}\right) = \sum_{t=1}^{T} \log \frac{\exp\left(h_{\theta}(\mathbf{x}_{1:t-1})^{\top} e(x_t)\right)}{\sum_{x'} \exp\left(h_{\theta}(\mathbf{x}_{1:t-1})^{\top} e(x')\right)}$$

where $h_{\theta}(\mathbf{x}_{1:t-1})$ is the hidden state produced by a unidirectional model from the prefix and $e(x)$ is the embedding of token $x$. The objective of an AE language model is:
$$\max_{\theta} \; \log p_{\theta}(\bar{\mathbf{x}} \mid \hat{\mathbf{x}}) \approx \sum_{t=1}^{T} m_t \log p_{\theta}\left(x_t \mid \hat{\mathbf{x}}\right) = \sum_{t=1}^{T} m_t \log \frac{\exp\left(H_{\theta}(\hat{\mathbf{x}})_t^{\top} e(x_t)\right)}{\sum_{x'} \exp\left(H_{\theta}(\hat{\mathbf{x}})_t^{\top} e(x')\right)}$$

where $\hat{\mathbf{x}}$ is the corrupted input, $\bar{\mathbf{x}}$ is the set of masked tokens, and $m_t = 1$ indicates that $x_t$ is masked. The $\approx$ reflects BERT's independence assumption: the masked tokens are reconstructed as if they were independent of one another given $\hat{\mathbf{x}}$.
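To make the two objectives concrete, here is a minimal NumPy sketch (my own toy illustration; `h_ar`, `H_ae`, and all shapes are stand-ins, not the paper's models) that evaluates one AR term and one AE term with random encoder outputs:

```python
import numpy as np

rng = np.random.default_rng(0)
V, d, T = 8, 4, 5                       # vocab size, hidden dim, sequence length
E = rng.normal(size=(V, d))             # token embedding matrix e(x)
x = rng.integers(0, V, size=T)          # toy token ids

def log_p(h, target):
    """Log-softmax of e(x')^T h over the vocab, evaluated at `target`."""
    logits = E @ h
    logits -= logits.max()
    return logits[target] - np.log(np.exp(logits).sum())

# AR: sum_t log p(x_t | x_<t); h_ar[t] stands in for h_theta(x_{1:t-1}),
# which a causal model would compute from the prefix only.
h_ar = rng.normal(size=(T, d))
ar_ll = sum(log_p(h_ar[t], x[t]) for t in range(T))

# AE: sum_t m_t log p(x_t | x_hat); only masked positions (m_t = 1)
# contribute, and H_ae[t] stands in for H_theta(x_hat)_t.
m = np.array([0, 1, 0, 1, 0])
H_ae = rng.normal(size=(T, d))
ae_ll = sum(m[t] * log_p(H_ae[t], x[t]) for t in range(T))
print(ar_ll, ae_ll)
```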
The drawbacks of the two approaches:

- AR models only condition on a unidirectional context (left-to-right or right-to-left), so they cannot capture bidirectional dependencies;
- AE models introduce noise ([MASK] tokens) during training that is absent at test time, creating a pretrain-finetune discrepancy, and they assume the masked tokens are conditionally independent of each other.
To address these issues, the paper proposes Permutation Language Modeling: for a sentence of length $T$, consider all $T!$ factorization orders:
$$\max_{\theta} \; \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_T}\left[\sum_{t=1}^{T} \log p_{\theta}\left(x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}}\right)\right]$$

where $\mathcal{Z}_T$ denotes the set of all permutations of length $T$. Since every position eventually appears in every context across permutations, the model learns to use bidirectional context while remaining autoregressive; in practice the expectation is approximated by sampling one permutation per sentence. Crucially, only the factorization order is permuted, not the input itself: the tokens stay in their natural order and the permutation is realized through attention masks.
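A small sketch of this last point (my own illustration, not the official preprocessing): the visibility mask derived from a sampled factorization order $\mathbf{z}$, under which a token may only attend to tokens earlier in the permutation:

```python
import numpy as np

rng = np.random.default_rng(0)
T = 5
z = rng.permutation(T)              # a sampled factorization order

# rank[i] = where token i sits inside the permutation
rank = np.empty(T, dtype=int)
rank[z] = np.arange(T)

# mask[i, j] = True  <=>  token i may attend to token j, i.e. j comes
# strictly earlier in the factorization order. The tokens themselves
# stay in their natural order; only this mask changes per sample.
mask = rank[None, :] < rank[:, None]
print("order:", z)
print(mask.astype(int))
```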
The next step is to compute $p_{\theta}\left(x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}}\right)$. A naive parameterization reuses the standard softmax:
$$p_{\theta}\left(X_{z_t} = x \mid \mathbf{x}_{\mathbf{z}_{<t}}\right) = \frac{\exp\left(e(x)^{\top} h_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}})\right)}{\sum_{x'} \exp\left(e(x')^{\top} h_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}})\right)}$$

This is problematic because the distribution does not depend on the target position $z_t$: two different target positions sharing the same context $\mathbf{x}_{\mathbf{z}_{<t}}$ would be assigned exactly the same distribution. The authors therefore propose a target-position-aware parameterization:
$$p_{\theta}\left(X_{z_t} = x \mid \mathbf{x}_{\mathbf{z}_{<t}}\right) = \frac{\exp\left(e(x)^{\top} g_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}}, z_t)\right)}{\sum_{x'} \exp\left(e(x')^{\top} g_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}}, z_t)\right)}$$

This creates a contradiction for a single set of hidden states: to predict $x_{z_t}$ the representation must use the position $z_t$ but not the content $x_{z_t}$, yet to predict later tokens it must also encode the content $x_{z_t}$. The authors therefore maintain two streams, calling $h_{\theta}$ the content representation and $g_{\theta}$ the query representation, computed layer by layer as:
$$\begin{aligned} g_{z_t}^{(m)} &\leftarrow \operatorname{Attention}\left(\mathbf{Q} = g_{z_t}^{(m-1)},\; \mathbf{KV} = \mathbf{h}_{\mathbf{z}_{<t}}^{(m-1)};\; \theta\right) && \text{(query stream: uses } z_t \text{ but cannot see } x_{z_t}\text{)} \\ h_{z_t}^{(m)} &\leftarrow \operatorname{Attention}\left(\mathbf{Q} = h_{z_t}^{(m-1)},\; \mathbf{KV} = \mathbf{h}_{\mathbf{z}_{\leq t}}^{(m-1)};\; \theta\right) && \text{(content stream: uses both } z_t \text{ and } x_{z_t}\text{)} \end{aligned}$$
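A single-head, single-layer NumPy sketch of the two streams (my own simplification; the weight shapes, initialization, and shared projections are illustrative assumptions): the content stream may attend to $\mathbf{h}_{\mathbf{z}_{\leq t}}$ including itself, while the query stream attends only to $\mathbf{h}_{\mathbf{z}_{<t}}$:

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 5, 8
z = rng.permutation(T)                      # factorization order
rank = np.empty(T, dtype=int)
rank[z] = np.arange(T)

h = rng.normal(size=(T, d))                 # content stream (layer m-1)
g = rng.normal(size=(T, d))                 # query stream (layer m-1)
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))

def attend(q, kv, allowed):
    """Masked scaled dot-product attention; allowed[i, j] says whether
    query position i may attend to key position j."""
    scores = (q @ Wq) @ (kv @ Wk).T / np.sqrt(d)
    scores = np.where(allowed, scores, -1e9)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ (kv @ Wv)

content_mask = rank[None, :] <= rank[:, None]  # sees z_<=t (itself included)
query_mask = rank[None, :] < rank[:, None]     # sees z_<t only (not itself)

h_next = attend(h, h, content_mask)  # h^(m): Q from h, KV from h_{z<=t}
g_next = attend(g, h, query_mask)    # g^(m): Q from g, KV from h_{z<t}
# Note: the first position in the order has an empty context here; real
# implementations give it memory or a learned start state to attend to.
```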
The authors also borrow the segment-level recurrence of Transformer-XL: the content stream additionally attends to the cached hidden states $\tilde{\mathbf{h}}^{(m-1)}$ of the previous segment, so the computation of $h$ becomes:

$$h_{z_t}^{(m)} \leftarrow \operatorname{Attention}\left(\mathbf{Q} = h_{z_t}^{(m-1)},\; \mathbf{KV} = \left[\tilde{\mathbf{h}}^{(m-1)},\; \mathbf{h}_{\mathbf{z}_{\leq t}}^{(m-1)}\right];\; \theta\right)$$
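The recurrence itself is just a concatenation on the key/value side; a sketch under the same toy setup (simplified; real code also handles relative positions and stops gradients through the cache):

```python
import numpy as np

rng = np.random.default_rng(0)
T_mem, T, d = 4, 5, 8
mem = rng.normal(size=(T_mem, d))       # cached h~ from the previous segment
h = rng.normal(size=(T, d))             # current segment's content stream
kv = np.concatenate([mem, h], axis=0)   # KV = [h~, h_{z<=t}]

# The attention mask is widened so every current position may see all
# memory positions in addition to its permutation-allowed current ones.
mem_mask = np.ones((T, T_mem), dtype=bool)
print(kv.shape, mem_mask.shape)
```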
The objective function:

Exhaustively enumerating all permutations is clearly infeasible, so one permutation is sampled at a time; moreover, to reduce optimization difficulty, only the last tokens of the sampled order are predicted (partial prediction). Splitting $\mathbf{z}$ at a cut point $c$ into a non-target prefix $\mathbf{z}_{\leq c}$ and a target suffix $\mathbf{z}_{>c}$, the objective becomes:
$$\max_{\theta} \; \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_T}\left[\log p_{\theta}\left(\mathbf{x}_{\mathbf{z}_{>c}} \mid \mathbf{x}_{\mathbf{z}_{\leq c}}\right)\right] = \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_T}\left[\sum_{t=c+1}^{|\mathbf{z}|} \log p_{\theta}\left(x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}}\right)\right]$$

A hyperparameter $K$ controls the split so that roughly $1/K$ of the tokens are selected as targets, i.e. $|\mathbf{z}| / (|\mathbf{z}| - c) \approx K$.
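A small sketch of partial prediction (illustrative, not the official input pipeline), using the $K$ described above:

```python
import numpy as np

rng = np.random.default_rng(0)
T, K = 8, 4                       # sequence length, partial-prediction K
z = rng.permutation(T)
c = T - T // K                    # cut point: predict roughly 1/K of tokens
non_targets, targets = z[:c], z[c:]
print("conditioned on:", non_targets, "-> predict:", targets)
# Only `targets` need the query stream, so g is computed just for them.
```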
Time Complexity

Since this is a pretraining objective, time complexity is not a concern here. (Note, though, that the extra query stream roughly doubles the attention computation relative to a single-stream Transformer.)
Training and Loss
Already covered above.
Code
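The official TensorFlow implementation is open-sourced at https://github.com/zihangdai/xlnet. For quick experiments the HuggingFace Transformers port is more convenient; a minimal forward pass (assumes `pip install torch transformers sentencepiece`; `xlnet-base-cased` is the public pretrained checkpoint):

```python
import torch
from transformers import XLNetLMHeadModel, XLNetTokenizer

tok = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased").eval()

inputs = tok("XLNet combines AR and AE pretraining.", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs)
print(out.logits.shape)  # (batch, seq_len, vocab_size)
```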
Experiments and Applicable Scenarios
The experiments show that XLNet brings substantial gains over BERT across a wide range of tasks, including question answering, natural language inference, sentiment analysis, and document ranking.
Details
None for now; the finer details will only become clear after reproducing the model.
Brief Review
A very interesting idea. Although the paper is a few years old now, I think it is well worth reproducing.