Prhub

#26383 [AMD][DSV4] DSV4 MTP graph + sparse triton attn optimizations

原始 PR 作者 kkHuang-amd 合并时间 2026-05-28 06:23 文件变更 10 提交数 5 评论 9 代码增减 +659 / -65

执行摘要

修复 DSV4 MTP 在 ROCm 上的 CUDA Graph 捕获并优化注意力与融合 kernel

引用 PR body:'The long-standing draft #25552 collected 85 commits of AMD / DSV4 work over several rebases. Most of the optimization PRs ... have already been merged into main through smaller follow-up PRs. What remains is a small but meaningful set of changes needed to:

  1. Make DSV4 + EAGLE/MTP speculative decoding work correctly under CUDA-graph capture on ROCm. 2. Land the last round of ROCm performance wins ... 3. Provide a clean, single-commit, conflict-free landing.'

建议精读。该 PR 展示了如何在不破坏 CUDA 路径的前提下为 ROCm 修复关键错误并注入性能优化,其设计权衡(始终 eager 构造、fused kernel 阈值选择、三级 fallback 模式)值得学习。尤其推荐关注 deepseek_v4_fused_mhc.py 中的缓冲池与运行时禁用机制。

讨论亮点

Review 中 gemini-code-assist[bot] 提出了 4 条中等偏高的修改建议:

  • (高) _get_triton_mhc_post_pre_ops 导入失败后未缓存:每次 forward 都会重试导入并打 warning 日志,建议用布尔标记缓存失败状态。
  • (中) _freqs_cis_to_cos_sin 使用 id(freqs_cis) 作缓存键:若原 tensor 被回收存在 id 重用风险,建议在缓存值中保留引用。
  • (中) sampler.py 中直接调用 get_bool_env_var('SGLANG_USE_AITER'):应使用集中的 envs 配置系统。
  • (中) init_forward_metadata_target_verify 新增 extend_seq_lens 参数但未使用:函数体内完全忽略,也未传给 init_forward_metadata_target_verify_old
    四条评论均未得到维护者回复或修改,PR 已合并,这些建议处于未解决状态。此外,yctseng0211 在 issue 评论中指出 AMD CI 的已知失败与本 PR 无关。

实现拆解

  1. HIP 后端分发与 MTP 图捕获修复
    draft_utils.py 中对 _create_dsv4_decode_backend_create_dsv4_prefill_backend 添加 is_hip() 判断,将 ROCm 平台分派到 deepseek_v4_backend_hip_radixDeepseekV4MultiStepBackend / DeepseekV4HipRadixBackend
    deepseek_v4_backend_hip_radix.py 中重写 init_forward_metadata_target_verify,使其始终走 eager metadata 构造路径(旧版 init_forward_metadata_target_verify_old),避免在 SGLANG_PREP_IN_CUDA_GRAPH 开启时因 lazy-upgrade 导致规划器不变量被破坏。同时修正了多步 draft 的 out_cache_loc 切片:在 init_forward_metadata 解码态下通过 per_step_draft_out_cache_loc 取当前 step 的缓存位置。

  2. Fused Triton mHC post+pre 算子
    新增 python/sglang/srt/models/deepseek_common/amd/deepseek_v4_fused_mhc.py,实现 try_fused_hc_post_pre 函数。当满足条件(token 数 ≤64、启用环境变量 SGLANG_OPT_USE_TRITON_FUSED_MHC、GFX95 支持)时,尝试调用 aitermhc_post_pre 内核,将 hc_posthc_pre 和后续的 layernorm 合并为单个 Triton 启动。该函数返回 (residual, hidden_states, post, comb, norm_fused) 五元组;若内核不可用则返回 None,调用方回退到原有的分开执行路径。

  3. 优化 Triton 稀疏注意力解码内核
    triton_mla_kernels_decode_optimized.py 中新增 _should_use_fused_nosplitk 决策函数,用于大 batch(≥256 token)时选择 fused no-splitk 双域注意力内核,实测比分离 gather+attention 路径快 10-14%。_should_use_fused_dual_scope 的阈值也随之调整,以匹配 MI355X 实测数据。

  4. 启用 w_qkv 融合与 aiter 采样
    deepseek_v4.py 中移除 fuse_wqa_wkvnot _is_hip 限制,使 ROCm 也能使用 w_qkv 融合压缩;添加 _freqs_cis_to_cos_sin 缓存以避免各层重复转换。在 sampler.py 中,当 SGLANG_USE_AITER 为 true 且为 HIP 时,将 all-greedy 采样替换为 aiter.greedy_sample,写入预分配的 int32 缓冲。

  5. 测试与配置配套
    新增 test/registered/ops/test_aiter_greedy_sample_amd.py,覆盖 kernel 级与 torch.argmax 的等价性验证、各种 batch 与 vocab 尺寸、tied values、负数 logits 等;注册到 AMD CI stage-b 测试套件。同时在 environ.py 中新增 SGLANG_OPT_USE_TRITON_FUSED_MHC 开关。

