Compressive Transformers for Long-Range Sequence Modelling
论文地址:
整体思路以及计算方式
计算方式:
构造记忆m∈Rnm×d和压缩记忆cm∈Rncm×d;
对于输入x∈Rn×d,将记忆和压缩记忆拼接为整体记忆mem=concat(cm,m,x)∈R(nm+ncm+n)×d,得到输出MHA(x,mem)∈Rn×d。
记忆的更新方式为:
记忆:
拼接,选择最近的nm个记忆:m=concat(m,x)[−nm:]∈Rnm×d
压缩记忆:
对m[:n]的序列维度降维c倍得到cmtmp∈R⌊cn⌋×d
拼接,选择最近的ncm个记忆:cm=concat(cm,cmtmp)[−ncm:]∈Rncm×d
时间复杂度
依然是标准Attention的计算方式,所以时间复杂度为O(ns(nm+ncm+n)d)。
训练以及loss
训练方式一致,loss部分增加了如下部分:
MHA(x,old_mem)−MHA(x,mem)2 其中old_mem/mem表示m,cm更新前/后拼接得到的整体记忆,应该是确保训练稳定。
代码
实验以及适用场景
单向双向模型均适用;论文里只测试了lm(单向模型),效果有所提升。
细节
记忆和压缩记忆都不在计算图内,即不使用梯度方式更新。
简评
优点:
不足:
引入的记忆机制增增加了不少显存,时间复杂度也增加了;
总结: