Simple Recurrence Improves Masked Language Models

论文地址:

整体思路以及计算方式

将Transformer中的FFN模块换成RNN,最终带来了提升,计算公式如下:

  • 输入XRn×d\mathbf X\in \mathbb R^{n\times d}

  • 隐藏状态X1=XW1Rn×d1,X2=XW2Rn×d1\mathbf X_1= \mathbf X\mathbf W_1\in \mathbb R^{n\times d_1},\mathbf X_2= \mathbf X\mathbf W_2\in \mathbb R^{n\times d_1}

  • 计算CRn×d1\mathbf C\in \mathbb R^{n\times d_1}

    • c[0]=0\mathbf c[0]=0

    • c[i]=Swish(c[i1]x1[i])+x1[i]\mathbf c[i]=\mathrm{Swish}\left(\mathbf c[i-1]- \mathbf x_1[i]\right)+\mathbf x_1[i]

  • H=((C+bc)σ(X2+bσ))W3+b3Rn×d{\mathbf H}=\left(\left(\mathbf {C}+\mathbf {b}_{c}\right) \odot \sigma\left(\mathbf {X}_{2}+\mathbf {b}_{\sigma}\right)\right)\mathbf {W}_{3}+\mathbf {b}_{3} \in \mathbb R^{n\times d}

改进:

由于循环太慢,另一种计算方式是对kk个位置同时计算,k=1k=1退化到前一种情形:

  • c[0:k]=0\mathbf c[0:k]=0

  • c[ik:(i+1)k]=Swish(c[(i1)k:ik]x1[ik:(i+1)k])+x1[ik:(i+1)k]\mathbf c[ik:(i+1)k]=\mathrm{Swish}\left(\mathbf c[(i-1)k:ik]- \mathbf x_1[ik:(i+1)k]\right)+\mathbf x_1[ik:(i+1)k]

时间复杂度

总时间为O(ndd1)O(n dd_1),但是由于是RNN,实际上会慢很多,作者给出的数字是k=1k=1时耗时为140%,k=2k=2时耗时为120%。

训练以及loss

不变。

代码

暂无。

实验以及适用场景

适用于所有场景,作者测试了BERT(Encoder)和GLUE任务,带来了一定的提升,注意这里是时间换性能,所以是否值得需要视场景而定;Decoder的结果作者没有测试,后续可以尝试一下。

细节

暂无。

简评

思路很简单的一篇论文,但是可以带来如下思考:

  • Transformer中FFN的作用到底是啥,之前一直理解为特征融合模块,但是利用RNN这样的序列融合模块也能达到同样作用;

  • 既然FFN和RNN起的作用相当,而RNN可以用Attention模块代替,那是否可以将FFN换成Attention?

Last updated