Resurrecting Recurrent Neural Networks for Long Sequences

论文地址:

整体思路以及计算方式

这篇论文主要解决了之前RNN无法在长序列上并行训练,或者说性能一般的问题。

动机:SSM和RNN那么像,为啥SSM work,RNN不work呢?

第一个问题:传统RNN带有非线性激活:

xk=σ(Axk1+Buk),yk=Cxk+Dukx_k=\sigma\left(A x_{k-1}+B u_k\right), \quad y_k=C x_k+D u_k

所以无法递推得到类似SSM的结果,那解决这点很简单,直接把激活拿掉即可:

xk=Axk1+Bukx_k=A x_{k-1}+B u_k

假设x1=0x_{-1}=0,那么:

xk=j=0k1AjBukjx_k=\sum_{j=0}^{k-1} A^j B u_{k-j}

这样第一个问题就解决,但是如果这样直接训练,效果还是很一般,因为是AjA^j的模长可能太大或者太小,作者使用如下方式解决,首先假设AA为对角阵Λ\Lambda,然后用如下方式初始化:

Λ=diag(exp(ν+iθ))\Lambda=\operatorname{diag}(\exp (-\nu+i \theta))

这样做的好处是,保证了矩阵的特征值<1<1,不会出现模长爆炸的情况,v,θv, \theta的初始化可以参考论文。

最后具体的实现还有一个残差部分,这里罗列一下:

xk=diag(λ)xk1+γBukλj=exp(exp(νjlog)+iexp(θjlog))γj(1λj2)1/2x_k=\operatorname{diag}(\lambda) x_{k-1}+\gamma \odot B u_k\\ \lambda_j=\exp \left(-\exp \left(\nu_j^{\log }\right)+i \exp \left(\theta_j^{\log }\right)\right)\\ \gamma_j \leftarrow\left(1-\left|\lambda_j\right|^2\right)^{1 / 2}

简评

初始化部分可以看看,其他部分,理解SSM的人应该不难自己得到结论。

Last updated