Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

ICML 2024 | 揭示非线形Transformer在上下文学习中学习和泛化的机制

图片
AIxiv专栏是机器之心发布学术、技术内容的栏目。过去数年,机器之心AIxiv专栏接收报道了2000多篇内容,覆盖全球各大高校与企业的顶级实验室,有效促进了学术交流与传播。如果您有优秀的工作想要分享,欢迎投稿或者联系报道。投稿邮箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com

本文作者李宏康,美国伦斯勒理工大学电气、计算机与系统工程系在读博士生,本科毕业于中国科学技术大学。研究方向包括深度学习理论,大语言模型理论,统计机器学习等等。目前已在 ICLR/ICML/Neurips 等 AI 顶会发表多篇论文。

上下文学习 (in-context learning, 简写为 ICL) 已经在很多 LLM 有关的应用中展现了强大的能力,但是对其理论的分析仍然比较有限。人们依然试图理解为什么基于 Transformer 架构的 LLM 可以展现出 ICL 的能力。

近期,一个来自美国伦斯勒理工大学和 IBM 研究院的团队从优化和泛化理论的角度分析了带有非线性注意力模块 (attention) 和多层感知机 (MLP) 的 Transformer 的 ICL 能力。他们特别从理论端证明了单层 Transformer 首先在 attention 层根据 query 选择一些上下文示例,然后在 MLP 层根据标签嵌入进行预测的 ICL 机制。该文章已收录在 ICML 2024。

图片

  • 论文题目:How Do Nonlinear Transformers Learn and Generalize in In-Context Learning?

  • 论文地址:https://arxiv.org/pdf/2402.15607

背景介绍

上下文学习 in context learning (ICL)

上下文学习 (ICL) 是一种新的学习范式,在大语言模型 (LLM) 中非常流行。它具体是指在测试查询 (testing query)图片前添加 N 个测试样本 testing examples (上下文),即测试输入图片和测试输出图片的组合,从而构成一个 testing prompt:图片,作为模型的输入以引导模型作出正确的推断。这种方式不同于经典的对预训练模型进行微调的方式,它不需要改变模型的权重,从而更加的高效。

ICL 理论工作的进展

近期的很多理论工作都是基于 [1] 所提出的研究框架,即人们可以直接使用 prompt 的格式来对 Transformer 进行训练 (这一步也可以理解为在模拟一种简化的 LLM 预训练模式),从而使得模型具有 ICL 能力。已有的理论工作聚焦于模型的表达能力 (expressive power) 的角度 [2]。他们发现,人们能够找到一个有着 “完美” 的参数的 Transformer 可以通过前向运算执行 ICL,甚至隐含地执行梯度下降等经典机器学习算法。但是这些工作无法回答为什么 Transformer 可以被训练成这样 “完美” 的,具有 ICL 能力的参数。因此,还有一些工作试图从 Transformer 的训练或泛化的角度理解 ICL 机制 [3,4]。不过,受制于分析 Transformer 结构的复杂性,这些工作目前止步于研究线性回归任务,而所考虑的模型通常会略去 Transformer 中的非线形部分。

本文从优化和泛化理论的角度分析了带有非线性 attention 和 MLP 的 Transformer 的 ICL 能力和机制:

  • 基于一个简化的分类模型,本文具体量化了数据的特征如何影响了一层单头 Transformer 的域内 (in-domain) 和域外 (out-of-domain, OOD) 的 ICL 泛化能力。

  • 本文进一步阐释了 ICL 是如何通过被训练的 Transformer 来实现了。

  • 基于被训练的 Transformer 的特点,本文还分析了在 ICL 推断的时候使用基于幅值的模型剪枝 (magnitude-based pruning) 的可行性。

理论部分

问题描述

本文考虑一个二分类问题,即将图片通过一个任务图片映射图片。为了解决这样的一个问题,本文构建了 prompt 来进行学习。这里的 prompt 被表示为:

图片

训练网络为一个单层单头 Transformer:

图片

预训练过程是求解一个对所有训练任务的经验风险最小化 (empirical risk minimization)。损失函数使用的是适合二分类问题的 Hinge loss,训练算法是随机梯度下降

