Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

ICML 2024 | 梯度检查点太慢?不降速、省显存,LowMemoryBP大幅提升反向传播显存效率

图片
AIxiv专栏是机器之心发布学术、技术内容的栏目。过去数年,机器之心AIxiv专栏接收报道了2000多篇内容,覆盖全球各大高校与企业的顶级实验室,有效促进了学术交流与传播。如果您有优秀的工作想要分享,欢迎投稿或者联系报道。投稿邮箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com

本文论文一作是南开大学统计与数据科学学院研二硕士生杨雨辰,指导老师为南开大学统计与数据科学学院的徐君副教授。徐君老师团队的研究重点是计算机视觉、生成式 AI 和高效机器学习,并在顶级会议和期刊上发表了多篇论文,谷歌学术引用超过 4700 次。

自从大型 Transformer 模型逐渐成为各个领域的统一架构,微调就成为了将预训练大模型应用到下游任务的重要手段。然而,由于模型的尺寸日益增大,微调所需要的显存也逐渐增加,如何高效地降低微调显存就成了一个重要的问题。此前,微调 Transformer 模型时,为了节省显存开销,通常的做法是使用梯度检查点(gradient checkpointing,也叫作激活重算),以牺牲训练速度为代价降低反向传播(Backpropagation, BP)过程中的激活显存占用。

最近,由南开大学统计与数据科学学院徐君老师团队发表在 ICML 2024 上的论文《Reducing Fine-Tuning Memory Overhead by Approximate and Memory-Sharing Backpropagation》提出通过更改反向传播(BP)过程,在不增加计算量的情况下,显著减少峰值激活显存占用。

图片

  • 论文:Reducing Fine-Tuning Memory Overhead by Approximate and Memory-Sharing Backpropagation

  • 论文链接:https://arxiv.org/abs/2406.16282

  • 项目链接:https://github.com/yyyyychen/LowMemoryBP

文章提出了两种反向传播改进策略,分别是 Approximate Backpropagation(Approx-BP)和 Memory-Sharing Backpropagation(MS-BP)。Approx-BP 和 MS-BP 分别代表了两种提升反向传播中内存效率的方案,可以将其统称为 LowMemoryBP。无论是在理论还是实践意义上,文章都对更高效的反向传播训练提供了开创性的指导。

在理论显存分析中,LowMemoryBP 可以大幅降低来自激活函数和标准化层的激活显存占用,以 ViT 和 LLaMA 为例,可以对 ViT 微调降低 39.47% 的激活显存,可以对 LLaMA 微调降低 29.19% 的激活显存。

图片

在实际实验中,LowMemoryBP 可以有效地使包括 ViT, LLaMA, RoBERTa, BERT, Swin 在内的 Transformer 模型微调峰值显存占用降低 20%~30%,并且不会带来训练吞吐量和测试精度的损失。

Approx-BP

在传统反向传播训练中,激活函数梯度的反向回传是严格对应其导函数的,对于 Transformer 模型中常用的 GELU 和 SiLU 函数,这意味着需要将输入特征张量完整地存入激活显存中。而本文的作者提出了一套反向传播近似理论,即 Approx-BP 理论。在该理论的指导下,作者使用分段线性函数逼近激活函数,并用分段线性函数的导数(阶梯函数)替代 GELU/SiLU 梯度的反向回传。这个方法导出了两个非对称的内存高效激活函数:ReGELU2 和 ReSiLU2。这类激活函数由于使用 4 段阶梯函数进行反向回传,从而使得激活存储只需要使用 2bit 数据类型。

图片

图片

MS-BP

BP 网络每一层通常都会将输入张量存入激活显存以用作反向传播计算。作者指出如果可以将某一层的反向传播改写成依赖输出的形式,那么这一层和后一层就可以共享同一个激活张量,从而降低激活存储的冗余。

而文章指出 Transformer 模型中常用的 LayerNorm 和 RMSNorm,在将仿射参数合并到后一层的线性层之后,可以很好地符合 MS-BP 策略的要求。经过重新设计的 MS-LayerNorm 和 MS-RMSNorm 不再产生独立的激活显存。

图片

实验结果

