XCiT: Cross-Covariance Image Transformers

论文地址:

整体思路以及计算方式

输入:

  • XRn×d1X\in \mathbb R^{n\times d_1}

  • Q,K,V=XWQ,XWK,XWVRn×d2Q,K,V=XW_Q, XW_K, XW_V\in \mathbb R^{n\times d_2}

  • Q=Norm(Q),K=Norm(K)Q= \mathrm{Norm}(Q), K=\mathrm{Norm}(K)

  • O=VSoftmax(KQ)WoRn×d1O=V\mathrm{Softmax}(K^{\top} Q) W_o\in \mathbb R^{n\times d_1}(分组计算)

时间复杂度

假设有hh个分组,那么时间复杂度为O(n(d/h)2×h)=O(nd2/h)O(n(d/h)^2\times h)=O(nd^2/h)

训练以及loss

不变。

代码

实验以及适用场景

适用于Encoder,作者进行了视觉任务,效果都不错。

细节

作者在Attention和FFN之间增加了一个模块,带来了不少提升,但是不加这个模块性能一般;另一方面,计算内积的同时增加了分组操作,这部分需要看源码。

简评

这篇思路过于简单,不知道该模块单独使用是否起作用。

Last updated