哈工大讯飞联合实验室(HFL)在前期发布了多个中文预训练系列模型,目前已成为最受欢迎的中文预训练资源之一。然而,众多预训练模型体积庞大,难以满足运行时要求,为技术落地提出了新的挑战。为此,我们很高兴地向大家推出基于PyTorch框架的知识蒸馏工具包TextBrewer,提供更加方便快捷的知识蒸馏框架。我们欢迎各位读者积极试用并反馈宝贵意见。
工具地址:http://textbrewer.hfl-rc.com
论文地址:https://arxiv.org/abs/2002.12620
特点
TextBrewer为NLP中的知识蒸馏任务设计,提供方便快捷的知识蒸馏框架,主要特点包括:
- 模型无关:适用于多种模型结构(主要面向Transfomer结构)
- 方便灵活:可自由组合多种蒸馏方法,支持增加自定义损失等模块
- 非侵入式:无需对教师与学生模型本身结构进行修改
- 适用面广:支持典型NLP任务,如文本分类、阅读理解、序列标注等
TextBrewer目前支持的主要知识蒸馏技术有:
- 软标签与硬标签混合训练
- 动态损失权重调整与蒸馏温度调整
- 多种蒸馏损失函数
- 任意构建中间层特征匹配方案
- 多教师知识蒸馏
- ...
工作流程
TextBrewer工具中的一个完整工作流程如下图所示。
△ 第一步:在开始蒸馏之前的准备工作
- 训练教师模型
- 定义并初始化学生模型(随机初始化或载入预训练权重)
- 构造蒸馏用数据集的DataLoader,训练学生模型用的Optimizer和Learning rate scheduler
△ 第二步 : 知识蒸馏
- 初始化Distiller,构造训练配置(TrainingConfig)和蒸馏配置(DistillationConfig)
- 定义adaptors和callback,分别用于适配模型输入输出和训练过程中的回调
- 调用Distiller的train方法开始蒸馏
用户应先进行第一步准备工作,得到训练好的教师模型。TextBrewer负责第二步的知识蒸馏工作。为了方便用户使用,TextBrewer也提供了BasicTrainer用于训练第一步的教师模型。
知识蒸馏本质上是“老师教学生”的过程。在初始化学生模型时,可以采用随机初始化的形式(即完全不包含任何先验知识),也可以载入已训练好的模型权重。例如,从BERT-base模型蒸馏到3层BERT时,可以预先载入RBT3模型权重,然后进一步进行蒸馏,避免了蒸馏过程的“冷启动”问题。我们建议用户在使用时尽量采用已预训练过的学生模型,以充分利用大规模数据预训练所带来的优势。
实验效果
为了验证工具的效果,我们在多个中英文自然语言处理任务上进行了实验并取得了与业界公开结果相当,甚至是超过相关公开工作的效果。我们对如下几种基于Transformer的模型结构进行了知识蒸馏。
在中文实验中,我们选取了4个经典数据集,其中包括XNLI(自然语言推断)、LCQMC(句对分类)、CMRC 2018(阅读理解)、DRCD(繁体阅读理解)。其中教师模型为哈工大讯飞联合实验室发布的在英文实验中,我们选取了3个经典数据集,其中包括MNLI(自然语言推断)、SQuAD(阅读理解)、CoNLL-2003 NER(命名实体识别)。我们也列出了一些公开论文中的知识蒸馏效果,方便大家进行比较。教师模型为谷歌发布的BERT-base-cased模型。
相关资源地址
- TextBrewer知识蒸馏工具
- http://textbrewer.hfl-rc.com
- 中文BERT、RoBERTa、RBT系列模型
- http://bert.hfl-rc.com
- 中文XLNet系列模型
- http://xlnet.hfl-rc.com