Prhub

#23146 [AMD] Enable EAGLE speculative decoding for Qwen3.5 FP8 and MXFP4 models with aiter's unified attention

原始 PR 作者 hubertlu-tw 合并时间 2026-05-05 15:09 文件变更 4 提交数 6 评论 17 代码增减 +588 / -148

执行摘要

AMD 启用 EAGLE 统一注意力验证并修复 MXFP4 加载

之前当SGLANG_USE_AITER_UNIFIED_ATTN=1配合--speculative-algorithm EAGLE时,非MLA且topk==1的EAGLE缺少unified_attention目标验证路径,被迫使用低效的自定义mask扩展注意力。此外,Quark导出的MXFP4 Qwen3.5检查点中MTP模块为bf16,但加载器误分配MXFP4 packed形状导致加载失败(issue #23113),且--quantization quark因未注册而被CLI拒绝。

值得精读,尤其关注注意力后端如何适配不同数据类型(MLA/non-MLA)和投机解码布局(ragged vs paged)。设计决策(如保持radix-cache分离)体现了模块化思维。建议后续补充单元测试覆盖新路径。

讨论亮点
  1. gemini-code-assist[bot] 指出target_verify的unified_attention输出视图应使用layer.v_head_dim而非qk_head_dim,与缓存形状一致。作者已修正。

  2. gemini-code-assist[bot] 建议在unified_attention调用中尊重sliding_window_size而非硬编码(-1,-1)。作者后续提交添加了滑动窗口支持。

  3. mqhc2020 建议将inline的Triton kernel移到triton_ops模块。HaiShaw也建议后续重构,作者已将kernel移至新文件。

  4. kkHuang-amd 指出init_forward_metadata中else分支对MLA或非unified_attention情况下可能破坏逻辑。HaiShaw询问后作者已处理(具体未明确但PR已合并)。

实现拆解

  1. aiter_backend.py核心改造:在AiterAttnBackend中新增_build_unified_page_table_from_spec_build_verify_unified_metadata方法,将spec_info的ragged token级索引转换为unified_attention所需的2D块级页表;在forward_extend中根据_use_unified_verify标志(非MLA且topk==1生效)将target_verify路径路由到unified_attention而非extend_attention_fwd;预计算qo_indptr_unified_decode避免逐请求cumsum。

  2. 新增Triton scatter kernel:在aiter_unified_attention.py中实现scatter_ragged_to_page_table_kernelscatter_req_to_token_to_page_table_kernel,高效完成ragged到page_table的并行转换,并支持SWA槽映射。

  3. 滑动窗口支持:在init_forward_metadata_build_verify_unified_metadata中传递swa_page_table,根据layer.sliding_window_size设置window_size参数。

  4. 修复MTP量化加载:在qwen3_5_mtp.py__init__中,检测quark量化配置的排除层列表,若包含mtp.*则跳过量化,确保线性层分配bf16形状。

  5. 注册quark量化选项:在server_args.pyQUANTIZATION_CHOICES中添加"quark"以支持CLI输入。测试配套:本次变更未新增测试文件,依赖自有实验验证。

文件 模块 状态 重要度
python/sglang/srt/layers/attention/aiter_backend.py 注意力后端 modified 8.39
python/sglang/srt/layers/attention/triton_ops/aiter_unified_attention.py 注意力 Kernel added 6.68
python/sglang/srt/models/qwen3_5_mtp.py 模型加载 modified 6.18
python/sglang/srt/server_args.py 配置 modified 4.18

关键符号

_build_unified_page_table_from_spec _build_verify_unified_metadata scatter_ragged_to_page_table_kernel scatter_req_to_token_to_page_table_kernel forward_extend

关键源码片段

python/sglang/srt/layers/attention/aiter_backend.py core-logic

主要改动,新增 EAGLE 统一验证路径和页表转换逻辑,核心变更文件

    def _build_unified_page_table_from_spec(
        self,
        spec_info,
        bs: int,
        dest_buf: Optional[torch.Tensor] = None,
        swa_dest_buf: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Convert ragged (token-level) kv_indices from spec_info into a 2D
        block-level page_table of shape (bs, max_num_blocks_per_seq).
        unified_attention expects max_seqlen_k = page_table.shape[1] * page_size
        to be a captured constant, so rows are sized to the backend-level
        max_num_blocks_per_seq regardless of seqused_k.
        """
        kv_indptr = spec_info.kv_indptr
        kv_flat = spec_info.kv_indices
        page_size = self.page_size
        max_blocks = (self.max_context_len + page_size - 1) // page_size
​
        swa_slot_mapping = None
        swa_page_table = None
​
        if dest_buf is not None:
            # The scatter kernel fills [0, num_blocks) and loads past that use
            # other=0, so the tail is 0-filled. Under graph replay rows > bs
            # are stale but unified_attention only walks rows [0, bs).
            page_table = dest_buf
        else:
            page_table = torch.zeros(
                bs, max_blocks, dtype=torch.int32, device=self.device
            )
​
        if self.use_sliding_window_kv_pool:
            swa_slot_mapping = self.token_to_kv_pool.full_to_swa_index_mapping.long()
​
        if swa_dest_buf is not None:
            swa_page_table = swa_dest_buf
        elif self.use_sliding_window_kv_pool:
            swa_page_table = torch.zeros_like(page_table)
​
        # Launch scatter kernel to populate page_table from ragged indices
        scatter_ragged_to_page_table_kernel[(bs, max_blocks)](
            kv_flat,
            kv_indptr,
            page_table,
            page_table.stride(0),
            swa_page_table,
            swa_slot_mapping,
            PAGE_SIZE=page_size,
            BLOCK_SIZE=128, # internal block size for Triton
            HAS_SWA=swa_page_table is not None,
        )
        return page_table
python/sglang/srt/layers/attention/triton_ops/aiter_unified_attention.py infrastructure

新增的 Triton scatter kernel,将 ragged 索引转换为块级页表,是统一验证路径的基础

import triton
import triton.language as tl
​
​
@triton.jit
def scatter_ragged_to_page_table_kernel(
    kv_flat_ptr,
    kv_indptr_ptr,
    dest_ptr,
    dest_stride,
    sw_page_table_ptr,
    swa_slot_mapping_ptr,
    PAGE_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    HAS_SWA: tl.constexpr,
):
    """Scatter ragged token-level kv_indices into a 2D block-level page table."""
    pid = tl.program_id(0) # request index
    block_id = tl.program_id(1) # block offset index
​
    start = tl.load(kv_indptr_ptr + pid).to(tl.int64)
    kv_len = tl.load(kv_indptr_ptr + pid + 1).to(tl.int64) - start
    num_blocks = (kv_len + PAGE_SIZE - 1) // PAGE_SIZE
​
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    if block_id * BLOCK_SIZE >= num_blocks:
        return
    mask = offsets < num_blocks
    token_idx = offsets.to(tl.int64) * PAGE_SIZE
    vals = tl.load(kv_flat_ptr + start + token_idx, mask=mask, other=0)
    block_vals = vals // PAGE_SIZE
    tl.store(
        dest_ptr + pid.to(tl.int64) * dest_stride + offsets,
        block_vals,
        mask=mask,
    )
​
    if HAS_SWA:
        sw_vals = tl.load(swa_slot_mapping_ptr + vals)
        block_vals = sw_vals // PAGE_SIZE
        tl.store(
            sw_page_table_ptr + pid.to(tl.int64) * dest_stride + offsets,
            block_vals,
            mask=mask,
        )

评论区精华

输出视图维度使用 v_head_dim vs qk_head_dim 正确性

gemini-code-assist[bot] 指出 out 视图应使用 layer.v_head_dim 以匹配 value 缓存形状,避免维度不匹配。

结论:作者接受了建议,在后续提交中修正。 · 已解决

启用滑动窗口注意力支持 正确性

gemini-code-assist[bot] 建议将 window_size 硬编码 (-1,-1) 改为尊重 layer.sliding_window_size,类似 decode 路径。

结论:作者添加了滑动窗口支持(提交 2 和 3)。 · 已解决

将内联 Triton kernel 重构到 triton_ops 模块 设计

mqhc2020 建议将 inline 的 Triton jit kernel 移到 triton_ops 目录下统一管理;HaiShaw 也建议后续重构。

结论:作者在第四提交中将 kernel 移至新文件 python/sglang/srt/layers/attention/triton_ops/aiter_unified_attention.py。 · 已解决

init_forward_metadata 中 else 分支对 MLA 和非 unified_attention 的潜在破坏 正确性

kkHuang-amd 指出当 use_mla 或 use_triton_unified_attention 为 False 时,新的 else 分支中的滑动窗口和页表变换逻辑可能导致错误。HaiShaw 回复 'this?' 表示需要关注。

结论:作者可能已处理(PR 合并),但具体响应不明确。 · addressed

风险与影响

该PR引入多项核心逻辑变更:

1) 环境变量SGLANG_AITER_UNIFIED_VERIFY控制新路径,默认开启,但未设置时可能回退到旧路径,需确保兼容性;
2) 新增的Triton kernel在非AMD平台可能未定义,通过try/except导入aiter,但scatter kernel始终编译,可能引入符号冲突;
3) Qwen3.5 MTP量化检测逻辑依赖quant_config.get_name()exclude_layers属性,若Quark配置格式变化可能失效;
4) 缺少单元测试,仅依赖手动GSM8K验证,回归风险存在;
5) 预计算qo_indptr_unified_decode假设q_len==1,若未来支持长序列可能会出错。

用户影响:AMD gfx950/MI355X用户可体验Qwen3.5 397B FP8/MXFP4模型启用EAGLE投机解码(需设置环境变量并--disable-radix-cache),吞吐提升约10倍(根据输出吞吐指标),精度保持94%+。系统影响:改动限制在注意力后端和量化配置,未波及调度器或其他模块。团队影响:维护者需关注相关环境变量和新kernel维护。

缺少测试覆盖 环境变量依赖 平台兼容性风险

关联 Issue

#23113 [Bug] [rocm] qwen 3.5 mtp fp4 broken
#23461 [AMD] Add kimi-k2.5 eagle3 support with unified attention.

完整报告

参与讨论