Hyena Hierarchy: Towards Larger Convolutional Language Models

论文地址:

整体思路以及计算方式

整体还是Toeplitz的思路,这里理一下如何计算。

首先是基本操作ff

f(X1,X2,T)=X1(TX2)f(\mathbf X_1, \mathbf X_2, \mathbf T)= \mathbf X_1 \odot (\mathbf T \mathbf X_2)

其中T\mathbf T是Toeplitz matrix。

然后是Hyena的计算方式:

  • 假设Hyena的循环次数为mm,输入为XRn×d\mathbf X \in \mathbb R^{n\times d}

  • 每次循环使用基本操作ff

  • 第一步得到初始输入和每次循环中使用的X1\mathbf X_1,即ff的输入1;

    • U=XWRn×(m+1)×d\mathbf U = \mathbf X \mathbf W \in \mathbb R^{n\times (m+1)\times d}

    • V=Conv1d(U)XWRn×(m+1)×d\mathbf V=\mathrm{Conv1d}(\mathbf U) \in \mathbf X \mathbf W \in \mathbb R^{n\times (m+1)\times d}

      • kernel size为3,这一步的作用是进行local token mixing;

    • V\mathbf V按照第二个维度拆分为m+1m+1个向量:

      • V1,,Vm,X0Rn×d\mathbf V^1,\ldots, \mathbf V^m, \mathbf X^0 \in \mathbb R^{n\times d}

  • 第二步利用一个网络计算每一步的Toeplitz matrix:

    • T1,,TmRn×n\mathbf T^1,\ldots, \mathbf T^m\in \mathbb R^{n\times n}

    • 使用了类似Tnn中的Rpe加上指数衰减;

  • for i in 1,,m1,\ldots, m:

    • Xi=f(Vi,Xi1,Ti)\mathbf X^{i}= f(\mathbf V^i, \mathbf X^{i-1}, \mathbf T^i)

  • return Xm\mathbf X^{m}

时间复杂度

O(nmdlogn+nd2)O(nmd\log n + nd^2)

代码

简评

非常有意思的工作:

  • local token mixing看起来比较关键;

  • 如何去掉那个循环,是一个值得研究的问题;

Last updated