CoLT5: Faster Long-Range Transformers with Conditional Computation

论文地址:

整体思路以及计算方式

分成两个部分:

  • Attention部分使用Sparse Attention,类似于window attention加上少量global pattern,后续记为SMHA\mathrm{SMHA}

  • 在Attention和FFN部分别使用Heavy和Light模块,前者参数多,后者参数少;

计算方式如下:

  • 输入XRn×d\mathbf X\in \mathbb R^{n\times d}

  • 路由函数:su(X)=Softmax(Topk(Xu)),uRds_{\mathbf u}(\mathbf X) = \mathrm{Softmax}(\mathrm {Topk} (\mathbf X \mathbf u^{\top})), \mathbf u \in \mathbb R^d

    • Topk函数:Topk(s)Rn\mathrm{Topk}(\mathrm{s})\in \mathbb R^n,取值最大的kk个值,其余设置为-\infty

  • Attention部分:

    • X=SMHAlight(X,X)+su1(X)SMHAheavy(X,su2(X))\mathbf X= \mathrm{SMHA}_{\mathrm {light}}(\mathbf X, \mathbf X) + s_{\mathbf u_1} (\mathbf X)\mathrm{SMHA}_{\mathrm {heavy}}(\mathbf X, s_{\mathbf u_2} (\mathbf X))

  • FFN部分:

    • X=FFNlight(X)+su3(X)FFNheavy(X)\mathbf X= \mathrm{FFN}_{\mathrm {light}}(\mathbf X) + s_{\mathbf u_3} (\mathbf X)\mathrm{FFN}_{\mathrm {heavy}}(\mathbf X)

时间复杂度

见论文。

代码

简评

很工程的思路,感觉一般。

Last updated