Prhub

#22094 [JIT Kernel] Reland JIT activation

原始 PR 作者 DarkSharpness 合并时间 2026-04-25 14:00 文件变更 14 提交数 11 评论 9 代码增减 +473 / -40

执行摘要

重新引入 JIT 激活内核,修复 num_token=0 边界问题

关联 Issue #22078 因 CI 失败回退了初版 JIT activation。本 PR 修复了导致失败的 num_token=0 边界 case,并重新提交。JIT activation 可以避免对 sgl-kernel 特定编译版本的依赖,在更多 GPU 架构(如 Blackwell SM100+)上灵活控制编译选项,并为后续扩展提供统一入口。

值得精读,尤其是 _fast_math_flags 的设计权衡、custom op 的注册方式、以及如何通过条件导入保持向后兼容。对于需要维护多平台支持的开发者,本 PR 展示了如何用 JIT 替代部分 AOT 组件。

讨论亮点

Review 讨论 (ch-wan): 在 activation.py 中,--use_fast_math 标志是否应该全局启用?DarkSharpness 后续在 commit 中实现条件禁用:SM100+ (Blackwell) 和 ROCm 上关闭,其余开启。结论:在 _fast_math_flags 函数中依据平台和架构版本动态决定,平衡性能与精度。
PR 普通评论 (ch-wan): "I fixed some numerical issues. Could you check my changes?" — 合并者发现了多个数值不一致问题,包括 custom op fake shape、GELU multiply 顺序、gate*up 精度等,并逐个提交修复,确保了 JIT 版本与 sgl-kernel 的输出在测试容忍度内一致。

实现拆解

  1. 添加 CUDA kernel: python/sglang/jit_kernel/csrc/elementwise/activation.cuh 定义 ActivationKernel<T, kUsePDL> 类模板,支持 SiLU、GELU、GELU_Tanh 三种激活,使用向量化访存和 PDL 延迟加载。
  2. 创建 Python 封装: python/sglang/jit_kernel/activation.py 通过 load_jit 编译 kernel,导出 silu_and_mulgelu_and_mulgelu_tanh_and_mul 函数,并注册 custom op _run_activation_inplace_fast_math_flags 根据平台条件添加 --use_fast_math 标志(SM100+ 和 ROCm 禁用)。
  3. 修改现有模块的导入路径: 在 gguf.pycutlass_w4a8_moe.pyfused_moe_triton/triton_kernels_moe.pyfused_marlin_moe.pycutlass_moe.pysrt/layers/activation.pymultimodal_gen/runtime/layers/activation.py 等文件中,将 CUDA 分支下激活函数的导入从 sgl_kernel 切换为 sglang.jit_kernel.activation,并作 _is_cuda / _is_musa 条件判断,确保多平台兼容。
  4. 添加单元测试: python/sglang/jit_kernel/tests/test_activation.py 包含 test_activation_correctnesstest_activation_out_param,在多种形状和 dtype 下验证数值正确性(与 PyTorch 参考实现对比),并测试外部 out 参数路径。
  5. 添加性能基准: python/sglang/jit_kernel/benchmark/bench_activation.py 对比 AOT(sgl-kernel)、JIT(本模块)和 torch.compile 实现的延迟,基于 Triton 测试框架生成性能报告。
  6. 后续精度修复: 多个 commit 由合并者 ch-wan 提交,包括:修复 custom op fake 实现输出形状(out_shape="input" 应匹配实际输出形状)、禁用 Blackwell 上的 --use_fast_math、调整 GELU Tanh 中 alpha 乘法的顺序以匹配浮点结果、将 gate*up 乘法提升到 float32 精度以对齐 flashinfer。
文件 模块 状态 重要度
python/sglang/jit_kernel/activation.py JIT 核心 added 8.71
python/sglang/jit_kernel/tests/test_activation.py 测试 added 7.37
python/sglang/srt/layers/quantization/gguf.py GGUF 集成 modified 6.67
python/sglang/jit_kernel/benchmark/bench_activation.py 基准测试 added 8.54
python/sglang/jit_kernel/csrc/elementwise/activation.cuh CUDA 内核 added 6.33
python/sglang/srt/layers/moe/cutlass_w4a8_moe.py Cutlass MoE modified 6.01

关键符号

_fast_math_flags _jit_activation_module _run_activation_inplace run_activation silu_and_mul gelu_and_mul gelu_tanh_and_mul benchmark _reference test_activation_correctness test_activation_out_param

关键源码片段

python/sglang/jit_kernel/activation.py core-logic

核心 JIT 激活函数封装,定义编译入口、fast_math 策略及公开 API

