Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

天下苦英伟达久矣!PyTorch官方免CUDA加速推理,Triton时代要来?

近日,PyTorch 官方分享了如何实现无 CUDA 计算,对各个内核进行了微基准测试比较,并讨论了未来如何进一步改进 Triton 内核以缩小与 CUDA 的差距。

在做大语言模型(LLM)的训练、微调和推理时,使用英伟达的 GPU 和 CUDA 是常见的做法。在更大的机器学习编程与计算范畴,同样严重依赖 CUDA,使用它加速的机器学习模型可以实现更大的性能提升。

虽然 CUDA 在加速计算领域占据主导地位,并成为英伟达重要的护城河之一。但其他一些工作的出现正在向 CUDA 发起挑战,比如 OpenAI 推出的 Triton,它在可用性、内存开销、AI 编译器堆栈构建等方面具有一定的优势,并持续得到发展。

近日,PyTorch 官宣要做「无英伟达 CUDA 参与的大模型推理」。在谈到为什么要 100% 使用 Triton 进行探索时,PyTorch 表示:「Triton 提供了一条途径,使大模型 能够在不同类型的 GPU 上运行,包括英伟达、AMD英特尔和其他基于 GPU 的加速器。

此外 Triton 还在 Python 中为 GPU 编程提供了更高的抽象层,使得使用 PyTorch 能够比使用供应商特定的 API 更快地编写高性能内核。」

图片

在 PyTorch 博客中讨论了使用流行的 LLM 模型(例如 Meta 的 Llama3-8B 和 IBM 的 Granite-8B Code)实现 FP16 推理的方法,其中计算是 100% 使用 OpenAI 的 Triton 语言执行的。

对于使用基于 Triton 内核的模型生成单个 token 的时间,PyTorch 能够实现在英伟达 H100 GPU 上 Llama 和 Granite 的 CUDA 内核主导工作流程的 0.76-0.78 倍性能,以及在英伟达 A100 GPU 上的 0.62-0.82 倍。

图片

图 1. 在英伟达 H100 和 A100 上,Llama3-8B 和 Granite-8B 的 Triton 和 CUDA 变体的推理吞吐量比较。设置:批大小 = 2,输入序列长度 = 512,输出序列长度 = 256

也许告别英伟达的时候真要来了。

图片

Transformer 块的组成

PyTorch 团队首先对基于 Transformer 的模型中发生的计算进行细分。下图显示了典型 Transformer 块的「内核(kernel)」。

图片

                                     图 2

Llama3 架构的核心操作总结如下:

  • 均方根归一化(RMSNorm)

  • 矩阵乘法:Fused QKV

  • RoPE

  • 注意力

  • 矩阵乘法:输出投影

  • RMSNorm

  • 矩阵乘法:Fused Gate + Up Projection

  • 激活函数:SiLU

  • 点乘(Element Wise Multiplication)

  • 矩阵乘法:Down Projection

这些操作中的每一个都是通过在 GPU 上执行一个(或多个)内核来计算的。虽然每个内核的细节在不同的 Transformer 模型中可能有所不同,但核心操作保持不变。例如,IBM 的 Granite 8B Code 模型在 MLP 层中使用偏置,与 Llama3 不同。此类更改确实需要对内核进行修改。典型的模型是这些 Transformer 块的堆叠,这些 Transformer 块通过嵌入层连接在一起。

模型推理

典型的模型架构代码与 PyTorch 启动的 python model.py 文件共享。在默认的 PyTorch Eager Execution 模式下,这些内核都是使用 CUDA 执行的。为了实现 100% Triton 进行端到端 Llama3-8B 和 Granite-8B 推理,需要编写和集成手写 Triton 内核以及利用 torch.compile(生成 Triton 操作)。首先,PyTorch 用编译器生成的 Triton 内核替换较小的操作,其次,PyTorch 用手写的 Triton 内核替换更昂贵和复杂的计算(例如矩阵乘法和闪存注意力)。

Torch.compile 自动为 RMSNorm、RoPE、SiLU 和点乘生成 Triton 内核。使用 Nsight Systems 等工具,可以观察到这些生成的内核,它们在矩阵乘法和注意力之间表现为微小的深绿色内核。

图片

                              图 3. 使用 torch.compile 跟踪 Llama3-8B,显示用于矩阵乘法和闪存注意力的 CUDA 内核。

对于上面的跟踪,PyTorch 团队注意到,在 Llama3-8B 样式模型中,占 E2E 延迟 80% 的两个主要操作是矩阵乘法和注意力内核,并且两者仍然是 CUDA 内核。因此,为了弥补剩余的差距,PyTorch 团队用手写的 Triton 内核替换了 matmul 和注意力内核。

Triton SplitK GEMM 内核

