Linear Complexity Randomized Self-attention Mechanism
论文地址:
整体思路以及计算方式
之前像RFA和Performer(后续统称为RFA)都是exp(q⊤v)的无偏估计,但并不是exp(q⊤v)/(∑vexp(q⊤v))的无偏估计,这偏论文的主要出发点就是解决这点,论文的整体思路如下:
指出RFA不是无偏估计,通过重要度抽样引入RA(Randomized Attention);
指出RA的计算复杂度太高,作为一个折中方案,引入LARA(Linear Randomized Attention);
RFA和重要度抽样
RFA:
如果有:
exp(x⊤y)=Eω∼N(ω;0,I)[ξ(x,ω)⊤ξ(y,ω)](1) 那么:
∑m′=1Mexp(qn⊤km′)∑m=1Mexp(qn⊤km)vm⊤≈∑m′=1M∑s=1Sξ(qn,ωs)⊤ξ(km′,ωs)∑m=1M∑s=1Sξ(qn,ωs)⊤ξ(km,ωs)vm⊤=∑s=1Sξ(qn,ωs)⊤∑m′=1Mξ(km′,ωs)∑s=1Sξ(qn,ωs)⊤∑m=1Mξ(km,ωs)vm⊤:=RFA(qn,K,V) 从这里不难看出,尽管公式(1)是exp的无偏估计,但是RFA并不是Attention的无偏估计,这里是利用了如下事实:
E[xi]=x,E[yi]=y⇒E[yixi]=yx(2) 这也是本文的主要出发点,注意到公式(2)涉及到分母,这一点是比较难处理的,因此,作者引入了重要度抽样的方法:
Ep(ω)[f(ω)]=Eg(ω)[g(ω)p(ω)f(ω)]≈S1s=1∑Sg(ωs)p(ωs)f(ωs)(3) 注意到概率分布p(ωs)一般可以写成:
p(ω)=p~(ω)/Z 而Z作为分母通常很难计算,所以公式(3)通常无法直接使用,为了消去Z,在公式(3)中取f=1:
1=Ep(ω)[1]=Eg(ω)[g(ω)p(ω)]≈S1s=1∑Sg(ωs)p(ωs)(3) 那么:
Ep(ω)[f(ω)]=Eg(ω)[g(ω)p(ω)]Eg(ω)[g(ω)p(ω)f(ω)]≈S1∑s=1SZ1g(ωs)p~(ωs)S1∑s=1SZ1g(ωs)p~(ωs)f(ωs)=∑s=1Sg(ωs)p~(ωs)∑s=1Sg(ωs)p~(ωs)f(ωs) 这样就可以消去分母Z,从而让重要度抽样的方法可计算。
RA
将之前的内容结合,最终作者得到如下结论:
SoftmaxAttn(qn,K,V)=Epn(ω)[fn(ω)](4) 其中:
p(ω)f(ω)=m=1∑MπmN(ω;qn+km,I)=ξ(qn,ω)⊤∑m′=1Mξ(km′,ω)ξ(qn,ω)⊤∑m=1Mξ(km,ω)vm⊤ 注意到这里一共涉及MN个概率分布p(ω),所以时间复杂度依然为O(MNd),并没有带来速度提升,所以后续需要解决这点。
LARA
首先引出MIS(multiple importance sampling):
Epn(ω)[fn(ω)]≈c=1∑Cαnc(ωc)gc(ωc)pn(ωc)fn(ωc)ωc∼gc(ω),c=1,…,Cc=1∑Cαnc=1 这样做的好处是可以将分布数量降低到C≪MN,通过比较复杂的推导,最终作者给出:
αnc(ωc)rnc′=∑c′=1Cgc′(ωc)qc(ωc)+rnc′−C1c=1∑Crnc′=∑n=1Nexp(qn⊤q~c′)exp(qn⊤q~c) 其中q~c是如何计算还没有完全理清,后续进行补充。
最终的计算式:
Epn(ω)[fn(ω)]p~n(ω)≈∑c=1Cαnc(ωc)qc(ωc)p~n(ωc)∑c=1Cαnc(ωc)qc(ωc)p~n(ωc)fn(ωc):=LARA(qn,K,V)=N(ω;0,I)ξ(qn,ω)⊤m=1∑Mξ(km,ω) 时间复杂度
不难看出为O(NCd2)。
训练以及loss
不变。
代码
暂无,详细的伪代码可以参考原论文。
实验以及适用场景
论文主要测试了Encoder,效果还不错,Decoder还没进行测试。
细节
实现的时候应该有不少技巧,等后续复现的时候进行讨论。
简评
理论性很强的一篇文章,但是写的很容易懂,出发点也比较明确,个人感觉比Performer这篇更值得关注。