Object-Centric Learning with Slot Attention

论文地址:

参考资料:

整体思路以及计算方式

对任务背景没有特别的了解,感觉是一种抽特征的方式,直接讨论计算方式,忽略Normlize相关部分:

  • XRN×d1\mathbf X\in \mathbb R^{N\times d_1}

  • SN(μ,diag(σ))RK×d2\mathbf {S} \sim \mathcal{N}(\mu, \operatorname{diag}(\sigma)) \in \mathbb{R}^{K \times d_2}(代表Slots\mathbf {Slots}

  • for t=0,,T1t=0,\ldots ,T-1:

    • Sprev=SRK×d2\mathbf {S}_{\mathrm{prev}}=\mathbf{S}\in \mathbb R^{K\times d_2}

    • Q=SWqRK×d,K=XWkRN×d,V=XWvRN×d\mathbf Q= \mathbf{S}\mathbf W_q \in \mathbb R^{K\times d},\mathbf K=\mathbf X\mathbf W_k\in \mathbb R^{N\times d},\mathbf V=\mathbf X\mathbf W_v \in \mathbb R^{N\times d}

    • A=Softmax(QK,dim=0)RK×N\mathrm{A}=\mathrm{Softmax}(\mathbf Q\mathbf K^{\top} , \mathrm{dim}=0)\in \mathbb R^{K\times N}

    • U=AVRK×d\mathbf{U}=\mathbf A\mathbf V\in \mathbb R^{K\times d}

    • S=GRU(Sprev,U)RK×d2\mathbf{S}= \mathrm{GRU}(\mathbf {S}_{\mathrm{prev}}, \mathbf U) \in \mathbb R^{K\times d_2}

时间复杂度

MHA\mathrm{MHA}的时间复杂度为O(KNd)O(KNd),总时间复杂度为O(TKNd)O(TKNd)

训练以及loss

没有变化。

代码

实验以及适用场景

作者进行的实验比较简单,这里不进行讨论。

细节

略过。

简评

个人理解是一种抽特征的方式,不知道能否适用于NLP任务。

Last updated