RWKV

Project address:

Training

RWKV is an interesting algorithm that is parallel at training time and sequential (recurrent) at inference time. There are four versions in total, introduced one by one below. Assume the input is $\mathbf X\in \mathbb R^{n\times d}$, the activation function is $f$, $n$ is the sequence length, and $d$ is the feature dimension.

Time mixing

Time mixing can be understood as a forced bi-gram: each token carries part of its own information and part of the previous token's information:

```python
# first half of the channels comes from the previous (time-shifted) token,
# second half from the current token
x = torch.cat([self.time_shift(x)[:, :T, :C//2], x[:, :T, C//2:]], dim=2)
```
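As a rough sketch of the snippet above (the wrapper module, the `(B, T, C)` shapes, and the choice of `ZeroPad2d` for `time_shift` are assumptions based on the public RWKV-LM code, not definitions given on this page):

```python
import torch
import torch.nn as nn

class TimeMix(nn.Module):  # hypothetical wrapper, only to make the snippet runnable
    def __init__(self):
        super().__init__()
        # shift the sequence one step: token t sees token t-1 (the first position is zero-padded)
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

    def forward(self, x):          # x: (B, T, C)
        B, T, C = x.size()
        # first half of the channels from the previous token, second half from the current token
        return torch.cat([self.time_shift(x)[:, :T, :C // 2],
                          x[:, :T, C // 2:]], dim=2)
```

For example, `TimeMix()(torch.randn(2, 5, 8))` keeps the shape `(2, 5, 8)`.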

Feature mixing

First, the feature mixing of all versions is nearly identical:

  1. Time mix produces $\mathbf X_1\in \mathbb R^{n\times d}$

  2. $\mathbf K =\mathbf X_1 \mathbf W_k\in \mathbb R^{n\times e},\ \mathbf V =\mathbf X_1 \mathbf W_v\in \mathbb R^{n\times e},\ \mathbf R =\mathbf X_1 \mathbf W_r \in \mathbb R^{n\times d}$

  3. $\mathbf {WKV}= (f(\mathbf K) \odot \mathbf V)\,\mathbf W_w \in \mathbb R^{n\times d}$

  4. $\mathbf {RWKV} =\mathrm{Sigmoid}(\mathbf R)\odot \mathbf{WKV} \in \mathbb R^{n\times d}$
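A minimal sketch of these four steps, assuming plain weight shapes $\mathbf W_k,\mathbf W_v \in \mathbb R^{d\times e}$, $\mathbf W_r \in \mathbb R^{d\times d}$, $\mathbf W_w \in \mathbb R^{e\times d}$, and using ReLU as a stand-in for $f$ (all names here are illustrative, not the reference implementation):

```python
import torch

def feature_mixing(x1, Wk, Wv, Wr, Ww, f=torch.relu):
    """Generic feature mixing: x1 is the (n, d) output of the time-mix step."""
    K = x1 @ Wk                      # (n, e)
    V = x1 @ Wv                      # (n, e)
    R = x1 @ Wr                      # (n, d)
    WKV = (f(K) * V) @ Ww            # (n, d)
    return torch.sigmoid(R) * WKV    # (n, d), gated output
```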

V1

  1. Token mix produces $\mathbf X_1\in \mathbb R^{n\times d}$

  2. $\mathbf K =\exp(\mathbf X_1 \mathbf W_k)\in \mathbb R^{n\times e},\ \mathbf V =\mathbf X_1 \mathbf W_v\in \mathbb R^{n\times e},\ \mathbf R =\mathbf X_1 \mathbf W_r \in \mathbb R^{n\times e}$

  3. $\mathbf K_{1}=\mathrm{cumsum}(\mathbf K, d=0)\in \mathbb R^{n\times e}$

  4. $\mathbf W_w \in \mathbb R^{n\times n},\ \mathbf W_w[i,j]=\alpha_i \lambda^{i-j} b_j$

  5. $\mathbf {KV}= \mathbf K \odot \mathbf V \in \mathbb R^{n\times e}$

  6. $\mathbf {WKV}=\mathbf W_w \mathbf {KV} \in \mathbb R^{n\times e}$

  7. $\mathbf {RWKV} =\mathrm{Sigmoid}(\mathbf R)\odot \mathbf{WKV} / \mathbf K_1 \in \mathbb R^{n\times e}$

  8. $\mathbf O= \mathbf {RWKV} \times \mathbf W_o \in \mathbb R^{n\times d}$

  9. $\mathbf O= \mathbf O \times \gamma \in \mathbb R^{n\times d},\ \gamma \in \mathbb R^{n}$
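The nine steps can be sketched as follows (the parameter names `alpha`, `beta`, `lam`, `gamma` stand for the position/decay parameters above, and the causal mask on $\mathbf W_w$ is an assumption, since the step list only gives the entry formula):

```python
import torch

def rwkv_v1(x1, Wk, Wv, Wr, Wo, alpha, beta, lam, gamma):
    """Sketch of V1. x1: (n, d); Wk/Wv/Wr: (d, e); Wo: (e, d);
    alpha, beta, gamma: (n,); lam: scalar decay."""
    n = x1.shape[0]
    K = torch.exp(x1 @ Wk)                  # (n, e)
    V = x1 @ Wv                             # (n, e)
    R = x1 @ Wr                             # (n, e)
    K1 = torch.cumsum(K, dim=0)             # (n, e), running normalizer
    idx = torch.arange(n)
    dt = (idx.unsqueeze(1) - idx.unsqueeze(0)).float()              # i - j
    Ww = alpha.unsqueeze(1) * lam ** dt.clamp(min=0) * beta.unsqueeze(0)
    Ww = Ww * (dt >= 0)                     # keep only j <= i (assumed causal mask)
    KV = K * V                              # (n, e)
    WKV = Ww @ KV                           # (n, e)
    RWKV = torch.sigmoid(R) * WKV / K1      # (n, e)
    O = RWKV @ Wo                           # (n, d)
    return O * gamma.unsqueeze(1)           # per-position scaling by gamma
```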

V2

Step 4 is modified to:

$$\begin{aligned} \mathbf W_w &\in \mathbb R^{n\times n \times e}, \\ \mathbf W_w [i,j,k] &= \begin{cases} 0, & i-j < 0\\ c_k, & i-j = 0\\ \lambda_k^{\,i-j}, & i-j \ge 1 \end{cases} \end{aligned}$$

Step 3 is modified to:

$$\begin{aligned} \mathbf {WK}[:, k] & = \mathbf W_w[:, :, k]\, \mathbf K[:, k] \in \mathbb R^{n\times 1}\\ \mathbf K_{1}&=\mathrm{cumsum}(\mathbf {WK}, d=0)\in \mathbb R^{n\times e} \end{aligned}$$

Step 6 is modified to:

$$\mathbf {WKV}[:, k]=\mathbf W_w[:,:, k]\, \mathbf {KV}[:, k] \in \mathbb R^{n\times 1}$$

Step 9 is removed.
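A sketch of the modified pieces, with the per-channel decay `lam` ($\lambda_k$) and self-weight `c` ($c_k$) as `(e,)`-shaped tensors. The names and the dense $n\times n\times e$ construction are illustrative only; a real implementation would avoid materialising $\mathbf W_w$:

```python
import torch

def wkv_v2(K, V, lam, c):
    """V2-style weighted sums. K, V: (n, e); lam, c: (e,).
    Returns WKV and WK, the per-channel weighted sums used above."""
    n, e = K.shape
    idx = torch.arange(n)
    dt = (idx.unsqueeze(1) - idx.unsqueeze(0)).float()          # i - j, (n, n)
    # W_w[i, j, k] = 0 for i<j, c_k for i=j, lam_k^(i-j) for i>j
    Ww = lam.view(1, 1, e) ** dt.clamp(min=1).unsqueeze(-1)     # (n, n, e)
    Ww = torch.where((dt >= 1).unsqueeze(-1), Ww, torch.zeros_like(Ww))
    Ww = Ww + torch.eye(n).unsqueeze(-1) * c.view(1, 1, e)      # diagonal = c_k
    KV = K * V                                                  # (n, e)
    WKV = torch.einsum('ijk,jk->ik', Ww, KV)                    # WKV[:, k] = Ww[:, :, k] @ KV[:, k]
    WK = torch.einsum('ijk,jk->ik', Ww, K)                      # WK[:, k]  = Ww[:, :, k] @ K[:, k]
    return WKV, WK
```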

V3

Step 2 is modified to:

  • $\mathbf X_{p}=\lambda_p \mathbf X +(1-\lambda_p) \mathbf X_1\in \mathbb R^{n\times d},\ p=k, v, r$

  • $\mathbf K =\exp(\mathbf X_k \mathbf W_k)\in \mathbb R^{n\times e},\ \mathbf V =\mathbf X_v \mathbf W_v\in \mathbb R^{n\times e},\ \mathbf R =\mathbf X_r \mathbf W_r \in \mathbb R^{n\times e}$
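A sketch of the modified step 2: each of K, V, R gets its own learned interpolation weight between the raw input and its time-shifted version (`lam_k`, `lam_v`, `lam_r` stand for the $\lambda_p$ above; all names are placeholders):

```python
import torch

def v3_projections(x, x1, lam_k, lam_v, lam_r, Wk, Wv, Wr):
    """x: raw input (n, d); x1: time-mixed input (n, d); lam_*: scalars or (d,) vectors."""
    xk = lam_k * x + (1 - lam_k) * x1
    xv = lam_v * x + (1 - lam_v) * x1
    xr = lam_r * x + (1 - lam_r) * x1
    K = torch.exp(xk @ Wk)    # (n, e)
    V = xv @ Wv               # (n, e)
    R = xr @ Wr               # (n, e)
    return K, V, R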

V4

V4 uses the identity $a/b = (am)/(bm)$ to control the magnitude of the $\mathbf K$ term and avoid numerical problems: since $\mathbf K$ enters through $\exp(\cdot)$, the numerator and denominator can be rescaled by a common factor so that the exponentials stay bounded.
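One way to realise this trick in recurrent form is to keep a running maximum of the exponents and rescale both accumulators by the same factor. The sketch below illustrates the idea only and is not the exact RWKV-4 kernel; the names `a`, `b`, `m`, `k`, `v`, `w` are placeholders:

```python
import torch

def stable_wkv_step(a, b, m, k, v, w):
    """a, b: numerator / denominator accumulators, already scaled by exp(-m);
    m: running maximum exponent; k, v: current key exponent and value;
    w: log-decay applied to the old state."""
    new_m = torch.maximum(m + w, k)     # new reference exponent
    e1 = torch.exp(m + w - new_m)       # rescales the decayed old state
    e2 = torch.exp(k - new_m)           # scales the incoming term
    a = e1 * a + e2 * v                 # ~ sum_i exp(k_i) * v_i, times exp(-new_m)
    b = e1 * b + e2                     # ~ sum_i exp(k_i),       times exp(-new_m)
    return a, b, new_m                  # a / b is unchanged by the shared factor
```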

Inference

For V2, V3, and V4, $\mathbf {RWKV}$ can be computed recursively. Write:

$$\begin{aligned} \mathbf a_{t+1} &= \lambda_k\,\mathbf a_{t} + \mathbf {KV}[t+1,:]\in \mathbb R^{1\times e} \\ \mathbf b_{t+1} & =\mathbf a_t + (1-c_k)\, \mathbf {KV}[t+1,:]\in \mathbb R^{1\times e} \\ \mathbf c_{t+1} &= \lambda_k\,\mathbf c_{t} + \mathbf {K}[t+1, :] \in \mathbb R^{1\times e} \\ \mathbf d_{t+1} &= \mathbf c_{t} + (1-c_k)\,\mathbf {K}[t+1, :] \in \mathbb R^{1\times e} \\ \mathbf {RWKV}_{t+1} &= \mathrm{Sigmoid}(\mathbf R[t+1, :]) \odot (\mathbf b_{t+1} / \mathbf d_{t+1}) \in \mathbb R^{1\times e} \end{aligned}$$
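A minimal sketch of one inference step following this recurrence (the sigmoid gate and the $\mathbf b/\mathbf d$ ratio mirror the training-time formula; `lam` and `ck` stand for the per-channel $\lambda_k$ and $c_k$, and all names are placeholders):

```python
import torch

def rwkv_rnn_step(a, c, kv_t, k_t, r_t, lam, ck):
    """a, c: carried states (1, e); kv_t = K[t+1,:] * V[t+1,:]; k_t = K[t+1,:];
    r_t = R[t+1,:]; lam, ck: per-channel decay and self-weight, shape (e,)."""
    b = a + (1 - ck) * kv_t                 # per-step numerator  b_{t+1}
    d = c + (1 - ck) * k_t                  # per-step denominator d_{t+1}
    rwkv = torch.sigmoid(r_t) * (b / d)     # gated weighted average
    a = lam * a + kv_t                      # carried state a_{t+1}
    c = lam * c + k_t                       # carried state c_{t+1}
    return rwkv, a, c
```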
