Large Memory Layers with Product Keys

论文地址:

参考资料:

整体思路以及计算方式

对Transformer中FFN\mathrm{FFN}的改进,称为PKM\mathrm{PKM}(Product Key Memory),注意这里也有q,k,v\mathrm q,\mathrm k,\mathrm v,思路是值得借鉴的。

首先给出符号:

  • Tm\mathcal T_m表示Topm\mathrm{Top}-m

  • kiK,kiR1×d,K=n\mathrm k_i\in \mathcal K,\mathrm k_i\in \mathbb R^{1\times d}, |\mathcal K|= n

  • q(x),viR1×d\mathrm q(\mathrm x),\mathrm v_i \in \mathbb R^{1\times d}

核心为如下计算问题:

I=Tm(q(x)ki)w=Softmax((q(x)ki)iI)m(x)=iIwivi\begin{aligned} \mathcal{I} &=\mathcal{T}_{m}\left(\mathrm q(\mathrm x)^{\top}\mathrm k_{i}\right) \\ \mathrm w &=\operatorname{Softmax}\left(\left(\mathrm q(\mathrm x)^{\top}\mathrm k_{i}\right)_{i \in \mathcal{I}}\right) \\ m(\mathrm x) &=\sum_{i \in \mathcal{I}} \mathrm w_{i} \mathrm v_{i} \end{aligned}

注意该模块依然为RdRd\mathbb R^d\to \mathbb R^d的映射,所以可以类比FFN\mathrm {FFN}

分析:

  • 第一步需要计算Tm(qk)R1×m,kK\mathcal{T}_{m}(\mathrm q\mathrm k^{\top})\in \mathbb R^{1\times m},\mathrm k\in \mathcal K

    • 由于需要求出全部nn项,每一项的计算复杂度为O(d)O(d),所以总计算复杂度为O(nd)O(nd)

    • Tm\mathcal T_m操作的时间复杂度为O(mlogn)O(m\log n)

  • 第二步的时间复杂度为O(m)O(m)

  • 第三步的时间复杂度为O(md)O(md)

由于第一步是主要开销,为了提速,论文里做了如下假设:

  • kK={(c,c)cC,cC}\mathbf k\in \mathcal K=\{(\mathbf c, \mathbf c')| \mathbf c\in \mathcal C, \mathbf c'\in \mathcal C\}

    • 这里(c,c)(\mathbf c,\mathbf c')表示向量拼接,c,cR1×d/2\mathbf c,\mathbf c'\in \mathbb R^{1\times d/2}

    • c=c=n|c|=|c'|= \sqrt{n}

注意到:

qk=q[:d/2]k[:d/2]+q[d/2:]k[d/2:]q(1)(k(1))+q(2)(k(2))\begin{aligned} \mathbf q\mathbf k^{\top} &= \mathbf q[:d/2]\mathbf k[:d/2]^{\top} +\mathbf q[d/2:] \mathbf k[d/2:]^{\top} \\ &\triangleq \mathbf q^{(1)} (\mathbf k^{(1)})^{\top} + \mathbf q^{(2)} (\mathbf k^{(2)})^{\top} \end{aligned}

结合假设:

q(1)(ki(1)),q(2)(kj(2)),ki(1)C,kj(2)C\mathbf q^{(1)} (\mathbf k^{(1)}_i)^{\top}, \mathbf q^{(2)} (\mathbf k^{(2)}_j)^{\top},\mathbf k^{(1)}_i\in \mathcal C, \mathbf k^{(2)}_j\in \mathcal C'

所以:

  • 只要求出2n2\sqrt n项即可,每一项的计算复杂度为O(d/2)O(d/2),所以总计算复杂度为O(nd)O(\sqrt nd)

接着要从2n2\sqrt n项中恢复qk\mathbf q\mathbf k^{\top},作者使用的方式为:

qk={q(1)(ki(1)),q(2)(kj(2))ki(1)C,kj(2)C}\mathbf q\mathbf k^{\top}=\{\mathbf q^{(1)} (\mathbf k^{(1)}_i)^{\top},\mathbf q^{(2)} (\mathbf k^{(2)}_j)^{\top}|\mathbf k^{(1)}_i\in {\mathcal C}, \mathbf k^{(2)}_j\in \mathcal C' \}

这里一共有n×n=n\sqrt n\times \sqrt n =n个元素,从这nn个元素中进行Tm\mathcal{T}_{m}运行即可,因此总时间复杂度为

O(nd+mlogn+md+d)=O((n+m)d)O(\sqrt nd +m\log n + md +d ) = O((\sqrt n+m)d)

备注,这里假设logn<d\log n < d

时间复杂度

假设xRL×d\mathbf x\in \mathbb R^{L\times d},所以总时间复杂度为:

O(L(n+m)d)O(L (\sqrt n+m)d)

注意到FFN\mathrm{FFN}的时间复杂度为:

O(4Ld2)O(4Ld^2)

所以一般来说前者比FFN\mathrm{FFN}快。

训练以及loss

保持一致。

代码

实验以及适用场景

因为是替换FFN,所以适用于所有场景,但是这样做的动机还不明确;从实验效果来说非常不错。

细节

实现细节需要看查看官方代码。

简评

总结:

  • 思路挺特别的,而且效果出人意料的好;

  • 值得复现;

Last updated