文件 模块 状态 重要度
python/sglang/srt/models/deepseek_common/amd/deepseek_v4_fused_mhc.py 融合算子 added 8.91
test/registered/ops/test_aiter_greedy_sample_amd.py AMD 测试 added 7.76
python/sglang/srt/layers/attention/nsa/triton_decode/triton_mla_kernels_decode_optimized.py 注意力内核 modified 7.68
python/sglang/srt/models/deepseek_v4.py 模型适配 modified 7.68
python/sglang/srt/layers/attention/deepseek_v4_backend_hip_radix.py HIP 后端 modified 7.25

关键符号

try_fused_hc_post_pre _get_triton_mhc_post_pre_ops _get_fused_hc_post_pre_buffers _should_use_fused_nosplitk _freqs_cis_to_cos_sin

关键源码片段

python/sglang/srt/models/deepseek_common/amd/deepseek_v4_fused_mhc.py core-logic

新增的 AMD 专用融合 mHC 算子模块,封装了从导入尝试到缓冲区分配再到运行的核心逻辑,是整个 PR 性能优化的关键。

import logging
from typing import Optional, Tupleimport torch
import tritonfrom sglang.srt.environ import envslogger = logging.getLogger(__name__)# 模块级常量与全局状态
_FUSED_HC_POST_PRE_M_THRESHOLD = 64 # batch token 阈值
_FUSED_HC_POST_PRE_CACHE: dict[tuple, dict[str, torch.Tensor]] = {} # 缓冲池缓存
_TRITON_MHC_POST_PRE_OPS = None # 惰性导入的结果
_TRITON_MHC_POST_PRE_RUNTIME_DISABLED = False # 运行时禁用标记
​
​
def _get_triton_mhc_post_pre_ops():
    """惰性导入 aiter 的 mhc_post_pre 及其配置工具,结果全局缓存。"""
    global _TRITON_MHC_POST_PRE_OPS
    if _TRITON_MHC_POST_PRE_OPS is not None:
        return _TRITON_MHC_POST_PRE_OPS
    try:
        from aiter.ops.triton.fusions.mhc import mhc_post_pre
        from aiter.ops.triton.utils.mhc_config_utils import get_mhc_config
    except Exception as err:
        # 导入失败时返回 None,调用方自行降级
        logger.warning("Triton fused mHC (mhc_post_pre) is unavailable, falling back: %s", err)
        return None
    _TRITON_MHC_POST_PRE_OPS = (mhc_post_pre, get_mhc_config)
    return _TRITON_MHC_POST_PRE_OPS