from __future__ import annotations
from typing import TYPE_CHECKING, Optional
import torchfrom sglang.jit_kernel.utils import (
    cache_once,
    get_jit_cuda_arch,
    is_arch_support_pdl,
    is_hip_runtime,
    load_jit,
    make_cpp_args,
)
from sglang.srt.utils.custom_op import register_custom_opif TYPE_CHECKING:
    from tvm_ffi.module import Module
​
​
def _fast_math_flags() -> list[str]:
    # 平台条件判断:ROCm 和 Blackwell (SM100+) 禁用 --use_fast_math
    # 因为 Blackwell 需要精确的 expf,且 triton 在 ROCm 下使用 clang 编译会出错
    if is_hip_runtime():
        return []
    if get_jit_cuda_arch().major >= 10:
        return []
    return ["--use_fast_math"]
​
​
@cache_once
def _jit_activation_module(dtype: torch.dtype) -> Module:
    args = make_cpp_args(dtype, is_arch_support_pdl())
    return load_jit(
        "activation",
        *args,
        cuda_files=["elementwise/activation.cuh"],
        extra_cuda_cflags=_fast_math_flags(),
        cuda_wrappers=[
            ("run_activation", f"ActivationKernel<{args}>::run_activation"),
        ],
    )
​
​
SUPPORTED_ACTIVATIONS = {"silu", "gelu", "gelu_tanh"}
​
​
@register_custom_op(mutates_args=["out"])
def _run_activation_inplace(op_name: str, input: torch.Tensor, out: torch.Tensor) -> None:
    # 视图变换为 (batch, hidden*2) 和 (batch, hidden) 后调用编译后的 kernel
    hidden_size = input.shape[-1] // 2
    module = _jit_activation_module(input.dtype)
    input_2d = input.view(-1, hidden_size * 2)
    out_2d = out.view(-1, hidden_size)
    module.run_activation(input_2d, out_2d, op_name)
