Prhub

#24933 Amd/deepseek v4 rebase main 0509

原始 PR 作者 kkHuang-amd 合并时间 2026-05-19 00:15 文件变更 17 提交数 28 评论 9 代码增减 +3678 / -70

执行摘要

为 ROCm 平台添加 DeepSeek V4 模型支持,新增 HIP 注意力后端与 Triton 内核

PR body明确说明:'Enable deepseek v4 model support (merge to main) to ROCm platform',目标是让DeepSeek V4系列模型在AMD GPU上首次可运行,作为后续优化的基础,并逐步迁移amd/deepseek_v4分支上的优化工作。

值得精读的文件包括deepseek_v4_backend_hip_radix.py(理解ROCm后端架构)、compress_hip.py(HIP专用压缩设计)和tilelang_kernel.py(TileLang内核实现与monkey-patch技巧)。建议重点关注环境变量治理和条件编译模式,后续可借鉴到其他平台适配。

讨论亮点

主要讨论集中在几个方面:

  • 设计:DarkSharpness建议将compressor相关改动移到单独文件(如compress_hip.py)以避免diff过大,PR采纳该建议并重构。
  • 精度:DarkSharpness担心tgemm.mm(x, y, otype=x.dtype).float()中bf16 in/out再cast到fp32可能导致精度退化,PR作者kkHuang-amd回应实测未见退化。
  • 文件命名:DarkSharpness对hip_flash_mla.py的必要性提出疑问,PR作者表示会专门针对ROCm修复该文件。
  • 风格:DarkSharpness建议将jit_kernel/deepseek_v4.py中新增的docstring移动到文件开头,已接受。

实现拆解

  1. 新增HIP注意力后端:创建deepseek_v4_backend_hip_radix.py,实现DeepseekV4HipRadixBackend,继承AttentionBackendCompressorBackendMixinC4IndexerBackendMixin。定义DSV4AttnMetadata数据结构,包含分页索引、压缩元数据等。_create_flashmla_metadata在HIP平台返回None,不依赖flash_mla库;init_compression_metadata导入Triton内核。
  2. HIP专用压缩器:新增compress_hip.py,实现CompressorHip(继承Compressor基类)。使用Triton RMS Normalize内核替代CUDA版本;use_fused_compress固定为False,通过环境变量SGLANG_OPT_USE_FUSED_COMPRESS选择性开启HIP fused压缩。提供overlap_transform等方法用于预填充阶段。
  3. Flash MLA入口适配:新增hip_flash_mla.py,暴露flash_mla_with_kvcache_entrypoint函数,根据SGLANG_HACK_FLASHMLA_BACKEND环境变量选择“tilelang”或“torch”后端;flash_mla_with_kvcache_torch用于调试对比。
  4. TileLang/Triton内核扩展:在tilelang_kernel.py中大幅新增fp8_paged_mqa_logits_kernel系列、dpsk_v4_fp8_partial_kerneldpsk_v4_fp8_attention_fwd函数,处理FP8分页MQA的logits计算和稀疏注意力解码。同时增加了tilelang适配器bug的monkey-patch(_legalize_result_idx_safe)。
  5. 内存管理与状态扩展:修改deepseek_v4_compress_state.py,为KVAndScore添加from_kv_scoreclonecat等便捷方法,CompressStatePool根据is_hip使用不同的内存分配策略。其他适配包括deepseek_v4_rope.py添加fused_norm_rope_inplace_triton融合内核,deepseek_v4.py调整导入路径,attention_registry.py注册新后端。
文件 模块 状态 重要度
python/sglang/srt/layers/attention/deepseek_v4_backend_hip_radix.py 注意力后端 added 9.18
python/sglang/srt/layers/attention/dsv4/compress_hip.py 压缩器 added 8.99
python/sglang/srt/layers/attention/nsa/tilelang_kernel.py TileLang 内核 modified 8.84
python/sglang/srt/layers/attention/hip_flash_mla.py FlashMLA 适配 added 8.59
python/sglang/srt/mem_cache/deepseek_v4_compress_state.py 压缩状态 modified 8.37
python/sglang/srt/layers/deepseek_v4_rope.py 融合 RoPE modified 7.92
python/sglang/srt/layers/quantization/fp8.py 量化层 modified 7.52
python/sglang/srt/models/deepseek_v4.py 模型定义 modified 7.01