​
​
def _get_fused_hc_post_pre_buffers(
    num_tokens: int, hidden_size: int, hc_mult: int,
    dtype: torch.dtype, device: torch.device,
) -> Optional[dict[str, torch.Tensor]]:
    """根据 shape 与配置获取预分配的中间缓冲(来自全局缓存或新创建)。"""
    ops = _get_triton_mhc_post_pre_ops()
    if ops is None:
        return None
    _, get_mhc_config = ops
    key = (num_tokens, hidden_size, hc_mult, dtype, device.type, device.index)
    bufs = _FUSED_HC_POST_PRE_CACHE.get(key)
    if bufs is not None:
        return bufs
    # 通过 get_mhc_config 获取 kernel 配置,动态计算拆分组数
    try:
        cfg, _ = get_mhc_config("MHC_FUSED", num_tokens, hidden_size, mode="sinkhorn")
    except Exception as err:
        logger.warning("Failed to initialize fused mHC config, falling back: %s", err)
        return None
    n_total = 2 * hc_mult + hc_mult * hc_mult
    k_dim = hc_mult * hidden_size
    block_k = cfg.get("BLOCK_K", min(512, triton.next_power_of_2(k_dim)))
    block_k = min(block_k, triton.next_power_of_2(k_dim))
    block_c_split = max(block_k // hc_mult, 1)
    num_ksplit = triton.cdiv(hidden_size, block_c_split)
    # 分配全部中间缓冲
    bufs = {
        "residual_out": torch.empty(num_tokens, hc_mult, hidden_size, dtype=dtype, device=device),
        "layer_input_out": torch.empty(num_tokens, hidden_size, dtype=dtype, device=device),
        "h_post": torch.empty(num_tokens, hc_mult, dtype=torch.float32, device=device),
        "h_res": torch.empty(num_tokens, hc_mult, hc_mult, dtype=torch.float32, device=device),
        "acc_partial": torch.empty(num_ksplit, num_tokens, n_total, dtype=torch.float32, device=device),
        "acc_sq_partial": torch.empty(num_ksplit, num_tokens, dtype=torch.float32, device=device),
    }
    _FUSED_HC_POST_PRE_CACHE[key] = bufs
    return bufs
test/registered/ops/test_aiter_greedy_sample_amd.py test-coverage

新增的 aiter.greedy_sample 单元测试,覆盖 kernel 正确性与 Sampler 集成,注册到 AMD CI。

import unittest
from unittest import mockimport torchfrom sglang.srt.utils.common import is_hip
from sglang.test.ci.ci_register import register_amd_ci# 注册到 AMD CI stage-b 套件,预估 60 秒
register_amd_ci(est_time=60, suite="stage-b-test-1-gpu-small-amd")
​
​
def _mock_global_server_args(backend="pytorch"):
    """替换 sampler 模块中的全局配置为伪值,避免真正初始化 ServerArgs。"""
    from sglang.srt.layers import sampler as sampler_mod
    from sglang.srt.server_args import ServerArgs
    sampler_mod.get_global_server_args = lambda: ServerArgs(
        model_path="dummy", sampling_backend=backend,
    )
    class _DummyTPGroup:
        device_group = None
    sampler_mod.get_tp_group = lambda: _DummyTPGroup()
    sampler_mod.is_dp_attention_enabled = lambda: False
​
​
def _make_sampling_info(batch_size, vocab_size, device="cuda"):
    """构造一个全 greedy 的 SamplingBatchInfo 用于测试。"""
    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
    return SamplingBatchInfo(
        temperatures=torch.ones(batch_size, 1, device=device, dtype=torch.float),
        top_ps=torch.ones(batch_size, device=device),
        top_ks=torch.zeros(batch_size, device=device, dtype=torch.int32),
        min_ps=torch.zeros(batch_size, device=device),
        is_all_greedy=True,
        need_top_p_sampling=False,
        need_top_k_sampling=False,
        need_min_p_sampling=False,
        vocab_size=vocab_size,
        device=device,
    )
​
​
@unittest.skipUnless(is_hip(), "aiter greedy_sample requires ROCm")
class TestAiterGreedySample(unittest.TestCase):
    """Kernel 级正确性:aiter.greedy_sample 与 torch.argmax 等价。"""
​
    @classmethod
    def setUpClass(cls):
        try:
            from aiter import greedy_sample
            cls.greedy_sample = staticmethod(greedy_sample)
        except ImportError:
            raise unittest.SkipTest("aiter not installed")
        cls.device = "cuda"
​
    def setUp(self):
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)
​
    def _run_and_compare(self, batch_size, vocab_size):
        """核心比较:随机 bf16 logits -> aiter.greedy_sample vs argmax。"""
        logits = torch.randn(batch_size, vocab_size, device=self.device, dtype=torch.bfloat16)
        expected = torch.argmax(logits, dim=-1)
        actual = torch.empty(logits.shape[0], device=logits.device, dtype=torch.int32)
        self.greedy_sample(actual, logits)
        self.assertTrue(
            torch.equal(actual.to(expected.dtype), expected),
            f"Mismatch for shape ({batch_size}, {vocab_size})",
        )
