Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

模型A:幸亏有你,我才不得0分,模型B:俺也一样

现在大模型都学会借力了。


琳琅满目的乐高积木,通过一块又一块的叠加,可以创造出各种栩栩如生的人物、景观等,不同的乐高作品相互组合,又能为爱好者带来新的创意。

我们把思路打开一点,在大模型(LLM)爆发的当下,我们能不能像拼积木一样,把不同的模型搭建起来,而不会影响原来模型的功能,还能起到 1+1>2 的效果。

这样的想法,谷歌已经实现了。他们的研究为未来的语言模型发展提供了一个新的方向,特别是在资源节约和模型适应性方面。

如今的大语言模型(LLM)仿佛一个全能战士,能进行常识和事实推理、懂得世界知识、生成连贯的文本…… 在这些基础功能的底座上,研究者们又进行了一系列努力对这些模型进行微调,以实现特定于领域的功能,如代码生成、文案编辑以及解决数学问题等。

但这些特定于领域的模型开始出现一些棘手的问题,例如,有些模型在标准代码生成方面做得很好,但在一般逻辑推理方面并不精通,反之亦然。

我们不禁要问:是否可以将 anchor 模型(即具有基础功能的模型)与特定于领域的增强模型组合在一起,从而开启模型新功能?例如,我们能否将理解代码的增强模型与 anchor 模型的语言生成能力组合起来,以实现从代码 - 文本的生成能力?

在此之前,该问题典型的解决方案是在最初用于训练增强模型的数据上进行进一步的预训练或微调 anchor 模型。然而,很多时候这样的解决方案是不可行的,因为训练大模型的计算成本很高。此外,由于数据隐私等问题,处理来自多个来源的数据可能不可行。

为了解决上述训练成本和数据带来的挑战,谷歌提出并研究了进行模型组合的实际设置,这些设置包括:(i)研究者可以访问一个或多个增强模型和 anchor 模型,(ii)不允许修改任一模型的权重,并且(iii)只能访问少量数据,这些数据代表了给定模型的组合技能。

图片

论文地址:https://arxiv.org/pdf/2401.02412.pdf

该研究是这样实现的,他们提出了一种新颖的 CALM(组合到增强语言模型 Composition to Augment Language Models)框架来解决模型组合设置。CALM 不是增强和 anchor LM 的浅层组合,而是在增强和 anchor 模型的中间层表示上引入了少量的可训练参数

这种方法不仅资源高效,只需增加少量额外的参数和数据,就能扩展到新任务上,比完全重新训练模型要经济得多。而且比单独使用一个模型能够更准确地执行新的挑战性任务,同时还能保留各个模型的功能。CALM 对特定任务和低资源语言也提供了更好的支持。

这种通过组合方式扩展模型功能的创新得到了很多人的好评:

「这项研究以及类似的 MoE 研究真的很令人惊讶。像堆乐高积木一样把模型拼在一起就行了!」

图片

还有人表示:「我们离 AI 奇点又近了一步!」

图片

方法介绍

对于给定的 anchor 模型 m_B 和增强模型 m_A,CALM 旨在将这两种模型结合起来,组成 m_(A⊕B),使得新模型的能力成为两个独立模型能力的组合。

研究过程中,开发人员做了以下假设:i)他们可以访问模型的权重,向前、向后传播,并有权限访问 m_B 和 m_A 的中间表示,ii)不允许更改两个模型的权重,iii)研究者无法访问两个基本模型的训练数据、超参数、训练状态,iv)研究者能提供来自目标组合域的一些示例。

在上述假设下,该研究的目标是学习组合 图片以实现某些联合任务 C。其中 m_B 和 m_A 的权重被冻结,θ_C 是为学习组合而引入的附加可训练参数集,D_C 是指用于学习该组合的示例集。

可训练参数

该研究在 m_B 和 m_A 的选定层上进行操作。具体而言,他们在这些层上学习两组附加参数:(i)一组是简单的线性变换,f_proj(.),它将来自 m_A 的第 i 层表示映射到来自 m_B 的表示的维度,以及(ii)一组交叉 - 注意力层,f_cross (.,.),该层位于线性变换后的层表示和 m_B 的第 j 层表示之间。

如图 1 所示,图中展示了具有不同功能的 m_A(黄色块):键值映射(左)、低资源语言(中)和代码(右)。模型 m_A 和 m_B 在合成过程中保持不变 。那些额外的参数是通过模型的层表示来学习的。最左边的图显示了在一组字符串 - 整数映射上训练的 m_A,例如 {x_1 : 10……,x_n:2}。m_B 是一个具有算术能力的大型 LM。CALM 组合这两个冻结模型来解决任一模型无法自行解决的键算术(arithmetic on keys)任务。值得注意的是,尽管使用仅涵盖 20% 键的算术示例进行训练,但 CALM 仍可扩展到整个键 - 值集。

图片

训练示例的构建

