Pay Attention to MLPs

论文地址:

整体思路以及计算方式

该论文主要讨论Attention的必要性,提出了gMLP block,整体思路如下:

  • 输入:XRn×d1\mathbf X \in \mathbb R^{n\times d_1}

  • X1=Norm(X)Rn×d1\mathbf X_1= \mathrm{Norm}(\mathbf X)\in \mathbb R^{n\times d_1}

  • X2=f(X1W1)Rn×d2\mathbf X_2=f(\mathbf X_1\mathbf W_1) \in \mathbb R^{n\times d_2}

  • X3=SGU(X2)W2Rn×d1\mathbf X_3= \mathrm{SGU}(\mathbf X_2)\mathbf W_2 \in \mathbb R^{n\times d_1}

SGU有如下几个版本:

  • SGU(Z)=f(Z)\mathrm{SGU}(\mathbf Z)=f(\mathbf Z)

  • SGU(Z)=Z+f(Z)\mathrm{SGU}(\mathbf Z)=Z+f(\mathbf Z)

  • SGU(Z)=Zf(Z)\mathrm{SGU}(\mathbf Z)=\mathbf Z\odot f(\mathbf Z)

  • SGU(Z)=Z1f(Z2),Z=Z1Z2\mathrm{SGU}(\mathbf Z)=\mathbf Z_1\odot f(\mathbf Z_2 ),\mathbf Z=\mathbf Z_{1} \| \mathbf Z_{2}

其中ff表示:

  • f(Z)=WNorm(Z)f(\mathbf Z)=\mathbf W\mathrm{Norm}(\mathbf Z)

最后一个版本效果最好。

时间复杂度

O(nd2+n2d)O(nd^2+n^2 d)

训练以及loss

不变。

代码

实验以及适用场景

进行了Encoder任务,在视觉任务上表现不错,在MLM上性能一般,不过在SGU\mathrm {SGU}中添加Attention可以大幅提升性能(aMLP版本)。

细节

SGU\mathrm {SGU}模块本质上是一个GLU\mathrm{GLU},而aMLP对应的SGU\mathrm {SGU}则和FLASH模型非常相似(可以发现本文的一作也是FLASH的作者之一)。

简评

一个非常不错的思路,可以看到Attention在许多任务中不是必须的;不过该模块缺乏Token之间的交互,所以应该替换FFN更合适,可以进行这方面的实验。

Last updated