Large Memory Layers with Product Keys
论文地址:
参考资料:
整体思路以及计算方式
对Transformer中FFN的改进,称为PKM(Product Key Memory),注意这里也有q,k,v,思路是值得借鉴的。
首先给出符号:
Tm表示Top−m
ki∈K,ki∈R1×d,∣K∣=n
q(x),vi∈R1×d
核心为如下计算问题:
Iwm(x)=Tm(q(x)⊤ki)=Softmax((q(x)⊤ki)i∈I)=i∈I∑wivi 注意该模块依然为Rd→Rd的映射,所以可以类比FFN。
分析:
第一步需要计算Tm(qk⊤)∈R1×m,k∈K,
由于需要求出全部n项,每一项的计算复杂度为O(d),所以总计算复杂度为O(nd);
Tm操作的时间复杂度为O(mlogn);
第三步的时间复杂度为O(md);
由于第一步是主要开销,为了提速,论文里做了如下假设:
k∈K={(c,c′)∣c∈C,c′∈C}
这里(c,c′)表示向量拼接,c,c′∈R1×d/2
∣c∣=∣c′∣=n
注意到:
qk⊤=q[:d/2]k[:d/2]⊤+q[d/2:]k[d/2:]⊤≜q(1)(k(1))⊤+q(2)(k(2))⊤ 结合假设:
q(1)(ki(1))⊤,q(2)(kj(2))⊤,ki(1)∈C,kj(2)∈C′ 所以:
只要求出2n项即可,每一项的计算复杂度为O(d/2),所以总计算复杂度为O(nd)
接着要从2n项中恢复qk⊤,作者使用的方式为:
qk⊤={q(1)(ki(1))⊤,q(2)(kj(2))⊤∣ki(1)∈C,kj(2)∈C′} 这里一共有n×n=n个元素,从这n个元素中进行Tm运行即可,因此总时间复杂度为
O(nd+mlogn+md+d)=O((n+m)d) 备注,这里假设logn<d。
时间复杂度
假设x∈RL×d,所以总时间复杂度为:
O(L(n+m)d) 注意到FFN的时间复杂度为:
所以一般来说前者比FFN快。
训练以及loss
保持一致。
代码
实验以及适用场景
因为是替换FFN,所以适用于所有场景,但是这样做的动机还不明确;从实验效果来说非常不错。
细节
实现细节需要看查看官方代码。
简评
总结: