执行摘要
- 一句话:LoRA CSGMV kernel 离线自动调优
- 推荐动作:LoRA 调优对生产环境有显著收益,建议所有使用 LoRA 的用户关注此 PR。设计上复用了 MoE 调优的模式,是值得参考的框架扩展方式。特别关注
lora_tuning_config.py 的版本回退逻辑和最近邻 chunk 选择策略,可复用至未来其他 Triton kernel 的调优。
功能与动机
LoRA CSGMV kernel 的性能高度依赖 BLOCK_N、BLOCK_K、num_warps 等 Triton 编译参数,手动调优费时且平台不通用。PR body 明确说明:'Add offline auto-tuning script for LoRA csgmv shrink / expand kernels (similar to MoE auto-tuning)',并给出在 Qwen3-Embedding-0.6B 上 shrink kernel 2-3x、expand kernel 1.1-1.5x 的加速数据。
实现拆解
-
新增离线调优脚本 (benchmark/kernels/lora_csgmv/tune_lora_csgmv.py):定义搜索空间(BLOCK_N, BLOCK_K, num_warps, num_stages, maxnreg)、编写基准测试函数 benchmark_shrink_config 和 benchmark_expand_config,通过网格搜索找到每个 (kernel, K, R, S, chunk_size) 下的最优配置并保存为 JSON。
-
新增配置加载器 (python/sglang/srt/lora/triton_ops/lora_tuning_config.py):提供 get_lora_config_file_name、get_lora_configs(含 LRU 缓存)、get_lora_shrink_config、get_lora_expand_config。加载时先精确匹配当前 Triton 版本目录,若找不到则回退到其他版本目录;仍未找到则返回 None,使调用方使用默认值。默认值在 review 后调整为仅含 BLOCK_N 和 BLOCK_K,保留 Triton 对 num_warps/num_stages 的自动选择。
-
修改 kernel 调用 (chunked_sgmv_shrink.py、chunked_sgmv_expand.py):在 chunked_lora_shrink 和 chunked_lora_expand 函数中,通过 get_lora_shrink_config / get_lora_expand_config 获取调优参数,并作为 kwargs 传入 Triton kernel。当配置不存在时使用硬编码默认值,保证了零侵入。
-
新增单元测试 (test/manual/lora/test_lora_tuning_config.py):覆盖配置文件名生成、精确加载、Triton 版本回退、最近邻 chunk_size 选择(最近匹配而非严格相等)、完全缺失回退到默认值等场景。
-
提供预调优配置 (csgmv_configs/triton_3_5_1/ 下 7 个 JSON 文件):针对 H200 GPU、Triton 3.5.1,覆盖 qkv_proj (S=3)、gate_up_proj (S=2) 等常见 layer 的 shrink 和 expand 配置,作为用户开箱即用的参考。
关键文件:
benchmark/kernels/lora_csgmv/tune_lora_csgmv.py(模块 调优工具;类别 source;类型 dependency-wiring;符号 _get_raw_kernel, build_batch_info, timed_cuda_ms, get_shrink_search_space): 新增离线调优脚本,是整个功能的核心引擎,定义了搜索空间、计时器和主循环。
test/manual/lora/test_lora_tuning_config.py(模块 单元测试;类别 test;类型 test-coverage;符号 TestLoraConfigFileName, test_includes_all_params, test_different_slices_different_filenames, TestLoraConfigLoading): 新增单元测试,覆盖了配置加载的三种关键场景:精确匹配、Triton 版本回退、缺失回退。
python/sglang/srt/lora/triton_ops/lora_tuning_config.py(模块 配置加载;类别 infra;类型 infrastructure;符号 get_lora_config_file_name, get_lora_configs, get_lora_shrink_config, get_lora_expand_config): 配置加载器,封装了文件名生成、LRU 缓存、Triton 版本回退和 chunk_size 最近邻选择策略,是运行时调优配置的枢纽。
python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py(模块 内核层;类别 infra;类型 infrastructure): 修改 shrink kernel 调用,集成调优配置加载。
python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py(模块 内核层;类别 infra;类型 infrastructure): 修改 expand kernel 调用,集成调优配置加载。
python/sglang/srt/lora/triton_ops/csgmv_configs/triton_3_5_1/lora_expand,K=4096,R=64,S=3,device=NVIDIA_H200.json(模块 配置数据;类别 infra;类型 infrastructure): H200 上 qkv_proj expand kernel 预调优配置示例,展示输出格式。
关键符号:get_lora_config_file_name, get_lora_configs, get_lora_shrink_config, get_lora_expand_config, _get_raw_kernel, benchmark_shrink_config, benchmark_expand_config
关键源码片段
python/sglang/srt/lora/triton_ops/lora_tuning_config.py
配置加载器,封装了文件名生成、LRU 缓存、Triton 版本回退和 chunk_size 最近邻选择策略,是运行时调优配置的枢纽。
以下为 get_lora_configs 函数,展示了版本回退逻辑:
@functools.lru_cache
def get_lora_configs(
kernel: str,
K: int,
R: int,
S: int,
) -> Optional[Dict[int, Dict[str, Any]]]:
"""加载调优配置,优先精确匹配 Triton 版本,否则回退到其他版本。"""
json_file_name = get_lora_config_file_name(kernel, K, R, S)
config_dir = os.environ.get(
"SGLANG_LORA_CONFIG_DIR",
os.path.dirname(os.path.realpath(__file__))
)
configs_root = os.path.join(config_dir, "csgmv_configs")
triton_version = triton.__version__
version_dir = f"triton_{triton_version.replace('.', '_')}"
# 1. 精确匹配当前 Triton 版本
config_file_path = os.path.join(configs_root, version_dir, json_file_name)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(f"Using LoRA {kernel} config from {config_file_path}.")
return {int(key): val for key, val in json.load(f).items()}
# 2. 扫描已有版本目录(按降序),取第一个找到的回退
if os.path.isdir(configs_root):
version_dirs = sorted(
(d for d in os.listdir(configs_root) if d.startswith("triton_")),
reverse=True,
)
for fallback_dir in version_dirs:
fallback_path = os.path.join(configs_root, fallback_dir, json_file_name)
if os.path.exists(fallback_path):
logger.info(
f"LoRA {kernel}: triton version {triton_version} not found, "
f"falling back to {fallback_dir}."
)
with open(fallback_path) as f:
return {int(key): val for key, val in json.load(f).items()}
# 3. 未找到任何版本,返回 None
logger.info(
f"LoRA {kernel} config not found for {K=} {R=} {S=}. "
"Falling back to hardcoded defaults."
)
return None
评论区精华
zminglei 在 lora_tuning_config.py 的 review 中指出:DEFAULT_SHRINK_CONFIG 和 DEFAULT_EXPAND_CONFIG 包含 num_warps=4, num_stages=2,这些参数之前未作为 kwargs 传入。当回退时,这些默认值会覆盖 Triton 的自动选择,可能改变默认行为。
作者通过最终 commit Remove num_warps/num_stages from default configs to preserve Triton auto-tuning 解决了此问题,移除了默认配置中的 num_warps 和 num_stages,使得无调优文件时完全使用 Triton 的原生自动选择。
- 默认配置覆盖 Triton 自动选择 (design): 作者移除了默认配置中的 num_warps 和 num_stages,仅保留 BLOCK_N 和 BLOCK_K,保证无调优文件时行为与之前一致。
风险与影响
-
风险:兼容性风险:调优配置仅作用于 --lora-backend csgmv 后端,不影响其他 backend。
性能风险:预调优配置针对 H200 生成,在其他 GPU 上可能不是最优,但用户可随时运行自己模型的调优脚本。
回归风险:通过回退默认值保留了原始 Triton 自动选择行为,且测试覆盖了缺失配置回退,回归可能性低。
维护成本:新增的配置 JSON 文件须随新 Triton 版本维护,但自动调优脚本可以重新生成。
-
影响:用户/开发者:LoRA 推理用户无需手动调优,Server 自动加载最优配置,显著提升性能(端到端 12-24%)。仅作用于 csgmv 后端,无侵入性。
系统:无额外运行时开销,配置仅在加载时读取。
团队:新增一个 benchmark 脚本和约 200 行配置加载代码,遵循 MoE 调优的已有模式,易于维护。
-
风险标记:回退配置覆盖, GPU 平台依赖
关联脉络
参与讨论