Multi-Head Attention Collaborate Instead of Concatenate

Multi-Head Attention: Collaborate Instead of Concatenate

论文地址:

整体思路以及计算方式

Q,K\mathbf Q,\mathbf K降维,然后通过对角阵增加模型表达力,最后达到相当的效果。

计算方式:

  • 压缩率pphh为头数;

  • 输入:XRn×d,YRm×d,miR,i=1,,h\mathbf X\in \mathbb R^{n\times d},\mathbf Y\in \mathbb R^{m\times d},m_i\in \mathbb R,i=1,\ldots, h,记d1=pdd_1=\lfloor pd\rfloor

  • for i=1,,hi=1,\ldots, h

    • 计算Q=XWQ(i)Rn×(d1/h),K=XWK(i)Rm×(d1/h),V=XWV(i)Rm×(d/h)\mathbf Q= \mathbf X\mathbf W_Q^{(i)} \in \mathbb R^{n\times (d_1 /h)},\mathbf K= \mathbf X\mathbf W_K^{(i)} \in \mathbb R^{m\times (d_1 /h)}, \mathbf V =\mathbf X\mathbf W_V^{(i)}\in \mathbb R^{m\times (d /h)}

    • H(i)=MHA(Qdiag(mi),K,V)\mathbf H^{(i)}=\mathrm{MHA}(\mathbf Q\mathrm{diag}(m_i), \mathbf K, \mathbf V)

  • 输出:Concat[H(i)]\mathrm{Concat}[\mathbf H^{(i)}]

说明:

  • 尽管原文中不同头算Q,K\mathbf Q,\mathbf KWQ,WK\mathbf W_Q,\mathbf W_K是共享的,但实际实现的时候并不是;

时间复杂度

对于每个头,时间复杂度为:

O(nmd1/h+mnd/h)=O(nmpd/h+mnd/h)O(nmd_1/h + mnd/h)=O(nmpd/h + mnd/h)

所以hh个头的时间复杂度为:

O(mn(p+1)d)O(mn(p+1)d )

训练以及loss

略过。

代码

实验以及适用场景

因为只改了Head部分,所以适用于所有场景;作者进行了大量实验,效果均不错。

细节

降维比例为30%的时候也能达到相当效果。

简评

总结:

  • 很简洁的思路,通过降维减少参数量,然后再通过少量参数恢复性能;

  • 非常简洁,值得复现;

Last updated