Perceiver General Perception with Iterative Attention

论文地址:

参考资料:

整体思路以及计算方式

整体思路是利用CrossAttention来降维。

具体计算方式如下:

  • 给定输入xRn×d,yRm×d\mathbf x\in \mathbb R^{n\times d}, \mathbf y\in \mathbb R^{m\times d}

    • x\mathbf x对应Latent array,y\mathbf y对应Byte array,这里假设nmn\ll m

    • 一个例子是x\mathbf x为图像的patch表示,y\mathbf y为像素级表示;

  • x=MHA1(x,y)Rn×d\mathbf x= \mathrm{MHA}_1(\mathbf x, \mathbf y)\in \mathbb R^{n\times d}

  • x=MHA2(x,x)Rn×d\mathbf x= \mathrm{MHA}_2(\mathbf x,\mathbf x)\in \mathbb R^{n\times d}

备注:这里省略了FFN以及NORM操作。

时间复杂度

MHA1\mathrm{MHA}_1的时间复杂度为O(mnd)O(mnd)MHA2\mathrm{MHA}_2的时间复杂度为O(n2d)O(n^2d),总时间复杂度为O(mnd+n2d)O(mnd+n^2d),论文里假设nmn\ll m,所以总复杂度为O(mnd)O(mnd)

训练以及loss

不变。

代码

https://github.com/lucidrains/perceiver-pytorch

实验以及适用场景

感觉还是主要适用于Encoder场景,像LM,NMT这样的任务似乎没法直接应用;论文做了除NLP以外的实验,效果还行。

细节

暂无,需要复现的时候体会。

简评

优点:

  • 把CrossAttention理解为降维是一个很好的点;

总结:

  • 值得复现,可以尝试应用于Roberta模型中;

  • LM, NMT场景是否能使用需要思考;

Last updated