Long-Short Transformer: Efficient Transformers for Language and Vision

论文地址:

参考资料:

整体思路以及计算方式

利用Local Attention计算局部Attention(short-term),利用降维的方式计算全局Attention(long-term),最终达到降低时间复杂度的效果。

short-term计算方式:

假设绿色部分长宽分别为l1,l2l_1, l_2,那么总时间复杂度为O(l1l2d×n/l1)=O(nl2d)O(l_1l_2 d \times n/l_1)=O(nl_2 d)

long-term计算方式:

  • 给定QRn×d,K,VRm×dQ\in \mathbb R^{n\times d}, K,V\in \mathbb R^{m\times d}

  • WpRd×r,P=Softmax(KWp)Rm×rW^{p} \in \mathbb R^{d\times r},P=\mathrm{Softmax}( K W^p)\in \mathbb R^{m\times r}

  • Kˉ=PKRr×d,Vˉ=PVRr×d\bar K = P^{\top} K \in \mathbb R^{r\times d}, \bar V = P^{\top} V \in \mathbb R^{r\times d}

  • O=Softmax(QKˉ)VˉRn×dO=\mathrm{Softmax}(Q \bar K^{\top} ) \bar V \in \mathbb R^{n\times d}

总的时间复杂度为O((n+m)dr)O((n+m)dr)

融合:

  • 记short-term对应的K,VK, V分别为K1,V1Rw×dK_1, V_1\in \mathbb R^{w\times d}

  • 记long-term对应的K,VK, V分别为K2,V2Rr×dK_2, V_2\in \mathbb R^{r\times d}

  • O=Softmax(Q[LN1(K1):LN2(K2)])[LN1(V1):LN2(V2)]Rn×dO=\mathrm{Softmax}(Q [\mathrm{LN}_1(K_1): \mathrm{LN}_2(K_2)]^{\top} ) [\mathrm{LN}_1(V_1): \mathrm{LN}_2(V_2)] \in \mathbb R^{n\times d}

总时间复杂度为O(n(r+w)d)O(n(r+w)d)

时间复杂度

总时间复杂度为O(n(r+w)d)O(n(r+w)d)

训练以及loss

不变。

代码

实验以及适用场景

测试了lra, lm以及imagenet,效果都很好。

细节

单向模型中还有一些实现细节。

简评

效果挺好的,但是整体方法感觉不算优雅,一些实现细节可以参考。

Last updated