Translational Equivariance in Kernelizable Attention

论文地址:

整体思路以及计算方式

论文讨论的是在线性Attention中添加相对位置编码信息,整体思路如下。

首先回顾Attention计算方式,其中ei\mathbf e_i是word embedding,ui\mathbf u_i是position embedding:

A~i,j=(ei+ui)Wq,(ej+uj)Wk=(ei+ui)WqWk(ej+uj)=eiWqWkej+eiWqWkuj+uiWqWkej+uiWqWkuj\begin{aligned} \tilde{\mathbf{A}}_{i, j} &=\left\langle\left(\mathbf e_{i}+\mathbf u_{i}\right) \mathbf {W}_{q},\left(\mathbf e_{j}+\mathbf u_{j}\right) \mathbf {W}_{k}\right\rangle=\left(\mathbf e_{i}+\mathbf u_{i}\right) \mathbf {W}_{q} \mathbf {W}_{k}^{\top}\left(\mathbf e_{j}+\mathbf u_{j}\right)^{\top} \\ &=\mathbf e_{i} \mathbf {W}_{q} \mathbf {W}_{k}^{\top}\mathbf e_{j}^{\top}+\mathbf e_{i} \mathbf {W}_{q} \mathbf {W}_{k}^{\top} \mathbf u_{j}^{\top}+\mathbf u_{i} \mathbf {W}_{q} \mathbf {W}_{k}^{\top} \mathbf e_{j}^{\top}+\mathbf u_{i} \mathbf {W}_{q} \mathbf {W}_{k}^{\top} \mathbf u_{j}^{\top} \end{aligned}

论文的思路是对该计算方式进行重构,并且仍然能保持线性Attention的性质。

方案1:

A~i,j=eiWqWkej+uiWqWkuj\begin{aligned} \tilde{\mathbf{A}}_{i, j} &=\mathbf e_{i} \mathbf {W}_{q} \mathbf {W}_{k}^{\top}\mathbf e_{j}^{\top} +\mathbf u_{i} \mathbf {W}_{q} \mathbf {W}_{k}^{\top}\mathbf u_{j}^{\top} \end{aligned}

其中:

Wq=blockdiag([α1β1β1α1],,[αmβmβmαm])Wk=I2m\begin{aligned} \mathbf {W}_{q}^{*} &=\operatorname{blockdiag}\left(\left[\begin{array}{cc} \alpha_{1} & \beta_{1} \\ -\beta_{1} & \alpha_{1} \end{array}\right], \ldots,\left[\begin{array}{cc} \alpha_{m} & \beta_{m} \\ -\beta_{m} & \alpha_{m} \end{array}\right]\right)\\ \mathbf {W}_{k}^{*}&=\mathbb{I}_{2 m} \end{aligned}

该方案的特点是,如果位置编码的形式为:

ux=ϕ(x)=[sin(ω1x),cos(ω1x),,sin(ωmx),cos(ωmx)]u_x =\phi(x)=\left[\sin \left(\omega_{1} x\right), \cos \left(\omega_{1} x\right), \ldots, \sin \left(\omega_{m} x\right), \cos \left(\omega_{m} x\right)\right]

那么满足如下性质:

uiuWq,ujuWk=uiWq,ujWk\left\langle \mathbf u_{i-u} \mathbf {W}_{q},\mathbf u_{j-u} \mathbf {W}_{k}\right\rangle =\left\langle \mathbf u_{i} \mathbf {W}_{q},\mathbf u_{j} \mathbf {W}_{k}\right\rangle

方案2:

A~ij=eiWq(ejWk+aij)=eiWqWkej+eiWqaij\tilde{\mathbf{A}}_{i j}=\mathbf e_{i} \mathbf{W}_{q}\left(\mathbf e_{j} \mathbf{W}_{k}+{a}_{i j}\right)^{\top}=\mathbf e_{i} \mathbf{W}_{q} \mathbf{W}_{k}^{\top} \mathbf e_{j}^{\top}+\mathbf e_{i} \mathbf{W}_{q} \mathbf {a}_{i j}^{\top}

其中

aij=wclip(ji,k)clip(x,k)=max(k,min(k,x))\begin{aligned} \mathbf {a}_{i j} &=\mathbf{w}_{\operatorname{clip}(j-i, k)} \\ \operatorname{clip}(x, k) &=\max (-k, \min (k, x)) \end{aligned}

clip函数的含义是将输入截断至[k,k][-k, k]之间。

如果使用普通的实现方式,那么时间复杂度为O(L2d)O(L^2 d),但是注意到:

j=1Lqi,aijvj=qi,wkj=1Lvj+mqi,aim\sum_{j=1}^{L}\left\langle\mathbf{q}_{i}^{\prime}, \mathbf{a}_{i j}^{\prime}\right\rangle \mathbf{v}_{j}=\left\langle\mathbf{q}_{i}^{\prime}, \mathbf{w}_{k}^{\prime}\right\rangle \sum_{j=1}^{L} \mathbf{v}_{j}+\sum_{m}\left\langle\mathbf{q}_{i}^{\prime}, \mathbf{a}_{i m}^{\prime} \right\rangle

可以将时间复杂度降低为O(Ld2)O(Ld^2),其中dd为embedding维度。

时间复杂度

和线性Attention一致,依然为O(Ld2)O(Ld^2)

训练以及loss

不变。

代码

实验以及适用场景

适用于所有场景,实验只测试了图像任务,但是效果一般。

细节

暂无。

简评

总体感觉新意尚可,但效果一般。

Last updated