Prhub

#43599 [Bugfix][Kernel] TRTLLM NVFP4 MoE chunking

原始 PR 作者 amitz-nv 合并时间 2026-05-28 08:36 文件变更 4 提交数 7 评论 9 代码增减 +74 / -21

执行摘要

修复 TRTLLM NVFP4 MoE 内核大批量 token 下的 CUDA grid 溢出

TRTLLM 融合 MoE NVFP4 内核在大量 token(>=305k)下存在两个崩溃问题:

1) CUDA grid.Y 维度可能超过 64k 限制;
2) 在 Nemotron 3 Super(num_experts=512, top_k=22)上 token 数 >=305k 时发生非法内存访问(IMA)。常见场景如 EP 与 DP=8 且 max_num_batched_tokens=50k 时会达到 400k token,导致服务崩溃。此前内核不支持任何分块机制。

建议阅读 trtllm_nvfp4_moe.py 中的 chunking 实现,特别是 _calc_max_supported_tokens 的公式推导,它展示了如何根据 CUDA grid 限制逆向计算安全 token 数。此外,设计上选择仅在 TRTLLM NVFP4 内核启用 chunking 并在其他实现中移除未使用的 supports_chunking,体现了清晰的职责分离。此 PR 的测试方法也值得参考:通过对比极大数据配置下的运行和精度来验证修复。

讨论亮点

公式正确性讨论:gemini-code-assist 最初指出 _calc_max_supported_tokens 公式可能错误,并给出了替代推导。作者 amitz-nv 通过引用 flashinfer 源码和数学不等式反驳,最终 AI 承认公式正确,确认其满足 floor 运算的约束条件。

  • 注释澄清:用户 tomeras91 认为 300k 的注释令人困惑,作者解释 305k 是观察到的 IMA 阈值,因此取 300k 作为安全裕度,注释无误。

实现拆解

  1. 计算安全分块大小:在 trtllm_nvfp4_moe.pyTrtLlmNvFp4ExpertsModular 类中,新增 _get_chunk_size 方法。其内部闭包 _calc_max_supported_tokens 套用了 flashinfer 的 TRTLLM MoE runner 中 getMaxNumCtasInBatchDim 函数关于 CUDA grid.Y 维度的约束公式,反解出不超过 65535 限制的最大 token 数;再与硬编码 300k(因 IMA 错误观察阈值为 305k)取最小值作为最终 chunk size。

  2. 精简内核调用接口:将原有的 apply 方法重命名为 _invoke_kernel,移除 expert_mapa2_scaleworkspace13workspace2expert_tokens_metaapply_router_weight_on_input 等六个未使用参数,并简化内部断言,使核心调用更清晰。

  3. 实现分块 apply 方法:重新实现 apply 方法,保持与原签名一致。内部先获取 chunk size,若当前 token 数 M 小于 chunk size,则直接调用 _invoke_kernel;否则在 0M 区间内以 chunk size 步进循环,每次将 hidden_statesoutput 的相应切片传入 _invoke_kernel,其他参数(权重、topk 等)保持不变。

  4. 清理其他实现:在 trtllm_bf16_moe.pytrtllm_fp8_moe.pytrtllm_mxfp4_moe.py 中删除 supports_chunking 方法。该方法原本返回 False 且未被上层调用,此处移除以减少冗余接口,与 NVFP4 实现保持一致(NVFP4 用 _get_chunk_size 替代了 supports_chunking)。

  5. 测试验证:虽然未添加自动化单元测试,但通过手工运行 Nemotron 3 Super 模型在 EP+DP=8、max_num_batched_tokens=50k(总 token 数 400k)的极限配置下,服务不再崩溃,GSM8K 评测精度(约 93%)与低 token 配置一致,证明修复有效。

文件 模块 状态 重要度
vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py MoE 专家层 modified 8.73
vllm/model_executor/layers/fused_moe/experts/trtllm_bf16_moe.py MoE 专家层 modified 5.03
vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py MoE 专家层 modified 5.03
vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py MoE 专家层 modified 5.03

关键符号

supports_chunking _get_chunk_size _calc_max_supported_tokens apply _invoke_kernel

关键源码片段

vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py core-logic

核心变更:引入分块机制以避免 CUDA grid 和 IMA 崩溃,重构内核调用接口。

# vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
# TrtLlmNvFp4ExpertsModular 类中的核心方法def _get_chunk_size(self) -> int:
    '''计算安全的 chunk size,同时避免 CUDA grid.Y 64k 限制和 IMA 崩溃。'''
    MAX_GRID_Y = 65535 # CUDA grid.Y 维度最大值
    MAX_TILE_TOKENS_DIM = 128 # flashinfer TRTLLM MoE 内核的 tile 大小