​
​
def run_activation(op_name: str, input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    assert op_name in SUPPORTED_ACTIVATIONS, f"Unsupported activation: {op_name}"
    hidden_size = input.shape[-1] // 2
    if out is None:
        out = input.new_empty(*input.shape[:-1], hidden_size)
    _run_activation_inplace(op_name, input, out)
    return out
​
​
def silu_and_mul(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    return run_activation("silu", input, out)
​
​
def gelu_and_mul(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    return run_activation("gelu", input, out)
​
​
def gelu_tanh_and_mul(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    return run_activation("gelu_tanh", input, out)
python/sglang/jit_kernel/tests/test_activation.py test-coverage

单元测试,覆盖三种激活、多种 dtype 和 shape,验证正确性及 out 参数传递

import sys
import pytest
import torch
import torch.nn.functional as Ffrom sglang.jit_kernel.activation import SUPPORTED_ACTIVATIONS, run_activation
from sglang.jit_kernel.utils import get_ci_test_range
from sglang.test.ci.ci_register import register_cuda_ciregister_cuda_ci(est_time=20, suite="stage-b-kernel-unit-1-gpu-large")
register_cuda_ci(est_time=30, suite="nightly-kernel-1-gpu", nightly=True)OPS = SUPPORTED_ACTIVATIONS
DTYPES = [torch.float16, torch.bfloat16, torch.float32]
SHAPES = get_ci_test_range(
    full_range=[
        (7, 16),
        (83, 1024),
        (3, 5, 16),
        (2, 3, 512),
        (1, 17, 4096),
        *[(2**x, 2048) for x in range(0, 15, 2)],
        *[(2**x, 65536) for x in range(0, 5, 2)],
    ],
    ci_range=[(7, 16), (2, 3, 512)],
)
​
​
def _reference(op_name: str, x: torch.Tensor) -> torch.Tensor:
    # 参考实现:使用 PyTorch 的官方激活函数,在 float32 下计算并缩放 up
    d = x.shape[-1] // 2
    lhs = x[..., :d].float()
    rhs = x[..., d:]
    if op_name == "silu":
        act = F.silu(lhs)
    elif op_name == "gelu":
        act = F.gelu(lhs, approximate="none")
    else:
        act = F.gelu(lhs, approximate="tanh")
    return act.to(dtype=x.dtype) * rhs
​
​
def _tolerances(dtype: torch.dtype) -> tuple[float, float]:
    # float32 使用更严格容差,半精度使用 1e-2
    if dtype == torch.float32:
        return 1e-4, 1e-4
    return 1e-2, 1e-2
​
​
@pytest.mark.parametrize("op_name", OPS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
def test_activation_correctness(op_name, dtype, shape):
    x = torch.randn(shape, dtype=dtype, device="cuda")
    out = run_activation(op_name, x, None)
    expected = _reference(op_name, x)
    atol, rtol = _tolerances(dtype)
    torch.testing.assert_close(out, expected, atol=atol, rtol=rtol)
​
​
@pytest.mark.parametrize("op_name", OPS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shape", SHAPES)
def test_activation_out_param(op_name, dtype, shape):
    # 验证传入外部 out 张量时,结果正确且返回的是同一对象
    x = torch.randn(shape, dtype=dtype, device="cuda")
    out = torch.empty(shape[:-1] + (shape[-1] // 2,), dtype=dtype, device="cuda")
    result = run_activation(op_name, x, out)
    assert result is out
    expected = _reference(op_name, x)
    atol, rtol = _tolerances(dtype)
    torch.testing.assert_close(out, expected, atol=atol, rtol=rtol)
python/sglang/srt/layers/quantization/gguf.py dependency-wiring

修改 GGUF 量化模块导入路径,将 CUDA 上的激活函数切换为 JIT 版本

# 在 gguf.py 中,根据平台条件导入激活函数
if _is_cuda:
    from sgl_kernel import moe_align_block_size, moe_sum
    from sgl_kernel.quantization import (
        ggml_dequantize,
        ggml_moe_a8,
        ggml_moe_a8_vec,
        ggml_moe_get_block_size,
        ggml_mul_mat_a8,
        ggml_mul_mat_vec_a8,
    )
    # CUDA 平台使用 JIT 编译的激活函数,替代 sgl-kernel 的 AOT 版本
    from sglang.jit_kernel.activation import gelu_and_mul, silu_and_mul
elif _is_musa:
    from sgl_kernel import gelu_and_mul, moe_align_block_size, moe_sum, silu_and_mul
    from sgl_kernel.quantization import (
        ggml_dequantize,
        ggml_moe_a8,
        ggml_moe_a8_vec,
        ggml_moe_get_block_size,
        ggml_mul_mat_a8,
        ggml_mul_mat_vec_a8,
    )
else:
    if not _is_hip:
        warnings.warn(f"Only CUDA and MUSA support GGUF quantization currently.")# fused_moe_gguf 内部的 act 辅助函数也相应简化
def act(x: torch.Tensor):
    # 直接调用简化后的接口,无需手动管理临时张量
    if activation == "silu":
        return silu_and_mul(x)
    elif activation == "gelu":
        return gelu_and_mul(x)
    raise ValueError(f"Unsupported activation: {activation}")

评论区精华

关于 --use_fast_math 的使用策略 性能

ch-wan 在 review 中询问是否应该全局启用 `--use_fast_math`。DarkSharpness 确认实现已根据平台条件禁用(SM100+ 和 ROCm 关闭)。

结论:采纳条件禁用方案,在 `_fast_math_flags` 中实现。 · 已解决

数值精度不一致修复 正确性

ch-wan 在 PR 评论中指出发现了数值问题,并提交多个 commit 修复 custom op fake shape、GELU multiply 顺序、gate*up 精度等。

结论:所有修复已合并,JIT 版本与 sgl-kernel 在测试容忍度内一致。 · 已解决

风险与影响

  1. 回归风险:替换多个模块的激活函数导入,如果条件判断错误(如 MUSA 平台意外使用了 JIT 版本),可能导致编译失败或行为异常。
  2. 精度风险:尽管有测试保障,但 --use_fast_math 在非 Blackwell 上的开启可能引入微小数值差异,尤其在混合精度场景下。
  3. 性能风险:JIT 编译首次调用有额外开销;benchmark 显示 JIT 版本与 AOT 版本性能接近,但在某些 shape 上可能略逊。
  4. 兼容性风险:非 CUDA 平台(MUSA, XPU, HIP)仍需 sgl_kernel 路径,修改后 gguf.py 等文件的导入区分了 _is_cuda_is_musa,需确保所有分支均正确。

用户影响:模型输出应保持一致,首次调用可能增加几秒 JIT 编译时间。系统影响:统一了激活函数实现,便于后续添加新激活函数。团队影响:减少对 sgl-kernel 二进制包的依赖,使新 GPU 架构(如 Blackwell)无需等待 sgl-kernel 预编译即可运行。

跨模块依赖变更 精度敏感 多平台条件分支 JIT 编译首次性能开销

关联 Issue

#22078 Revert "[Feature] JIT activation and update skills (by codex)"

完整报告

参与讨论