Stable, Fast and Accurate: Kernelized Attention with Relative Positional Encoding

Paper link:

Overall idea and computation

The idea is to introduce relative positional encoding into linear attention:

$$
\mathbf z_{i}=\frac{\phi\left(\mathbf q_{i}\right)^{\top} \sum_{j=1}^{n} \mathrm{e}^{b_{j-i}} \phi\left(\mathbf k_{j}\right)\mathbf v_{j}^{\top}} {\phi\left(\mathbf q_{i}\right)^{\top} \sum_{j=1}^{n} \mathrm{e}^{b_{j-i}} \phi\left(\mathbf k_j\right)}
$$
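A minimal O(n²) reference sketch of this formula in NumPy; the feature map φ = elu + 1 and all function names are my own assumptions rather than choices fixed by the paper, and this naive version serves only as a correctness baseline for the fast computation derived below:

```python
import numpy as np

def phi(x):
    # Assumed feature map: elu(x) + 1, a common choice for linear attention.
    return np.where(x > 0, x + 1.0, np.exp(x))

def kernelized_rpe_attention_naive(Q, K, V, b):
    """Direct O(n^2 d) evaluation of the formula above.
    Q, K, V: (n, d); b: length 2n-1 relative biases with b[k + n - 1] = b_k."""
    n, d = Q.shape
    Qf, Kf = phi(Q), phi(K)
    Z = np.zeros_like(V, dtype=float)
    for i in range(n):
        w = np.exp(b[np.arange(n) - i + n - 1])        # e^{b_{j-i}} for all j
        num = (Qf[i] @ (w[:, None] * Kf).T) @ V        # phi(q_i)^T sum_j e^{b_{j-i}} phi(k_j) v_j^T
        den = Qf[i] @ (w[:, None] * Kf).sum(axis=0)    # phi(q_i)^T sum_j e^{b_{j-i}} phi(k_j)
        Z[i] = num / den
    return Z
```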

Define the vectorization notation:

$$
\begin{aligned}
\mathrm{vec}(\mathbf M) &= \mathbf m, \\
\mathbf m_{sb+t} &= \mathbf M_{st}, \\
\mathbf M &\in \mathbb R^{a\times b}, \\
\mathbf m &\in \mathbb R^{ab}
\end{aligned}
$$
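In NumPy this vectorization is simply a row-major reshape; a quick illustrative check:

```python
import numpy as np

M = np.arange(6).reshape(2, 3)   # M in R^{a x b} with a = 2, b = 3
m = M.reshape(-1)                # vec(M) in R^{ab}
assert m[1 * 3 + 2] == M[1, 2]   # m_{s*b + t} == M_{s t}
```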

Further define:

$$
\begin{aligned}
\mathbf D_{1}&=\left(\begin{array}{c} \operatorname{vec}\left(\sum_{j=1}^{n} \mathrm{e}^{b_{j-1}} \phi\left(\mathbf k_j\right) \mathbf v_j^{\top} \right) \\ \vdots \\ \operatorname{vec}\left(\sum_{j=1}^{n} \mathrm{e}^{b_{j-n}} \phi\left(\mathbf k_j\right) \mathbf v_j^{\top} \right) \end{array}\right)\in \mathbb R^{n\times d^2}\\
\mathbf D_{2}&=\left(\begin{array}{c} \operatorname{vec}\left(\sum_{j=1}^{n} \mathrm{e}^{b_{j-1}} \phi\left(\mathbf k_j\right)\right) \\ \vdots \\ \operatorname{vec}\left(\sum_{j=1}^{n} \mathrm{e}^{b_{j-n}} \phi\left(\mathbf k_j\right)\right) \end{array}\right) \in \mathbb R^{n\times d} \\
\mathbf C &=\left(\begin{array}{ccccc} c_{0} & c_{1} & c_{2} & \cdots & c_{n-1} \\ c_{-1} & c_{0} & c_{1} & \cdots & c_{n-2} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ c_{-(n-1)} & c_{-(n-2)} & c_{-(n-3)} & \cdots & c_{0} \end{array}\right), \quad c_{i}=\mathrm{e}^{b_{i}} \\
\mathbf B_{1}&=\left(\begin{array}{c} \operatorname{vec}\left(\phi\left(\mathbf k_1\right) \mathbf v_1^{\top} \right) \\ \vdots \\ \operatorname{vec}\left(\phi\left(\mathbf k_n\right) \mathbf v_n^{\top} \right) \end{array}\right)\in \mathbb R^{n\times d^2}\\
\mathbf B_{2}&=\left(\begin{array}{c} \operatorname{vec}\left(\phi\left(\mathbf k_1\right)\right) \\ \vdots \\ \operatorname{vec}\left(\phi\left(\mathbf k_n\right)\right) \end{array}\right) \in \mathbb R^{n\times d}
\end{aligned}
$$
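To make these definitions concrete, a small sketch (my own helper name; φ as in the reference sketch above) that materializes C, B1 and B2 explicitly:

```python
import numpy as np

def build_C_B1_B2(Kf, V, b):
    """Materialize C, B1, B2 from Kf = phi(K): (n, d), V: (n, d) and
    b: length 2n-1 relative biases with b[k + n - 1] = b_k."""
    n, d = Kf.shape
    c = np.exp(b)                                           # c_k = e^{b_k}
    idx = np.arange(n)
    C = c[idx[None, :] - idx[:, None] + n - 1]              # C[i, j] = e^{b_{j-i}}
    B1 = np.einsum('jp,jq->jpq', Kf, V).reshape(n, d * d)   # row j = vec(phi(k_j) v_j^T)
    B2 = Kf                                                 # row j = phi(k_j)
    return C, B1, B2
```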

Then:

$$
\begin{aligned}
\mathbf D_1 & = \mathbf C \mathbf B_1 \\
\mathbf D_2 & = \mathbf C \mathbf B_2
\end{aligned}
$$
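The point of this factorization is that a Toeplitz matrix times a matrix can be computed with the FFT instead of materializing C. A sketch of that product (my own helper, same layout of b as above), via the standard circulant embedding:

```python
import numpy as np

def toeplitz_matmul_fft(c, B):
    """Compute C @ B for the n x n Toeplitz matrix C[i, j] = c[j - i + n - 1]
    (c has length 2n - 1) and B of shape (n, m), in O(n m log n) time,
    by embedding C in a circulant matrix and using the FFT."""
    n, m = B.shape
    # First column of the circulant embedding of size 2n - 1:
    # (c_0, c_{-1}, ..., c_{-(n-1)}, c_{n-1}, ..., c_1)
    col = np.concatenate([c[n - 1::-1], c[:n - 1:-1]])
    B_pad = np.zeros((2 * n - 1, m))
    B_pad[:n] = B
    out = np.fft.ifft(np.fft.fft(col)[:, None] * np.fft.fft(B_pad, axis=0), axis=0)
    return out[:n].real
```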

The output is then:

$$
\mathbf O_i = \frac{\phi(\mathbf q_i)^{\top} \mathrm{reshape}(\mathbf D_{1,i})} {\phi(\mathbf q_i)^{\top} \mathbf D_{2,i}} \in \mathbb R^{d}
$$
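Given D1 and D2, each output row is just a reshape of the corresponding d²-dimensional row back to d × d plus two small contractions; a sketch continuing the notation above:

```python
import numpy as np

def outputs_from_D(Qf, D1, D2):
    """Qf = phi(Q): (n, d), D1: (n, d*d), D2: (n, d) -> outputs (n, d)."""
    n, d = Qf.shape
    num = np.einsum('ip,ipq->iq', Qf, D1.reshape(n, d, d))  # phi(q_i)^T reshape(D_{1,i})
    den = np.einsum('ip,ip->i', Qf, D2)                     # phi(q_i)^T D_{2,i}
    return num / den[:, None]
```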

Time complexity

Since $\mathbf C$ is a Toeplitz matrix, the products $\mathbf C \mathbf B_1$ and $\mathbf C \mathbf B_2$ can be computed with the FFT, so the time complexity is $O(nd^2\log n)$. Note that $\mathbf B_1$ has to be formed, so the space complexity is $O(nd^2)$.

Training and loss

Unchanged.

Code

Not available yet, but it is not hard to implement; see the sketch below.
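The sketches above combine into a complete, unoptimized $O(nd^2\log n)$ forward pass, which should agree with the naive O(n²) reference up to floating-point error (all helper names are my own, not from an official release):

```python
import numpy as np

def kernelized_rpe_attention_fft(Q, K, V, b):
    """End-to-end forward pass reusing phi, toeplitz_matmul_fft and
    outputs_from_D from the sketches above.
    Q, K, V: (n, d); b: length 2n-1 relative biases with b[k + n - 1] = b_k."""
    n, d = Q.shape
    Qf, Kf = phi(Q), phi(K)
    c = np.exp(b)                                           # c_k = e^{b_k}; C itself is never formed
    B1 = np.einsum('jp,jq->jpq', Kf, V).reshape(n, d * d)   # rows vec(phi(k_j) v_j^T)
    B2 = Kf                                                 # rows phi(k_j)
    D1 = toeplitz_matmul_fft(c, B1)                         # D1 = C B1 via FFT
    D2 = toeplitz_matmul_fft(c, B2)                         # D2 = C B2 via FFT
    return outputs_from_D(Qf, D1, D2)

# Sanity check against the naive reference:
# rng = np.random.default_rng(0)
# Q, K, V = rng.normal(size=(3, 64, 8))
# b = rng.normal(size=2 * 64 - 1)
# assert np.allclose(kernelized_rpe_attention_fft(Q, K, V, b),
#                    kernelized_rpe_attention_naive(Q, K, V, b))
```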

Experiments and applicable scenarios

Applicable to all linear-attention settings; in terms of performance, the improvement is quite noticeable.

Details

None for now.

Brief review

The idea is quite clever. The time and space complexity may be somewhat high, but overall it is a work worth reproducing.
