Resurrecting Recurrent Neural Networks for Long Sequences
论文地址:
整体思路以及计算方式
这篇论文主要解决了之前RNN无法在长序列上并行训练,或者说性能一般的问题。
动机:SSM和RNN那么像,为啥SSM work,RNN不work呢?
第一个问题:传统RNN带有非线性激活:
xk=σ(Axk−1+Buk),yk=Cxk+Duk 所以无法递推得到类似SSM的结果,那解决这点很简单,直接把激活拿掉即可:
xk=Axk−1+Buk 假设x−1=0,那么:
xk=j=0∑k−1AjBuk−j 这样第一个问题就解决,但是如果这样直接训练,效果还是很一般,因为是Aj的模长可能太大或者太小,作者使用如下方式解决,首先假设A为对角阵Λ,然后用如下方式初始化:
Λ=diag(exp(−ν+iθ)) 这样做的好处是,保证了矩阵的特征值<1,不会出现模长爆炸的情况,v,θ的初始化可以参考论文。
最后具体的实现还有一个残差部分,这里罗列一下:
xk=diag(λ)xk−1+γ⊙Bukλj=exp(−exp(νjlog)+iexp(θjlog))γj←(1−∣λj∣2)1/2 简评
初始化部分可以看看,其他部分,理解SSM的人应该不难自己得到结论。