Transformer with a Mixture of Gaussian Keys
论文地址:
整体思路以及计算方式
作者对Attention中Softmax部分利用GMM替换,最后达到了相当的效果,这里回顾下技术细节。
全篇文章的出发点是如下假设:
P(qi∣tj=1)=N(qi∣kj,σj2I) 其中等号左边的概率表示qi和kj有交互的概率(相当于Softmax中qi和kj对应的权重)。
通过该假设,以及qi,kj模长相等的假设,可以得到Softmax函数。
随后作者对上式进行推广,利用GMM可以拟合任意分布,作者假设:
P(qi∣tj=1)=r∑πjrN(qi∣kjr,σjr2I) 所以:
P(tj=1∣qi)=∑j′∑rπj′rexp(−∥qi−kj′r∥2/2σj′r2)∑rπjrexp(−∥qi−kjr∥2/2σjr2) 最后的输出为:
hi=j∑∑j′∑rπj′rexp(−∥qi−kj′r∥2/2σj′r2)∑rπjrexp(−∥qi−kjr∥2/2σjr2)vj Linear版本:
上述方法可以推广到Linear Attention,唯一的区别就是增加了权重πjr:
hi=∑j∑rπjrϕ(qi)⊤ϕ(kjr)∑j∑rπjrϕ(qi)⊤ϕ(kjr)vj=ϕ(qi)⊤∑j∑rπjrϕ(kjr)ϕ(qi)⊤∑j∑rπjrϕ(kjr)vj⊤ 学习策略:
具体的细节可以参考论文,主要是利用了EM算法。
时间复杂度
Vanilla版本为O(N2d),Linear版本为O(Nd2)。
训练以及loss
不变。
代码
实验以及适用场景
适用于所有场景,从结果来看提升并不明显。
细节
见代码。
简评
一个很好的思路,但是缺点也比较明显,性能基本没有提升,而且感觉学习的效率会降低。