本文定义了两种 ICL 泛化的情况。一个是 in-domain 的,即泛化的时候测试数据的分布和训练数据一样,注意这个情况里面测试任务不必和训练任务一样,即这里已经考虑了对未见任务 (unseen task) 的泛化。另一个是 out-of-domain 的,即测试、训练数据分布不一样。

本文还涉及了在 ICL 推断的时候进行 magnitude-based pruning 的分析,这里的剪枝方式是指对于训练得到的中的各个神经元,根据其幅值大小,进行从小到大的删除。

对数据和任务的构建

这一部分请参考原文的 Section 3.2,这里只做一个概述。本文的理论分析是基于最近比较火热的 feature learning 路线,即通常将数据假设为可分(通常是正交)的 pattern,从而推导出基于不同 pattern 的梯度变化。本文首先定义了一组 in-domain-relevant (IDR) pattern 用于决定 in-domain 任务的分类,和一组与任务无关的 in-domain-irrelevant (IDI) pattern,这些 pattern 之间互相正交。IDR pattern 有图片个,IDI pattern 有图片个。一个图片被表示为一个 IDR pattern 和一个 IDI pattern 的和。一个 in-domain 任务就被定义为基于某两个 IDR pattern 的分类问题

类似地,本文通过定义 out-of-domain-relevant (ODR) pattern 和 out-of-domain-irrelevant (ODI) pattern,可以刻画 OOD 泛化时候的数据和任务。

本文对 prompt 的表示可以用下图的例子来阐述,其中图片是 IDR pattern,图片是 IDI pattern。这里在做的任务是基于 x 中的图片做分类,如果是图片那么其标签为 + 1,对应于 +q,如果是图片那么其标签为 - 1,对应于 -q。α,α' 分别被定义为训练和测试 prompt 中跟 query 的 IDR/ODR pattern 一样的上下文示例。下图中的例子里面,图片

图片

理论结果

首先,对于 in-domain 的情况,本文先给了一个 condition 3.2 来规定训练任务需要满足的条件,即训练任务需要覆盖所有的 IDR pattern 和标签。然后 in-domain 的结果如下:

图片

这里表明:1,训练任务的数量只需要在全部任务中占比达到满足 condition 3.2 的小比例,我们就可以对 unseen task 实现很好的泛化;2,跟当前任务相关的 IDR pattern 在 prompt 中的比例越高,就可以以更少的训练数据,训练迭代次数,以及更短的 training/testing prompt 实现理想的泛化。

接下来是 out-of-domain 泛化的结果。

图片

这里说明,如果 ODR pattern 是 IDR pattern 的线性组合且系数和大于 1,那么此时 OOD ICL 泛化可以达到理想的效果。这个结果给出了在 ICL 的框架下,好的 OOD 泛化所需要的训练和测试数据之间的内在联系。该定理也通过 GPT-2 的实验得到了验证。如下图所示,当 (12) 中的系数和图片大于 1 的时候,OOD 分类可以达到理想的结果。与此同时,当图片,即 prompt 中和分类任务相关的 ODR/IDR pattern 比例越高的时候,所需要的 context 长度越小。

图片

然后,本文给出了带有 magnitude-based pruning 的 ICL 泛化结果。

图片

这个结果表明,首先,训练得到的图片中有一部分(常数比例)神经元的幅值很小,而剩下的相对比较大(公式 14)。当我们只枝剪小神经元的时候,对泛化结果基本没有影响,而当枝剪比例增加到要剪大神经元的时候,泛化误差会随之显著变大(公式 15,16)。以下实验验证了定理 3.7。下图 A 中浅蓝色的竖线表示训练得到的图片呈现出了公式 14 的结果。而对小神经元进行枝剪不会使泛化变差,这个结果符合理论。图 B 反映出当 prompt 中和任务相关的上下文越多的时候,我们可以允许更大的枝剪比例以达到相同的泛化性能。

图片

ICL 机制