关键符号

_pad_last_dim _create_flashmla_metadata DSV4AttnMetadata.get_flashmla_metadata DSV4AttnMetadata.copy_ _rms_normalize_kernel rms_normalize_triton CompressorHip.use_fused_compress flash_mla_with_kvcache_entrypoint fp8_paged_mqa_logits_kernel dpsk_v4_fp8_attention_fwd _legalize_result_idx_safe fused_norm_rope_inplace_triton KVAndScore.from_kv_score KVAndScore.clone

关键源码片段

python/sglang/srt/layers/attention/deepseek_v4_backend_hip_radix.py core-logic

核心新增文件,实现了 HIP 专用的 DeepSeek V4 注意力后端,包含 DSV4AttnMetadata 数据结构、flash_mla 元数据管理、分页压缩数据创建等关键逻辑,是整个 ROCm 适配的入口。

# 文件 : sglang/srt/layers/attention/deepseek_v4_backend_hip_radix.py
from __future__ import annotationsimport torch
import torch.nn.functional as F
from sglang.srt.utils import ceil_align# 常量定义
SWA_WINDOW = 128
C4_TOPK = 512
PAGE_INDEX_ALIGNED_SIZE = 64
​
​
def _pad_last_dim(x, multiples_of: int = PAGE_INDEX_ALIGNED_SIZE):
    """将张量最后一维补齐到 multiples_of 的整数倍,填充值为 -1。"""
    if x is None:
        return None
    curr_size = x.shape[-1]
    target_size = ceil_align(curr_size, multiples_of)
    return F.pad(x, pad=(0, target_size - curr_size), mode="constant", value=-1)
​
​
def _create_flashmla_metadata():
    """在HIP(ROCm)上返回 None,避免依赖 flash_mla 库。"""
    from sglang.srt.utils import is_hip
    if is_hip():
        return None
    import flash_mla
    return flash_mla.get_mla_metadata()[0]
​
​
def _create_dummy_paged_compress_data(compress_ratio: int):
    """HIP 平台暂不支持 paged compress,返回 None。"""
    return None
​
​
@dataclass
class DSV4AttnMetadata:
    """存储 DeepSeek V4 注意力后端所需的所有元数据。
    包含 page_table、raw_out_loc、各种压缩层级(c1/c4/c128)的索引和 flash_mla 元数据。
    """
    page_size: int
    page_table: torch.Tensor
    raw_out_loc: torch.Tensor
    cuda_int32_kwargs: dict
    seq_lens_casual: torch.Tensor
    positions_casual: torch.Tensor
    swa_page_indices: torch.Tensor
    swa_topk_lengths: torch.Tensor
    c4_sparse_topk: int
    # 可选字段
    c4_out_loc: Optional[torch.Tensor] = None
    c4_topk_lengths_raw: Optional[torch.Tensor] = None
    c4_topk_lengths_clamp1: Optional[torch.Tensor] = None
    # 动态初始化字段
    c4_sparse_topk_lengths: torch.Tensor = field(init=False)
    c4_sparse_page_indices: torch.Tensor = field(init=False)
    c128_out_loc: Optional[torch.Tensor] = None
    c128_page_indices: Optional[torch.Tensor] = None
    c128_topk_lengths_clamp1: Optional[torch.Tensor] = None
    c1_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False)
    c4_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False)
    c128_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False)
​
    @property
    def positions(self) -> torch.Tensor:
        return self.positions_casual
​
    def get_flashmla_metadata(self, compress_ratio: Literal[0, 4, 128]):
        """根据压缩比返回对应的 flash_mla 元数据。"""
        if compress_ratio == 0:
            return self.c1_flashmla_metadata
        elif compress_ratio == 4:
            return self.c4_flashmla_metadata
        elif compress_ratio == 128:
            return self.c128_flashmla_metadata
        else:
            raise ValueError(f"invalid {compress_ratio=}")
python/sglang/srt/layers/attention/dsv4/compress_hip.py core-logic

