背景
在2015年这样的年份,深度学习在自然语言、图像处理分别发展着。在自然语言领域RNN (LSTM,GRU)是最主流的模型结构,用于进行长时序建模。即使如此,这样的结构依然 对长时序不友好,比如某些翻译任务最开始出现的单词在另外一种语言可能在句子最后。 这样最早期的特征信息可能已经遗忘。
注意力机制的说明,图片源于 olah2016attention1
如上面的例子,我们在翻译的时候会重点关注正在翻译的单词。同样的,我们可以在构建 网络的时候引入注意力机制,即只关注输入信息的一部分子集。具体注意哪一部分还是 给予不同部分不同的权重,后续常用的是 Soft-attention,这样方便权重计算和训练2。 具体值的计算一般是基于内容的,以下图举例说明,关注网络RNN生成一个 query,然后 与输入A中的每一项进行点积计算相似度,从而生成注意力分布。
Self-Attention 将输入序列中的【不同位置】信息进行关联,以生成新的表示,通过不同的 MLP处理输入部分生成不同的query/key/value,然后进行attention特征提取;另外一种是 Cross-Attention,其中 query 和 key/value 来自不同的源。
以上是本文3提出的网络结构,模型分为 Encoder 和 Decoder 两部分:
- Encoder: Maps an input sequence of repr $(x_1,...,x_n)$ to a sequence of continuous representations $z=(z_1,...,z_n)$,
- Decoder: Given $z$, generates an output sequence $(y_1,...,y_m)$ of symbols one element at a time。
Encoder使用 Self-attention 堆叠模块替换了常规的 RNN 网络进行特征编码。 优势是能【同时处理】序列中的所有字段,而非RNN依赖旧时的记忆,实现高效处理长序列。 Encoder部分是序列的所有输入一同处理,Decoder 部分仍然是一个个元素依次处理; 实现上,Decoder 的训练通过对目标序列未处理部分进行掩码覆盖也能达到并行处理的效果。
下面介绍网络的三个主要组成部分:Multi-Head Attention, Feed Forward Network 和 position encoding。
先介绍 Mutli-Head Attention 的组成部分为 Scaled Dot Product Attention:
$$ \text{Attention}(Q, K, V)=\text{softmax}(\frac{QK^T}{\sqrt{d_k}})V $$
其中 $Q \in \mathbb{R}^{m \times d_k}$,$m$ 为 token 长度,$d_k$ 为每个 token 的向量表示维度, $K \in \mathbb{R}^{n \times d_k}$, $V \in \mathbb{R}^{n \times d_v}$,$\text{softmax}$ 之前 除以 $\sqrt{d_k}$ 可以缓解梯度反传时的消散问题。
假设 Attention 输入的纬度为 $d_{model}$,网络不一定要在此纬度上计算, 单一的 Attention 模块表达能力不足,所以使用多个 Attention 并联的方式,提升模型的表达能力, 同时调帧内部每个 head 的纬度(一般是降低),类似与卷积层输入、输出的不同通道的概念: $$ \begin{align} \text{MultiAttention}(Q, K, V) & = \text{Concat}(head_1,...,head_n)W^0 \\ \text{where} \; head_i & = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) \end{align} $$ 其中 $W_i^Q \in \mathbb{R}^{d_{model}\times d_k}, W_i^K \in \mathbb{R}^{d_{model}\times d_k}, W_i^V \in \mathbb{R}^{d_{model}\times d_v}, W^O \in \mathbb{R}^{hd_v \times d_{model}}$。 经过 Concat 之后的特征(长度为$h d_v$)再与${W^O}$相乘最终与原输入特征长度相同, 如果输入的 Q, K, V 是同一个向量,Multi-Head Attention 可以类比为 Activation 激活函数。
第二部分是 Position-wise Feed Forward Model 来提升模型的非线性表达能力:
$$ FFN(x)=\max{(0, xW_1+b_1)}W_2+b_2 $$
Least but not last 是 Position Encoding 即位置编码模块。由于在 Attention 模块中并没有区分 不同位置的数据,但是时序输入有先后位置关系,需要将位置信息编码到特征表示中:
$$ \begin{align} PE_{(pos,2i)} & = sin(pos/10000^{2i/d_{model}}) \\ PE_{(pos,2i+1)} & = cos(pos/10000^{2i/d_{model}}) \end{align} $$
详细的解释可以参考Position Encoding。
理解
在 Neural Turing Machines4 中引入 Memory
的概念,使得模型的输出不仅与输入有关,
而且与模型内部的记忆力有关。记忆力用于编码 query 和 key 之间的关系,使得特征输出
更具有相关性。
总结
这篇文章可能在提出的时候都不曾想到,attention 模型结构会成为后续大语言模型(Large Language Model, LLM)发展的重要起源,也成为了自然语言与视觉任务融合的一个桥梁。