Prhub

#25274 [Refactor] JIT kernel benchmark

原始 PR 作者 DarkSharpness 合并时间 2026-05-28 15:49 文件变更 6 提交数 10 评论 7 代码增减 +597 / -314

执行摘要

重写 JIT kernel benchmark 框架,替换 triton.testing

PR body指出triton.testing.benchmark过于沉重且存在bug:1)缺失当前CUDA stream的同步;2)无法在CUDA graph benchmark中冷L2缓存。以往通过在代码中手动模拟多层来规避L2缓存效果,导致代码混乱难以理解。因此需要一个轻量级benchmark工具,提升可读性和结果准确性。

建议精读此PR,尤其是marker.pydo_bench的实现和parametrize的pytest风格设计。它为CUDA kernel benchmark提供了一套可复用的轻量方案,值得其他项目借鉴。bench_qknorm.py的迁移展示了如何大幅简化代码。

讨论亮点

Review中gemini-code-assist提出了多项建议:

  • 带宽计算单位歧义:内部计算使用GiB但输出标签为GB/s,建议统一为GiB/s或修改计算。
  • IndexError风险:当line_vals为空时,需要对空列表增加守卫。
  • memory_args计数:in-place kernel(如qknorm)应重复传入读写张量以准确计量带宽,建议memory_args=(q, q, k, k, q_weight, k_weight)
  • 索引唯一性:store cache benchmark应使用torch.randperm避免重复索引导致的写冲突(但原代码已使用,建议多余)。
  • graph_clone_args扩展:需要克隆所有缓存buffer才能完全规避L2缓存效果,建议将k_cache、v_cache也加入克隆列表。
  • memory_args包含写操作:store cache benchmark应计入kv的写入,建议memory_args=(k, k, v, v, indices)

这些讨论中,部分已在head代码中实现或原本就正确(如randperm),部分未采纳(如memory_args重复计数)。

实现拆解

  1. 创建核心benchmark框架:新增python/sglang/jit_kernel/benchmark/marker.py,定义Benchmark类、parametrize装饰器、do_bench函数等核心组件。支持CUDA graph捕获、自动流同步、buffer克隆以冷L2、内存带宽计算。
  2. 重写示例benchmark:修改bench_qknorm.pybench_store_cache.py,用@marker.parametrize声明参数网格,@marker.benchmark声明比较轴,替代原来的itertools.producttriton.testing.Benchmark。移除手动多层模拟代码,改为框架内部通过graph_clone_args刷新缓存。
  3. 适配更多benchmark:更新bench_activation.py同样使用新框架,简化了activation和filter benchmark的编写。
  4. 增加工具函数:在utils.py中添加create_emptycreate_random便捷函数,统一使用默认dtype/device。
  5. 更新开发文档:修改.claude/skills/add-jit-kernel/SKILL.md,详细说明marker框架用法,包括参数化、do_bench关键参数和示例代码。
文件 模块 状态 重要度
python/sglang/jit_kernel/benchmark/marker.py 基准框架 added 9.08
python/sglang/jit_kernel/benchmark/bench_qknorm.py 基准测试 modified 7.88
python/sglang/jit_kernel/benchmark/bench_store_cache.py 基准测试 modified 7.85
python/sglang/jit_kernel/benchmark/bench_activation.py 基准测试 modified 7.43
python/sglang/jit_kernel/benchmark/utils.py 工具函数 modified 5.73
.claude/skills/add-jit-kernel/SKILL.md 开发者文档 modified 4.65

关键符号

skip do_bench parametrize benchmark Benchmark.run create_random create_empty sglang_aot_qknorm torch_impl_qknorm

关键源码片段

python/sglang/jit_kernel/benchmark/marker.py core-logic

核心新增文件,定义了整个 benchmark 框架,包括 Benchmark 类、parametrize 装饰器、do_bench 函数、BenchResult 和 Table 等,是 PR 的核心贡献。

# marker.py - 轻量级 JIT kernel benchmark 框架核心import torch
from sglang.jit_kernel.utils import cache_once# 缓存每个设备的 CUDA stream,用于确保 benchmark 在独立 stream 上进行 @cache_once
def _get_benchmark_stream(device_id: int) -> torch.cuda.Stream:
    return torch.cuda.Stream(device=device_id)# 递归克隆输入,用于每次 graph 重放时刷新 L2 缓存
def _clone_recursive(in_: Any) -> Any:
    if isinstance(in_, torch.Tensor):
        return in_.clone()
    elif isinstance(in_, (list, tuple)):
        return type(in_)(_clone_recursive(x) for x in in_)
    elif isinstance(in_, dict):
        return {k: _clone_recursive(v) for k, v in in_.items()}
    # 基本类型直接返回
    elif isinstance(in_, (bool, int, float, str, torch.dtype, torch.device, type(None))):
        return in_
    raise ValueError(f"unsupported type: {type(in_)}")# 递归获取张量总字节数(用于带宽估算)
