Last updated 2 years ago
论文地址:
https://arxiv.org/abs/1907.01470
思路非常简单,给Transformer模块增加和输入无关的记忆模块,并且删除FFN:
计算流程:
输入:Q,K,V∈Rn×d\mathbf Q,\mathbf K, \mathbf V\in \mathbb R^{n\times d}Q,K,V∈Rn×d;
内部参数:Mk,Mv∈RN×d\mathbf M_k, \mathbf M_{v} \in \mathbb R^{N\times d}Mk,Mv∈RN×d;
NNN为预设的最大值;
K1=[K,Mk[:n]]∈R(n+N)×d\mathbf K_1 = [\mathbf K, \mathbf M_{k}[:n]]\in \mathbf R^{(n+N)\times d}K1=[K,Mk[:n]]∈R(n+N)×d
V1=[V,Mv[:n]]∈R(n+N)×d\mathbf V_1 = [\mathbf V, \mathbf M_{v}[:n]]\in \mathbf R^{(n+N)\times d}V1=[V,Mv[:n]]∈R(n+N)×d
输出:MHA(Q,K1,V1)∈Rn×d\mathrm{MHA}(\mathbf Q, \mathbf K_1, \mathbf V_1)\in \mathbb R^{n\times d}MHA(Q,K1,V1)∈Rn×d
O(n(n+N)d)O(n(n+N)d)O(n(n+N)d)。
不变。
暂无,但是实现起来不难。
论文测试了lm,有一定的性能提升。
暂无。
不错的想法,有助于大家理解FFN的作用,后续可以考虑复现。