Last updated 2 years ago
论文地址:
思路非常简单,降低infercence时间复杂度:
输入:X∈Rn×dX\in \mathbb R^{n\times d}X∈Rn×d
W1=WQWK⊤∈Rd×d,W2=WvWo∈Rd×dW_1= W_QW_K^{\top} \in \mathbb R^{d\times d}, W_2= W_v W_o \in \mathbb R^{d\times d}W1=WQWK⊤∈Rd×d,W2=WvWo∈Rd×d
S1=XW1X⊤∈Rn×n(=QK⊤)S_1 = XW_1 X^{\top} \in \mathbb R^{n\times n}(=QK^{\top} )S1=XW1X⊤∈Rn×n(=QK⊤)
O1=Softmax(S1)XW2∈Rn×dO_1=\mathrm{Softmax(S_1)}XW_2\in \mathbb R^{n\times d}O1=Softmax(S1)XW2∈Rn×d
不考虑。
主要是用于inference,可以提升不少速度。
暂无。
非常好的思路,感觉可以尝试在training上。