def _get_nbytes_recursive(in_: Any) -> int:
    if isinstance(in_, torch.Tensor):
        return in_.nbytes
    elif isinstance(in_, (list, tuple)):
        return sum(_get_nbytes_recursive(x) for x in in_)
    elif isinstance(in_, dict):
        return sum(_get_nbytes_recursive(v) for v in in_.values())
    elif isinstance(in_, (bool, int, float, str, torch.dtype, torch.device, type(None))):
        return 0
    raise ValueError(f"unsupported type: {type(in_)}")def _process_metrics(times: list[float], metrics: tuple[Metric, ...]) -> list[float]:
    # 排序后将微秒转换为秒
    times = sorted(x / 1000 for x in times)
    results = []
    for metric in metrics:
        if metric == "avg":
            results.append(sum(times) / len(times))
        else:
            # 取指定分位数
            which = min(int(len(times) * metric), len(times) - 1)
            results.append(times[which])
    return resultsclass BenchResult(NamedTuple):
    metrics: Tuple[Metric, ...]
    times: List[float] # 单位:秒
    memory_footprint: Optional[int]class Table:
    # 对齐文本表格,用于打印结果
    SEP = " | "
    def __init__(self):
        self._headers: List[str] = []
        self._mins: List[int] = []
        self._pads: List[int] = []
        self._aligns: List[str] = []
        self._seps: set = set()
        self._rows: List[List[str]] = []
​
    @staticmethod
    def format_latency(r: float) -> str:
        if math.isnan(r):
            return "N/A"
        length = len(str(int(r)))
        if length < 5:
            return f"{r:.4f}"
        digits = max(0, 4 - (length - 5))
        return f"{r:.{digits}f}"
​
    @staticmethod
    def format_bandwidth(b: float) -> str:
        if math.isnan(b):
            return "N/A"
        return f"{b:.2f}"
    # ...(其余方法用于构建和打印表格)
python/sglang/jit_kernel/benchmark/bench_qknorm.py core-logic

示例迁移,展示如何用新框架重写 benchmark,去掉大量样板代码,变为声明式参数化。

# bench_qknorm.py - 使用新框架的示例 benchmarkimport torch
from sglang.jit_kernel.benchmark import marker
from sglang.jit_kernel.benchmark.utils import create_random
from sglang.jit_kernel.norm import fused_inplace_qknorm
from sglang.srt.utils import get_current_device_stream_fastFN_MAP = {
    "aot": sglang_aot_qknorm, # 改用 flashinfer rmsnorm
    "jit": fused_inplace_qknorm,
    "torch": torch_impl_qknorm,
}# @marker.parametrize 声明参数网格(每个参数可独立指定 full 和 CI 下的值)
# @marker.benchmark 声明比较轴 line_arg="impl",取值为 ["aot", "jit", "torch"]
@marker.parametrize("head_dim", [128, 256, 512, 1024], [128])
@marker.parametrize("GQA", [4, 8], [4])
@marker.parametrize("num_kv_heads", [1, 2, 4, 8], [1])
@marker.parametrize("batch_size", [2**n for n in range(0, 14)], [16])
@marker.benchmark("impl", ["aot", "jit", "torch"])
def benchmark(head_dim: int, GQA: int, num_kv_heads: int, batch_size: int, impl: str):
    num_qo_heads = GQA * num_kv_heads
    q = create_random(batch_size, num_qo_heads, head_dim)
    k = create_random(batch_size, num_kv_heads, head_dim)
    q_weight = create_random(head_dim)
    k_weight = create_random(head_dim)
    # do_bench 负责 CUDA graph 捕获、L2 缓存规避、内存带宽计算
    return marker.do_bench(
        FN_MAP[impl],
        input_args=(q, k, q_weight, k_weight),
        memory_output=(q, k), # 标记 inplace 写出的张量
    )if __name__ == "__main__":
    benchmark.run()

评论区精华

带宽计算单位歧义(GiB vs GB/s) 设计

gemini-code-assist 指出内部计算使用 GiB(1024^3),但输出标签为 GB/s(10^9),建议统一或更改计算。

结论:未在提交中显式修复,可能维持现状未采纳。 · 待处理

memory_args 对 in-place kernel 应双倍计数 正确性

对于 qknorm 等 in-place 写入的 kernel,read & write 都应计入 memory_args 才能得到准确带宽。建议将 q 和 k 分别传入两次。

结论:head 代码中 memory_args 仍为默认(memory_args="all" 实际仅计一次),未采纳建议。 · 待处理

store cache 索引唯一性 性能

gemini-code-assist 建议使用 torch.randperm 而非 torch.randint 避免重复索引导致写冲突,影响测量真实性。

结论:原代码和 head 代码均已使用 torch.randperm,建议多余。 · 已解决

风险与影响

  1. 框架稳定性:新框架marker.py仅应用于少量benchmark,可能存在未发现的CUDA graph兼容性或数值问题。
  2. CI基准波动:由于加入了L2缓存规避和流同步,benchmark结果可能更稳定但也可能因GPU型号差异产生新的波动。
  3. 内存开销:CUDA graph克隆buffer会增加显存占用,尤其在大型tensor场景下可能OOM。
  4. 单位歧义:带宽计算使用GiB但标签为GB/s,可能造成数据误读。
  5. 迁移不完整:部分旧benchmark(如bench_activation.py中filter部分)已迁移,但仍有其他未迁移的benchmark,维护两个模式成本。

用户:JIT kernel开发者编写benchmark的体验显著改善——使用声明式参数化、自动CUDA graph、准确带宽计算。系统:benchmark结果对L2缓存影响更鲁棒,stream同步使计时准确。团队:需要推广新框架,逐步迁移所有旧benchmark并废弃triton.testing用法。文档SKILL.md已更新,降低学习门槛。

新框架稳定性风险 单位混淆风险 CI 基准波动可能

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论