Memory Transformer

论文地址:

整体思路以及计算方式

思路比较简洁,对输入部分增加mm个mem token,记为xmemRm×d\mathbf x^{mem}\in \mathbb R^{m\times d},原始输入记为xseqRn×d\mathbf x^{seq}\in \mathbb R^{n\times d},合并后的输入记为xmem+seq=[xmem;xseq]R(n+m)×d\mathbf x^{mem+seq}=[\mathbf x^{mem};\mathbf x^{seq}]\in \mathbb R^{(n+m)\times d}

论文一共介绍了三个模型,分别为:

  • Mem Transformer:

    xmem+seq=MHA(xmem+seq,xmem+seq)\mathbf x^{mem+seq}= \mathrm{MHA}(\mathbf x^{mem+seq},\mathbf x^{mem+seq})
  • MemCtrl Transformer:

    xmem=MHA(xmem,xmem+seq)xseq=MHA(xmem,xmem+seq)\begin{aligned} \mathbf x^{mem}&= \mathrm{MHA}(\mathbf x^{mem},\mathbf x^{mem+seq}) \\ \mathbf x^{seq}&= \mathrm{MHA}(\mathbf x^{mem},\mathbf x^{mem+seq}) \end{aligned}
  • MemBottleNeck Transformer:

    xmem=MHA(xmem,xmem+seq)xseq=MHA(xmem,xmem)\begin{aligned} \mathbf x^{mem}&= \mathrm{MHA}(\mathbf x^{mem},\mathbf x^{mem+seq}) \\ \mathbf x^{seq}&= \mathrm{MHA}(\mathbf x^{mem},\mathbf x^{mem}) \end{aligned}

时间复杂度

依然是标准Attention的计算方式,所以时间复杂度为O((n+m)2d)O((n+m)^2 d)

训练以及loss

不变。

代码

实验以及适用场景

Encoder和Decoder均适用;实验比较全,Encoder, Decoder以及Encoder-Decoder结构均测试过,总体效果积极。

细节

论文中mm取的比较小,所以增加的时间并不多。

简评

优点:

  • 非常简洁清晰的方法;

  • 适用于单向和双向模型;

总结

  • 值得复现;

Last updated