​
    def test_single_request(self):
        self._run_and_compare(1, 32000)
​
    def test_small_batch(self):
        self._run_and_compare(4, 32000)
​
    # 更多测试(medium/large batch, 真实词汇量 , tied values 等)在完整文件中

评论区精华

导入失败未缓存导致重复日志 性能

gemini-code-assist[bot] 指出 _get_triton_mhc_post_pre_ops 在导入失败后 _TRITON_MHC_POST_PRE_OPS 仍为 None,导致每次 forward 都重试导入并打 warning。建议用布尔标记缓存失败。

结论:未采纳 / 未回复 · unresolved

id(freqs_cis) 缓存键存在回收风险 正确性

gemini-code-assist[bot] 指出 _freqs_cis_to_cos_sin 使用 id(freqs_cis) 作缓存键,若原 tensor 被 GC 回收则地址可重用,导致返回错误的 cos/sin。建议在缓存值中保留引用。

结论:未采纳 / 未回复 · unresolved

SGLANG_USE_AITER 应使用集中 envs 机制 设计

gemini-code-assist[bot] 指出 sampler.py 中直接调用 get_bool_env_var('SGLANG_USE_AITER') 应改为使用 sglang.srt.environ.envs 配置系统。

结论:未采纳 / 未回复 · unresolved

新增 extend_seq_lens 参数未被使用 正确性

gemini-code-assist[bot] 指出 init_forward_metadata_target_verify 新增 extend_seq_lens 参数,但函数体内完全忽略,也未传给下层。

结论:未采纳 / 未回复 · unresolved

风险与影响

  1. id 缓存回收风险_freqs_cis_to_cos_sin 使用 id(freqs_cis) 作为 key,若原 tensor 被垃圾回收,地址重用可能导致返回错误的 cos/sin 表。实践中 freqs_cis 常驻,但仍是契约风险。
  2. extend_seq_lens 被忽略init_forward_metadata_target_verify 新增的 extend_seq_lens 参数未传入下层,可能影响特定场景下 target-verify 的元数据构造,导致偶发错误。
  3. 依赖回退路径测试不足aiter.greedy_samplemhc_post_pre 依赖外部 aiter 包,缺失或版本不兼容时会静默回退到 PyTorch 原语,但回退路径的精度与性能未在 CI 中充分验证。
  4. HIP 后端专用代码的维护成本:部分修改通过 is_hip() 分支隔离,未来若 CUDA 侧逻辑变化可能导致 HIP 分支滞后。

用户影响:AMD ROCm 用户运行 DeepSeek V4 模型时,MTP 推测解码不再因 CUDA Graph 捕获失败而中断;解码延迟降低(fused kernel 加速)。CUDA 用户不受影响(行为不变的 HIP 分支)。
系统影响:新增环境变量 SGLANG_OPT_USE_TRITON_FUSED_MHC 控制融合开关;新测试仅运行在 AMD CI。
团队影响:AMD 团队维护约 159 行新模块 deepseek_v4_fused_mhc.py 和 289 行测试,后续可通过移除旧路径逐步简化。

id 回收风险 extend_seq_lens 忽略 依赖条件回退 核心路径变更 review 建议未修复

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论