作者对计算机视觉自然语言处理领域的若干个代表模型进行了微调实验。其中,在 ViT,LLaMA 和 RoBERTa 的微调实验中,文章提出的方法分别将峰值显存占用降低了 27%,29% 和 21%,并且没有带来训练效果和训练速度的损失。注意到,作为对比的 Mesa(一个 8-bit Activation Compressed Training 方法)使训练速度降低了约 20%,而文章提出的 LowMemoryBP 方法则完全保持了训练速度。

图片

图片

图片

结论及意义

文章提出的两种 BP 改进策略,Approx-BP 和 MS-BP,均在保持训练效果和训练速度的同时,实现了激活显存的显著节省。这意味着从 BP 原理上进行优化是非常有前景的显存节省方案。此外,文章提出的 Approx-BP 理论突破了传统神经网络的优化框架,为使用非配对导数提供了理论可行性。其导出的 ReGELU2 和 ReSiLU2 展现了这一做法的重要实践价值。

欢迎大家阅读论文或者代码去了解算法的详细细节,LowMemoryBP 项目的 github 仓库上已经开源相关的模块。

工程LowMemoryBPICML 2024
相关数据
激活函数技术

在 计算网络中, 一个节点的激活函数定义了该节点在给定的输入或输入的集合下的输出。标准的计算机芯片电路可以看作是根据输入得到"开"(1)或"关"(0)输出的数字网络激活函数。这与神经网络中的线性感知机的行为类似。 一种函数(例如 ReLU 或 S 型函数),用于对上一层的所有输入求加权和,然后生成一个输出值(通常为非线性值),并将其传递给下一层。

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

数据科学技术

数据科学,又称资料科学,是一门利用数据学习知识的学科,其目标是通过从数据中提取出有价值的部分来生产数据产品。它结合了诸多领域中的理论和技术,包括应用数学、统计、模式识别、机器学习、数据可视化、数据仓库以及高性能计算。数据科学通过运用各种相关的数据来帮助非专业人士理解问题。

导数技术

导数(Derivative)是微积分中的重要基础概念。当函数y=f(x)的自变量x在一点x_0上产生一个增量Δx时,函数输出值的增量Δy与自变量增量Δx的比值在Δx趋于0时的极限a如果存在,a即为在x0处的导数,记作f'(x_0) 或 df(x_0)/dx。

张量技术

张量是一个可用来表示在一些矢量、标量和其他张量之间的线性关系的多线性函数,这些线性关系的基本例子有内积、外积、线性映射以及笛卡儿积。其坐标在 维空间内,有 个分量的一种量,其中每个分量都是坐标的函数,而在坐标变换时,这些分量也依照某些规则作线性变换。称为该张量的秩或阶(与矩阵的秩和阶均无关系)。 在数学里,张量是一种几何实体,或者说广义上的“数量”。张量概念包括标量、矢量和线性算子。张量可以用坐标系统来表达,记作标量的数组,但它是定义为“不依赖于参照系的选择的”。张量在物理和工程学中很重要。例如在扩散张量成像中,表达器官对于水的在各个方向的微分透性的张量可以用来产生大脑的扫描图。工程上最重要的例子可能就是应力张量和应变张量了,它们都是二阶张量,对于一般线性材料他们之间的关系由一个四阶弹性张量来决定。

计算机视觉技术

计算机视觉(CV)是指机器感知环境的能力。这一技术类别中的经典任务有图像形成、图像处理、图像提取和图像的三维推理。目标识别和面部识别也是很重要的研究领域。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

自然语言处理技术

自然语言处理(英语:natural language processing,缩写作 NLP)是人工智能和语言学领域的分支学科。此领域探讨如何处理及运用自然语言;自然语言认知则是指让电脑“懂”人类的语言。自然语言生成系统把计算机数据转化为自然语言。自然语言理解系统把自然语言转化为计算机程序更易于处理的形式。

机器之心机构

机器之心,成立于2014年,是国内最具影响力、最专业、唯一用于国际品牌的人工智能信息服务与产业服务平台。目前机器之心已经建立起涵盖媒体、数据、活动、研究及咨询、线下物理空间于一体的业务体系,为各类人工智能从业者提供综合信息服务和产业服务。

https://www.jiqizhixin.com/
推荐文章
暂无评论
暂无评论~