由于目标模型 m_(A⊕B)涉及两个模型 m_A 和 m_B 的组合,因此该研究还构建了一组训练示例 D_C 来描述模型的组合技能。

理想情况下,如果组合任务中包含任务 t_1 和 t_2,例如组合任务 (C) 是对一组键执行算术运算。增强模型 m_A 用来学习给定的键值对(标记为任务 t_1), anchor 模型 m_B 是可以很好地执行数字运算的通用模型(标记为任务 t_2)。

为了学习组合参数 θ_C,该研究定义 D_C 包含两个模型的组合技能。与 LoRA 等在训练期间需要整个知识源(此处为键值)的微调方法相比,本文发现仅对一小部分键进行训练组合就可以泛化到全部。

实验结果

键值算术

论文作者首先研究了这样一种情况:有一个小型的增强 LM(m_A),它已被训练成能够记忆从字符串到整数的键值(KV)映射;还有一个大型的 anchor LM(m_B),它能够对整数进行算术运算。作者希望使用 CALM 将它们组合在一起,从而实现解决包含这些键的算术表达式的新功能。

表 1 显示了 m_A、m_B 和 m_(A⊕B) 这三个模型在一些数据集中的表现。首先,可以看到增强模型 m_A 在 KV 替换(KV-Substitution)任务中取得了 98.1% 的成绩,这表明它能很好地记忆 D_KV。接下来,可以看到它在数字算术(Numeric-Arithmetic)任务中的表现很差(4.2%),这表明它不具备算术能力。因此,该模型无法求解包含 D_KV 的键的算术表达式。

图片

不出所料,anchor 模型 m_B 在 KV 替换和 KV 算术(KV-Arithmetic)任务中的准确率为 0%,因为它没有看到任何来自 D_KV 的数据。然而,它在数字算术任务中的表现却很好(73.7%),这表明它有能力对数字进行算术运算。

最后,可以看到组合模型 m_(A⊕B) 能够以很高的准确率解决所有任务,尤其是 KV 算术任务(84.3%),而这是两个底层模型都无法解决的。这表明组合模型能够利用增强模型和 anchor 模型的相关能力来解决复杂任务。

接下来,作者研究了能否将这样一个大型 anchor LM m_B 与经过低资源语言预训练的小型增强 LM m_A 结合在一起,以执行以这些低资源语言呈现的翻译和数学词语解题任务。

表 2 显示了模型在 FLORES-200 数据集上的表现。对于表中所示的 10 种低资源语言,可以看到基础模型 m_A 和 m_B 的表现都不如组合模型 m_(A⊕B)。作者发现,在全部 192 种语言中的 175 种语言上,组合模型 m (A⊕B) 的表现都优于 m_B(见图 2)。

图片

图片

表 3 显示了这些模型在 GSM8K 任务中低资源语言和高资源语言的小学数学单词问题上的表现。首先,可以观察到,由于数学推理能力有限,增强模型 m_A 在这项任务中表现不佳。另一方面,鉴于 anchor 模型 m_B 数学推理能力和高资源语言的迁移学习能力,它的表现要好得多。最后,作者发现在 25 种低资源语言中的 18 种和 10 种高资源语言中的 9 种上,m (A⊕B) 的表现都优于 m_A 和 m_B,这证明了模型组合的有效性。请参见表 6 以了解完整的评估结果。请注意,表 3 的最后一行显示,在 D_NTL 上微调后的 m_B 比预训练的 m_B 性能更差,这表明存在遗忘。使用 CALM 将特定领域的模型 m_A 与 m_B 组合在一起可以避免这种情况。

图片

图片

代码理解和生成

代码理解和生成需要两类不同的能力:(a)代码语法和语义知识;(b)代码所操纵的世界的知识。虽然 LLM 拥有丰富的世界知识,但由于其预训练语料库中的代码数据表示有偏差,它们往往缺乏代码语法方面的具体知识。相反,专门用代码数据训练的小模型可以很好地理解代码语法,但它们可能缺乏广泛的世界知识和推理能力。CALM 可以实现这两方面的最佳效果。

表 4 展示了单个模型 m_A 和 m_B、组合模型 m (A⊕B) 以及经过微调的 anchor 基线 图片 的性能比较。首先,在 HumanEval 数据集上进行的评估表明,由于 m_A 在 D_Code 上进行了额外的训练,它对代码语法的理解能力更强。而由于 m_B 的规模更大,而且进行了通用预训练,它在一般语言理解方面表现出色,因此在 T2C 和 C2T 任务中表现更好。

图片

当使用 CALM 来组成这两个模型时,作者通过显著的性能改进观察到了能力的清晰迁移和组合:与 m_B 相比,组合模型在 CC 和 T2C 任务上的绝对性能分别提高了 6.1% 和 3.6%。作者观察到,由于灾难性遗忘,在 D_Code 上微调 m_B 会导致 C2T 性能显著下降。在所有语言中,CALM 保持了性能,并略微优于 m_B。作者还研究了 C2T 任务的定性示例,并观察到了有趣的共同模式,详情见附录 B。

