Transformer with Fourier Integral Attentions

论文地址:

整体思路以及计算方式

利用非参数回归的方式对Attention进行改进,整体思路分为两步:

非参数回归:

  • vj=f(kj)+εj{v}_{j}=f\left({k}_{j}\right)+\varepsilon_{j}

  • E[vk]=RDvp(vk)dv=vp(v,k)p(k)dv{\mathbb E}[{v} \mid {k}]=\int_{{R}^{D}} {v} \cdot p({v} \mid {k}) d {v}=\int \frac{{v} \cdot p({v}, {k})}{p({k})} d {v}

  • 利用Kernel法估计概率密度(φ\varphi为高斯核函数):

    p^σ(v,k)=1Nj=1Nφσ(vvj)φσ(kkj),p^σ(k)=1Nj=1Nφσ(kkj)\hat{p}_{\sigma}({v}, {k})=\frac{1}{N} \sum_{j=1}^{N} \varphi_{\sigma}\left({v}-{v}_{j}\right) \varphi_{\sigma}\left({k}-{k}_{j}\right), \quad \hat{p}_{\sigma}({k})=\frac{1}{N} \sum_{j=1}^{N} \varphi_{\sigma}\left({k}-{k}_{j}\right)
  • 带入:f^σ(k)=E[vk]=j=1Nvjφσ(kkj)j=1Nφσ(kkj)\widehat{f}_{\sigma}({k})={\mathbb E}[{v} \mid {k}]= \frac{\sum_{j=1}^{N} v_{j} \varphi_{\sigma}\left({k}-{k}_{j}\right)}{\sum_{j=1}^{N} \varphi_{\sigma}\left({k}-{k}_{j}\right)}

  • kk换成qq得到:

    f^σ(qi)=jNvjexp(qikj2/2σ2)jNexp(qikj2/2σ2)=jNvjexp[(qi2+kj2)/2σ2]exp(qikj/σ2)jNexp[(qi2+kj2)/2σ2]exp(qikj/σ2)\begin{aligned} \widehat{f}_{\sigma}\left({q}_{i}\right) &=\frac{\sum_{j}^{N} {v}_{j} \exp \left(-\left\|{q}_{i}-{k}_{j}\right\|^{2} / 2 \sigma^{2}\right)}{\sum_{j}^{N} \exp \left(-\left\|{q}_{i}-{k}_{j}\right\|^{2} / 2 \sigma^{2}\right)} \\ &=\frac{\sum_{j}^{N} {v}_{j} \exp \left[-\left(\left\|{q}_{i}\right\|^{2}+\left\|{k}_{j}\right\|^{2}\right) / 2 \sigma^{2}\right] \exp \left({q}_{i} {k}_{j}^{\top} / \sigma^{2}\right)}{\sum_{j}^{N} \exp \left[-\left(\left\|{q}_{i}\right\|^{2}+\left\|{k}_{j^{\prime}}\right\|^{2}\right) / 2 \sigma^{2}\right] \exp \left({q}_{i} {k}_{j}^{\top} / \sigma^{2}\right)} \end{aligned}

    如果假设qi=kj\|q_i\| = \|k_j\|,那么上式退化为Attention,由此作者说该方法是Attention的推广;

计算:

  • 作者利用傅里叶定理求解非参数回归问题,思路为利用傅里叶积分定理计算φσ(kkj)\varphi_{\sigma}\left({k}-{k}_{j}\right)

  • 直接给出计算公式:

    h^i:=fN,R(qi)=i=1Nvij=1Dϕ(sin(R(qijkij))R(qijkij))i=1Nj=1Dϕ(sin(R(qijkij))R(qijkij))\hat{{h}}_{i}:=f_{N, R}\left({q}_{i}\right)=\frac{\sum_{i=1}^{N} {v}_{i} \prod_{j=1}^{D} \phi\left(\frac{\sin \left(R\left(q_{i j}-k_{i j}\right)\right)}{R\left(q_{i j}-k_{i j}\right)}\right)}{\sum_{i=1}^{N} \prod_{j=1}^{D} \phi\left(\frac{\sin \left(R\left(q_{i j}-k_{i j}\right)\right)}{R\left(q_{i j}-k_{i j}\right)}\right)}
  • 这里ϕ\phi是一个函数,论文里有介绍。

时间复杂度

依然为O(n2d)O(n^2d),所以理论复杂度没有改进,根据计算的形式,推测速度会慢。

训练以及loss

不变。

代码

实验以及适用场景

适用于Encoder, Decoder,结果有所提升。

细节

暂无。

简评

不错的一个思路,让人眼前一亮。

Last updated