Parallelizing Linear Recurrent Neural Nets Over Sequence Length

Parallelizing Linear Recurrent Neural Nets Over Sequence Length

论文地址:

备注:作者提供了CUDA版本的实现。

整体思路以及计算方式

RNN的思路,不过中间隐状态的计算不使用激活函数:

ht=αtht1+(1αt)xth_t=\alpha_t \odot h_{t-1}+\left(1-\alpha_t\right) \odot x_t

利用这种形式,可以利用并行算法在O(nlogn)O(n\log n)时间复杂度内得到结果。

代码

简评

非常有意思的论文,S4其实思路和这个类似,个人觉得这篇被严重低估。

Last updated