Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

五倍吞吐量,性能全面包围Transformer:新架构Mamba引爆AI圈

屹立不倒的 Transformer 迎来了一个强劲竞争者。

在别的领域,如果你想形容一个东西非常重要,你可能将其形容为「撑起了某领域的半壁江山」。但在 AI 大模型领域,Transformer 架构不能这么形容,因为它几乎撑起了「整个江山」。

自 2017 年被提出以来,Transformer 已经成为 AI 大模型的主流架构,但随着模型规模的扩展和需要处理的序列不断变长,Transformer 的局限性也逐渐凸显。一个很明显的缺陷是:Transformer 模型中自注意力机制的计算量会随着上下文长度的增加呈平方级增长,比如上下文增加 32 倍时,计算量可能会增长 1000 倍,计算效率非常低。

为了克服这些缺陷,研究者们开发出了很多注意力机制的高效变体,但这往往以牺牲其有效性特为代价。到目前为止,这些变体都还没有被证明能在不同领域发挥有效作用。

最近,一项名为「Mamba」的研究似乎打破了这一局面。

图片

在这篇论文中,研究者提出了一种新的架构 ——「选择性状态空间模型( selective state space model)」。它在多个方面改进了先前的工作。

作者表示,「Mamba」在语言建模方面可以媲美甚至击败 Transformer。而且,它可以随上下文长度的增加实现线性扩展,其性能在实际数据中可提高到百万 token 长度序列,并实现 5 倍的推理吞吐量提升。

消息一出,人们纷纷点赞,有人表示已经迫不及待想要把它用在大模型上了。

图片

作为通用序列模型的骨干,Mamba 在语言、音频和基因组学等多种模态中都达到了 SOTA 性能。在语言建模方面,无论是预训练还是下游评估,他们的 Mamba-3B 模型都优于同等规模的 Transformer 模型,并能与两倍于其规模的 Transformer 模型相媲美。

图片

这篇论文的作者只有两位,一位是卡内基梅隆大学机器学习系助理教授 Albert Gu,另一位是 Together.AI 首席科学家、普林斯顿大学计算机科学助理教授(即将上任)Tri Dao。

图片左:Albert Gu;右:Tri Dao。

Albert Gu 表示,这项研究的一个重要创新是引入了一个名为「选择性 SSM」的架构,该架构是 Albert Gu 此前主导研发的 S4 架构(Structured State Spaces for Sequence Modeling ,用于序列建模的结构化状态空间)的一个简单泛化,可以有选择地决定关注还是忽略传入的输入。一个「小小的改变」—— 让某些参数成为输入的函数,结果却非常有效。

图片

值得一提的是,S4 是一个非常成功的架构。此前,它成功地对  Long Range Arena (LRA) 中的长程依赖进行了建模,并成为首个在 Path-X 上获得高于平均性能的模型。更具体地说,S4 是一类用于深度学习的序列模型,与 RNN、CNN 和经典的状态空间模型(State Space Model,SSM)广泛相关。SSM 是独立的序列转换,可被整合到端到端神经网络架构中( SSM 架构有时也称 SSNN,它与 SSM 层的关系就像 CNN 与线性卷积层的关系一样)。Mamba 论文也讨论了一些著名的 SSM 架构,比如 Linear attention、H3、Hyena、RetNet、RWKV,其中许多也将作为论文研究的基线。Mamba 的成功让 Albert Gu 对 SSM 的未来充满了信心。

图片

Tri Dao 则是 FlashAttentionFlash Attention v2Flash-Decoding的作者。FlashAttention 是一种对注意力计算进行重新排序并利用经典技术(平铺、重新计算)加快速度并将内存使用从序列长度的二次减少到线性的算法。Flash Attention v2、Flash-Decoding 都是建立在 Flash Attention 基础上的后续工作,把大模型的长文本推理效率不断推向极限。在 Mamba 之前,Tri Dao 和 Albert Gu 也有过合作。

图片

另外,这项研究的模型代码和预训练的检查点是开源的,参见以下链接:https://github.com/state-spaces/mamba.

图片

图片

论文链接:https://arxiv.org/ftp/arxiv/papers/2312/2312.00752.pdf

方法创新

论文第 3.1 节介绍了如何利用合成任务的直觉来启发选择机制,第 3.2 节解释了如何将这一机制纳入状态空间模型。由此产生的时变 SSM 不能使用卷积,导致了高效计算的技术难题。研究者采用了一种硬件感知算法,利用当前硬件的内存层次结构来克服这一难题(第 3.3 节)。第 3.4 节描述了一个简单的 SSM 架构,不需要注意力,甚至不需要 MLP 块。第 3.5 节讨论了选择机制的一些其他特性。

