BERT 等大模型性能强大,但很难部署到算力、内存有限的设备中。为此,来自华中科技大学、华为诺亚方舟实验室的研究者提出了 TinyBERT,这是一种为基于 transformer 的模型专门设计的知识蒸馏方法,模型大小还不到 BERT 的 1/7,但速度是 BERT 的 9 倍还要多,而且性能没有出现明显下降。目前,该论文已经提交机器学习顶会 ICLR 2020。
论文:https://arxiv.org/abs/1909.10351
在自然语言处理(NLP)领域,BERT 等预训练语言模型极大地提升了诸多 NLP 任务的性能。但是,这类预训练语言模型通常计算开销大,内存占用也大,因此很难在一些资源紧张的设备上有效执行。
为了在加快推理速度和降低模型大小的同时保持准确率,来自华中科技大学和华为诺亚方舟实验室的研究者提出了一种新颖的 transformer 蒸馏法,这是为基于 transformer 的模型专门设计的知识蒸馏(knowledge distillation,KD)方法。通过这种新的 KD 方法,大型 teacherBERT 模型中编码的大量知识可以很好地迁移到小型 student TinyBERT 模型中。
此外,研究者还提出了一种专门用于 TinyBERT 的两段式学习框架,从而分别在预训练和针对特定任务的学习阶段执行 transformer 蒸馏。这一框架确保 TinyBERT 可以获取 teacherBERT 的通用知识和针对特定任务的知识。除了提出新的 transformer 蒸馏法之外,研究者还提出了一种专门用于 TinyBERT 的两段式学习框架,从而分别在预训练和针对特定任务的具体学习阶段执行 transformer 蒸馏。这一框架确保 TinyBERT 可以获取 teacherBERT 的通用和针对特定任务的知识。
实证研究结果表明,TinyBERT 是有效的,在 GLUE 基准上实现了与 BERT 相当(下降 3 个百分点)的效果,并且模型大小仅为 BERT 的 13.3%(BERT 是 TinyBERT 的 7.5 倍),推理速度是 BERT 的 9.4 倍。此外,TinyBERT 还显著优于当前的 SOTA 基准方法(BERT-PKD),但参数仅为为后者的 28%,推理时间仅为后者的 31%左右。
研究者提出的 Transformer 蒸馏是专门为 Transformer 网络设计的知识蒸馏方法,下图 1 为本文提出的 Transformer 蒸馏方法概览图:
在这篇论文中,student 和 teacher 网络都是通过 Transformer 层构建的。为了表述清楚,研究者在详解 TinyBERT 之前阐述了以下问题。
假定 student 模型有 M 个 Transformer 层,teacher 模型有 N 个 Transformer 层,从 teacher 模型中选择 M 个 Transformer 层用于 Transformer 层蒸馏。n=g(m) 是 student 层到 teacher 层的映射函数,这意味着 student 模型的第 m 层从 teacher 模型的第 n 层开始学习信息。嵌入层蒸馏和预测层蒸馏也考虑进来,将嵌入层的指数设为 0,预测层的指数设为 M+1,并且对应的层映射分别定义为 0 = g(0) 和 N + 1 = g(M + 1)。下文实验部分将探讨不同的映射函数对性能的影响。在形式上,通过最小化以下目标函数,student 模型可以获取 teacher 模型的知识:其中 L_layer 是给定模型层(如 Transformer 层或嵌入层)的损失函数,λ_m 是表征第 m 层蒸馏重要度的超参数。
研究者提出的 Transformer 层蒸馏包含基于注意力的蒸馏和基于隐状态的蒸馏,具体可参考上图 1(b)。基于注意力的蒸馏是为了鼓励语言知识从 teacherBERT 迁移到 student TinyBERT 模型中。具体而言,student 网络学习如何拟合 teacher 网络中多头注意力的矩阵,目标函数定义如下:其中,h 是注意力头数。A_i ∈ R^( l×l) 是与 teacher 或 student 的第 i 个注意力头对应的注意力矩阵。
此外,(非归一化)注意力矩阵 A_i 用作拟合目标,而不是其 softmax 输出的 softmax(A_i),因为实验表明前者的设置呈现更快的收敛速度和更佳的性能。除了基于注意力的蒸馏之外,研究者还对 Transformer 层输出的知识进行蒸馏处理(具体可参考上图 1(b)),目标函数定义如下:其中 H^S∈R^l×d'和 H^T∈R^l×d 分别表示 student 和 teacher 网络的隐状态,由方程式 4 计算得到。标量值 d 和 d'分别表示 teacher 和 student 模型的隐状态,并且 d'通常小于 d,以获得更小的 student 网络。
研究者还执行了嵌入层的蒸馏,与基于隐状态的蒸馏类似,定义如下:其中矩阵 E^S 和 H^T 分别表示 student 和 teacher 网络的嵌入。在论文中,这两种矩阵的形状与隐状态矩阵相同。矩阵 W_e 表示线性变化,它起到与 W_h 类似的作用。
除了模拟中间层的行为之外,研究者还利用知识蒸馏来拟合 teacher 模型的预测结果。具体而言,他们对 student 网络 logits 和 teacher 网络 logits 之间的 soft 交叉熵损失进行惩罚:其中 z^S 和 z^T 分别表示 student 和 teacher 模型预测的 logits 向量,log_softmax() 表示 log 似然,t 表示温度值。实验表明,t=1 时运行良好。
通过以上几个蒸馏目标函数(即方程式 7、8、9 和 10),可以整合 teacher 和 student 网络之间对应层的蒸馏损失:在实验中,研究者首先执行的是中间层蒸馏(M ≥ m ≥ 0),其次是预测层蒸馏(m = M + 1)。
BERT 的应用通常包含两个学习阶段:预训练和微调。BERT 在预训练阶段学到的大量知识非常重要,并且迁移的时候也应该包含在内。因此,研究者提出了一个两段式学习框架,包含通用蒸馏和特定于任务的蒸馏,如下图 2 所示:
通用蒸馏可以帮助 student TinyBERT 学习到 teacher BERT 中嵌入的丰富知识,对于提升 TinyBERT 的泛化能力至关重要。特定于任务的蒸馏赋予 student 模型特定于任务的知识。这种两段式蒸馏可以缩小 teacher 和 student 模型之间的差距。
在通用蒸馏中,研究者使用原始 BERT 作为 teacher 模型,而且不对其进行微调,利用大规模文本语料库作为学习数据。通过在通用领域文本上执行 Transformer 蒸馏,他们获取了一个通用 TinyBERT,可以针对下游任务进行微调。然而,由于隐藏/嵌入层大小及层数显著降低,通用 TinyBERT 的表现不如 BERT。
研究者提出通过针对特定任务的蒸馏来获得有竞争力的微调 TinyBERT 模型。而在蒸馏过程中,他们在针对特定任务的增强数据集上(如图 2 所示)重新执行了提出的 Transformer 蒸馏。具体而言,微调的 BERT 用作 teacher 模型,并提出以数据增强方法来扩展针对特定任务的训练集。
此外,上述两个学习阶段是相辅相成的:通用蒸馏为针对特定任务的蒸馏提供良好的初始化,而针对特定任务的蒸馏通过专注于学习针对特定任务的知识来进一步提升 TinyBERT 的效果。
为了验证 TinyBERT 的效果,研究者在多个任务上将其与其他模型进行了比较。
研究者在 GLUE 基准上评估了 TinyBERT 的性能,结果如下表 2 所示。模型大小和推理时间的效率见下表 3。
表 3:基线模型和 TinyBERT 的模型大小和推理时间。层数量不包含嵌入和预测层。
实验结果表明:1)TinyBERT 在所有 GlUE 任务中的表现都优于 BERTSMALL,平均性能提升了 6.3%,表明本文提出的 KD 学习框架可以有效地提升小模型在下游任务中的性能;2)TinyBERT 显著超越了 KD SOTA 基线(即 BERT-PKD 和 DistillBERT),比 BERT-PKD 高出 3.9%(见图 2),但参数只有基线模型的 28%,推理时间只有基线模型的 31% 左右(见图 3);3)与 teacher BERTBASE 相比,TinyBERT 的大小仅为前者的 13.3%,但速度却是前者的 9.4 倍,而且性能损失不大。4)TinyBERT 与 Distilled BiLSTM_SOFT 模型效率相当,在 BiLSTM 基线公开的所有任务中均显示出了明显的性能提升。5)对于具有挑战性的 CoLA 数据集,所有的蒸馏小模型与 teacher 模型的性能差距都比较大。TinyBERT 与基线模型相比实现了显著的性能提升,如果利用更深、更宽的模型来捕获更复杂的语言信息,它的性能还能进一步提升。
为了测试模型大小对性能的影响,研究者在几个典型的 GLUE 任务中测试了不同大小 TinyBERT 模型的性能。结果如下表 4 所示:表 4:提升宽度、深度之后的 TinyBERT 变体与基线的性能比较结果。
本文提出的两段式 TinyBERT 学习框架包含三个关键步骤:TD(特定于任务的蒸馏)、GD(通用蒸馏)和 DA(数据蒸馏)。每个学习步骤的影响如下表 5 所示:
研究者还探索了不同目标对 TinyBERT 学习的影响,结果如下表 6 所示:表 6:不同蒸馏目标对 TinyBERT 学习的影响。
研究者探究了不同映射函数 n = g(m) 对于 TinyBERT 学习的影响,比较结果见下表 7: