When Attention Meets Fast Recurrence: Training Language Models with Reduced Compute

论文地址:

整体思路以及计算方式

将SRU中的全连接层替换成MHA,增加模型表达性,首先回顾SRU:

U=(WWW)X\mathbf{U}^{\top}=\left(\begin{array}{l} \mathbf{W} \\ \mathbf{W}^{\prime} \\ \mathbf{W}^{\prime \prime} \end{array}\right) \mathbf{X}^{\top}

其中URL×3×d,XRL×d.\mathbf{U} \in \mathbb{R}^{L \times 3 \times d}, \mathbf{X} \in \mathbb{R}^{L \times d}.

然后利用递推式计算:

f[t]=σ(U[t,0]+vc[t1]+b)r[t]=σ(U[t,1]+vc[t1]+b)c[t]=f[t]c[t1]+(1f[t])U[t,2]h[t]=r[t]c[t]+(1r[t])x[t].\begin{aligned} \mathbf{f}[t] & =\sigma(\mathbf{U}[t, 0]+\mathbf{v} \odot \mathbf{c}[t-1]+\mathbf{b}) \\ \mathbf{r}[t] & =\sigma\left(\mathbf{U}[t, 1]+\mathbf{v}^{\prime} \odot \mathbf{c}[t-1]+\mathbf{b}^{\prime}\right) \\ \mathbf{c}[t] & =\mathbf{f}[t] \odot \mathbf{c}[t-1]+(1-\mathbf{f}[t]) \odot \mathbf{U}[t, 2] \\ \mathbf{h}[t] & =\mathbf{r}[t] \odot \mathbf{c}[t]+(1-\mathbf{r}[t]) \odot \mathbf{x}[t] . \end{aligned}

这里的改进是,将U\mathbf U部分替换为MHA:

U=WoLayernorm(Q+αA)A=softmax(QKd)VQ=WqXK=WkQV=WvQ\begin{aligned} \mathbf{U}^{\top}&=\mathbf{W}^o\mathrm{Layernorm}(\mathbf{Q}+\alpha \cdot \mathbf{A}) \\ \mathbf{A}^{\top}&=\operatorname{softmax}\left(\frac{\mathbf{Q}^{\top} \mathbf{K}}{\sqrt{d^{\prime}}}\right) \mathbf{V}^{\top} \\ \mathbf{Q}& =\mathbf{W}^q \mathbf{X}^{\top} \\ \mathbf{K}& =\mathbf{W}^k \mathbf{Q} \\ \mathbf{V}& =\mathbf{W}^v \mathbf{Q} \end{aligned}

代码

细节

作者测试了每几层增加Attention的效果,最后的结论是,只要增加一层Attention,就能比纯SRU的效果好很多。

简评

不错的工作,最后的结论也有启发意义。

Last updated