Is Attention Better Than Matrix Decomposition

Is Attention Better Than Matrix Decomposition?

论文地址:

参考资料:

整体思路以及计算方式

思路是利用矩阵分解代替Attention。

假设XRd×n\mathbf X\in \mathbb R^{d\times n}可以分解为如下形式:

X=X+E=DC+E\mathbf {X}=\overline{\mathbf {X}}+\mathbf {E}=\mathbf {D}\mathbf C+\mathbf {E}

这些变量满足如下条件(E\mathbf E表示噪声):

XRd×nERd×nDRd×rCRr×n\begin{aligned} \overline{\mathbf {X}}& \in \mathbb{R}^{d \times n} \\ \mathbf {E}& \in \mathbb{R}^{d \times n} \\ \mathbf D&\in \mathbb R^{d\times r}\\ \mathbf C&\in \mathbb R^{r\times n}\\ \end{aligned}

使用流程为通过某种方式计算D,C\mathbf D,\mathbf C,最后输出DC\mathbf D\mathbf C

作者给出了两种方式计算D,C\mathbf D,\mathbf C

  • Soft VQ

    • for k=1,,Kk=1,\ldots, K:

      • CSoftmax(1Tcosine(D,X))\mathbf {C} \leftarrow \operatorname{Softmax}\left(\frac{1}{T} \operatorname{cosine}(\mathbf {D},\mathbf {X})\right)

      • DXCdiag(C1n)1\mathbf {D} \leftarrow \mathbf {X} \mathbf {C}^{\top} \operatorname{diag}\left(\mathbf {C} \mathbf{1}_{n}\right)^{-1}

    • return X=DC\overline {\mathbf X} =\mathbf D\mathbf C

  • NMF with MU

    • for k=1,,Kk=1,\ldots, K:

      • CijCij(DX)ij(DDC)ij{\mathbf C}_{i j} \leftarrow {\mathbf C}_{i j} \frac{\left(\mathbf {D}^{\top}\mathbf X\right)_{i j}}{\left(\mathbf {D}^{\top} \mathbf {D C}\right)_{i j}}

      • DijDij(XC)ij(DCC)ij\mathbf {D}_{i j} \leftarrow \mathbf {D}_{i j} \frac{\left(\mathbf {X} \mathbf C^{\top}\right)_{i j}}{\left(\mathbf {D C} \mathbf {C}^{\top}\right)_{i j}}

    • return X=DC\overline{\mathbf X} =\mathbf {DC}

时间复杂度

Soft VQ时间复杂度:

  • CSoftmax(1Tcosine(D,X))\mathbf {C} \leftarrow \operatorname{Softmax}\left(\frac{1}{T} \operatorname{cosine}(\mathbf {D},\mathbf {X})\right),所以时间复杂度为O(nrd)O(nrd)

    • cosine(D,X)\operatorname{cosine}(\mathbf {D},\mathbf {X})需要计算DX\mathbf D^{\top} \mathbf X,即r×d,d×nr×nr\times d,d\times n\to r\times n,所以时间复杂度为O(nrd)O(nrd)

    • Softmax\operatorname{Softmax}r×nr×nr\times n \to r\times n,所以时间复杂度为所以时间复杂度为O(nr)O(nr)

    • 总时间复杂度为O(nrd)O(nrd)

  • DXCdiag(C1n)1:\mathbf {D} \leftarrow \mathbf {X} \mathbf {C}^{\top} \operatorname{diag}\left(\mathbf {C} \mathbf{1}_{n}\right)^{-1}:

    • diag(C1n)1:r×n,n×1r×1r×r\operatorname{diag}\left(\mathbf {C} \mathbf{1}_{n}\right)^{-1}:r\times n,n\times 1 \to r\times 1 \to r\times r ,时间复杂度为O(nr)O(nr)

    • XC:d×n,n×rd×r\mathbf {X} \mathbf {C}^{\top}:d\times n, n\times r \to d\times r,时间复杂度为O(nrd)O(nrd)

    • Cdiag(C1n)1:d×r,r×rr×r\mathbf {C}^{\top} \operatorname{diag}\left(\mathbf {C} \mathbf{1}_{n}\right)^{-1}: d\times r, r\times r \to r\times r,时间复杂度为O(dr2)O(dr^2)

  • 所以时间复杂度为O(nrd+dr2)O(nrd+ dr^2)

  • 循环KK次,时间复杂度为O(K(nrd+dr2))O(K(nrd+ dr^2))

  • X=DC:d×r,r×nd×n\overline {\mathbf X} =\mathbf D\mathbf C: d\times r , r\times n \to d\times n,时间复杂度为O(nrd)O(nrd)

  • 总时间复杂度为O((K+1)nrd+Kdr2)O((K+1)nrd + Kdr^2)

NMF with MU时间复杂度:

  • CijCij(DX)ij(DDC)ij\mathbf {C}_{i j} \leftarrow \mathbf {C}_{i j} \frac{\left(\mathbf {D}^{\top} \mathbf X\right)_{i j}}{\left(\mathbf {D}^{\top} \mathbf {D C}\right)_{i j}}

    • DX:r×d,d×nr×n\mathbf {D}^{\top} \mathbf X: r\times d, d\times n\to r\times n,时间复杂度为O(nrd)O(nrd)

    • DDC\mathbf {D}^{\top} \mathbf {D C}

      • 先计算DC\mathbf {DC},再计算DDC\mathbf {D}^{\top} \mathbf {D C}

        • DC:d×r,r×nd×n\mathbf {DC}:d\times r, r\times n \to d\times n,时间复杂度为O(nrd)O(nrd)

        • DDC:r×d,d×nr×n\mathbf {D}^{\top} \mathbf {D C}: r\times d, d\times n \to r\times n,时间复杂度为O(nrd)O(nrd)

      • 先计算DD\mathbf {D}^{\top}\mathbf D,再计算DDC\mathbf {D}^{\top} \mathbf {D C}

        • DD:r×d,d×rr×r\mathbf {D}^{\top}\mathbf D: r\times d, d\times r \to r\times r ,时间复杂度为O(dr2)O(dr^2)

        • 再计算DDC:r×r,r×nr×n\mathbf {D}^{\top} \mathbf {D C}:r\times r, r\times n\to r\times n ,时间复杂度为O(nrd)O(nrd)

      • 一般r<nr<n,所以选择第二种算法,时间复杂度为O(nrd+dr2)O(nrd+dr^2)

    • 两次element wise乘法/除法:r×nr×nr×nr\times n \to r\times n \to r\times n,时间复杂度为O(nr)O(nr)

    • 循环KK次,时间复杂度为O(K(nrd+dr2))O(K(nrd+dr^2))

    • X=DC:d×r,r×nd×n\overline {\mathbf X}=\mathbf {DC}: d\times r , r\times n \to d\times n,时间复杂度为O(nrd)O(nrd)

    • 总时间复杂度为O((K+1)nrd+Kdr2)O((K+1)nrd + Kdr^2)

注意r,Kr, K一般不会很大(远小于nn),所以该方法的时间复杂度关于序列长度大概能到线性。

训练以及loss

没有区别。

代码

实验以及适用场景

目前只适用于Encoder结构,不适用于Decoder结构;实验主要是基于CV的,能达到和Attention相当的结果。

细节

在进行循环的时候不计算梯度,只有最后一次操作计算梯度。

简评

优点:

  • 提供了一种新的理解Attention的视角;

  • 方法实现比较简洁;

缺点:

  • 没法直接应用到Decoder结构中,即无法训练lm;

总结

  • 值得复现;

Last updated