Momentum Transformer: Closing the Performance Gap Between Self-attention and Its Linearization

论文地址:

整体思路以及计算方式

在Linear Attention的RNN版本中添加动量项来提升性能,从底层原理上来看,实际上是给Linear Attention添加了指数衰减的相对位置编码。

Linear Attention的RNN形式:

si=si1+ϕ(ki)vizi=zi1+ϕ(ki)v^i=ϕ(qi)siϕ(qi)zi\begin{aligned} \boldsymbol{s}_i & =\boldsymbol{s}_{i-1}+\phi\left(\boldsymbol{k}_i\right) \boldsymbol{v}_i^{\top} \\ \boldsymbol{z}_i & =\boldsymbol{z}_{i-1}+\phi\left(\boldsymbol{k}_i\right) \\ \hat{\boldsymbol{v}}_i & =\frac{\phi\left(\boldsymbol{q}_i\right)^{\top} \boldsymbol{s}_i}{\phi\left(\boldsymbol{q}_i\right)^{\top} \boldsymbol{z}_i} \end{aligned}

动量形式:

mi=βmi1ϕ(ki)visi=si1γmizi=zi1+ϕ(ki)v^i=ϕ(qi)siϕ(qi)zi\begin{aligned} \boldsymbol{m}_i & =\beta \boldsymbol{m}_{i-1}-\phi\left(\boldsymbol{k}_i\right) \boldsymbol{v}_i^{\top} \\ \boldsymbol{s}_i & =\boldsymbol{s}_{i-1}-\gamma \boldsymbol{m}_i \\ \boldsymbol{z}_i & =\boldsymbol{z}_{i-1}+\phi\left(\boldsymbol{k}_i\right) \\ \hat{\boldsymbol{v}}_i & =\frac{\phi\left(\boldsymbol{q}_i\right)^{\top} \boldsymbol{s}_i}{\phi\left(\boldsymbol{q}_i\right)^{\top} \boldsymbol{z}_i} \end{aligned}

并行形式:

v^i=γϕ(qi)j=1i(1βij+11βϕ(kj)vj)ϕ(qi)zi\hat{\boldsymbol{v}}_i=\frac{\gamma \phi\left(\boldsymbol{q}_i\right)^{\top} \sum_{j=1}^i\left(\frac{1-\beta^{i-j+1}}{1-\beta} \phi\left(\boldsymbol{k}_j\right) \boldsymbol{v}_j^{\top}\right)}{\phi\left(\boldsymbol{q}_i\right)^{\top} \boldsymbol{z}_i}

可以看到最后的形式多了一个指数衰减的相对位置编码,这是最后性能提升的根本原因。

代码

简评

最后的形式是很简单的,我也进行过相关实验,确实有一定的提升。

Last updated