通过对预训练过程的刻画,本文得到了单层单头非线性 Transformer 做 ICL 的内在机制,这一部分在原文的 Section 4。该过程可以用下图表示。

图片

简而言之,attention 层会选择和 query 的 ODR/IDR pattern 一样的上下文,赋予它们几乎全部 attention 权重,然后 MLP 层会重点根据 attention 层输出中的标签嵌入来作出最后的分类。

总结

本文讲解了在 ICL 当中,非线性 Transformer 的训练机制,以及对于新任务和分布偏移数据的泛化能力。理论结果对于设计 prompt 选择算法和 LLM 剪枝算法有一定实际意义。

参考文献

[1] Garg, et al., Neurips 2022. "What can transformers learn in-context? a case study of simple function classes."

[2] Von Oswald et al., ICML 2023. "Transformers learn in-context by gradient descent."

[3] Zhang et al., JMLR 2024. "Trained transformers learn linear models in-context."

[4] Huang et al., ICML 2024. "In-context convergence of transformers."

产业ICML 2024
相关数据
IBM机构

是美国一家跨国科技公司及咨询公司,总部位于纽约州阿蒙克市。IBM主要客户是政府和企业。IBM生产并销售计算机硬件及软件,并且为系统架构和网络托管提供咨询服务。截止2013年,IBM已在全球拥有12个研究实验室和大量的软件开发基地。IBM虽然是一家商业公司,但在材料、化学、物理等科学领域却也有很高的成就,利用这些学术研究为基础,发明很多产品。比较有名的IBM发明的产品包括硬盘、自动柜员机、通用产品代码、SQL、关系数据库管理系统、DRAM及沃森。

https://www.ibm.com/us-en/
相关技术
深度学习技术

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

权重技术

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

多层感知机技术

感知机(Perceptron)一般只有一个输入层与一个输出层,导致了学习能力有限而只能解决线性可分问题。多层感知机(Multilayer Perceptron)是一类前馈(人工)神经网络及感知机的延伸,它至少由三层功能神经元(functional neuron)组成(输入层,隐层,输出层),每层神经元与下一层神经元全互连,神经元之间不存在同层连接或跨层连接,其中隐层或隐含层(hidden layer)介于输入层与输出层之间的,主要通过非线性的函数复合对信号进行逐步加工,特征提取以及表示学习。多层感知机的强大学习能力在于,虽然训练数据没有指明每层的功能,但网络的层数、每层的神经元的个数、神经元的激活函数均为可调且由模型选择预先决定,学习算法只需通过模型训练决定网络参数(连接权重与阈值),即可最好地实现对于目标函数的近似,故也被称为函数的泛逼近器(universal function approximator)。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

剪枝技术

剪枝顾名思义,就是删去一些不重要的节点,来减小计算或搜索的复杂度。剪枝在很多算法中都有很好的应用,如:决策树,神经网络,搜索算法,数据库的设计等。在决策树和神经网络中,剪枝可以有效缓解过拟合问题并减小计算复杂度;在搜索算法中,可以减小搜索范围,提高搜索效率。

损失函数技术

在数学优化,统计学,计量经济学,决策理论,机器学习和计算神经科学等领域,损失函数或成本函数是将一或多个变量的一个事件或值映射为可以直观地表示某种与之相关“成本”的实数的函数。

线性回归技术

在现实世界中,存在着大量这样的情况:两个变量例如X和Y有一些依赖关系。由X可以部分地决定Y的值,但这种决定往往不很确切。常常用来说明这种依赖关系的最简单、直观的例子是体重与身高,用Y表示他的体重。众所周知,一般说来,当X大时,Y也倾向于大,但由X不能严格地决定Y。又如,城市生活用电量Y与气温X有很大的关系。在夏天气温很高或冬天气温很低时,由于室内空调、冰箱等家用电器的使用,可能用电就高,相反,在春秋季节气温不高也不低,用电量就可能少。但我们不能由气温X准确地决定用电量Y。类似的例子还很多,变量之间的这种关系称为“相关关系”,回归模型就是研究相关关系的一个有力工具。

梯度下降技术

