Last updated 2 years ago
Paper:
References:
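Computes attention with ideas from network flow. Keys act as sources and queries as sinks, and flow conservation is used to reweight both sides while keeping the cost linear in sequence length.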
Input:
$Q\in \mathbb R^{n\times d},\ K\in \mathbb R^{m\times d},\ V\in \mathbb R^{m\times d}$
$Q=\phi(Q)\in \mathbb R^{n\times d}$
$K=\phi(K)\in \mathbb R^{m\times d}$
where $\phi(\cdot)$ is a non-negative feature map, so that all flows computed below are non-negative.
Compute incoming and outgoing flow
$Q_{sum}=\mathrm{Sum}(Q,d=0) \in \mathbb R^{d}$
$K_{sum}=\mathrm{Sum}(K,d=0) \in \mathbb R^{d}$
$d_Q = 1/(Q K_{sum}^{\top})\in \mathbb R^{n}$
$d_K = 1/(K Q_{sum}^{\top})\in \mathbb R^{m}$
Here $QK_{sum}^{\top}$ is each query's incoming flow (the row sums of $QK^{\top}$, obtained without forming the $n\times m$ matrix) and $KQ_{sum}^{\top}$ is each key's outgoing flow, so $d_Q$ and $d_K$ are their reciprocals.
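A minimal NumPy sketch of this step. NumPy, the choice $\phi=$ sigmoid, and the helper name `flow_normalizers` are assumptions for illustration, not the authors' code. It also checks that $QK_{sum}^{\top}$ matches the row sums of the full $QK^{\top}$.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    # Assumed choice of phi: strictly positive outputs keep every flow positive.
    return 1.0 / (1.0 + np.exp(-x))


def flow_normalizers(Q: np.ndarray, K: np.ndarray):
    """Hypothetical helper: return (d_Q, d_K) for Q: (n, d), K: (m, d) after phi."""
    Q_sum = Q.sum(axis=0)      # (d,)
    K_sum = K.sum(axis=0)      # (d,)
    d_Q = 1.0 / (Q @ K_sum)    # (n,) reciprocal of each query's incoming flow
    d_K = 1.0 / (K @ Q_sum)    # (m,) reciprocal of each key's outgoing flow
    return d_Q, d_K


rng = np.random.default_rng(0)
Q = sigmoid(rng.standard_normal((4, 3)))
K = sigmoid(rng.standard_normal((5, 3)))
d_Q, d_K = flow_normalizers(Q, K)

# Q @ K_sum equals the row sums of Q @ K.T, but never builds the (n, m) matrix.
assert np.allclose(1.0 / d_Q, (Q @ K.T).sum(axis=1))
print(d_Q.shape, d_K.shape)  # (4,) (5,)
```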
Conservation refinement for sources and sinks
$t_Q= \mathrm{Sum}(K\odot d_K, d=0)\in \mathbb R^{d}$
$t_K= \mathrm{Sum}(Q\odot d_Q, d=0)\in \mathbb R^{d}$
$sink= Q \odot t_Q \in \mathbb R^{n\times d}$
$source = K \odot t_K \in \mathbb R^{m\times d}$
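The name "conservation" can be checked numerically. Scaling each key by $d_K$ gives it an outgoing flow of exactly 1, so the total flow reaching the sinks is $\sum_i Q_i t_Q^{\top} = \sum_j (K_j Q_{sum}^{\top})/(K_j Q_{sum}^{\top}) = m$, and symmetrically the sources receive $n$. A NumPy sketch of the element-wise formulas above, assuming $\phi=$ sigmoid (`conservation_refine` is a hypothetical name):

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def conservation_refine(Q, K):
    """Return (sink, source) for Q: (n, d), K: (m, d), both already mapped through phi."""
    d_Q = 1.0 / (Q @ K.sum(axis=0))        # (n,) 1 / incoming flow
    d_K = 1.0 / (K @ Q.sum(axis=0))        # (m,) 1 / outgoing flow
    t_Q = (K * d_K[:, None]).sum(axis=0)   # (d,) keys rescaled to unit outgoing flow
    t_K = (Q * d_Q[:, None]).sum(axis=0)   # (d,) queries rescaled to unit incoming flow
    sink = Q * t_Q                          # (n, d), t_Q broadcast over the n rows
    source = K * t_K                        # (m, d), t_K broadcast over the m rows
    return sink, source


rng = np.random.default_rng(0)
n, m, d = 4, 5, 3
Q = sigmoid(rng.standard_normal((n, d)))
K = sigmoid(rng.standard_normal((m, d)))
sink, source = conservation_refine(Q, K)

# Conservation: total flow into the sinks is m, total flow into the sources is n.
assert np.isclose(sink.sum(), m)
assert np.isclose(source.sum(), n)
```

Summing `sink` over its last axis gives the per-query scalar $Q_i t_Q^{\top}$, i.e. the refined incoming flow of each query.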
Competition & Allocation
$\alpha = \mathrm{Sigmoid}(sink) \in \mathbb R^{n\times d}$
$\beta= \mathrm{Softmax}(source) \in \mathbb R^{m\times d}$
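A small NumPy sketch of the two gates. The axis of the softmax is not stated above; this sketch assumes it runs over the $m$ keys (dim 0), so that keys compete for a fixed total weight of 1 per feature column, while the sigmoid gates each query independently. The random `sink` and `source` arrays are placeholders.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def softmax(x, axis):
    # Subtracting the max is for numerical stability; the result is unchanged.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
sink = rng.random((4, 3))    # placeholder for the (n, d) sink
source = rng.random((5, 3))  # placeholder for the (m, d) source

alpha = sigmoid(sink)            # allocation: independent gate in (0, 1) per entry
beta = softmax(source, axis=0)   # competition (assumed axis): the m keys share weight 1 per column
print(beta.sum(axis=0))
```

The design difference is the point: a softmax makes the keys compete (raising one key's weight lowers the others), whereas a sigmoid lets every query keep its own share.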
Dot product (output)
$Q_1 = Q\odot \alpha \in \mathbb R^{n\times d}$
$K_1 = K\odot \beta \in \mathbb R^{m\times d}$
$O=\alpha \odot (Q_1(K_1^{\top} V)) \in \mathbb R^{n\times d}$
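Putting the steps together, a minimal end-to-end NumPy sketch that follows the formulas in these notes as written, with $K_1 = K\odot\beta$. Assumptions: $\phi$ is sigmoid, the softmax in $\beta$ runs over the $m$ keys, and `flow_attention` is a hypothetical name; the authors' implementation may differ in details such as scaling constants.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def softmax(x, axis):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def flow_attention(Q, K, V):
    """Q: (n, d), K: (m, d), V: (m, d) -> O: (n, d), never building an (n, m) matrix."""
    Q, K = sigmoid(Q), sigmoid(K)              # assumed phi = sigmoid
    # incoming / outgoing flow
    d_Q = 1.0 / (Q @ K.sum(axis=0))            # (n,)
    d_K = 1.0 / (K @ Q.sum(axis=0))            # (m,)
    # conservation refinement
    t_Q = (K * d_K[:, None]).sum(axis=0)       # (d,)
    t_K = (Q * d_Q[:, None]).sum(axis=0)       # (d,)
    sink, source = Q * t_Q, K * t_K            # (n, d), (m, d)
    # competition & allocation
    alpha = sigmoid(sink)                      # (n, d)
    beta = softmax(source, axis=0)             # (m, d), softmax over the m keys (assumed)
    # output: K1^T V is only (d, d), then multiply by Q1 and gate with alpha
    Q1, K1 = Q * alpha, K * beta
    return alpha * (Q1 @ (K1.T @ V))


rng = np.random.default_rng(0)
n, m, d = 6, 8, 4
O = flow_attention(rng.standard_normal((n, d)),
                   rng.standard_normal((m, d)),
                   rng.standard_normal((m, d)))
print(O.shape)  # (6, 4)
```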
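In theory the complexity is still $O((n+m)d^2)$, i.e. linear in the sequence lengths, but in practice it is probably not very fast, presumably because of the several extra passes needed for the flow computations.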
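The linear cost comes from associativity: computing $K_1^{\top}V$ first gives a $d\times d$ intermediate, so the two products cost about $(n+m)d^2$ multiply-adds instead of about $2nmd$ for $(Q_1K_1^{\top})V$, while the flow steps only add $O((n+m)d)$ element-wise work. A quick NumPy check with arbitrary sizes (the random inputs are placeholders):

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, d = 512, 384, 16
Q1 = rng.random((n, d))
K1 = rng.random((m, d))
V = rng.standard_normal((m, d))

# Quadratic order: materializes an (n, m) matrix, about n*m*d multiply-adds per product.
quadratic = (Q1 @ K1.T) @ V
# Linear order: only a (d, d) intermediate, about (n + m) * d**2 multiply-adds in total.
linear = Q1 @ (K1.T @ V)

assert np.allclose(quadratic, linear)
print(2 * n * m * d, (n + m) * d * d)  # 6291456 229376
```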
Unchanged.
Evaluated on a variety of common tasks; overall, performance improves.
None for now.
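Solid in both theory and experiments, and a good piece of work overall, but the way the computation is constructed feels somewhat forced, and it does not seem to get at the core of the problem.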