论文地址:
https://arxiv.org/abs/2205.11588arrow-up-right
将Transformer中的FFN模块换成RNN,最终带来了提升,计算公式如下:
输入X∈Rn×d\mathbf X\in \mathbb R^{n\times d}X∈Rn×d
隐藏状态X1=XW1∈Rn×d1,X2=XW2∈Rn×d1\mathbf X_1= \mathbf X\mathbf W_1\in \mathbb R^{n\times d_1},\mathbf X_2= \mathbf X\mathbf W_2\in \mathbb R^{n\times d_1}X1=XW1∈Rn×d1,X2=XW2∈Rn×d1
计算C∈Rn×d1\mathbf C\in \mathbb R^{n\times d_1}C∈Rn×d1
c[0]=0\mathbf c[0]=0c[0]=0
c[i]=Swish(c[i−1]−x1[i])+x1[i]\mathbf c[i]=\mathrm{Swish}\left(\mathbf c[i-1]- \mathbf x_1[i]\right)+\mathbf x_1[i]c[i]=Swish(c[i−1]−x1[i])+x1[i]
H=((C+bc)⊙σ(X2+bσ))W3+b3∈Rn×d{\mathbf H}=\left(\left(\mathbf {C}+\mathbf {b}_{c}\right) \odot \sigma\left(\mathbf {X}_{2}+\mathbf {b}_{\sigma}\right)\right)\mathbf {W}_{3}+\mathbf {b}_{3} \in \mathbb R^{n\times d}H=((C+bc)⊙σ(X2+bσ))W3+b3∈Rn×d
改进:
由于循环太慢,另一种计算方式是对kkk个位置同时计算,k=1k=1k=1退化到前一种情形:
c[0:k]=0\mathbf c[0:k]=0c[0:k]=0
c[ik:(i+1)k]=Swish(c[(i−1)k:ik]−x1[ik:(i+1)k])+x1[ik:(i+1)k]\mathbf c[ik:(i+1)k]=\mathrm{Swish}\left(\mathbf c[(i-1)k:ik]- \mathbf x_1[ik:(i+1)k]\right)+\mathbf x_1[ik:(i+1)k]c[ik:(i+1)k]=Swish(c[(i−1)k:ik]−x1[ik:(i+1)k])+x1[ik:(i+1)k]
总时间为O(ndd1)O(n dd_1)O(ndd1),但是由于是RNN,实际上会慢很多,作者给出的数字是k=1k=1k=1时耗时为140%,k=2k=2k=2时耗时为120%。
不变。
暂无。
适用于所有场景,作者测试了BERT(Encoder)和GLUE任务,带来了一定的提升,注意这里是时间换性能,所以是否值得需要视场景而定;Decoder的结果作者没有测试,后续可以尝试一下。
思路很简单的一篇论文,但是可以带来如下思考:
Transformer中FFN的作用到底是啥,之前一直理解为特征融合模块,但是利用RNN这样的序列融合模块也能达到同样作用;
既然FFN和RNN起的作用相当,而RNN可以用Attention模块代替,那是否可以将FFN换成Attention?
Last updated 3 years ago