论文地址:
https://arxiv.org/abs/1911.05507arrow-up-right
计算方式:
构造记忆m∈Rnm×d\mathbf m\in \mathbb R^{n_m\times d}m∈Rnm×d和压缩记忆cm∈Rncm×d\mathbf {cm}\in \mathbb R^{n_{cm}\times d}cm∈Rncm×d;
对于输入x∈Rn×d\mathbf x\in \mathbb R^{n\times d}x∈Rn×d,将记忆和压缩记忆拼接为整体记忆mem=concat(cm,m,x)∈R(nm+ncm+n)×d\mathbf{mem}=\mathrm{concat}(\mathbf {cm},\mathbf m,\mathbf x) \in \mathbb R^{(n_m+ n_{cm}+n)\times d}mem=concat(cm,m,x)∈R(nm+ncm+n)×d,得到输出MHA(x,mem)∈Rn×d\mathrm{MHA}(\mathbf x, \mathbf{mem}) \in \mathbb R^{n\times d}MHA(x,mem)∈Rn×d。
记忆的更新方式为:
记忆:
拼接,选择最近的nmn_mnm个记忆:m=concat(m,x)[−nm:]∈Rnm×d\mathbf m =\mathrm{concat}(\mathbf m,\mathbf x)[-n_m:] \in \mathbb R^{n_m \times d}m=concat(m,x)[−nm:]∈Rnm×d
压缩记忆:
对m[:n]\mathbf m[:n]m[:n]的序列维度降维ccc倍得到cmtmp∈R⌊nc⌋×d\mathbf {cm}_{tmp}\in \mathbb R^{\left\lfloor\frac{n}{c}\right\rfloor \times d}cmtmp∈R⌊cn⌋×d
拼接,选择最近的ncmn_{cm}ncm个记忆:cm=concat(cm,cmtmp)[−ncm:]∈Rncm×d\mathbf {cm}=\mathrm{concat}(\mathbf {cm}, \mathbf {cm}_{tmp})[-n_{cm}:]\in \mathbb R^{n_{cm}\times d}cm=concat(cm,cmtmp)[−ncm:]∈Rncm×d
依然是标准Attention的计算方式,所以时间复杂度为O(ns(nm+ncm+n)d)O(n_s(n_m+ n_{cm} +n) d)O(ns(nm+ncm+n)d)。
训练方式一致,loss部分增加了如下部分:
其中old_mem/mem\mathbf {old\_mem/mem}old_mem/mem表示m,cm\mathbf {m},\mathbf c_mm,cm更新前/后拼接得到的整体记忆,应该是确保训练稳定。
https://github.com/lucidrains/compressive-transformer-pytorcharrow-up-right
单向双向模型均适用;论文里只测试了lm(单向模型),效果有所提升。
记忆和压缩记忆都不在计算图内,即不使用梯度方式更新。
优点:
适用于单向和双向模型;
引入了记忆机制,提升了性能;
不足:
引入的记忆机制增增加了不少显存,时间复杂度也增加了;
压缩记忆的动机不够清晰;
总结:
是一种时间和空间换性能的方法,不会进行复现;
Last updated 3 years ago