选择机制

研究者发现了此前模型的一个关键局限:以依赖输入的方式高效选择数据的能力(即关注或忽略特定输入)。

序列建模的一个基本方法是将上下文压缩到更小的状态,我们可以从这个角度来看待当下流行的序列模型。例如,注意力既高效又低效,因为它根本没有明确压缩上下文。这一点可以从自回归推理需要明确存储整个上下文(即 KV 缓存)这一事实中看出,这直接导致了 Transformer 缓慢的线性时间推理和二次时间训练。

递归模型的效率很高,因为它们的状态是有限的,这意味着恒定时间推理和线性时间训练。然而,它们的高效性受限于这种状态对上下文的压缩程度。

为了理解这一原理,下图展示了两个合成任务的运行示例:

图片

研究者设计了一种简单的选择机制,根据输入对 SSM 参数进行参数化。这样,模型就能过滤掉无关信息,并无限期地记住相关信息。

将选择机制纳入模型的一种方法是让影响序列交互的参数(如 RNN 的递归动力学或 CNN 的卷积核)与输入相关。算法 1 和 2 展示了本文使用的主要选择机制。其主要区别在于,该方法只需将几个参数 ∆,B,C 设置为输入函数,并在整个过程中改变张量形状。这些参数现在都有一个长度维度 L ,意味着模型已经从时间不变变为时间可变。

图片

硬件感知算法

上述变化对模型的计算提出了技术挑战。所有先前的 SSM 模型都必须是时间和输入不变的,这样才能提高计算效率。为此,研究者采用了一种硬件感知算法,通过扫描而不是卷积来计算模型,但不会将扩展状态具体化,以避免在 GPU 存储器层次结构的不同级别之间进行 IO 访问。由此产生的实现方法在理论上(与所有基于卷积的 SSM 的伪线性相比,在序列长度上呈线性缩放)和现有硬件上都比以前的方法更快(在 A100 GPU 上可快达 3 倍)。

图片

架构

研究者将先前的 SSM 架构设计与 Transformer 的 MLP 块合并为一个块,从而简化了深度序列模型架构,形成了一种包含选择性状态空间的简单、同质的架构设计(Mamba)。

与结构化 SSM 一样,选择性 SSM 也是一种独立的序列变换,可以灵活地融入神经网络。H3 架构是著名的同质化架构设计的基础,通常由线性注意力启发的块和 MLP(多层感知器)块交错组成。

研究者简化了这一架构,将这两个部分合二为一,均匀堆叠,如图 3。他们受到门控注意力单元(GAU)的启发,该单元也对注意力做了类似的处理。

图片

选择性 SSM 以及 Mamba 架构的扩展是完全递归模型,几个关键特性使其适合作为在序列上运行的通用基础模型的骨干:

  1. 高质量:选择性为语言和基因组学等密集模型带来了强大的性能。

  2. 快速训练和推理:在训练过程中,计算量和内存与序列长度成线性关系,而在推理过程中,由于不需要缓存以前的元素,自回归展开模型每一步只需要恒定的时间。

  3. 长上下文:质量和效率共同提高了实际数据的性能,序列长度可达 100 万。

实验评估

实证验证了 Mamba 作为通用序列基础模型骨干的潜力,无论是在预训练质量还是特定领域的任务性能方面,Mamba 都能在多种类型的模态和环境中发挥作用:

合成任务。在复制和感应头等重要的语言模型合成任务上,Mamba 不仅能轻松解决,而且能推断出无限长的解决方案(>100 万 token)。

图片

音频和基因组学。在音频波形和 DNA 序列建模方面,Mamba 在预训练质量和下游指标方面都优于 SaShiMi、Hyena、Transformer 等先前的 SOTA 模型(例如,在具有挑战性的语音生成数据集上将 FID 降低了一半以上)。在这两种情况下,它的性能随着上下文长度的增加而提高,最高可达百万长度的序列。

图片

语言建模。Mamba 是首个线性时间序列模型,在预训练复杂度和下游评估方面都真正达到了 Transformer 质量的性能。通过多达 1B 参数的缩放规律,研究者发现 Mamba 的性能超过了大量基线模型,包括 LLaMa 这种非常强大的现代 Transformer 训练配方。

