Compressive Transformers for Long-Range Sequence Modelling

论文地址:

整体思路以及计算方式

计算方式:

  • 构造记忆mRnm×d\mathbf m\in \mathbb R^{n_m\times d}和压缩记忆cmRncm×d\mathbf {cm}\in \mathbb R^{n_{cm}\times d}

  • 对于输入xRn×d\mathbf x\in \mathbb R^{n\times d},将记忆和压缩记忆拼接为整体记忆mem=concat(cm,m,x)R(nm+ncm+n)×d\mathbf{mem}=\mathrm{concat}(\mathbf {cm},\mathbf m,\mathbf x) \in \mathbb R^{(n_m+ n_{cm}+n)\times d},得到输出MHA(x,mem)Rn×d\mathrm{MHA}(\mathbf x, \mathbf{mem}) \in \mathbb R^{n\times d}

记忆的更新方式为:

  • 记忆:

    • 拼接,选择最近的nmn_m个记忆:m=concat(m,x)[nm:]Rnm×d\mathbf m =\mathrm{concat}(\mathbf m,\mathbf x)[-n_m:] \in \mathbb R^{n_m \times d}

  • 压缩记忆:

    • m[:n]\mathbf m[:n]的序列维度降维cc倍得到cmtmpRnc×d\mathbf {cm}_{tmp}\in \mathbb R^{\left\lfloor\frac{n}{c}\right\rfloor \times d}

    • 拼接,选择最近的ncmn_{cm}个记忆:cm=concat(cm,cmtmp)[ncm:]Rncm×d\mathbf {cm}=\mathrm{concat}(\mathbf {cm}, \mathbf {cm}_{tmp})[-n_{cm}:]\in \mathbb R^{n_{cm}\times d}

时间复杂度

依然是标准Attention的计算方式,所以时间复杂度为O(ns(nm+ncm+n)d)O(n_s(n_m+ n_{cm} +n) d)

训练以及loss

训练方式一致,loss部分增加了如下部分:

MHA(x,old_mem)MHA(x,mem)2\left\| \mathrm{MHA}(\mathbf x, \mathbf{old\_{mem}}) - \mathrm{MHA}(\mathbf x, \mathbf {mem}) \right\|_2

其中old_mem/mem\mathbf {old\_mem/mem}表示m,cm\mathbf {m},\mathbf c_m更新前/后拼接得到的整体记忆,应该是确保训练稳定。

代码

实验以及适用场景

单向双向模型均适用;论文里只测试了lm(单向模型),效果有所提升。

细节

记忆和压缩记忆都不在计算图内,即不使用梯度方式更新。

简评

优点:

  • 适用于单向和双向模型;

  • 引入了记忆机制,提升了性能;

不足:

  • 引入的记忆机制增增加了不少显存,时间复杂度也增加了;

  • 压缩记忆的动机不够清晰;

总结:

  • 是一种时间和空间换性能的方法,不会进行复现;

Last updated