新增的 HIP 专用压缩器,使用 Triton RMS Normalize 内核,实现 CompressorHip 类,展示 ROCm 上与 CUDA 差异化实现的关键逻辑。

# 文件 : sglang/srt/layers/attention/dsv4/compress_hip.py
from functools import cached_property
import triton
import triton.language as tl
from sglang.srt.layers.attention.dsv4.compressor import Compressor as _CompressorBase
​
​
@triton.jit
def _rms_normalize_kernel(x_ptr, weight_ptr, eps, stride_row, dim,
                           BLOCK_SIZE: tl.constexpr, HAS_WEIGHT: tl.constexpr):
    """Triton 实现的 RMS 归一化内核,支持可选的逐元素权重缩放。"""
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < dim
    base = pid * stride_row
    x = tl.load(x_ptr + base + offs, mask=mask, other=0.0).to(tl.float32)
    mean_sq = tl.sum(x * x, axis=0) / dim
    rms_inv = tl.rsqrt(mean_sq + eps)
    out = x * rms_inv
    if HAS_WEIGHT:
        weight = tl.load(weight_ptr + offs, mask=mask, other=0.0)
        out = out * weight
    tl.store(x_ptr + base + offs, out, mask=mask)
​
​
def rms_normalize_triton(x: torch.Tensor, eps: float, weight: torch.Tensor = None) -> torch.Tensor:
    """对输入 x 做 RMS 归一化,可选权重。
    将输入展开为二维,逐行调用 Triton 内核。
    """
    dim = x.shape[-1]
    x_flat = x.view(-1, dim)
    num_rows = x_flat.shape[0]
    BLOCK_SIZE = triton.next_power_of_2(dim)
    grid = (num_rows,)
    _rms_normalize_kernel[grid](x_flat, weight, eps, x_flat.stride(0), dim,
                                   BLOCK_SIZE=BLOCK_SIZE, HAS_WEIGHT=(weight is not None))
    return x
​
​
class CompressorHip(_CompressorBase):
    """HIP 平台专用的 Compressor,使用 Triton 内核完成归一化。
    默认关闭 fused_compress(避免 CUDA 特定优化),通过环境变量可选启用。
    """
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.norm = DeepseekRefRMSNorm(self.head_dim, eps=self.norm.variance_epsilon)
​
    @cached_property
    def use_fused_compress(self) -> bool:
        # HIP 上禁用 fused 压缩,改用 Triton 实现
        return False
​
    @cached_property
    def use_hip_fused_compress(self) -> bool:
        # 通过环境变量控制是否使用 HIP 本地的 fused 压缩
        return envs.SGLANG_OPT_USE_FUSED_COMPRESS.get()
python/sglang/srt/layers/attention/nsa/tilelang_kernel.py core-logic

大幅修改并新增大量 TileLang 内核,包括 fp8_paged_mqa_logits_kernel 和 dpsk_v4_fp8_attention_fwd,是 ROCm 计算性能的关键。包含 tilelang 适配器 bug 的绕过补丁。

# 文件 : sglang/srt/layers/attention/nsa/tilelang_kernel.py
import functools
from tilelang.jit.adapter.base import BaseKernelAdapter as _BaseKernelAdapter# -----------------------
# 绕过 tilelang 内部 bug:_legalize_result_idx 会就地修改 out_idx 列表
# 当同一个 @tilelang.jit 工厂编译两个不同参数计数的 prim_func 时,
# 第二次编译看到的索引已被第一次转换,导致适配器生成错误,运行时 IndexError。
# 补丁:在 mutation 前复制列表。
# -----------------------
if not getattr(_BaseKernelAdapter, "_legalize_result_idx_patched", False):
    _orig_legalize = _BaseKernelAdapter._legalize_result_idx
​
    def _legalize_result_idx_safe(self, result_idx):
        if isinstance(result_idx, list):
            result_idx = list(result_idx) # 复制避免原地修改
        return _orig_legalize(self, result_idx)
​
    _BaseKernelAdapter._legalize_result_idx = _legalize_result_idx_safe
    _BaseKernelAdapter._legalize_result_idx_patched = True