消融研究

m_A 的影响 

作者首先研究了 m_A 的影响,即在组成过程中用 vanilla 和随机变体替换 m_A。表 5 显示了在 NTL 和代码任务中,当专门的 m_A 被 vanilla PaLM2-XXS 检查点或未经训练的模型版本(即随机模型)替换时,性能的变化情况。作者发现,在所有任务中,这些变体的性能都大幅下降。在 FLORES-200 XX-En 任务中,使用 vanilla 和随机模型时,语言的组合性能分别下降到 115 和 43。与 m_B 相比,vanilla 模型的性能略有提高,这表明非专门化模型(与 m_B 的训练机制不同)可能具有正交能力,从而增强了模型的性能。这一发现验证了 CALM 的性能提升是利用 m_A 而不是增加 Θ_C 参数的结果。

图片

迭代解码的影响

作者还研究了一个变体,即将 m_A 用作编码器,也就是说,在给定时间步解码的输出 token 不会添加到 m_A 的输入中。在这种情况下,只使用 m_A 的前缀表示。这种设置与过去针对图像和文本模型的工作不太一样,后者将编码器和解码器模型组合使用。作者观察到,在采用之前的设置时,各种任务的性能都有明显下降。

与 LoRA 的比较 

最后,作者通过训练 LoRA 层来评估一种参数高效微调方法,以适应 m_B。在所有实验中,他们都设置了 LoRA rank,使添加的参数数量等于 CALM 引入的参数数量。作者还在与 CALM 相同的数据(即 D_C)上训练 LoRA。他们发现这两种方法在所有任务和指标上的性能差异都很大。

请参阅原始论文以获取更多详细信息。

参考链接:https://twitter.com/GPTDAOCN/status/1743240332136030542
工程组合到增强语言模型框架
相关数据
权重技术

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

逻辑推理技术

逻辑推理中有三种方式:演绎推理、归纳推理和溯因推理。它包括给定前提、结论和规则

超参数技术

在机器学习中,超参数是在学习过程开始之前设置其值的参数。 相反,其他参数的值是通过训练得出的。 不同的模型训练算法需要不同的超参数,一些简单的算法(如普通最小二乘回归)不需要。 给定这些超参数,训练算法从数据中学习参数。相同种类的机器学习模型可能需要不同的超参数来适应不同的数据模式,并且必须对其进行调整以便模型能够最优地解决机器学习问题。 在实际应用中一般需要对超参数进行优化,以找到一个超参数元组(tuple),由这些超参数元组形成一个最优化模型,该模型可以将在给定的独立数据上预定义的损失函数最小化。

准确率技术

分类模型的正确预测所占的比例。在多类别分类中,准确率的定义为:正确的预测数/样本总数。 在二元分类中,准确率的定义为:(真正例数+真负例数)/样本总数

映射技术

映射指的是具有某种特殊结构的函数,或泛指类函数思想的范畴论中的态射。 逻辑和图论中也有一些不太常规的用法。其数学定义为:两个非空集合A与B间存在着对应关系f,而且对于A中的每一个元素x,B中总有有唯一的一个元素y与它对应,就这种对应为从A到B的映射,记作f:A→B。其中,y称为元素x在映射f下的象,记作:y=f(x)。x称为y关于映射f的原象*。*集合A中所有元素的象的集合称为映射f的值域,记作f(A)。同样的,在机器学习中,映射就是输入与输出之间的对应关系。

语料库技术

语料库一词在语言学上意指大量的文本,通常经过整理,具有既定格式与标记;事实上,语料库英文 "text corpus" 的涵意即为"body of text"。

逻辑技术

人工智能领域用逻辑来理解智能推理问题;它可以提供用于分析编程语言的技术,也可用作分析、表征知识或编程的工具。目前人们常用的逻辑分支有命题逻辑(Propositional Logic )以及一阶逻辑(FOL)等谓词逻辑。

迁移学习技术

迁移学习是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中。迁移学习是通过从已学习的相关任务中转移知识来改进学习的新任务,虽然大多数机器学习算法都是为了解决单个任务而设计的,但是促进迁移学习的算法的开发是机器学习社区持续关注的话题。 迁移学习对人类来说很常见,例如,我们可能会发现学习识别苹果可能有助于识别梨,或者学习弹奏电子琴可能有助于学习钢琴。

语言模型技术

统计式的语言模型是借由一个几率分布,而指派几率给字词所组成的字串。语言模型经常使用在许多自然语言处理方面的应用,如语音识别,机器翻译,词性标注,句法分析和资讯检索。

算术技术

算术(英语:arithmetic)是数学最古老且最简单的一个分支,几乎被每个人使用着,从日常生活上简单的算数到高深的科学及工商业计算都会用到。一般而言,算术这一词指的是记录数字某些运算基本性质的数学分支。

推荐文章
暂无评论
暂无评论~