论文地址:
https://arxiv.org/abs/2302.10866arrow-up-right
整体还是Toeplitz的思路,这里理一下如何计算。
首先是基本操作fff:
其中T\mathbf TT是Toeplitz matrix。
然后是Hyena的计算方式:
假设Hyena的循环次数为mmm,输入为X∈Rn×d\mathbf X \in \mathbb R^{n\times d}X∈Rn×d;
每次循环使用基本操作fff;
第一步得到初始输入和每次循环中使用的X1\mathbf X_1X1,即fff的输入1;
U=XW∈Rn×(m+1)×d\mathbf U = \mathbf X \mathbf W \in \mathbb R^{n\times (m+1)\times d} U=XW∈Rn×(m+1)×d;
V=Conv1d(U)∈XW∈Rn×(m+1)×d\mathbf V=\mathrm{Conv1d}(\mathbf U) \in \mathbf X \mathbf W \in \mathbb R^{n\times (m+1)\times d}V=Conv1d(U)∈XW∈Rn×(m+1)×d;
kernel size为3,这一步的作用是进行local token mixing;
将V\mathbf VV按照第二个维度拆分为m+1m+1m+1个向量:
V1,…,Vm,X0∈Rn×d\mathbf V^1,\ldots, \mathbf V^m, \mathbf X^0 \in \mathbb R^{n\times d}V1,…,Vm,X0∈Rn×d;
第二步利用一个网络计算每一步的Toeplitz matrix:
T1,…,Tm∈Rn×n\mathbf T^1,\ldots, \mathbf T^m\in \mathbb R^{n\times n}T1,…,Tm∈Rn×n;
使用了类似Tnn中的Rpe加上指数衰减;
for i in 1,…,m1,\ldots, m1,…,m:
Xi=f(Vi,Xi−1,Ti)\mathbf X^{i}= f(\mathbf V^i, \mathbf X^{i-1}, \mathbf T^i)Xi=f(Vi,Xi−1,Ti)
return Xm\mathbf X^{m}Xm;
O(nmdlogn+nd2)O(nmd\log n + nd^2)O(nmdlogn+nd2)。
https://github.com/HazyResearch/safariarrow-up-right
非常有意思的工作:
local token mixing看起来比较关键;
如何去掉那个循环,是一个值得研究的问题;
Last updated 2 years ago