S2-MLPv2 Improved Spatial-Shift MLP Architecture for Vision

论文地址:

整体思路以及计算方式

对S2-MLP的改进,使用了多种spatial-shift,然后split attention进行特征融合:

这里主要介绍split attention的计算方式:

  • 输入:[X1,X2,,XK]Rn×K×c,XkRn×c\left[\mathbf{X}_{1}, \mathbf{X}_{2}, \cdots, \mathbf{X}_{K}\right]\in \mathbb R^{n\times K\times c} , \mathbf X_k\in \mathbb R^{n\times c}

  • 特征融合:a=k=1K1nXkRc\mathbf a= \sum_{k=1}^K 1_n \mathbf X_k \in \mathbb R^c

  • a1=σ(aW1)Rc1\mathbf a_1 = \sigma(\mathbf a \mathbf W_1) \in \mathbb R^{c_1}

  • a2=a1W2RKc\mathbf a_2 = \mathbf a_1 \mathbf W_2 \in \mathbb R^{K c}

  • a3=Softmax(reshape(a2),dim=1)R1×K×c\mathbf a_3 =\mathrm{Softmax}(\mathrm{reshape}(\mathbf a_2),\mathrm{dim}=1)\in \mathbb R^{1\times K\times c}

  • o1=[X1,X2,,XK]a3Rn×K×c\mathbf o_1=\left[\mathbf {X}_{1}, \mathbf {X}_{2}, \cdots, \mathbf {X}_{K}\right]\odot \mathbf a_3 \in \mathbb R^{n\times K \times c }

  • o2=Sum(o1,dim=1)Rn×c\mathbf o_2 = \mathrm{Sum}(\mathbf o_1, \mathrm{dim}=1)\in \mathbb R^{n\times c}

该模块主要融合了Xk\mathbf X_k的特征,不知道是否可以代替Attention的效果?

时间复杂度

split attention模块的时间复杂度为O(nKc+cc1+Kc1c)O(nKc + cc_1 + Kc_1c ),其余部分任然为线性复杂度。

训练以及loss

不变。

代码

简评

spatial-shift + split attention可以大幅提升性能,可以研究下,然后在nlp中使用。

Last updated