​
    def _calc_max_supported_tokens(top_k: int, num_experts: int) -> int:
        '''基于 flashinfer 的 `getMaxNumCtasInBatchDim` 函数推导:
        https://github.com/flashinfer-ai/flashinfer/blob/719ee23fd82cb220d51ad118ca60198718f6c9d1/include/flashinfer/trtllm/fused_moe/runner.h#L97
        已知 `maxNumCtas = numExperts + (numTokens * topK - numExperts) // tileTokensDim`,
        约束 `maxNumCtas <= MAX_GRID_Y`,反解出 `numTokens` 的最大值。
        '''
        # 整数 floor 关系 : (a)//b <= c 等价于 a < b*(c+1),即 a <= b*(c+1)-1
        # 代入 a = numTokens * topK - numExperts, b = tileTokensDim, c = MAX_GRID_Y - numExperts
        return (num_experts + (MAX_GRID_Y - num_experts + 1) * MAX_TILE_TOKENS_DIM - 1) // top_k
​
    # 使用 300k 作为硬上限,因为超过 305k 时内核会出现 IMA 错误(Nemotron 3 Super 观察值)
    return min(300000, _calc_max_supported_tokens(self.topk, self.moe_config.num_experts))
​
​
def apply(self,
          output: torch.Tensor,
          hidden_states: torch.Tensor,
          w1: torch.Tensor,
          w2: torch.Tensor,
          topk_weights: torch.Tensor,
          topk_ids: torch.Tensor,
          activation: MoEActivation,
          global_num_experts: int,
          expert_map: torch.Tensor | None,
          a1q_scale: torch.Tensor | None,
          a2_scale: torch.Tensor | None,
          workspace13: torch.Tensor,
          workspace2: torch.Tensor,
          expert_tokens_meta: mk.ExpertTokensMetadata | None,
          apply_router_weight_on_input: bool):
    assert self._supports_activation(activation)
    assert a1q_scale is not None
​
    M = hidden_states.shape[0]
    chunk_size = self._get_chunk_size()
​
    if chunk_size >= M:
        # 无需分块,直接调用内核
        self._invoke_kernel(
            output, hidden_states, w1, w2,
            topk_weights, topk_ids, activation,
            global_num_experts, a1q_scale)
    else:
        # 按 chunk_size 分块,每次处理连续的 hidden_states / output 切片
        # 注意:topk_weights 和 topk_ids 未切片,因为它们与全局 token 索引相关,
        # 由上层路由保证正确性
        for start in range(0, M, chunk_size):
            end = min(start + chunk_size, M)
            self._invoke_kernel(
                output[start:end],
                hidden_states[start:end],
                w1, w2,
                topk_weights, topk_ids,
                activation,
                global_num_experts,
                a1q_scale)

评论区精华

公式正确性 正确性

gemini-code-assist 指出 `_calc_max_supported_tokens` 公式可能错误;amitz-nv 反驳并引导根据 flashinfer 源码推导,最终 AI 承认公式正确。

结论:公式经数学推导验证正确,满足 `maxNumCtas <= MAX_GRID_Y` 约束。 · 已解决

注释澄清 question

tomeras91 认为 `300k` 注释令人困惑,amitz-nv 解释指 `>=305k` 导致 IMA 错误,极限 300k 是安全裕度。

结论:注释正确且恰当。 · 已解决

风险与影响

性能:chunking 仅在 token 数超过 300k 时触发,正常场景无额外开销;但对于极端大 batch,会增加 kernel launch 次数,可能略微增加延迟。正确性:公式推导经过 flashinfer 验证,且手动测试确认精度无损。兼容性:改动仅限于 trtllm_nvfp4_moe.py,不影响其他 MoE 实现;移除的 supports_chunking 方法在其余文件中未使用,无功能变化。潜在风险:chunk size 计算逻辑依赖于 flashinfer 内部实现细节,若 flashinfer 更新可能需同步调整;但公式与上游保持对齐,风险较低。此外,分块时 topk_weightstopk_ids 未切片,依赖于上层路由已保证全局一致性,若未来使用方式变化可能存在隐患。

用户影响:运行 NVIDIA TRTLLM NVFP4 MoE 模型(如 Nemotron 3 Super)且使用高数据并行度或大 max_num_batched_tokens 的用户将不再遇到崩溃,服务稳定性提升。系统影响:增加了少量分支判断和循环,但整体额外计算量可忽略。团队影响:明确了 TRTLLM MoE 内核的 CUDA 限制,为后续其他内核类似处理提供了参考模式。

核心路径变更 CUDA 硬件特定限制 缺少自动化测试

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论