摘要:Thinking Machines Lab 在 X 平台高调宣布,将频繁发布博客文章、开源代码以及各种研究成果,既是「造福公众」,也是为了「优化我们自己的研究文化」。
一群 OpenAI 前高管,创立 Thinking Machines Lab 才半年时间,连个正经产品都没发布,估值已经狂飙 120 亿美元(折合人民币 850 亿元)。
而就在刚刚,他们终于憋出了第一个「大招」——发布了成立以来的首篇重磅研究论文。
Thinking Machines Lab 在 X 平台高调宣布,将频繁发布博客文章、开源代码以及各种研究成果,既是「造福公众」,也是为了「优化我们自己的研究文化」。
这篇作为全新博客系列《Connectionism》的开山之作,显然也是宏大计划的第一步。
有意思的是,OpenAI 当年创立时也信誓旦旦地承诺要开放研究,结果随着越做越大,变得越来越封闭。至于这群 OpenAI 前高管能否坚持初心,还真不好说。
今天,Thinking Machines Lab 正式推出我们的研究博客 Connectionism。我们的第一篇文章是《击败 LLM 推理中的非确定性》。我们相信,科学在分享中才能更好地发展。Connectionism 将涵盖与我们研究一样多样的主题:从内核数值计算到提示词工程。在这里,我们会公开而频繁地分享正在进行的工作,并与研究社区保持交流。Connectionism 这个名字向人工智能早期的一段历史致敬。在 20 世纪 80 年代,它是一个研究方向的名称,专注于神经网络及其与生物大脑的相似性。
Mira Murati 是 OpenAI 的前 CTO,同时目前也是 Thinking Machines Lab 的创始人兼 CEO。她在 7 月份就曾表示 Thinking Machines Lab 已经完成 20 亿美元的融资,且首款产品将在未来几个月内亮相,并且会对研究人员和正在开发定制模型的初创公司大有裨益。
那么这篇论文具体讲了啥?
你有没有发现,同样的问题问 ChatGPT 好几遍,每次得到的答案都不太一样?在 AI 圈子里,这种现象早就被视作习以为常——大家都认为现在的 AI 模型就是概率模型。
但 Thinking Machines Lab 显然不信这个邪。他们首发的这篇文章题为《克服 LLM 推理中的不确定性》,试图解决这个老大难。
论文认为 AI 模型产生随机性的罪魁祸首,其实是 GPU 内核——就是那些在英伟达芯片里跑的小程序——在推理过程中的拼接方式有问题。
文章的核心作者是 Thinking Machines Lab 的 AI 研究员 Horace He,他的解决思路很巧妙,如果能精确控制这一层的执行流程,就有可能让 AI 模型的输出变得更加稳定可靠。
此外,除了提供更靠谱的 AI 响应,让 AI 模型生成可重复的答案还有个更大的价值——能显著改善强化学习(RL)训练效果。
RL 的基本原理是通过奖励机制来强化 AI 模型的正确输出,但如果每次答案都有细微差别,训练数据就会变得很「嘈杂」。Horace He 认为,更一致的响应能让整个 RL 训练过程「丝滑」很多。
根据之前 The Information 的爆料,Thinking Machines Lab 已经向投资者透露,他们计划用 RL 技术为企业量身定制 AI 模型。
这篇论文发布后不久,引来百万网友的围观,也有不少调侃称,Thinking Machines Lab 才是真正的 Open AI,从目前的反响热度来看,Thinking Machines Lab 这第一炮算是打响了。
克服 LLM 推理中的不确定性
可重复性是科学进步的基石。然而,要让大型语言模型(LLM)产生可重复的结果,却异常困难。
比如,你可能会发现,多次向 ChatGPT 提出同一个问题,它给出的回答却不一样。这本身并不奇怪,因为语言模型生成结果的过程涉及「采样」:模型先将输出转化为一个概率分布,然后再按照概率选择一个词元(token)。
更令人意外的是,即便我们将温度参数调到 0(这意味着 LLM 总是选择概率最高的词元,也就是所谓的贪婪采样,从理论上讲是确定性的),LLM API 在实际运行中仍然不是确定性的。即便是在你自己的硬件上,用像 vLLM 或 SGLang 这样的开源推理库运行推理,采样结果依然不具备确定性。
那么,为什么 LLM 推理引擎不是确定性的呢?一个常见的假设是:浮点数运算的非结合性和并发执行共同作用,导致结果取决于哪个并发核心先完成运算。我们把这种解释称为 「并发 + 浮点数」假设。比如,一篇最新的 arXiv 预印本写道:
GPU 中的浮点运算具有非结合性,也就是说(a+b)+c \neq a+(b+c)这是由于有限精度和舍入误差造成的。这个特性会直接影响 Transformer 架构中的注意力分数和 logits 的计算,在多个线程并行操作时,执行顺序不同就可能导致不同的结果。
你也能在其他地方看到类似的「并发 + 浮点数」假设。比如这里有人说:「为了加快响应速度,端点会使用 GPU,而 GPU 进行的是并行(非确定性)计算,任何现代 GPU 的神经网络计算都会受到这个影响。」 还有人说:「由于 GPU 高度并行化,每次执行时加法或乘法的顺序可能不同,这会逐步累积成输出上的细微差异。」
不过,这一假设虽然不算错,但并没有揭示全部真相。举个例子:即便在 GPU 上,用同样的数据反复运行相同的矩阵乘法,每次得到的结果也完全一致(逐位相同)。我们确实在使用浮点数,GPU 也确实存在高度并发,但在这种测试里为什么没有出现非确定性呢?
Python:A = torch.randn(2048, 2048, device=『cuda』, dtype=torch.bfloat16)B = torch.randn(2048, 2048, device=『cuda』, dtype=torch.bfloat16)ref = torch.mm(A, B)for _ in range(1000):assert (torch.mm(A, B) - ref).abs.max.item == 0
要理解 LLM 推理中非确定性的真正原因,我们必须更深入地探究。
不幸的是,甚至连如何定义 LLM 推理的「确定性」都很困难。或许让人困惑的是,以下几种说法同时都成立:
1. 有些 GPU 内核确实是非确定性的。2. 然而,语言模型前向传播中使用的所有内核都是确定性的。3. 此外,像 vLLM 这样的 LLM 推理服务器的前向传播过程,也可以说是确定性的。4. 尽管如此,从使用推理服务器的用户角度来看,结果却是非确定性的。
在这篇文章里,我们将解释为什么「并发 + 浮点数」的假设没有击中要害,揭示 LLM 推理非确定性的真正元凶,并说明如何解决这个问题,从而在 LLM 推理中获得真正可重复的结果。
原罪:浮点数的非结合性
在谈论非确定性之前,先有必要解释为什么会出现数值差异。毕竟,我们通常认为机器学习模型是一类数学函数,遵循诸如交换律或结合律这样的结构性规则。那为什么机器学习库不能给我们一个「数学上正确」的结果呢?
问题出在浮点数的非结合性。也就是说,对于浮点数:
(a+b)+c \neq a+(b+c)
Python:(0.1 + 1e20) - 1e20>>> 00.1 + (1e20 - 1e20)>>> 0.1
具有讽刺意味的是,正是结合律的「失效」,才让浮点数真正有了用处。
浮点数之所以有用,是因为它们提供了一种「动态」的精度水平。为了便于说明,我们这里使用 十进制(而不是二进制),并假设浮点数的表示格式为:尾数 × 10^指数。我们设定尾数保留 3 位数字,指数 1 位数字。
例如,数值 3450 可以精确表示为:
3.45 × 10³。
而更小的数值 0.486 可以表示为:
4.86 × 10⁻¹。
通过这种方式,浮点数既能表示非常大的数值,也能表示非常小的数值。在科学计算中,我们会说浮点数能够保持固定数量的「有效数字」。
当我们相加的两个浮点数指数相同时,看起来就和整数加法差不多。比如:
123 (1.23 × 10²) + 456 (4.56 × 10²) = 579 (5.79 × 10²)。
但是,当我们相加的两个浮点数指数不同,比如 1230 和 23.4,情况就不一样了。精确结果应为 1253.4。然而,我们一次只能保留 3 位有效数字。于是浮点数加法会舍弃最后 2 位,得到:1.25 × 10³(即 1250)。
图 1:我们需要 3 位有效数字来表示 1230,也需要 3 位有效数字来表示 23.4。然而,将这两个数相加后,结果 1253.4 需要 5 位有效数字才能表示。于是我们的浮点数格式只能把最后的 34 舍去。换句话说,我们实际上在加法前就把原本的 23.4 近似成了 20.0。
但在这个过程中,我们已经丢失了信息。需要注意的是,每当我们相加的浮点数处于不同「尺度」(即不同指数)时,这种情况都会发生。而浮点数的不同指数加法在计算中几乎无处不在。事实上,如果我们能保证永远不需要处理不同指数的情况,那直接用整数就足够了!
换句话说,每当我们以不同的顺序对浮点数进行加法时,最终结果都有可能完全不同。举个极端的例子:对于某个数组,仅仅因为加法顺序不同,结果就可能出现 102 种不同的情况。
Python:import randomvals = [1e-10, 1e-5, 1e-2, 1]vals = vals + [-v for v in vals]results = random.seed(42)for _ in range(10000):random.shuffle(vals)results.append(sum(vals))results = sorted(set(results))print(f「There are {len(results)} unique results: {results}」)Output:There are 102 unique results: [-8.326672684688674e-17, -7.45931094670027e-17, ..., 8.326672684688674e-17]
虽然浮点数的非结合性是导致输出结果不一致的根本原因,但它并不能直接解释非确定性从何而来。它无法告诉我们:为什么浮点数会以不同顺序相加、这种情况何时发生、以及我们该如何避免。
答案其实藏在内核的实现方式里。
为什么内核不会总是按相同顺序相加?
正如前面提到的,一个常见的解释是 「并发 + 浮点数」假设。该假设认为,如果并发线程完成的顺序是非确定性的,而累加的顺序又依赖于线程完成的顺序(比如使用了 atomic add),那么累加的结果自然也会是非确定性的。
不过令人困惑的是,尽管这种情况确实会导致某些内核非确定性,但在 LLM 推理的非确定性 中,并发(以及 atomic add)其实完全没有参与其中!要解释真正的元凶,我们需要先理解:为什么现代 GPU 内核几乎不需要使用 atomic add。
什么时候需要用到 atomic add?
通常情况下,GPU 会在许多「核心」(即 SM,多处理器单元)上并发运行一个程序。由于这些核心之间没有天然的同步机制,如果核心之间需要通信,就会带来挑战。比如,当所有核心都要把结果累加到同一个元素时,就可以使用 atomic add(有时也叫 「fetch-and-add」)。
atomic add 是「非确定性」的——其累加结果的顺序完全取决于哪个核心先完成。
举个具体例子:假设你要用 100 个核心对一个 100 元素的向量做归约运算(例如 torch.sum)。虽然你可以并行加载这 100 个元素,但最终必须将它们归并为一个值。一种实现方法就是使用某种 atomic add 原语,硬件会保证所有加法都被处理,但不会保证它们的执行顺序。
图 2 :atomic add 可以确保每个核心的计算结果都会被包含在最终的和里。但它并不保证累加的顺序,顺序完全取决于哪个核心先完成运算,而这是非确定性的。于是,即便是同一个并行程序,在相同输入下多次运行,结果也可能不一样。
这通常就是人们所说的「非确定性」——你在完全相同的输入条件下运行同一个内核两次,却得到不同的结果。这类情况被称为 运行间非确定性(run-to-run nondeterminism),就像你两次运行相同依赖环境下的 Python 脚本,却得到不一样的结果。
不过,尽管并发的 atomic add 确实会让内核非确定性,但在绝大多数内核中,其实并 不需要 使用 atomic add。事实上,在典型的 LLM 前向传播过程中,几乎不存在任何一个 atomic add 操作。
这可能会让人惊讶,因为在并行化归约操作时,atomic add 本可以派上用场。但实际并不需要,原因主要有两点:
1. 批量维度的并行性足够大,通常在「批量」维度上有足够的并行性,因此不需要在归约维度上再做并行。比如,与其并行归约一个 100 维向量,不如并行归约 500 个向量。这样每个核心负责一个完整向量的归约,各个核心之间互不干扰。神经网络库已经发展出确定性的归约策略,
2、多数神经网络库采用了多种策略,在保证性能的同时也能实现确定性。比如,可以用 分块(或树状)归约,将 100 个元素的归约拆分成 5 个 20 元素的归约(实现 5 路并行)。最后再把 5 个结果合并:要么用一个独立的「收尾」归约(不并行,但操作的元素足够少,所以开销很低);要么用 信号量(semaphore),保证并发的线程块以确定性的顺序累加。
基于这两点,对于绝大多数神经网络运算,避免 atomic add 几乎不会带来性能损失。
当然,仍然有少数操作在避免 atomic add 时会付出显著性能代价。比如 PyTorch 中的 scatter_add(a[b] += c)。而在 LLM 中唯一常见的这种情况是 FlashAttention 的反向传播。
有趣的是,你知道吗?广泛使用的 Triton 版 FlashAttention 反向传播实现,算法上其实与 Tri Dao 的 FlashAttention-2 论文不同!标准的 Triton 实现会在反向传播中额外做一些重复计算,以避免 atomic add,但因此多消耗了 40% 的 FLOPs!
然而,LLM 的 前向传播 中并不存在需要 atomic add 的操作。因此,LLM 的前向传播实际上是 运行间确定性的(run-to-run deterministic)。
图 3:从推理服务器的角度来看,它是确定性的。只要用户请求完全相同,服务器就会始终给出相同的确定性输出。
维基百科写道:「确定性算法是一种算法,只要输入相同,就总会产生相同的输出。」在这里也是一样:当输入(即推理服务器接收到的请求)完全一致时,前向传播总会生成完全一致的输出。
然而,仅仅前向传播本身是「确定性的」,并不足以保证包含它的整个系统就是确定性的。举个例子:如果一个请求的输出依赖于其他并行用户的请求(比如 batch-norm),那么对每个单独请求而言,它根本无法预知并行请求的内容,于是从用户的角度看,整个 LLM 推理依然是 非确定性的!
事实证明,请求的输出确实依赖于并行用户的请求。但原因并不是批次之间的信息「泄露」,而是因为前向传播缺乏所谓的 「批量不变性(batch invariance)」,导致单个请求的输出会受到前向传播批量大小的影响。
批量不变性与「确定性」
为了说明批量不变性,我们把系统简化,只看矩阵乘法(matmul)。你可以假设大多数矩阵乘法实现都是 运行间确定性的(run-to-run deterministic)。这并非百分之百正确,但在常见实现中确实如此。不过,它们并不具备批量不变性。换句话说,当批量大小发生变化时,批量中每个元素的计算结果都可能不同。
从数学角度看,这是一个相当反常的现象。矩阵乘法本应在批量的每个元素上是「独立」的——其他元素的存在与否,或批量的大小,都不该影响某个特定元素的计算结果。
然而,实证观察表明,事实并非如此。
Python:import torchtorch.set_default_device(『cuda』)B = 2048D = 4096a = torch.linspace(-1000, 1000, B*D).reshape(B, D)b = torch.linspace(-1000, 1000, D*D).reshape(D, D)Doing a matrix vector multiplication by takingthe first element of the batchout1 = torch.mm(a[:1], b)Doing a matrix matrix multiplication and then takingthe first element of the batchout2 = torch.mm(a, b)[:1]print((out1 - out2).abs.max) # tensor(1669.2500, device=「cuda:0」)
需要注意,这里说的是 「运行间确定性(run-to-run deterministic)」。也就是说,如果你多次运行同一个脚本,它会始终返回相同的结果。
但它并不是 「硬件/软件版本不变性」——也就是说,不同的 GPU 或 PyTorch 版本可能会给出不同的数值,不过在同一硬件和软件版本下,它应当始终返回一致的结果。
然而,当一个 不具备批量不变性 的内核被用于更大的推理系统时,整个系统就可能变得非确定性。因为当你向推理端点发出请求时,从用户的角度看,服务器的负载情况本质上是「非确定的」。而服务器的负载会决定内核运行时的 批量大小,从而导致每个请求最终结果的变化!
图 4 :尽管从整体上看,推理服务器本身可以被称为「确定性的」,但对于单个用户而言,情况就不同了。对于某个用户来说,其他并发用户并不是系统的「输入」,而是系统的一种 非确定性特征。这就使得从每个用户的角度来看,LLM 推理是非确定性的。
换句话说,如果某个内核在某种属性下不是不变的(例如批量大小),而这个属性本身又是非确定性的(例如服务器的负载情况),那么组合起来的系统自然就是非确定性的。
也就是说,几乎所有 LLM 推理端点之所以非确定性,主要原因就是 服务器负载(进而批量大小)是非确定性变化的!而且,这种非确定性并非 GPU 独有——运行在 CPU 或 TPU 上的 LLM 推理端点,同样会受到这种影响。
因此,如果我们想避免推理服务器中的非确定性,就必须让内核具备批量不变性。要理解如何做到这一点,我们得先看看为什么内核一开始就缺乏批量不变性。
如何让内核具备批量不变性?
为了让 Transformer 的实现具备批量不变性,我们必须让 每一个内核 都具备批量不变性。幸运的是,我们可以假设所有 逐点操作(pointwise operations) 都是批量不变的。虽然在 PyTorch 中确实如此,但严格来说这并不是天生成立的。比如,一些 CPU 内核的实现会在数组的某些部分使用向量化指令,而在其他部分使用非向量化指令,而这些指令在数值结果上未必能保证逐位完全一致。
因此,我们真正需要关注的只有三类涉及 归约(reduction) 的操作:
RMSNorm矩阵乘法注意力机制
(与并行相关的归约超出了本文的讨论范围,但原理相同。有一个可能有用的信息是:在 CUDA 12.8+ 环境下,Blackwell 和 Hopper GPU 的 NVLink-Sharp 交换机内归约操作是确定性的。相关信息可以在 NCCL 的 GitHub issues 中找到。)
巧合的是,这三类操作按难度顺序正好从低到高排列。要让它们在保持合理性能的同时实现批量不变性,各自都需要额外的考虑。我们先来看看 RMSNorm。
批量不变的 RMSNorm
图 4:数据并行 RMSNorm。在并行化策略中,理想情况下我们希望尽量避免核心之间的通信。其中一种方法就是 把每个批量元素分配给一个核心,从而保证每次归约运算都完全在单个核心内部完成。这种方式被称为 「数据并行」策略,因为我们只是沿着一个不需要通信的维度进行并行化。在这个例子中,我们有 4 行数据和 4 个核心,正好把核心全部用满。
RMSNorm 的实现方式如下:
Python:x: [batch_size, hidden_dim]weight: [hidden_dim]def rms_norm(x, weight):return x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True)) * weight
要实现 批量不变性,要求每个元素的归约顺序在任何批量大小下都必须是固定的。
需要注意的是,这并不意味着我们必须始终使用完全相同的归约策略。比如,当归约的元素数量发生变化时,即便归约策略随之改变,也依然可以保持批量不变性。Quack 的一篇博文给出了不少不错的例子,展示了各种归约策略的层级结构(如线程归约、warp 归约、block 归约、集群归约)。
因此,只有当 批量大小会影响归约策略 时,我们才会破坏批量不变性。
接下来看看 RMSNorm 的标准并行化策略。一般来说,并行算法的优势在于能够最大限度减少核心之间的通信。在这里,为了便于讨论,可以假设「核心」指的是 SM(Streaming Multiprocessors,流式多处理器)。更具体来说,这里重要的性质是:内核启动的线程块数要大于 SM 的数量。所以,一个可行的策略就是像上图那样,把每个批量元素分配给一个核心。
在这种情况下,增加批量大小并不会影响归约策略。如果批量大小为 200 时已经能为内核提供足够的并行性,那么批量大小为 2000 时也一定能够提供足够的并行性。
图 6:大批量情况下的数据并行 RMSNorm:将数据并行策略扩展到更大的批量其实很直接——与其让每个核心只处理一行,不如让每个核心顺序处理多行。这样可以保持批量不变性,因为每个批量元素的归约策略仍然是完全一致的。
另一方面,减小批量大小则可能带来挑战。由于我们是把每个批量元素分配给一个核心,当批量大小减少时,最终可能出现核心数量多于批量元素的情况,从而导致部分核心闲置。
在这种情况下,一个优秀的内核工程师通常会采用前一节提到的解决方案(如 atomic add 或 分块归约),以保持良好的并行度,从而维持性能。遗憾的是,这么做会改变归约策略,从而使该内核失去批量不变性。
图 7:分块归约 RMSNorm:当批量较小时,数据并行策略可能不再具备足够的并行度来充分利用所有核心。这时,更高效的办法是把一次归约任务 拆分(split) 给多个核心,从而让 GPU 得到充分利用。然而,这样做会失去批量不变性,因为每个元素的归约顺序不再一致。
最简单的解决办法是干脆忽略这些情况。这并非完全不可接受——批量很小时,内核本身执行得就很快,即便性能有所下降也未必会造成严重影响。
如果确实必须优化这种场景,可以考虑一种策略:无论批量大小如何,都始终使用一种具备足够并行度的归约方式。这种方式在批量很小时能保证合理的性能,而在批量很大时则会带来过度的并行化。不过,这种方法能让性能在不同批量规模下保持「尚可」,虽然无法达到峰值表现。
批量不变的矩阵乘法
图 8:数据并行矩阵乘法(Matmul):与 RMSNorm 类似,矩阵乘法的标准并行化策略也是一种 数据并行 策略,即将整个归约操作限制在单个核心内完成。最直观的做法是把输出张量划分为二维小块(tiles),并将每个小块分配给不同的核心。每个核心只计算属于该小块的点积运算,从而依旧在单个核心内完成全部归约。不过,与 RMSNorm 不同的是,矩阵乘法受算术强度和 tensor core 利用率等额外约束的影响。为了实现高效的 matmul 内核,我们必须把 二维小块 分配给核心,而不是单个输出元素。
本质上,可以把矩阵乘法看作 逐点运算 + 归约。因此,当我们通过切分输出张量为小块来并行化矩阵乘法时,得到的就是一种类似的 数据并行内核策略,依旧保证每个归约在单个核心内完成。
同样地,就像 RMSNorm 那样,批量维度(M 和 N)可能过小,迫使我们必须沿着归约维度(K)进行切分。而且,即便有两个「批量维度」,为了有效利用 tensor core,matmul 也要求每个核心承担更多的计算「工作量」。例如,对于一个形状为 [1024, K] × [K, 1024] 的矩阵乘法,如果使用标准的 [128, 128] tile 大小,数据并行策略最多只能把计算切分到 64 个核心上,远不足以让 GPU 饱和。
在矩阵乘法中,沿着归约维度 K 进行切分被称为 Split-K Matmul。然而,就和 RMSNorm 一样,这种策略会破坏批量不变性。
另一个有趣的矩阵乘法并行策略是 Stream-K。Stream-K 更特别,因为它的「不变性」比常规矩阵乘法还要弱。大多数 matmul 库虽然不是批量不变的,但至少是 批量位置不变的(batch-position-invariant),也就是说,即使改变批量中元素的位置,数值结果也不会受影响。但 Stream-K 连这一点也不满足!
它的核心思想是:通过对不同输出 tile 在 K 维上采用不同的切分方式,可以实现更均衡的负载。但一旦利用了这种方法,内核也就不再具备批量位置不变性了。
Split-K 矩阵乘法(Matmul):当批量维度非常小时,并行度可能不足,这时就需要用到 Split-K matmul。在这种策略下,我们把一次归约拆分到两个核心上分别计算,最后再把结果合并。这样一来,虽然每次归约被分散到多个核心,但依然能利用更多的硬件资源,例如这里就能调动 8 个核心。
矩阵乘法相比 RMSNorm 还有一个额外的复杂性——tensor core 指令。在归约中,我们可以一次只处理一行;但在高效的矩阵乘法内核中,我们必须一次处理整个 tile(小块)。
每条 tensor core 指令(例如 wgmma.mma_async.sync.aligned.m64n128k16)在内部可能都有不同的归约顺序。选择不同的 tensor core 指令的原因之一,就是 批量大小过小。比如,如果我们使用的 tensor core PTX 指令要求 tile 长度为 256,而批量大小只有 32,那么几乎所有计算资源都会被浪费!在批量大小为 1 的情况下,最快的内核通常根本不会使用 tensor core。
图 10 :填充后的 Tensor-Core 指令:当批量大小太小时,可能会出现连一个完整的二维 tile 都无法放进输出的情况。在这种情况下,最有效的做法是切换到更小的 tensor core 指令,或者干脆完全不用 tensor core!不过,这两种选择都会让内核失去批量不变性。
因此,保证矩阵乘法批量不变性的最简单方法就是:只编译一个内核配置,并对所有形状都使用它。虽然这会牺牲一部分性能,但在 LLM 推理中通常并不是灾难性的。尤其是,split-k 最常在 M 和 N 都很小时才会用到,而幸运的是,在我们的场景下,N(也就是模型维度)通常都很大!
图 11:即便保证了批量不变性,我们的性能相比 cuBLAS 也只损失了大约 20%。需要注意的是,这里用的还不是一个优化过的 Triton 内核(比如没有用到 TMA)。不过,一些性能上的表现可以很好地说明:批量不变性要求到底是在哪些地方带来了性能损失。首先,在非常小的批量下,我们的性能损失显著,这是因为使用的指令过大,同时并行度不足。其次,随着批量大小的增加,会出现一种类似「拼图」的性能波动模式,这是由量化效应(包括 tile 和 wave)导致的。而在通常情况下,这类问题是通过改变 tile 大小来缓解的。
批量不变的注意力机制
图 12:FlashAttention2 策略:在 FlashAttention2 中,我们沿着 Q 做并行,同时沿着 K/V 做归约。这样一来,整个归约过程都能保持在单个核心内完成,因此这也是一种 数据并行策略。
不过,在矩阵乘法实现了批量不变性之后,注意力机制又带来了两个额外的复杂点——这也很合理,因为它本身包含了两个矩阵乘法:
1. 与 RMSNorm 和 Matmul 只在 特征维度 上做归约不同,注意力机制需要同时在 特征维度和序列维度 上做归约。2. 注意力机制还必须应对各种推理优化方式对序列处理方式的影响(如分块预填 chunked prefill、前缀缓存 prefix caching 等)。
因此,要在 LLM 推理中实现真正的确定性,我们的数值计算必须在两方面保持不变:与 一次处理多少请求无关:与 推理引擎如何切分请求无关。
让我们先看 FlashAttention2 中引入的注意力机制标准并行策略。和 RMSNorm、Matmul 类似,默认的还是 数据并行 策略。由于归约发生在 key/value 张量 上,因此数据并行只能沿着 query 张量来并行。
举个例子:根据推理引擎的调度方式,一个序列可能会被分块处理(比如分块预填),也可能一次性处理完(如果没有分块)。为了保证 批量不变性,对于某个给定的 token,它的归约顺序不能依赖于这个序列中同时被处理的其他 token 数量。
如果你把 KV 缓存中的 K/V 值与当前正在处理的 token 的 K/V 值分开归约(比如 vLLM 的 Triton 注意力内核就是这样做的),那么就无法实现批量不变性。举个具体例子:当处理序列中的第 1000 个查询 token 时,无论 KV 缓存里是 0 个 token(预填阶段),还是 999 个 token(解码阶段),它的归约顺序都必须完全一致。
图 13:带 KV 缓存的 FlashAttention:之所以把 KV 缓存和当前 KV 值分开处理会破坏批量不变性,原因比较微妙,涉及到 「边界条件」。举个例子:假设 block 大小是 32,而 KV 缓存里目前有 80 个元素。接着我们又要计算 48 个未缓存的元素。在这种情况下: 计算缓存部分 P cache 需要 3 个 block(两个完整的 + 一个带掩码的)。计算未缓存部分 P 需要 2 个 block(一个完整的 + 一个带掩码的)。总共需要 5 个 block 来完成归约,而实际上总元素数是 128(也就是 4 个 block)。这就必然会改变归约顺序。比如,如果我们没有任何 KV 缓存,而是一次性处理 128 个元素,那么两种情况下的数值结果必须完全一致,才能保证注意力机制的 批量不变性。
为了解决这个问题,我们可以在进入注意力内核之前,先更新 KV 缓存和页表,从而确保 无论处理多少个 token,keys 和 values 的布局始终一致。
结合这一点,以及前面提到的其他措施(比如保持 tile 大小一致),我们就能实现一个 批量不变的注意力机制实现!
不过,这里仍然存在一个严重问题。与矩阵乘法不同,在 LLM 推理中,注意力机制的输入形状往往确实需要用到 分块归约内核,通常被称为 Split-KV 或 FlashDecoding。
原因是:如果我们不在归约维度上做并行,那么只能在 batch 维度、head 维度 和 query 长度维度 上并行。而在注意力的解码阶段,query 长度通常非常小,如果批量规模不大,就很难让 GPU 得到充分利用。
遗憾的是,这种情况不像 RMSNorm 和矩阵乘法那样可以轻易忽略。比如,如果 KV 缓存非常长,即便只处理一个请求,注意力内核的计算也可能需要非常长的时间。
i图 14:固定分块数的 Split-KV 策略(即 FlashDecode):当 query 长度变得非常小时(比如在解码阶段),内核可能几乎没有并行度。这种情况下,我们不得不再次沿着归约维度切分——这次是 KV 维度。典型的做法是:先确定需要多少并行度,然后把 KV 维度均匀切分。比如,如果 KV 长度是 1000,而我们需要 4 个切分,那么每个核心就处理 250 个元素。但遗憾的是,这依然会破坏 批量不变性,因为具体的归约策略依赖于某个请求里到底有多少个 query token 被处理。
此外,注意力机制中常见的分块归约策略,本身也对批量不变性构成挑战。比如,FlashInfer 的「平衡调度算法」会选择一种能让所有 GPU 核心都被填满的最大切分方式,这就使得归约策略不再是批量不变的。不同于 RMSNorm 或 Matmul,在注意力机制中,单纯固定分块数(#splits)并不足以保证批量不变性。
相反,要实现批量不变性,我们必须采用 固定分块大小(fixed split-size) 的策略。换句话说,我们不是固定分块数,而是固定每个分块的大小,这样分块数会随着输入变化而不同。通过这种方式,我们就能确保无论处理多少 token,归约顺序始终保持一致。
图 15:固定大小的 Split-KV 策略,这种策略与之前的策略唯一的区别是:切分的大小是固定的。例如,当 KV 长度为 1000 时,我们不会把它均分为 4 个长度为 250 的切块,而是切分为 3 个固定大小为 256 的切块,以及 1 个长度为 232 的切块。这样就能保持批量不变性,因为归约策略不再依赖于一次要处理多少个 query token!
实现
我们在 vLLM 的 FlexAttention 后端以及 torch.Library 的支持下,演示了如何实现确定性推理。借助 torch.Library,我们能够以一种不具侵入性的方式替换掉大部分相关的 PyTorch 运算算子。
你可以在 thinking-machines-lab/batch-invariant-ops 中找到这套「批量不变」内核库,以及基于 vLLM 的确定性模式运行示例。
实验
补全到底有多非确定性?
我们使用 Qwen/Qwen3-235B-A22B-Instruct-2507,在温度设为 0 的情况下采样 1000 次补全,提示词为 「Tell me about Richard Feynman」(非思考模式),每次生成 1000 个 token。
令人惊讶的是,我们得到了 80 种不同的补全结果,其中最常见的一种出现了 78 次。
观察这些补全的差异,可以发现它们在前 102 个 token 完全一致!首次出现差异是在第 103 个 token:所有补全都生成了相同的片段——「Feynman was born on May 11, 1918, in」。然而,其中 992 个补全继续生成了 「Queens, New York」,而 8 个补全生成了 「New York City」。
另一方面,当我们启用批量不变内核后,1000 次补全结果完全一致。这正是从数学角度上我们对采样器的预期,但如果没有批量不变内核,我们是无法得到确定性结果的。
性能
目前我们还没有投入大量精力去优化这些批量不变内核的性能。不过,我们可以做一些实验来验证它们的性能是否依然可用。
我们搭建了一个 API 服务器,使用单张 GPU 运行 Qwen-3-8B,并请求生成 1000 条序列,每条输出长度在 90 到 110 之间。
很多性能下降其实源于 vLLM 中 FlexAttention 集成尚未经过深入优化。尽管如此,整体性能依然是可接受的,并没有出现灾难性的问题。
真正的 On-Policy 强化学习(True On-Policy RL)
研究人员指出,训练与推理之间的数值差异,会在隐性上把本应是 on-policy 的强化学习 变成 off-policy RL。
显然,如果连两次完全相同的推理请求都不能保证逐位一致的结果,那么训练与推理之间更不可能逐位一致。而一旦我们能够实现 确定性推理,就能进一步修改训练流程,使采样与训练之间也能达到逐位一致,从而实现真正的 on-policy RL。
我们在 Bigmath 上进行了 RLVR 实验,RL 策略初始化自 Qwen 2.5-VL instruct 8B,最大 rollout 长度为 4096。
当训练时 没有使用 off-policy 校正(即重要性加权)时,奖励在训练中途就会崩溃。而加入 off-policy 校正项 后,训练可以平稳进行。但如果我们能保证采样器与训练器之间逐位一致,那么整个过程就是完全 on-policy(即 KL 散度为 0),同样也能顺利训练。
我们还绘制了采样器与训练器之间 logprobs 的 KL 散度曲线,三种运行方式的表现差异明显:
使用重要性加权时,KL 散度大约稳定在 0.001,偶尔出现波动。不使用重要性加权时,KL 散度最终会出现一次明显飙升,与此同时奖励也会崩溃。而在运行 「真正的 On-Policy RL」 时,KL 散度始终保持在 0,这表明训练策略与采样策略之间完全没有差异。
图 16:需要注意的是,在 未使用重要性加权 的实验中,大约在 第 318 步 出现了显著的损失(loss)飙升,同时伴随着 logprobs 的 KL 散度 剧烈上升。相比之下,无论是使用 off-policy 校正,还是运行 「真正的 On-Policy」,强化学习训练都能顺利进行。而图中蓝色的 「True On-Policy」 曲线并不是 bug —— 它只是始终平稳地贴合在 0 上。
结论
现代软件系统包含了大量抽象层。在机器学习中,当我们遇到非确定性和细微的数值差异时,人们往往会选择「掩盖」它们。毕竟,我们的系统本身就是「概率性的」,再多一点非确定性又能怎样呢?单元测试里把 atol/rtol 容忍度调高一点,好像也不是什么大问题?训练器和采样器之间的 logprobs 差异,大概也不算是真正的 bug,对吧?
但我们拒绝这种 消极妥协。只要多花一些精力,我们完全可以理解非确定性的根源,甚至把它们解决掉!
我们希望这篇博客能帮助社区对 推理系统中的非确定性问题 有一个扎实的理解,并能激励更多人去真正掌握自己系统的运行机制。
引用
请按以下格式引用本文:
He, Horace 和 Thinking Machines Lab, 《Defeating Nondeterminism in LLM Inference》,Thinking Machines Lab: Connectionism, 2025 年 9 月。
或者使用 BibTeX 格式引用:
@article{he2025nondeterminism,author = {Horace He and Thinking Machines Lab},title = {Defeating Nondeterminism in LLM Inference},journal = {Thinking Machines Lab: Connectionism},year = {2025},note = {https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/}, doi = {10.64434/tml.20250910}}
附上博客原文:
来源:AI观察室