Prhub

#24762 [AMD] fix(triton-mla): cap max_kv_splits at 256 on gfx942 (Kimi-K2.6 hang)

原始 PR 作者 bingxche 合并时间 2026-06-03 15:13 文件变更 5 提交数 19 评论 7 代码增减 +27 / -4

执行摘要

限制 gfx942 上 max_kv_splits 为 256,修复 Kimi-K2.6 挂起

修复nightly-8-gpu-kimi-k26 MI325X挂起问题,原因是PR #20479的_mla_decode_kv_splits_cap()max_kv_splits提升至512,导致cuda_graph_attn_logits缓冲区膨胀至4 GiB,超出ROCm CUDA图重放能力(https://github.com/sgl-project/sglang/actions/runs/25513282022/job/74877480809)。

值得精读。设计决策:针对特定SKU硬编码上限是否优于动态内存预算?后续若能统一为“两倍最大上下文分割数”则更通用。此外,is_gfx942_supported的引入为后续AMD特殊处理提供了范例。

讨论亮点

HaiShaw在审查时指出应避免使用含糊变量名(如早期提交中的bs),确保代码可读性;最终实现使用自解释的self.max_kv_splits

实现拆解

  1. python/sglang/srt/utils/common.py中新增is_gfx942_supported()函数,带@lru_cache装饰,检测gcnArchName是否包含gfx942
  2. python/sglang/srt/layers/attention/triton_backend.py中导入is_gfx942_supported,模块级缓存_is_gfx942。在TritonAttentionBackend.__init__的MLA分支内追加条件:若_is_gfx942为真,则将self.max_kv_splits限制为min(self.max_kv_splits, 256)
  3. test/registered/amd/test_kimi_k2_instruct.py中将parallel=1319改为parallel=512,避免修复后剩余内存不足以支撑高并发。
  4. 更新.github/workflows/pr-test-amd.ymlpr-test-amd-rocm720.yml中的--auto-partition-size从3增至4,以容纳新增的MI325X测试分区。
文件 模块 状态 重要度
python/sglang/srt/utils/common.py 工具库 modified 6.37
python/sglang/srt/layers/attention/triton_backend.py 注意力层 modified 6.52
test/registered/amd/test_kimi_k2_instruct.py 测试 modified 3.82
.github/workflows/pr-test-amd.yml CI 配置 modified 2.38
.github/workflows/pr-test-amd-rocm720.yml CI 配置 modified 2.73

关键符号

is_gfx942_supported

关键源码片段

python/sglang/srt/layers/attention/triton_backend.py core-logic

核心修复:在 MLA 初始化时限制 `max_kv_splits`,防止缓冲区过大。

# triton_backend.py 文件头部分
from sglang.srt.utils import (
    is_gfx942_supported,
)
_is_gfx942 = is_gfx942_supported() # 模块级缓存,只检测一次# 在 __init__ 方法中,self.use_mla 分支内
if self.use_mla:
    self.max_kv_splits = _mla_decode_kv_splits_cap(
        self.max_kv_splits,
        self.device_core_count,
        self.max_context_len,
    )
    if _is_gfx942:
        # gfx942 (MI300X / MI325X) 有 304 个 CU,next_power_of_2 会得到 512,
        # 导致 cuda_graph_attn_logits 缓冲区在 Kimi-K2.6 上膨胀到 4 GiB。
        # 强制限制为 256,与 gfx950 的行为一致且经过验证。
        self.max_kv_splits = min(self.max_kv_splits, 256)

评论区精华

变量命名规范 style

HaiShaw 要求避免使用通用变量名 `bs`,应更具体。

结论:最终代码未使用 `bs`,采用了自解释的 `self.max_kv_splits`。 · 已解决

风险与影响

  1. 仅针对gfx942限制,不影响NVIDIA或其他AMD SKU;但未来若有新SKU的CU数量超过256,需重新评估。
  2. 测试并行度降低至512,可能未覆盖高并发场景下的内存压力。
  3. is_gfx942_supported()基于GPU名称字符串匹配,若ROCm报告格式变化可能导致误判。

用户:Kimi-K2.6在AMD MI325X上不再hang,恢复正常推理。系统:CUDA图捕获阶段内存占用从4 GiB降至2 GiB,图重放稳定性提升。团队:新增平台检测函数可复用,但需注意代码维护。CI测试增加一个分区以确保覆盖。

平台特异性硬编码 测试并发度降低 字符串匹配依赖

关联 Issue

#20479 Support Triton MLA FP8 KV cache

完整报告

参与讨论