图片

与类似规模的 Transformer 相比,Mamba 具有 5 倍的生成吞吐量,而且 Mamba-3B 的质量与两倍于其规模的 Transformer 相当(例如,与 Pythia-3B 相比,常识推理的平均值高出 4 分,甚至超过 Pythia-7B)。

图片

扩展阅读:

想把半本《红楼梦》搬进 ChatGPT 输入框?先把这个问题解决掉

产业MambaTransformer
相关数据
深度学习技术

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

感知技术

知觉或感知是外界刺激作用于感官时,脑对外界的整体的看法和理解,为我们对外界的感官信息进行组织和解释。在认知科学中,也可看作一组程序,包括获取信息、理解信息、筛选信息、组织信息。与感觉不同,知觉反映的是由对象的各样属性及关系构成的整体。

自注意力技术

自注意力(Self-attention),有时也称为内部注意力,它是一种涉及单序列不同位置的注意力机制,并能计算序列的表征。自注意力在多种任务中都有非常成功的应用,例如阅读理解、摘要概括、文字蕴含和语句表征等。自注意力这种在序列内部执行 Attention 的方法可以视为搜索序列内部的隐藏关系,这种内部关系对于翻译以及序列任务的性能非常重要。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

注意力机制技术

我们可以粗略地把神经注意机制类比成一个可以专注于输入内容的某一子集(或特征)的神经网络. 注意力机制最早是由 DeepMind 为图像分类提出的,这让「神经网络在执行预测任务时可以更多关注输入中的相关部分,更少关注不相关的部分」。当解码器生成一个用于构成目标句子的词时,源句子中仅有少部分是相关的;因此,可以应用一个基于内容的注意力机制来根据源句子动态地生成一个(加权的)语境向量(context vector), 然后网络会根据这个语境向量而不是某个固定长度的向量来预测词。

张量技术

张量是一个可用来表示在一些矢量、标量和其他张量之间的线性关系的多线性函数,这些线性关系的基本例子有内积、外积、线性映射以及笛卡儿积。其坐标在 维空间内,有 个分量的一种量,其中每个分量都是坐标的函数,而在坐标变换时,这些分量也依照某些规则作线性变换。称为该张量的秩或阶(与矩阵的秩和阶均无关系)。 在数学里,张量是一种几何实体,或者说广义上的“数量”。张量概念包括标量、矢量和线性算子。张量可以用坐标系统来表达,记作标量的数组,但它是定义为“不依赖于参照系的选择的”。张量在物理和工程学中很重要。例如在扩散张量成像中,表达器官对于水的在各个方向的微分透性的张量可以用来产生大脑的扫描图。工程上最重要的例子可能就是应力张量和应变张量了,它们都是二阶张量,对于一般线性材料他们之间的关系由一个四阶弹性张量来决定。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

堆叠技术

堆叠泛化是一种用于最小化一个或多个泛化器的泛化误差率的方法。它通过推导泛化器相对于所提供的学习集的偏差来发挥其作用。这个推导的过程包括:在第二层中将第一层的原始泛化器对部分学习集的猜测进行泛化,以及尝试对学习集的剩余部分进行猜测,并且输出正确的结果。当与多个泛化器一起使用时,堆叠泛化可以被看作是一个交叉验证的复杂版本,利用比交叉验证更为复杂的策略来组合各个泛化器。当与单个泛化器一起使用时,堆叠泛化是一种用于估计(然后纠正)泛化器的错误的方法,该泛化器已经在特定学习集上进行了训练并被询问了特定问题。

语言模型技术

语言模型经常使用在许多自然语言处理方面的应用,如语音识别,机器翻译,词性标注,句法分析和资讯检索。由于字词与句子都是任意组合的长度,因此在训练过的语言模型中会出现未曾出现的字串(资料稀疏的问题),也使得在语料库中估算字串的机率变得很困难,这也是要使用近似的平滑n元语法(N-gram)模型之原因。

常识推理技术

常识推理是人工智能(AI)的一个分支,它关注模拟人类每天遇到的普通情境的类型和本质的假设。这些假设包括对人和物体的物理特性,目的,意图和行为的判断,以及他们的行为和相互作用的可能结果。展示常识推理的设备将能够预测结果并得出类似于人类民间心理学(人类对人们的行为和意图进行推理的天生能力)和天真物理学(人类对物理世界的自然理解)的结论。

推荐文章
暂无评论
暂无评论~