紫影基地

 找回密码
 立即注册
查看: 148|回复: 0

详解 Transformer-XL

[复制链接]
阅读字号:

2002

主题

2117

帖子

21万

积分

超级版主

Rank: 8Rank: 8

积分
210303
发表于 2024-3-21 12:16:39 | 显示全部楼层 |阅读模式


序列模型捕获数据长期依赖的能力在任何NLP任务中都是至关重要的,LSTM通过引进门机制将RNN的长期依赖的捕获能力提升到200个左右,Transformer的提出则进一步提升了获长期依赖的能力,但是Transformer的捕获长期依赖的能力是无限长的吗?如果有一个需要捕获几千个时间片的能力的模型才能完成的任务,Transformer能够胜任吗?答案从目前Transformer的设计来看,它还是做不到。

这篇文章介绍的Transformer-XL(extra long)则是为了进一步提升Transformer建模长期依赖的能力。它的核心算法包含两部分:片段递归机制(segment-level recurrence)和相对位置编码机制(relative positional encoding)。Transformer-XL带来的提升包括:1. 捕获长期依赖的能力;2. 解决了上下文碎片问题(context segmentation problem);3. 提升模型的预测速度和准确率。

1. Transformer回顾
关于Transformer的详细介绍可以参考论文或者我之前的文章《详解Transformer(Attention is all you need)》。Transformer-XL的提出当然是为了对传统的Transformer进行改进,在了解改进之前,我们得先看一下Transformer的缺点。

1.1 输入
NLP相关的任务都很难避免处理输入为变长数据的场景,这个问题的解决方案有两个,一是将数据输入到类似前馈神经网络这样的模型中得到长度固定的特征向量,这个方法往往因为计算资源的限制很难执行;另一个是通过数据切段或者padding的方式将数据填充到固定长度。Transfomer采取的便是第二个方案,这个值这里用
来表示,
的值在Transformer的论文中为512。

将数据分完段之后,接下来便是将分段的数据依次喂到网络中进行模型的训练,如图(1)所示。

动图封面
图1:Transformer的训练流程


这种分段式的提供数据的方式的一个很大的问题是数据并不会在段与段之间流通,因此模型能够捕获的长期依赖的上限便是段的长度。另外这种将数据分段,而不考虑段与段之间的关系无疑是非常粗暴的,对于模型的能力无疑是要打折的。这个问题便是我们所说的上下文碎片问题。

1.2 Self-Attention
这里以单头Transformer为例进行说明,对于一个长度为
的输入序列,
,通过Transformer得到的序列为

的计算方式为
中各元素的加权和:


权值
是通过一个
运算得到:


的计算则是通过Q,K两个矩阵得到:



其中
是三个权值矩阵。

1.3 测试
Transformer是一个自回归模型(auto-regressive),也就是说在测试时模型依次取时间片为
的分段,然后将整个片段提供给模型后预测一个结果,如图2所示。在下个时间片时再将这个分段向右移一个单位,这个新的片段也将通过整个网络的计算后得到一个值。Transformer的这个特性导致其预测阶段的计算量是非常大的,这也限制了其应用领域。

动图封面
图2:Transformer的测试过程

1.4 绝对位置编码
Tansformer的位置编码是以段为单位的,它使用的是无参数的sinusoid decoding matrix,表示为
,第
个元素
表示的是在这个分段中第
个元素的相对位置,
表示的是能编码的最大长度。然后这个位置编码会通过单位加的形式和词嵌入(word Embedding)合并成一个矩阵,表示为:



其中
表示第
个碎片
的词嵌入,
表示转换方程。从(1)式中我们可以看出,对于第
和第
个片段来说,它们的时间位置编码是完全相同的,我们完全没法确认它属于哪个片段或者它在分段之前的输入数据中的相对位置。

在Transformer中,self-attention可以表示为:


考虑到词嵌入,
的完整表达式为:


我们使用乘法分配律将其展开,展开式会在后面使用:


Transformer的问题是无论对于第几个片段,它们的位置编码
都是一样的,也就是说Transformer的位置编码是相对于片段的绝对位置编码(absulate position encoding),与当前内容在原始句子中的相对位置是没有关系的。