​
​
@functools.cache
def fp8_paged_mqa_logits_kernel(
    head_dim: int = 128,
    num_heads: int = 64,
    block_size: int = 64,
    clear_accum: bool = True,
    split_kv: int = 1,
):
    """构造一个 TileLang 内核,计算 FP8 分页 MQA 的 logits。
    使用符号化形状 (N, L, S, C),由 tilelang 自动处理任意 batch 大小。
    """
    # 符号化变量声明
    N = T.symbolic("batch_size")
    L = T.symbolic("max_table_length")
    S = T.symbolic("max_seq_len")
    C = T.symbolic("num_blocks")
    B = block_size
    D = head_dim
    H = num_heads
    SK = int(split_kv)
    BLOCK_BYTES = B * (D + 4)
    SCALE_OFFSET = B * D
​
    @tilelang.jit(pass_configs={**pass_configs, tilelang.PassConfigKey.TL_DISABLE_SAFE_MEMORY_ACCESS: True})
    def fp8_paged_mqa_logits(
        q: T.Tensor[(N, H, D), FP8],
        kvcache_u8: T.Tensor[(C, BLOCK_BYTES), UINT8],
        weight: T.Tensor[(N, H), FP32],
        seq_lens: T.Tensor[(N,), INT32],
        page_table: T.Tensor[(N, L), INT32],
        o: T.Tensor[(N, S), FP32],
    ) -> None:
        # 内核实现:每个 CU 处理一个 batch 和 split 后的分块
        with T.Kernel(N * SK) as bxs:
            bx = bxs % N
            pid_split = bxs // N
            seq_len = seq_lens[bx]
            np_total = T.ceildiv(seq_len, B)
            stride = T.ceildiv(np_total, SK)
            # ... 后续按分块计算 logits,存储在 o 中
            # 注意:clear_accum 和 split_kv 参数控制累加策略
            pass # 具体计算略
​
    return fp8_paged_mqa_logits

评论区精华

compressor 文件拆分建议 设计

DarkSharpness 建议将 compressor 相关大段改动移到单独文件(如 compress_hip.py),遵循相同接口,避免 diff 过大。

结论:作者已采纳,创建了 compress_hip.py。 · 已解决

bf16 精度问题 正确性

DarkSharpness 指出 tgemm.mm 返回 bf16 再 cast 到 fp32 可能导致精度退化,质疑是否需要 fp32 累加。

结论:作者回应实测 GSM8K 未见精度退化,且 DSV4 要求 bf16 in fp32 out 但内核本身使用 fp32 累加。 · 已解决

hip_flash_mla.py 是否需要保留 question

DarkSharpness 询问 hip_flash_mla.py 是否主要用于调试,建议重命名或说明用途。

结论:作者回应会修复该文件以便在 ROCm 上专用运行,保留该文件但后续优化。 · 已解决

docstring 位置调整 style

DarkSharpness 建议将 jit_kernel/deepseek_v4.py 中新增的 fused_rope docstring 移到文件开头。

结论:作者接受该建议,后续提交中调整。 · 已解决

风险与影响

主要技术风险包括:

  • CUDA回归风险:多处使用了is_hip()分支,若未充分测试可能影响CUDA路径(已在commit中明确fix breakage)。
  • 内核兼容性:新增的Triton/TileLang内核依赖ROCm环境,可能存在编译器兼容问题或性能瓶颈。
  • 环境变量复杂度:大量环境变量(如SGLANG_HACK_FLASHMLA_BACKENDSGLANG_OPT_USE_FUSED_COMPRESS)控制行为,配置不当可能导致运行时错误。
  • 缺少单元测试:PR未添加自动化测试,仅提供手动GSM8K精度验证(0.948),未来回归风险较高。

对用户:AMD ROCm用户首次能在MI35x上运行DeepSeek V4系列模型(flash/pro),开启后续优化。CUDA用户无影响。对系统:新增了完整的HIP注意力后端和多个Triton内核,代码侵入性中等(通过is_hip隔离)。对团队:后续将有多个PR迁移剩余优化(压缩流融合、多stream等),此PR奠定架构基础。

缺少测试覆盖 环境变量复杂度 CUDA 回归风险 内核兼容性

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论