# Time-aware large kernel convolutions

论文地址：

* <https://arxiv.org/abs/2002.03184>
* <https://curve.carleton.ca/system/files/etd/66a5a0db-fc74-4de3-8ad2-742c860100f5/etd_pdf/e8bfee5862f9739964ee653864b59f0b/lioutas-sequencemodelingwithlinearcomplexity.pdf>

补充：这篇论文提供了cuda代码。

## 整体思路以及计算方式

整体思路是利用卷积的方式进行序列建模，看完之后感觉非常赞，这里详细理一下计算思路。

步骤一，simple conv1d，序列建模的方式。

以一维序列$$x\_0,\ldots, x\_{n-1}$$为例，利用部分和（kernel值为1的卷积）的形式得到输出：

$$
o\_i=\sum\_{j=\alpha\_i^l}^{\alpha\_i^r} x\_j
$$

其中$$\alpha\_{i}^l , \alpha\_i^r$$为$$o\_i$$对应的边界值，注意上式的计算复杂度太高，但是可以构造前缀和降低计算复杂度：

$$
\left{\begin{array}{l}\mathcal{S}\_0=0 \ \mathcal{S}*i=\mathcal{S}*{i-1}+x\_i, \quad 1 \leq i \leq n .\end{array}\right.
$$

那么：

$$
o\_i=\mathcal{S}*{a\_i^r}-\mathcal{S}*{a\_i^l-1}
$$

那么现在的问题就是如何计算$$\alpha\_i^r, \alpha\_i^r$$，这在步骤二中可以解决。

步骤二，确定上界和下界。

首先构造可学习的参数：

$$
\tilde{a}\_i^{{l, r}}=\sigma\left(f^{{l, r}}\left(x\_i\right)\right) \in\[0,1]
$$

然后利用下式计算边界：

$$
\begin{aligned} a\_i^l & =i-\tilde{a}*i^l \cdot l*{\max } \ a\_i^r & =i+\tilde{a}*i^r \cdot r*{\max }\end{aligned}
$$

其中$$l\_{\max}, r\_{\max}$$是超参。现在的一个问题是，$$a\_i^l, a\_i^r$$不一定是整数，但是我们只能计算整数下标的$$\mathcal S\_k, k\in \mathbb Z$$，这一点利用插值即可解决：

$$
\begin{aligned} \mathcal{S}*{a\_i^l-1} & =\gamma^l \cdot \mathcal{S}*{\left\lfloor a\_i^l\right\rfloor-1}+\left(1-\gamma^l\right) \cdot \mathcal{S}*{\left\lceil a\_i^l\right\rceil-1}\ \mathcal{S}*{a\_i^r} & =\left(1-\gamma^r\right) \cdot \mathcal{S}*{\left\lfloor a\_i^r\right\rfloor}+\gamma^r \cdot \mathcal{S}*{\left\lceil a\_i^r\right\rceil}\end{aligned}
$$

步骤三：归一化和鲁棒性。

这里作者为了让算法work，增加了归一化和dropout：

$$
\tilde{o}*i=o\_i \cdot\left(\frac{1}{l*{\max }+r\_{\max }+1}\right)
$$

步骤四：

之前讨论的的都是一维的情形，然后作者将其推广到$$d$$维度时候发现性能一般，这里感觉主要问题是映射$$\tilde{a}\_i^{{l, r}}=\sigma\left(f^{{l, r}}\left(x\_i\right)\right) \in\[0,1]$$不太稳定，为了缓解这点，作者将$$d$$维拆成$$h\times (d/ h)$$，每$$d/h$$个特征共享一个$$\alpha\_i^{{l,r}}$$，并且由这$$d/h$$共同确定。

## 时间复杂度

因为是查找表，所以时间复杂度是$$O(nd)$$。

## 代码

* <https://github.com/lioutasb/TaLKConvolutions>

## 简评

非常赞的一个思路，其作用机理优点类似于window attention，其中window范围由输入确定。
