Prhub

#25976 [DeepSeek-V4] Add mhc_fused_post_pre kernel

原始 PR 作者 JoeLee314 合并时间 2026-05-30 17:04 文件变更 5 提交数 1 评论 15 代码增减 +875 / -48

执行摘要

融合 DeepSeek-V4 mHC 后 / 前步长 kernel,解码性能 +3.35%

优化 DeepSeek-V4 mHC 延迟敏感解码路径,避免启动分离内核,借鉴 TRTLLM 融合策略。

建议精读,重点关注 mhc_fused_post_pre_fma_tilelang 的 TileLang 实现和融合调度策略,对 LLM 推理 kernel fusion 有参考价值。同时注意其与现有 TileLang mHC 路径的依赖关系。

讨论亮点
  1. 测试覆盖率扩展:审查者 yhyang201 建议增加 hidden_size=7168 和 batch size 64,以覆盖 V4 Pro 和大路径;作者已采纳。
  2. 融合内核预热:yhyang201 提出在 prewarm_mhc_token_counts 中添加融合内核预热,避免首次调用冷启动延迟;已实现。
  3. 代码质量:gemini-code-assist 指出未使用变量 block_k/block_m、重复导入及魔数 2.0;后续添加了 _MHC_POST_MULT_VALUE 常量,部分问题未明确回应。

实现拆解

  1. 环境变量注册:在 python/sglang/srt/environ.py 添加 SGLANG_OPT_FUSE_MHC_POST_PRE = EnvBool(False),默认关闭。
  2. 核心融合内核:在 python/sglang/srt/layers/mhc.py 新增 mhc_fused_post_pre_fma_tilelang(TileLang JIT 内核),以及调度函数 mhc_fused_post_pre,小批次走融合 FMA,大批次走非融合 mhc_post + mhc_pre
  3. 模型集成:在 python/sglang/srt/models/deepseek_v4.py 中添加 _is_fused_mhc_post_pre_enabledrefresh_mhc_norm_weight_cacheprewarm_mhc_token_counts 预热扩展;修改 forward 方法实现跨层与层内融合。
  4. NextN 适配:在 python/sglang/srt/models/deepseek_v4_nextn.py 中,解码器返回四元组,NextN 层执行最终 hc_post。
  5. 单元测试:新增 tests/kernels/test_mhc_kernels.py,参数化验证融合路径与分离路径数值等价。
文件 模块 状态 重要度
python/sglang/srt/models/deepseek_v4.py 模型适配 modified 9.21
python/sglang/srt/layers/mhc.py mHC 内核 modified 8.33
tests/kernels/test_mhc_kernels.py 测试用例 added 6.76
python/sglang/srt/models/deepseek_v4_nextn.py NextN 适配 modified 5.94
python/sglang/srt/environ.py 环境配置 modified 4.58

关键符号

_is_fused_mhc_post_pre_enabled refresh_mhc_norm_weight_cache prewarm_mhc_token_counts prewarm_mhc_token_count_buckets mhc_fused_post_pre_fma_tilelang mhc_fused_post_pre test_mhc_fused_post_pre_matches_unfused

关键源码片段

python/sglang/srt/models/deepseek_v4.py data-contract

集成融合逻辑的核心文件,包含使能检查、归一化权重缓存、预热扩展和 forward 修改,是整个 PR 的主控点。

# 全局函数:检查融合 mHC post/pre 是否启用
def _is_fused_mhc_post_pre_enabled() -> bool:
    return (
        envs.SGLANG_OPT_FUSE_MHC_POST_PRE.get()
        and envs.SGLANG_OPT_USE_TILELANG_MHC_PRE.get()
        and envs.SGLANG_OPT_USE_TILELANG_MHC_POST.get()
    )# DeepseekV4DecoderLayer 类新增的方法:缓存 bf16 归一化权重
# 避免每轮 forward 重复 cast 和 contiguous 操作
def refresh_mhc_norm_weight_cache(self):
    self._input_layernorm_weight_bf16 = (
        self.input_layernorm.weight.data.bfloat16().contiguous()
    )
    self._post_attention_layernorm_weight_bf16 = (
        self.post_attention_layernorm.weight.data.bfloat16().contiguous()
    )
tests/kernels/test_mhc_kernels.py test-coverage

新测试文件,参数化验证融合路径与分离路径数值等价,覆盖多种隐藏维度、token 数量和归一化配置。

import pytest
import torchimport sglang.srt.layers.mhc as mhc
from sglang.srt.layers.mhc import mhc_fused_post_pre, mhc_post, mhc_pre
​
​
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("num_tokens", [0, 1, 8, 17, 32, 64])
@pytest.mark.parametrize("use_norm", [False, True])
def test_mhc_fused_post_pre_matches_unfused(
    monkeypatch, hidden_size, num_tokens, use_norm
):
    if not torch.cuda.is_available():
        pytest.skip("CUDA is required for TileLang mHC kernels")