2. 相对位置编码
最先介绍相对位置编码的是论文《self-attention with relative positional representation》(后面简称RPR)。对比RNN系列的模型,Transformer的一个缺点是没有从网络结构上对位置信息进行处理,而只是把位置编码加入到了输入层。RPR的动机就是解决Transformer的这个天然缺陷,它的做法是把相对位置编码加入到了self-attention的内部。

例如图3的例子,输入序列为“I think therefore I am”,对RNN来说两个‘I’接收到的信息是不同的,第一个'I'接收的隐层状态是初始化的值,第二个'I'接收的隐层状态是经过'I think therefore'编码之后的。


图3:RNN结构具有编码相对位置的能力

而对于Transormer来说,在没有位置编码的情况下,尽管两个‘I’在句子中的位置不同,但是两个‘I’的输入信息是完全一致的。正如我在分析Transformer的文章中所说的,只在输入中加入位置信息是显然不够的,Transformer也应该在其结构中加入序列信息。这样做的好处是当我们在计算权值或者特征值的时候,额外添加了位置信息,无疑将有助于这两个变量的计算。


图4:Transformer不具有编码相对位置的结构特征

RPR提出的模型的原理是在计算第
个元素与第
个元素之间的attention的值和权值的时候加入

之间的距离编码,因为加入的是

之间的相对位置关系,因此叫做相对位置编码。

例如对一个长度为5的序列,它共有9个相对位置编码信息(当前位置编码,当前位置的前4个,当前位置的后四个),如下表所示:

Index        解释        Value
0        位置i与位置i-4之间的距离        -4
1        位置i与位置i-3之间的距离        -3
2        位置i与位置i-2之间的距离        -2
3        位置i与位置i-1之间的距离        -1
4        位置i与位置i之间的距离        0
5        位置i与位置i+1之间的距离        1
6        位置i与位置i+2之间的距离        2
7        位置i与位置i+3之间的距离        3
8        位置i与位置i+4之间的距离        4
通过加入上面的相对位置编码信息,我们再对比一下“I think therefore I am”中两个‘I’的输入有什么不同,如图5所示。(a)是第一个‘I’的相对位置编码信息,(b)是第二个‘I’的相对位置编码信息。RPR并没有根据输入序列的长度来确定需要考虑的相对位置之间的距离,而是用了一个固定的常数
,也就是说我们需要学习的相对位置编码的序列长度为
。对于
的取值,论文中给出了不同值得对比实验结果,结论是当
时,得到的效果非常接近。


图5:RPR的相对位置编码
刚刚我们已经说过,RPR需要为学习两个相对位置向量,一个是计算第
各词特特征
,另一个用于计算第
个词到第
个词之间的权值系数
,不同于投影矩阵,这两个嵌入在注意力头之间是共享的。

对比(1)式和(3)式,RPR在self-attention中添加了两个可学习的变量

。其中
的计算方式改为:



式的变化和
基本相同:



这里用加法的原因是因为这样设计计算效率更高。关于计算效率的分析,自行翻阅参考文献中的[3]和[5]。


的计算方式相同,即在
的范围之内计算相对距离,超出范围的用
或者
进行截断:





3. Transformer-XL介绍
Transformer-XL的提出旨在解决上面所列出的Transformer的三个问题,为了解决上下文碎片和推理速度慢的问题,作者推出了片段递归机制,为了解决长期依赖,作者对绝对位置编码进行了改进,并推出了相对位置编码机制。下面分别详细介绍两个优化点。

3.1 片段递归
和Transformer一样,Transformer-XL在训练的时候也是以固定长度的片段的形式进行输入的,不同的是Transformer-XL的上一个片段的状态会被缓存下来然后在计算当前段的时候再重复使用上个时间片的隐层状态。因为上个片段的特征在当前片段进行了重复使用,这也就赋予了Transformer-XL建模更长期的依赖的能力。

那么Transformer-XL是如何重用上个片段的隐层状态呢,我们通过数学的形式具体说明。长度为
的连续两个片段表示为


的隐层节点的状态表示为
,其中
是隐层节点的维度。
的隐层节点的状态
的计算过程为:




其中
表示stop-gradient,表示这一部分并不参与BP的计算,
表示两个隐层节点在长度维度进行拼接,
是模型需要学习的参数。注意

使用的是扩展了上个片段的隐层状态的
。这一部分如图6所示:

动图封面
图6:Transformer-XL的训练过程
片段递归的另一个好处是带来的推理速度的提升,对比Transformer的自回归架构每次只能前进一个时间片,Transfomer-XL的推理过程(图7)通过直接复用上一个片段的表示而不是从头计算,讲推理过程提升到以片段为单位进行推理,这种简化带来的速度提升是成百上千倍的。

动图封面
图7:Transformer-XL的推理过程
Transformer-XL是一个典型的用空间换时间的方案,因为这个方法需要对上个片段的隐层节点的状态进行缓存,无疑将增大模型的显存占用量,但依照目前硬件的发展速度来看,对一个速度和准确率都大幅提升的模型,显存是不会成为它的瓶颈的。而且只要显存足够大,其实我们也可以复用更多的之前片段的隐层状态。

从这个角度看,Transformer-XL是一个和残差网络思想非常接近的一个模型,它相当于在两个片段之间添加了一条short-cut。而复用更多片段的结构则是一个DenseNet思想的模型。

3.2 Transformer-XL的相对位置编码
Transformer-XL的相对位置编码参考了RPR中把相对位置编码加入到self-attention中的思想,Transfomer-XL在(7)式的基础上做了若干变化,得到了下面的计算方法:


第一个变化出现在了(a),(b),(c),(d)中:
被拆分成立

,也就是说输入序列和位置编码不再共享权值。
第二个变化是(b)和(d)中将绝对位置编码
换成了相对位置编码
,其中
是Transformer中采用的不需要学习的sinsoid编码矩阵,原因正如第二借所介绍的,相对位置比绝对位置更为重要。
第三个变化是(c),(d)中引入了两个新的可学习的参数

来替换Transformer中的query向量
。表明对于所有的query位置对应的query(位置)向量是相同的。 即无论query位置如何,对不同词的注意偏差都保持一致。
改进之后(16)中的四个部分也有了各自的含义:

(a) 没有考虑位置编码的原始分数,只是基于内容的寻址;
(b) 相对于当前内容的位置偏差;
(c) 从内容层面衡量键的重要性,表示全局的内容偏置;
(d) 从相对位置层面衡量键的重要性,表示全局的位置偏置。
式(16)使用乘法分配律得到的表达式为:


4. 总结
Transformer由于自回归的特性,每个时间片的预测都需要从头开始,这样的推理速度限制了它在很多场景的应用。Transformer-XL提出的递归机制,使得推理过程以段为单位,段的长度越长,无疑提速越明显,从实验结果来看,Transformer-XL提速了300-1800倍,为Transformer-XL的使用提供了基础支撑。同时递归机制增加了Transformer-XL可建模的长期依赖的长度(
),这对提升模型的泛化能力也是很有帮助的。

仿照RPR,Transformer-XL提出了自己的相对位置编码算法,此编码方法对比Transformer和RPR都有了性能的提升,而且从理论角度也有了可解释性。谷歌推出的XLNet也是使用了Transformer-XL为基础,我们会在之后的文章中进行分析。

参考文献
[1] Radford A, Wu J, Child R, et al. Language models are unsupervised multitask learners[J]. OpenAI blog, 2019, 1(8): 9.

[2] Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need [C]//Advances in Neural Information Processing Systems. 2017: 5998-6008.

[3] Shaw P, Uszkoreit J, Vaswani A. Self-attention with relative position representations[J]. arXiv preprint arXiv:1803.02155, 2018.

[4] https://ai.googleblog.com/2019/0 ... g-potential-of.html

[5] https://medium.com/@_init_/how-s ... -works-28173b8c245a

[6] https://zhuanlan.zhihu.com/p/48

回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|手机版|小黑屋|紫影基地

GMT+8, 2025-1-12 12:29 , Processed in 0.084609 second(s), 18 queries .

Powered by Discuz! X3.4

Copyright © 2001-2020, Tencent Cloud.

快速回复 返回顶部 返回列表