论文地址:
https://arxiv.org/abs/2107.02192arrow-up-right
参考资料:
https://blog.csdn.net/qq_43542339/article/details/118771339arrow-up-right
利用Local Attention计算局部Attention(short-term),利用降维的方式计算全局Attention(long-term),最终达到降低时间复杂度的效果。
short-term计算方式:
假设绿色部分长宽分别为l1,l2l_1, l_2l1,l2,那么总时间复杂度为O(l1l2d×n/l1)=O(nl2d)O(l_1l_2 d \times n/l_1)=O(nl_2 d)O(l1l2d×n/l1)=O(nl2d)
long-term计算方式:
给定Q∈Rn×d,K,V∈Rm×dQ\in \mathbb R^{n\times d}, K,V\in \mathbb R^{m\times d}Q∈Rn×d,K,V∈Rm×d
Wp∈Rd×r,P=Softmax(KWp)∈Rm×rW^{p} \in \mathbb R^{d\times r},P=\mathrm{Softmax}( K W^p)\in \mathbb R^{m\times r}Wp∈Rd×r,P=Softmax(KWp)∈Rm×r
Kˉ=P⊤K∈Rr×d,Vˉ=P⊤V∈Rr×d\bar K = P^{\top} K \in \mathbb R^{r\times d}, \bar V = P^{\top} V \in \mathbb R^{r\times d}Kˉ=P⊤K∈Rr×d,Vˉ=P⊤V∈Rr×d
O=Softmax(QKˉ⊤)Vˉ∈Rn×dO=\mathrm{Softmax}(Q \bar K^{\top} ) \bar V \in \mathbb R^{n\times d}O=Softmax(QKˉ⊤)Vˉ∈Rn×d
总的时间复杂度为O((n+m)dr)O((n+m)dr)O((n+m)dr)。
融合:
记short-term对应的K,VK, VK,V分别为K1,V1∈Rw×dK_1, V_1\in \mathbb R^{w\times d}K1,V1∈Rw×d
记long-term对应的K,VK, VK,V分别为K2,V2∈Rr×dK_2, V_2\in \mathbb R^{r\times d}K2,V2∈Rr×d
O=Softmax(Q[LN1(K1):LN2(K2)]⊤)[LN1(V1):LN2(V2)]∈Rn×dO=\mathrm{Softmax}(Q [\mathrm{LN}_1(K_1): \mathrm{LN}_2(K_2)]^{\top} ) [\mathrm{LN}_1(V_1): \mathrm{LN}_2(V_2)] \in \mathbb R^{n\times d}O=Softmax(Q[LN1(K1):LN2(K2)]⊤)[LN1(V1):LN2(V2)]∈Rn×d
总时间复杂度为O(n(r+w)d)O(n(r+w)d)O(n(r+w)d)。
不变。
https://github.com/NVIDIA/transformer-lsarrow-up-right
测试了lra, lm以及imagenet,效果都很好。
单向模型中还有一些实现细节。
效果挺好的,但是整体方法感觉不算优雅,一些实现细节可以参考。
Last updated 2 years ago