对于线性层中的矩阵乘法,PyTorch 团队编写了一个自定义 FP16 Triton GEMM(通用矩阵 - 矩阵乘法)内核,该内核利用了 SplitK 工作分解。

GEMM 内核调优

为了实现最佳性能,PyTorch 团队使用穷举搜索方法来调整 SplitK GEMM 内核。Granite-8B 和 Llama3-8B 具有如下形状的线性层:

图片

图 4. Granite-8B 和 Llama3-8B 线性层权重矩阵形状。

每个线性层都有不同的权重矩阵形状。因此,为了获得最佳性能,必须针对每个形状轮廓调整 Triton 内核。在对每个线性层进行调整后,PyTorch 能够在 Llama3-8B 和 Granite-8B 上实现相对于未调整的 Triton 内核 1.20 倍的 E2E 加速。

Flash Attention 内核

PyTorch 团队使用不同的配置,对现有 Triton flash attention 内核进行了评估,包括

  • AMD Flash

  • OpenAI Flash

  • Dao AI Lab Flash

  • XFormers Flash

  • PyTorch FlexAttention

PyTorch 团队分别在 eager 模式和编译模式下评估了每个内核的文本生成质量。下图 5 为不同 Flash Attention 内核的比较。

图片

上图总结了 PyTorch 观察到的开箱即用情况,并预计内核 2 到 5 可以在修改后满足上述标准。不过这也表明,拥有一个可用于基准测试的内核通常只是将它用作端到端生产内核的开始。

PyTorch 团队选择在后续测试中使用 AMD flash attention 内核,它通过 torch.compile 进行编译,并在 eager 和编译模式下产生清晰的输出。

为了满足 torch.compile 与 AMD flash attention 内核的兼容性,PyTorch 团队必须将它定义为 torch 自定义算子。并且封装更复杂的 flash attention 内核遵循以下两个步骤:

一是将函数封装为一个 PyTorch 自定义算子。

图片

二是向该算子添加一个 FakeTensor 内核,并在给定 flash 输入张量的形状(q、k 和 v)时,计算 flash 内核的输出形状。

图片

在将 Triton flash 内核定义为一个自定义 op 后,PyTorch 团队可以成功地对它进行编译以实现端到端运行。

图片

                              图 6:在交换 Triton matmul 和 Triton flash attention 内核后,使用 torch.compile 的 Llama3-8B 轨迹。

从图中可以看到,在集成 SplitK 矩阵乘法内核后,torch op 封装 flash attention 内核,然后运行 torch.compile,即可实现使用 100% Triton 计算内核的前向传递。

端到端基准测试

PyTorch 团队分别对运行 Granite-8B 和 Llama3-8B 模型的英伟达 H100 和 A100(单 GPU)进行了端到端测试,使用了两种不同的配置来执行基准测试。

其中 Triton 内核配置使用了:

  • Triton SplitK GEMM

  • AMD Triton Flash Attention

CUDA 内核配置使用了

  • cuBLAS GEMM

  • cuDNN Flash Attention - Scaled Dot-Product Attention (SDPA)

在典型推理设置下,两种 eager 和 torch 编译模式的吞吐量和 inter-token 延迟如下图所示。
图片

                      图 7:H100 和 A100 上 Granite-8B 和 Llama3-8B 单 token 生成延迟(批大小 = 2,输入序列长度 = 512,输出序列长度 = 256)。

总的来说,在 H100 上,Triton 模型最高可以达到 CUDA 模型性能的 78%;在 A100 上可以达到 82%。这些性能差距是由 matmul 和 flash attention 的内核延迟造成的。

基准测试

下图 8 为 Triton 和 CUDA 内核延迟比较(英伟达 H100 上运行 Llama3-8B)。输入为一个任意 prompt(批大小 = 1,prompt 序列长度 = 44),以解码延迟时间。

最后结果显示,Triton matmul 内核比 CUDA 慢了 1.2 至 1.4 倍,而 AMD Triton Flash Attention 比 CUDA SDPA 慢了 1.6 倍。

以上结果凸显了需要进一步提升 GEMM 和 Flash Attention 等核心原语内核的性能。最近的一些工作(如 FlashAttention-3、FlexAttention) 已经提出了更好地利用底层硬件和 Triton 的方法,PyTorch 希望在它们的基础上实现更大加速。为了阐明这一点,PyTorch 团队将 FlexAttention 与 SDPA、AMD’s Triton Flash 内核进行了比较。

PyTorch 团队 正努力验证 FlexAttention 的端到端性能。目前,FlexAttention 的初始微基准测试结果表明,在查询向量较小的情况下,有望实现更长的上下文以及解码问题形状。

图片

                             图 9:英伟达 H100 SXM5 80GB 上 FlexAttention 内核基准测试(批大小 = 1,最大头数 = 32,头维数 = 128)。

未来工作

