Prhub

#22342 [AMD] Enable DFLASH speculative decoding on ROCm

原始 PR 作者 andyluo7 合并时间 2026-04-18 04:10 文件变更 3 提交数 3 评论 5 代码增减 +33 / -10

执行摘要

AMD ROCm DFLASH 投机解码支持

DFLASH 投机解码目前仅支持 FlashInfer/FA3/FA4 后端,这些后端在 AMD ROCm 上不可用。为了在 AMD 硬件上运行 DFLASH,需要启用 Triton 注意力后端。

值得精读,尤其是 Triton 后端的 custom_mask 守卫和 ROCm 的 fallback 逻辑。建议后续合并 fallback 逻辑为单一 helper 函数以简化维护。

讨论亮点

gemini-code-assist[bot] 建议合并 fallback 后端选择逻辑(重复的 'triton' if torch.version.hip else 'flashinfer')并移除冗余的 torch 局部导入,但 PR 未采纳该建议。

实现拆解

  1. dflash_worker.py: 在 supported_draft_backends 元组中添加 "triton";当 draft_backend 为 None 时,自动检测 torch.version.hip 并默认为 triton;trtllm_mha 和不支持后端的 fallback 也改为 ROCm 感知。
  2. triton_backend.py: 在 CUDA graph capture 和 replay 的两个分支中,对 custom_mask 的访问添加 None 守卫,避免非因果 ENCODER_ONLY attention 模式下 spec_info.custom_mask 为 None 时崩溃。
  3. dflash.py: 修改 apply_k_rope 中的 dummy_q shape,从 [batch, head_dim] 改为匹配 K 的完整 shape [batch, num_heads, head_dim],以满足 ROCm sgl_kernel.rotary_embedding 的 num_heads % num_kv_heads == 0 约束。
  4. qwen3.py: 为 Qwen3ForCausalLM 添加 set_dflash_layers_to_capture 方法,使其支持 DFLASH hidden state 捕获。
文件 模块 状态 重要度
python/sglang/srt/speculative/dflash_worker.py 投机解码 modified 6.64
python/sglang/srt/layers/attention/triton_backend.py 注意力层 modified 6.06
python/sglang/srt/models/dflash.py 模型定义 modified 5.28

关键符号

DFlashWorker.__init__ TritonAttentionBackend.init_forward_metadata_capture_cuda_graph TritonAttentionBackend.init_forward_metadata_replay_cuda_graph DFlashAttention.apply_k_rope

关键源码片段

python/sglang/srt/speculative/dflash_worker.py dependency-wiring

核心入口:将 triton 加入支持后端列表,实现 ROCm 自动检测与 fallback 逻辑。

# dflash_worker.py (partial: backend selection)
supported_draft_backends = ("flashinfer", "fa3", "fa4", "triton")if draft_backend is None:
    draft_backend, _ = draft_server_args.get_attention_backends()
if draft_backend is None:
    import torch as _torch
    # 默认:ROCm 上使用 triton ,CUDA 上使用 flashinfer
    draft_backend = "triton" if _torch.version.hip else "flashinfer"
elif draft_backend == "trtllm_mha":
    import torch as _torch
    _fb = "triton" if _torch.version.hip else "flashinfer"
    logger.warning("DFLASH draft worker does not support 'trtllm_mha'...Falling back to '%s'.", _fb)
    draft_backend = _fb
elif draft_backend not in supported_draft_backends:
    import torch as _torch
    _fb = "triton" if _torch.version.hip else "flashinfer"
    logger.warning("DFLASH draft worker only supports %s...Falling back to '%s'.", supported_draft_backends, draft_backend, _fb)
    draft_backend = _fb
python/sglang/srt/layers/attention/triton_backend.py core-logic

修复 CUDA graph 下 custom_mask 为 None 时的崩溃,影响所有使用非因果注意力的投机模式。

# triton_backend.py (partial: CUDA graph custom_mask guard)
custom_mask = self.cuda_graph_custom_mask
if (
    spec_info is not None
    and getattr(spec_info, "custom_mask", None) is not None
):
    custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
else:
    # 非因果模式下(如 DFLASH draft 的 ENCODER_ONLY attention )没有 mask ,避免访问 None.shape
    custom_mask = None
python/sglang/srt/models/dflash.py data-contract

修复 RoPE kernel 在 ROCm 上的 shape 兼容性,影响所有使用 sgl_kernel 的推理。

# dflash.py (partial: apply_k_rope)
def apply_k_rope(self, positions: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # 匹配 K shape 使得 RoPE kernel 的 head count 检查能通过所有后端(尤其是 ROCm 的 sgl_kernel )
    dummy_q = k.new_empty(k.shape) # 原为 (k.shape[0], self.head_dim) ,只对单头,导致 num_heads % num_kv_heads != 0
    _, k = self.rotary_emb(positions, dummy_q, k)
    return k

评论区精华

重复的 fallback 逻辑与冗余 import style

gemini-code-assist[bot] 指出 dflash_worker.py 中多次重复编写 'triton' if torch.version.hip else 'flashinfer' 且局部导入 torch 冗余,建议提取为 fallback_backend 变量。

结论:未采纳,PR 保持原有风格合并。 · unresolved

风险与影响

  • 回归风险: 低。Triton 后端已被 SGLang 用于非投机场景,新增的 None 守卫是防御性编程。
  • 性能风险: 低。仅在 ROCm 上启用新后端,CUDA 行为不变。
  • 兼容性风险: 低。Qwen3 的 set_dflash_layers_to_capture 方法与 Llama 实现一致,无破坏性。
  • 测试覆盖: 缺少显式 ROCm 测试文件,依赖 CI 中的 test_dflash.py 来验证。
  • 用户: AMD ROCm 用户现在可以使用 DFLASH 投机解码,提升推理吞吐量。
  • 系统: 无重大影响,Triton 后端已是系统一部分。
  • 团队: DFLASH 维护者需关注 Triton 后端在 ROCm 上的稳定性。
缺少 ROCm 测试覆盖 重复代码未合并

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论