论文地址:
参考资料:
整体思路以及计算方式
利用Local Attention计算局部Attention(short-term),利用降维的方式计算全局Attention(long-term),最终达到降低时间复杂度的效果。
short-term计算方式:
假设绿色部分长宽分别为l1,l2,那么总时间复杂度为O(l1l2d×n/l1)=O(nl2d)
long-term计算方式:
给定Q∈Rn×d,K,V∈Rm×d
Wp∈Rd×r,P=Softmax(KWp)∈Rm×r
Kˉ=P⊤K∈Rr×d,Vˉ=P⊤V∈Rr×d
O=Softmax(QKˉ⊤)Vˉ∈Rn×d
总的时间复杂度为O((n+m)dr)。
融合:
记short-term对应的K,V分别为K1,V1∈Rw×d
记long-term对应的K,V分别为K2,V2∈Rr×d
O=Softmax(Q[LN1(K1):LN2(K2)]⊤)[LN1(V1):LN2(V2)]∈Rn×d
总时间复杂度为O(n(r+w)d)。
时间复杂度
总时间复杂度为O(n(r+w)d)。
训练以及loss
不变。
代码
实验以及适用场景
测试了lra, lm以及imagenet,效果都很好。
细节
单向模型中还有一些实现细节。
简评
效果挺好的,但是整体方法感觉不算优雅,一些实现细节可以参考。