未来,PyTorch 团队计划探索进一步优化 matmuls 的方法,以便更好地利用硬件,并为基于 Triton 的方法实现更大的加速。

对于 flash attention,PyTorch 团队计划探索 FlexAttention 和 FlashAttention-3 等内核中使用到的技术,以帮助进一步缩小 Triton 与 CUDA 之间的差距。同时还将探索端到端 FP8 LLM 推理。

原文链接:https://pytorch.org/blog/cuda-free-inference-for-llms/

工程PyTorch
相关数据
英特尔机构

英特尔(NASDAQ: INTC)是全球半导体行业的引领者,以计算和通信技术奠定全球创新基石,塑造以数据为中心的未来。我们通过精尖制造的专长,帮助保护、驱动和连接数十亿设备以及智能互联世界的基础设施 —— 从云、网络到边缘设备以及它们之间的一切,并帮助解决世界上最艰巨的问题和挑战。

http://www.intel.cn/
相关技术
IBM机构

是美国一家跨国科技公司及咨询公司,总部位于纽约州阿蒙克市。IBM主要客户是政府和企业。IBM生产并销售计算机硬件及软件,并且为系统架构和网络托管提供咨询服务。截止2013年,IBM已在全球拥有12个研究实验室和大量的软件开发基地。IBM虽然是一家商业公司,但在材料、化学、物理等科学领域却也有很高的成就,利用这些学术研究为基础,发明很多产品。比较有名的IBM发明的产品包括硬盘、自动柜员机、通用产品代码、SQL、关系数据库管理系统、DRAM及沃森。

https://www.ibm.com/us-en/
相关技术
激活函数技术

在 计算网络中, 一个节点的激活函数定义了该节点在给定的输入或输入的集合下的输出。标准的计算机芯片电路可以看作是根据输入得到"开"(1)或"关"(0)输出的数字网络激活函数。这与神经网络中的线性感知机的行为类似。 一种函数(例如 ReLU 或 S 型函数),用于对上一层的所有输入求加权和,然后生成一个输出值(通常为非线性值),并将其传递给下一层。

权重技术

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

基准技术

一种简单的模型或启发法,用作比较模型效果时的参考点。基准有助于模型开发者针对特定问题量化最低预期效果。

张量技术

张量是一个可用来表示在一些矢量、标量和其他张量之间的线性关系的多线性函数,这些线性关系的基本例子有内积、外积、线性映射以及笛卡儿积。其坐标在 维空间内,有 个分量的一种量,其中每个分量都是坐标的函数,而在坐标变换时,这些分量也依照某些规则作线性变换。称为该张量的秩或阶(与矩阵的秩和阶均无关系)。 在数学里,张量是一种几何实体,或者说广义上的“数量”。张量概念包括标量、矢量和线性算子。张量可以用坐标系统来表达,记作标量的数组,但它是定义为“不依赖于参照系的选择的”。张量在物理和工程学中很重要。例如在扩散张量成像中,表达器官对于水的在各个方向的微分透性的张量可以用来产生大脑的扫描图。工程上最重要的例子可能就是应力张量和应变张量了,它们都是二阶张量,对于一般线性材料他们之间的关系由一个四阶弹性张量来决定。

查询技术

一般来说,查询是询问的一种形式。它在不同的学科里涵义有所不同。在信息检索领域,查询指的是数据库和信息系统对信息检索的精确要求

堆叠技术

堆叠泛化是一种用于最小化一个或多个泛化器的泛化误差率的方法。它通过推导泛化器相对于所提供的学习集的偏差来发挥其作用。这个推导的过程包括:在第二层中将第一层的原始泛化器对部分学习集的猜测进行泛化,以及尝试对学习集的剩余部分进行猜测,并且输出正确的结果。当与多个泛化器一起使用时,堆叠泛化可以被看作是一个交叉验证的复杂版本,利用比交叉验证更为复杂的策略来组合各个泛化器。当与单个泛化器一起使用时,堆叠泛化是一种用于估计(然后纠正)泛化器的错误的方法,该泛化器已经在特定学习集上进行了训练并被询问了特定问题。

文本生成技术

文本生成是生成文本的任务,其目的是使人类书写文本难以区分。

AMD机构

超威半导体(中国)有限公司专门为计算机、通信和消费电子行业设计和制造各种创新的微处理器(CPU、GPU、主板芯片组、电视卡芯片等),以及提供闪存和低功率处理器解决方案,公司成立于1969年。AMD致力为技术用户——从企业、政府机构到个人消费者——提供基于标准的、以客户为中心的解决方案。

https://www.amd.com/zh-hans
语言模型技术

统计式的语言模型是借由一个几率分布,而指派几率给字词所组成的字串。语言模型经常使用在许多自然语言处理方面的应用,如语音识别,机器翻译,词性标注,句法分析和资讯检索。

推荐文章
暂无评论
暂无评论~