论文地址:
https://arxiv.org/abs/2103.03206arrow-up-right
参考资料:
https://zhuanlan.zhihu.com/p/360773327arrow-up-right
https://blog.csdn.net/weixin_39707121/article/details/117258115arrow-up-right
整体思路是利用CrossAttention来降维。
具体计算方式如下:
给定输入x∈Rn×d,y∈Rm×d\mathbf x\in \mathbb R^{n\times d}, \mathbf y\in \mathbb R^{m\times d}x∈Rn×d,y∈Rm×d;
x\mathbf xx对应Latent array,y\mathbf yy对应Byte array,这里假设n≪mn\ll mn≪m;
一个例子是x\mathbf xx为图像的patch表示,y\mathbf yy为像素级表示;
x=MHA1(x,y)∈Rn×d\mathbf x= \mathrm{MHA}_1(\mathbf x, \mathbf y)\in \mathbb R^{n\times d}x=MHA1(x,y)∈Rn×d
x=MHA2(x,x)∈Rn×d\mathbf x= \mathrm{MHA}_2(\mathbf x,\mathbf x)\in \mathbb R^{n\times d}x=MHA2(x,x)∈Rn×d
备注:这里省略了FFN以及NORM操作。
MHA1\mathrm{MHA}_1MHA1的时间复杂度为O(mnd)O(mnd)O(mnd),MHA2\mathrm{MHA}_2MHA2的时间复杂度为O(n2d)O(n^2d)O(n2d),总时间复杂度为O(mnd+n2d)O(mnd+n^2d)O(mnd+n2d),论文里假设n≪mn\ll mn≪m,所以总复杂度为O(mnd)O(mnd)O(mnd)。
不变。
https://github.com/lucidrains/perceiver-pytorcharrow-up-right
感觉还是主要适用于Encoder场景,像LM,NMT这样的任务似乎没法直接应用;论文做了除NLP以外的实验,效果还行。
暂无,需要复现的时候体会。
优点:
把CrossAttention理解为降维是一个很好的点;
总结:
值得复现,可以尝试应用于Roberta模型中;
LM, NMT场景是否能使用需要思考;
Last updated 3 years ago