Fast Transformers with Clustered Attention

论文地址:

整体思路以及计算方式

QQ进行聚类,从而降低时间复杂度。

输入:

  • QRn×d,KRn×d,VRn×dQ\in \mathbb R^{n\times d}, K\in \mathbb R^{n\times d}, V\in \mathbb R^{n\times d}

  • 聚类矩阵:S{0,1}n×cS\in \{0,1\}^{n\times c}

  • qjc=i=1Nsijqii=1Nsijq_{j}^{c}=\frac{\sum_{i=1}^{N} s_{i j} q_{i}}{\sum_{i=1}^{N} s_{i j}}

  • Ac=softmax(QcK)Rc×nA^{c}=\operatorname{softmax}\left({Q^{c} K^{\top}}\right)\in \mathbb R^{c\times n}

  • Oˉ=AcVRc×d\bar O=A^{c} V\in \mathbb R^{c\times d}

  • oi=j=1csijoˉjo_{i}=\sum_{j=1}^{c} s_{i j} \bar o_{j}

聚类方式见论文。

时间复杂度

O(ncd)O(ncd)

训练以及loss

不变。

代码

实验以及适用场景

作者跑了Encoder实验,Decoder部分需要适配。

细节

暂无。

简评

一个很简洁的思路,不过高效实现需要花一定的功夫,主要是聚类方式部分。

Last updated