执行摘要
- 一句话:AMD启用EAGLE统一注意力验证并修复MXFP4加载
- 推荐动作:值得精读,尤其关注注意力后端如何适配不同数据类型(MLA/non-MLA)和投机解码布局(ragged vs paged)。设计决策(如保持radix-cache分离)体现了模块化思维。建议后续补充单元测试覆盖新路径。
功能与动机
之前当SGLANG_USE_AITER_UNIFIED_ATTN=1配合--speculative-algorithm EAGLE时,非MLA且topk==1的EAGLE缺少unified_attention目标验证路径,被迫使用低效的自定义mask扩展注意力。此外,Quark导出的MXFP4 Qwen3.5检查点中MTP模块为bf16,但加载器误分配MXFP4 packed形状导致加载失败(issue #23113),且--quantization quark因未注册而被CLI拒绝。
实现拆解
-
aiter_backend.py核心改造:在AiterAttnBackend中新增_build_unified_page_table_from_spec和_build_verify_unified_metadata方法,将spec_info的ragged token级索引转换为unified_attention所需的2D块级页表;在forward_extend中根据_use_unified_verify标志(非MLA且topk==1生效)将target_verify路径路由到unified_attention而非extend_attention_fwd;预计算qo_indptr_unified_decode避免逐请求cumsum。
-
新增Triton scatter kernel:在aiter_unified_attention.py中实现scatter_ragged_to_page_table_kernel和scatter_req_to_token_to_page_table_kernel,高效完成ragged到page_table的并行转换,并支持SWA槽映射。
-
滑动窗口支持:在init_forward_metadata和_build_verify_unified_metadata中传递swa_page_table,根据layer.sliding_window_size设置window_size参数。
-
修复MTP量化加载:在qwen3_5_mtp.py的__init__中,检测quark量化配置的排除层列表,若包含mtp.*则跳过量化,确保线性层分配bf16形状。
-
注册quark量化选项:在server_args.py的QUANTIZATION_CHOICES中添加"quark"以支持CLI输入。测试配套:本次变更未新增测试文件,依赖自有实验验证。
关键文件:
python/sglang/srt/layers/attention/aiter_backend.py(模块 注意力后端;类别 source;类型 core-logic;符号 _build_unified_page_table_from_spec, _build_verify_unified_metadata, forward_extend, init_forward_metadata): 主要改动,新增EAGLE统一验证路径和页表转换逻辑,核心变更文件
python/sglang/srt/layers/attention/triton_ops/aiter_unified_attention.py(模块 注意力Kernel;类别 infra;类型 infrastructure;符号 scatter_ragged_to_page_table_kernel, scatter_req_to_token_to_page_table_kernel): 新增的Triton scatter kernel,将ragged索引转换为块级页表,是统一验证路径的基础
python/sglang/srt/models/qwen3_5_mtp.py(模块 模型加载;类别 source;类型 data-contract;符号 init): 修复Quark量化MTP模块bf16加载异常,关联issue #23113
python/sglang/srt/server_args.py(模块 配置;类别 source;类型 core-logic): 注册quark量化选项使CLI可接受--quantization quark
关键符号:_build_unified_page_table_from_spec, _build_verify_unified_metadata, scatter_ragged_to_page_table_kernel, scatter_req_to_token_to_page_table_kernel, forward_extend
关键源码片段
python/sglang/srt/layers/attention/aiter_backend.py
主要改动,新增EAGLE统一验证路径和页表转换逻辑,核心变更文件
def _build_unified_page_table_from_spec(
self,
spec_info,
bs: int,
dest_buf: Optional[torch.Tensor] = None,
swa_dest_buf: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Convert ragged (token-level) kv_indices from spec_info into a 2D
block-level page_table of shape (bs, max_num_blocks_per_seq).
unified_attention expects max_seqlen_k = page_table.shape[1] * page_size
to be a captured constant, so rows are sized to the backend-level
max_num_blocks_per_seq regardless of seqused_k.
"""
kv_indptr = spec_info.kv_indptr
kv_flat = spec_info.kv_indices
page_size = self.page_size
max_blocks = (self.max_context_len + page_size - 1) // page_size
swa_slot_mapping = None
swa_page_table = None
if dest_buf is not None:
# The scatter kernel fills [0, num_blocks) and loads past that use
# other=0, so the tail is 0-filled. Under graph replay rows > bs
# are stale but unified_attention only walks rows [0, bs).
page_table = dest_buf
else:
page_table = torch.zeros(
bs, max_blocks, dtype=torch.int32, device=self.device
)
if self.use_sliding_window_kv_pool:
swa_slot_mapping = self.token_to_kv_pool.full_to_swa_index_mapping.long()
if swa_dest_buf is not None:
swa_page_table = swa_dest_buf
elif self.use_sliding_window_kv_pool:
swa_page_table = torch.zeros_like(page_table)
# Launch scatter kernel to populate page_table from ragged indices
scatter_ragged_to_page_table_kernel[(bs, max_blocks)](
kv_flat,
kv_indptr,
page_table,
page_table.stride(0),
swa_page_table,
swa_slot_mapping,
PAGE_SIZE=page_size,
BLOCK_SIZE=128, # internal block size for Triton
HAS_SWA=swa_page_table is not None,
)
return page_table
python/sglang/srt/layers/attention/triton_ops/aiter_unified_attention.py
新增的Triton scatter kernel,将ragged索引转换为块级页表,是统一验证路径的基础
import triton
import triton.language as tl
@triton.jit
def scatter_ragged_to_page_table_kernel(
kv_flat_ptr,
kv_indptr_ptr,
dest_ptr,
dest_stride,
sw_page_table_ptr,
swa_slot_mapping_ptr,
PAGE_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HAS_SWA: tl.constexpr,
):
"""Scatter ragged token-level kv_indices into a 2D block-level page table."""
pid = tl.program_id(0) # request index
block_id = tl.program_id(1) # block offset index
start = tl.load(kv_indptr_ptr + pid).to(tl.int64)
kv_len = tl.load(kv_indptr_ptr + pid + 1).to(tl.int64) - start
num_blocks = (kv_len + PAGE_SIZE - 1) // PAGE_SIZE
offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
if block_id * BLOCK_SIZE >= num_blocks:
return
mask = offsets < num_blocks
token_idx = offsets.to(tl.int64) * PAGE_SIZE
vals = tl.load(kv_flat_ptr + start + token_idx, mask=mask, other=0)
block_vals = vals // PAGE_SIZE
tl.store(
dest_ptr + pid.to(tl.int64) * dest_stride + offsets,
block_vals,
mask=mask,
)
if HAS_SWA:
sw_vals = tl.load(swa_slot_mapping_ptr + vals)
block_vals = sw_vals // PAGE_SIZE
tl.store(
sw_page_table_ptr + pid.to(tl.int64) * dest_stride + offsets,
block_vals,
mask=mask,
)
评论区精华
-
gemini-code-assist[bot] 指出target_verify的unified_attention输出视图应使用layer.v_head_dim而非qk_head_dim,与缓存形状一致。作者已修正。
-
gemini-code-assist[bot] 建议在unified_attention调用中尊重sliding_window_size而非硬编码(-1,-1)。作者后续提交添加了滑动窗口支持。
-
mqhc2020 建议将inline的Triton kernel移到triton_ops模块。HaiShaw也建议后续重构,作者已将kernel移至新文件。
-
kkHuang-amd 指出init_forward_metadata中else分支对MLA或非unified_attention情况下可能破坏逻辑。HaiShaw询问后作者已处理(具体未明确但PR已合并)。
- 输出视图维度使用 v_head_dim vs qk_head_dim (correctness): 作者接受了建议,在后续提交中修正。
- 启用滑动窗口注意力支持 (correctness): 作者添加了滑动窗口支持(提交2和3)。
- 将内联 Triton kernel 重构到 triton_ops 模块 (design): 作者在第四提交中将 kernel 移至新文件 python/sglang/srt/layers/attention/triton_ops/aiter_unified_attention.py。
- init_forward_metadata 中 else 分支对 MLA 和非 unified_attention 的潜在破坏 (correctness): 作者可能已处理(PR 合并),但具体响应不明确。
风险与影响
- 风险:该PR引入多项核心逻辑变更:
1) 环境变量SGLANG_AITER_UNIFIED_VERIFY控制新路径,默认开启,但未设置时可能回退到旧路径,需确保兼容性;
2) 新增的Triton kernel在非AMD平台可能未定义,通过try/except导入aiter,但scatter kernel始终编译,可能引入符号冲突;
3) Qwen3.5 MTP量化检测逻辑依赖quant_config.get_name()和exclude_layers属性,若Quark配置格式变化可能失效;
4) 缺少单元测试,仅依赖手动GSM8K验证,回归风险存在;
5) 预计算qo_indptr_unified_decode假设q_len==1,若未来支持长序列可能会出错。
- 影响:用户影响:AMD gfx950/MI355X用户可体验Qwen3.5 397B FP8/MXFP4模型启用EAGLE投机解码(需设置环境变量并
--disable-radix-cache),吞吐提升约10倍(根据输出吞吐指标),精度保持94%+。系统影响:改动限制在注意力后端和量化配置,未波及调度器或其他模块。团队影响:维护者需关注相关环境变量和新kernel维护。
- 风险标记:缺少测试覆盖, 环境变量依赖, 平台兼容性风险
关联脉络
- PR #23113 [Bug] [rocm] qwen 3.5 mtp fp4 broken: 关联的bug issue,本PR修复了MTP加载问题
- PR #23461 [AMD] Add kimi-k2.5 eagle3 support with unified attention.: 相关的draft-decode修复PR,本PR保持范围外但依赖
参与讨论