Transformer with a Mixture of Gaussian Keys

论文地址:

整体思路以及计算方式

作者对Attention中Softmax部分利用GMM替换,最后达到了相当的效果,这里回顾下技术细节。

全篇文章的出发点是如下假设:

P(qitj=1)=N(qikj,σj2I)\mathbb{P}\left(\mathbf{q}_{i} | {t}_{j}=1\right)=\mathcal{N}\left(\mathbf{q}_{i} |\mathbf{k}_{j}, \sigma_{j}^{2} \mathbf{I}\right)

其中等号左边的概率表示qi\mathbf q_ikj\mathbf k_j有交互的概率(相当于Softmax中qi\mathbf q_ikj\mathbf k_j对应的权重)。

通过该假设,以及qi,kj\mathbf q_i, \mathbf k_j模长相等的假设,可以得到Softmax函数。

随后作者对上式进行推广,利用GMM可以拟合任意分布,作者假设:

P(qitj=1)=rπjrN(qikjr,σjr2I)\mathbb{P}\left(\mathbf {q}_{i} | {t}_{j}=1\right)=\sum_r \pi_{jr}\mathcal{N}\left(\mathbf {q}_{i} |\mathbf {k}_{jr}, \sigma_{jr}^{2} {I}\right)

所以:

P(tj=1qi)=rπjrexp(qikjr2/2σjr2)jrπjrexp(qikjr2/2σjr2)\mathbb{P}\left({t}_{j}=1 | \mathbf {q}_{i}\right)=\frac{\sum_{r} \pi_{j r} \exp \left(-\left\|\mathbf {q}_{i}-\mathbf {k}_{j r}\right\|^{2} / 2 \sigma_{j r}^{2}\right)}{\sum_{j^{\prime}} \sum_{r} \pi_{j^{\prime} r} \exp \left(-\left\|\mathbf {q}_{i}-\mathbf {k}_{j^{\prime} r}\right\|^{2} / 2 \sigma_{j^{\prime} r}^{2}\right)}

最后的输出为:

hi=j(rπjrexp(qikjr2/2σjr2)jrπjrexp(qikjr2/2σjr2))vj{h}_{i}=\sum_{j}\left(\frac{\sum_{r} \pi_{j r} \exp \left(-\left\|\mathbf {q}_{i}-\mathbf {k}_{j r}\right\|^{2} / 2 \sigma_{j r}^{2}\right)}{\sum_{j^{\prime}} \sum_{r} \pi_{j^{\prime} r} \exp \left(-\left\|\mathbf {q}_{i}-\mathbf {k}_{j^{\prime} r}\right\|^{2} / 2 \sigma_{j^{\prime} r}^{2}\right)}\right) {v}_{j}

Linear版本:

上述方法可以推广到Linear Attention,唯一的区别就是增加了权重πjr\pi_{jr}

hi=jrπjrϕ(qi)ϕ(kjr)vjjrπjrϕ(qi)ϕ(kjr)=ϕ(qi)jrπjrϕ(kjr)vjϕ(qi)jrπjrϕ(kjr){h}_{i}=\frac{\sum_{j} \sum_{r} \pi_{j r} \phi\left(\mathbf {q}_{i}\right)^{\top} \phi\left(\mathbf {k}_{j r}\right) \mathbf {v}_{j}}{\sum_{j} \sum_{r} \pi_{j r} \phi\left(\mathbf {q}_{i}\right)^{\top} \phi\left(\mathbf {k}_{j r}\right)}=\frac{\phi\left(\mathbf {q}_{i}\right)^{\top} \sum_{j} \sum_{r} \pi_{j r} \phi\left(\mathbf {k}_{j r}\right) \mathbf {v}_{j}^{\top}}{\phi\left(\mathbf {q}_{i}\right)^{\top} \sum_{j} \sum_{r} \pi_{j r} \phi\left(\mathbf {k}_{j r}\right)}

学习策略:

具体的细节可以参考论文,主要是利用了EM算法。

时间复杂度

Vanilla版本为O(N2d)O(N^2d),Linear版本为O(Nd2)O(Nd^2)

训练以及loss

不变。

代码

实验以及适用场景

适用于所有场景,从结果来看提升并不明显。

细节

见代码。

简评

一个很好的思路,但是缺点也比较明显,性能基本没有提升,而且感觉学习的效率会降低。

Last updated