Generalization through Memorization: Nearest Neighbor Language Models

论文地址:

代码:

参考资料:

整体思路以及计算方式

利用Knn做模型集成,整体思路如下:

  • ct=(w1,wt1)c_t=\left(w_1, \ldots w_{t-1}\right)

  • 假设有一个训练好的语言模型ff

  • 我们构造如下KV数据库:

    (K,V)={(f(ci),wi)(ci,wi)D}(\mathcal{K}, \mathcal{V})=\left\{\left(f\left(c_i\right), w_i\right) \mid\left(c_i, w_i\right) \in \mathcal{D}\right\}
  • 然后利用下式构造概率分布:

    pkNN(yx)(ki,vi)N1y=viexp(d(ki,f(x)))p_{\mathrm{kNN}}(y \mid x) \propto \sum_{\left(k_i, v_i\right) \in \mathcal{N}} \mathbb{1}_{y=v_i} \exp \left(-d\left(k_i, f(x)\right)\right)
  • 最后模型的输出为:

    p(yx)=λpkNN(yx)+(1λ)pLM(yx)p(y \mid x)=\lambda p_{\mathrm{kNN}}(y \mid x)+(1-\lambda) p_{\mathrm{LM}}(y \mid x)

简评

  • 最后的效果是十分明显的,这里唯一的问题是,KV数据库和训练文本大小成正比,如果训练文本太大,则开销太大;

  • 另一方面Knn的作用似乎是记忆功能,所以基于Transformer的模型似乎记忆能力较弱?是否可以引入类似功能的模块提升性能;

  • 基于检索的LM是否有可行性?

Last updated