Prhub

#27188 [AMD] Fix TP2 DeepSeek-R1 nhead=64 MLA decode crash and add nightly coverage

原始 PR 作者 clintg6 合并时间 2026-06-04 07:56 文件变更 6 提交数 1 评论 3 代码增减 +518 / -1

执行摘要

修复 DeepSeek-R1 TP2 时 nhead=64 MLA decode 崩溃并添加夜间测试

来自 PR body:'Fixes a Memory access fault crash when running DeepSeek-R1-MXFP4 with TP2 and the AITER persistent MLA decode path.' 以及关联的 ROCm/aiter#3496 报告了在 TP2 下使用 native qh64 persistent kernel 时多 GPU 并发崩溃。

值得精读。PR 展示了处理多 GPU 内核选择时的边界情况(head count 门控),并提供了完整的回归测试设计。建议关注 persistent 模式与非 persistent 模式的切换条件,以及如何通过 CI 配置覆盖不同 TP 场景。

讨论亮点

Review 中 HaiShaw 直接批准,Lzy17 给出 LGTM。bingxche 在 CI 中观察到准确率问题(TP2 准确率 0.945、TP4 0.965)并提供了链接,但未进一步讨论。该问题未在此 PR 中解决,但准确率阈值(0.93)已达标。

实现拆解

  1. python/sglang/srt/layers/attention/aiter_backend.py 中,将 __init__ 方法内条件从 if self.num_head == 32 or self.num_head == 128: 改为 if self.num_head in (32, 64, 128):,使 nhead=64 时启用 fast_mode=Trueintra_batch_mode=False,从而使用正确的 persistent metadata 模式。
  2. 新增两个 nightly 测试文件:test_deepseek_r1_mxfp4_tp2_mi35x.pytest_deepseek_r1_mxfp4_tp4_mi35x.py,分别覆盖 TP2 (nhead=64) 和 TP4 (nhead=32) 的 GSM8K 准确率测试,注册到 nightly 测试套件。
  3. test/run_suite.pyHWBackend.AMD 套件列表中添加两个新套件名称。
  4. 在两个 AMD nightly 工作流文件(nightly-test-amd.ymlnightly-test-amd-rocm720.yml)中分别添加对应的 job 定义,触发条件、运行环境和步骤。
  5. 本地验证:在 MI35X 上 TP2 配置之前 5/5 次崩溃,修复后 3/3 通过;TP4 控制组也通过;非 FP8 KV cache 和 MTP 场景也通过 smoke test。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/aiter_backend.py 注意力层 modified 5.58
test/registered/amd/accuracy/mi35x/test_deepseek_r1_mxfp4_tp2_mi35x.py 回归测试 added 7.37
test/registered/amd/accuracy/mi35x/test_deepseek_r1_mxfp4_tp4_mi35x.py 回归测试 added 7.37
.github/workflows/nightly-test-amd.yml CI 配置 modified 4.64
.github/workflows/nightly-test-amd-rocm720.yml CI 配置 modified 4.64
test/run_suite.py 测试配置 modified 3.25

关键符号

AiterMlaBackend.__init__ TestDeepSeekR1MXFP4TP2MI35x.setUpClass TestDeepSeekR1MXFP4TP4MI35x.setUpClass run_gsm8k_benchmark

关键源码片段

python/sglang/srt/layers/attention/aiter_backend.py core-logic

核心修复,改动一行条件,使 nhead=64 启用 persistent 模式,解决多 GPU 崩溃。

# 设置 persistent MLA 解码元数据模式
# 当前 mla_decode_fwd 只支持 fake-nps 在 self.num_head == 16
# 因此所有 num_head 大小都不使用 qh16 内核来模拟
# 它不应该使用 fake-nps (fast_mode=False, intra_batch_mode=True)
# 否则会导致 GPU 故障或精度问题
if self.num_head in (32, 64, 128): # 修复前:只有 32 和 128;修复后:加入 64
    fast_mode = True
    intra_batch_mode = False# 当前 persistent a16w16 mla_decode 内核不支持 head_num=128
# 需要回退到非 persistent 模式
# 仅当 fp8 kv_cache 时使用 mla_ps_kernel
if (
    self.num_head_padded == 16 or self.num_head_padded == 128
) and self.kv_cache_dtype is not fp8_dtype:
    _use_mla_ps_kernel = False
    fast_mode = False
    intra_batch_mode = False
test/registered/amd/accuracy/mi35x/test_deepseek_r1_mxfp4_tp2_mi35x.py test-coverage

新增 TP2 回归测试,直接覆盖崩溃路径,确保 nhead=64 路径被 nightly CI 测试。

# 注册夜间测试套件
register_amd_ci(
    est_time=1800,
    suite="nightly-amd-2-gpu-mi35x-deepseek-r1-mxfp4-tp2",
    nightly=True,
)# 常量定义
DEEPSEEK_R1_MXFP4_LOCAL_PATH = "/data2/models/amd-DeepSeek-R1-MXFP4-Preview"
DEEPSEEK_R1_MXFP4_HF_MODEL_ID = "amd/DeepSeek-R1-MXFP4-Preview"
SERVER_LAUNCH_TIMEOUT = 3600
GSM8K_ACCURACY_THRESHOLD = 0.93# 测试类:TP=2 时每 rank 64 个 heads
class TestDeepSeekR1MXFP4TP2MI35x(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = get_model_path()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.num_questions = int(os.environ.get("GSM8K_NUM_QUESTIONS", "200"))
​
        # 强制使用 AITER 和 persistent MLA 模式
        env = os.environ.copy()
        env["SGLANG_USE_AITER"] = "1"
        env["SGLANG_AITER_MLA_PERSIST"] = "1"
​
        cls.process = popen_launch_server(
            model=cls.model,
            base_url=cls.base_url,
            timeout=SERVER_LAUNCH_TIMEOUT,
            other_args=[
                "--attention-backend", "aiter",
                "--tp", "2",
                "--chunked-prefill-size", "131072",
                "--disable-radix-cache",
                "--mem-fraction-static", "0.85",
                "--trust-remote-code",
                "--kv-cache-dtype", "fp8_e4m3",
                "--model-loader-extra-config", '{"enable_multithread_load": true}',
            ],
            env=env,
        )

评论区精华

CI 准确率问题观察 测试

bingxche 在 CI 运行中观察到准确率问题,并提供了指向 CI 日志的链接。

结论:未在本次 PR 中解决;准确率阈值 (0.93) 已达标,但偏差需要后续调查。 · 待处理

风险与影响

核心更改仅修改一行条件表达式,风险极低。但新增的 nightly 测试依赖于特定硬件(MI35X)和模型路径,若环境配置不正确可能失败。此外,准确率偏差提示 persistent kernel 在 nhead=64 时的浮点行为可能与之前非 persistent 路径略有差异,但 GSM8K 准确率仍在阈值内,且崩溃问题已修复。

直接影响所有在 AMD GPU 上使用 AITER 后端、DeepSeek-R1 模型且 TP=2 的用户,解决了一直以来多 GPU 崩溃的问题。对其他模型和配置无影响。新增的 nightly 测试将确保该路径持续被 CI 覆盖。

AMD 特定修复 测试环境依赖 潜在准确率漂移

关联 Issue

#3496 [Issue]: Multi-GPU crash with native qh64 fp8 persistent MLA kernel (mla_a8w8_qh64_qseqlen1_gqaratio64_v3_ps)

完整报告

参与讨论