​
    monkeypatch.setattr(mhc, "is_dsa_prefill_cp_round_robin_split", lambda: False)
    torch.manual_seed(0)
    device = torch.device("cuda")
    hc_mult = 4
    hc_mult3 = hc_mult * 2 + hc_mult * hc_mult
    hc_hidden_size = hc_mult * hidden_size
​
    x = torch.randn(num_tokens, hidden_size, device=device, dtype=torch.bfloat16) * 0.1
    residual = (
        torch.randn(
            num_tokens, hc_mult, hidden_size, device=device, dtype=torch.bfloat16
        )
        * 0.1
    )
    post_prev = torch.rand(num_tokens, hc_mult, 1, device=device, dtype=torch.float32)
    comb_prev = (
        torch.rand(num_tokens, hc_mult, hc_mult, device=device, dtype=torch.float32)
        * 0.25
    )
    fn = (
        torch.randn(hc_mult3, hc_hidden_size, device=device, dtype=torch.float32) * 0.01
    )
    hc_scale = torch.tensor([0.5, 0.25, 0.25], device=device, dtype=torch.float32)
    hc_base = torch.zeros(hc_mult3, device=device, dtype=torch.float32)
    norm_weight = (
        torch.ones(hidden_size, device=device, dtype=torch.bfloat16)
        if use_norm
        else None
    )
    norm_eps = 1e-6 if use_norm else None
​
    rms_eps = 1e-6
    hc_eps = 1e-6
    sinkhorn_repeat = 2
​
    residual_ref = post_ref = comb_ref = layer_ref = None
    if num_tokens > 0:
        residual_ref = mhc_post(x, residual, post_prev, comb_prev)
        post_ref, comb_ref, layer_ref = mhc_pre(
            residual_ref, fn, hc_scale, hc_base, rms_eps, hc_eps, hc_eps, 2.0,
            sinkhorn_repeat, norm_weight=norm_weight, norm_eps=norm_eps
        )
    residual_out, post_out, comb_out, layer_out = mhc_fused_post_pre(
        x, residual, post_prev, comb_prev, fn, hc_scale, hc_base, rms_eps,
        hc_eps, hc_eps, 2.0, sinkhorn_repeat, norm_weight=norm_weight,
        norm_eps=norm_eps
    )
​
    torch.cuda.synchronize()
    if num_tokens == 0:
        assert residual_out.shape == residual.shape
        assert post_out.shape == (0, hc_mult, 1)
        assert comb_out.shape == (0, hc_mult, hc_mult)
        assert layer_out.shape == (0, hidden_size)
        return
​
    assert residual_ref is not None and post_ref is not None
    assert comb_ref is not None and layer_ref is not None
    torch.testing.assert_close(residual_out, residual_ref, atol=0, rtol=0)
    torch.testing.assert_close(post_out, post_ref, atol=1e-3, rtol=1e-3)
    torch.testing.assert_close(comb_out, comb_ref, atol=1e-3, rtol=1e-3)
    layer_atol = 2e-2 if use_norm else 2e-3
    layer_rtol = 2e-2 if use_norm else 2e-3
    torch.testing.assert_close(layer_out, layer_ref, atol=layer_atol, rtol=layer_rtol)

评论区精华

增加 hidden_size 参数化和 batch size 64 测试

审查者 yhyang201 建议测试覆盖更多 hidden_size(如 7168)和大批次(64)以确保大路径正确。

结论:作者已采纳并补充。 · 已解决

融合内核预热 性能

yhyang201 提出在 prewarm_mhc_token_counts 中添加融合内核预热,避免首次调用冷启动延迟。

结论:作者已实现。 · 已解决

风险与影响

  1. 数值精度:融合内核计算顺序不同,测试允许 2e-2 误差,需持续监控端到端精度。
  2. 大批次回退:依赖现有 DeepGEMM 路径,无回归风险。
  3. TileLang 依赖:仅 CUDA 平台支持,其他硬件自动禁用。
  4. 配置复杂性:需同时启用 SGLANG_OPT_FUSE_MHC_POST_PRE 和 TileLang 开关才能激活。
  1. 性能收益:小批次解码吞吐提升约 3.35%,对交互式部署有利。
  2. 用户影响:新增环境变量默认关闭,不影响现有行为。
  3. 维护负担:融合路径需与分离路径保持语义一致,后续 mHC 修改需同步。
核心路径变更 TileLang 依赖 仅 CUDA 支持 精度兼容风险

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论