梯度下降是用于查找函数最小值的一阶迭代优化算法。 要使用梯度下降找到函数的局部最小值,可以采用与当前点的函数梯度(或近似梯度)的负值成比例的步骤。 如果采取的步骤与梯度的正值成比例,则接近该函数的局部最大值,被称为梯度上升。

映射技术

映射指的是具有某种特殊结构的函数,或泛指类函数思想的范畴论中的态射。 逻辑和图论中也有一些不太常规的用法。其数学定义为:两个非空集合A与B间存在着对应关系f,而且对于A中的每一个元素x,B中总有有唯一的一个元素y与它对应,就这种对应为从A到B的映射,记作f:A→B。其中,y称为元素x在映射f下的象,记作:y=f(x)。x称为y关于映射f的原象*。*集合A中所有元素的象的集合称为映射f的值域,记作f(A)。同样的,在机器学习中,映射就是输入与输出之间的对应关系。

随机梯度下降技术

梯度下降(Gradient Descent)是遵循成本函数的梯度来最小化一个函数的过程。这个过程涉及到对成本形式以及其衍生形式的认知,使得我们可以从已知的给定点朝既定方向移动。比如向下朝最小值移动。 在机器学习中,我们可以利用随机梯度下降的方法来最小化训练模型中的误差,即每次迭代时完成一次评估和更新。 这种优化算法的工作原理是模型每看到一个训练实例,就对其作出预测,并重复迭代该过程到一定的次数。这个流程可以用于找出能导致训练数据最小误差的模型的系数。

分类问题技术

分类问题是数据挖掘处理的一个重要组成部分,在机器学习领域,分类问题通常被认为属于监督式学习(supervised learning),也就是说,分类问题的目标是根据已知样本的某些特征,判断一个新的样本属于哪种已知的样本类。根据类别的数量还可以进一步将分类问题划分为二元分类(binary classification)和多元分类(multiclass classification)。

神经元技术

(人工)神经元是一个类比于生物神经元的数学计算模型,是神经网络的基本组成单元。 对于生物神经网络,每个神经元与其他神经元相连,当它“兴奋”时会向相连的神经元发送化学物质,从而改变这些神经元的电位;神经元的“兴奋”由其电位决定,当它的电位超过一个“阈值”(threshold)便会被激活,亦即“兴奋”。 目前最常见的神经元模型是基于1943年 Warren McCulloch 和 Walter Pitts提出的“M-P 神经元模型”。 在这个模型中,神经元通过带权重的连接接处理来自n个其他神经元的输入信号,其总输入值将与神经元的阈值进行比较,最后通过“激活函数”(activation function)产生神经元的输出。

查询技术

一般来说,查询是询问的一种形式。它在不同的学科里涵义有所不同。在信息检索领域,查询指的是数据库和信息系统对信息检索的精确要求

GPT-2技术

GPT-2是OpenAI于2019年2月发布的基于 transformer 的大型语言模型,包含 15 亿参数、在一个 800 万网页数据集上训练而成。据介绍,该模型是对 GPT 模型的直接扩展,在超出 10 倍的数据量上进行训练,参数量也多出了 10 倍。在性能方面,该模型能够生产连贯的文本段落,在许多语言建模基准上取得了 SOTA 表现。而且该模型在没有任务特定训练的情况下,能够做到初步的阅读理解、机器翻译、问答和自动摘要。

机器之心机构

机器之心,成立于2014年,是国内最具影响力、最专业、唯一用于国际品牌的人工智能信息服务与产业服务平台。目前机器之心已经建立起涵盖媒体、数据、活动、研究及咨询、线下物理空间于一体的业务体系,为各类人工智能从业者提供综合信息服务和产业服务。

https://www.jiqizhixin.com/
语言模型技术

统计式的语言模型是借由一个几率分布,而指派几率给字词所组成的字串。语言模型经常使用在许多自然语言处理方面的应用,如语音识别,机器翻译,词性标注,句法分析和资讯检索。

量化技术

深度学习中的量化是指,用低位宽数字的神经网络近似使用了浮点数的神经网络的过程。

推荐文章
暂无评论
暂无评论~