Prhub

#24978 [MUSA]: Add flashinfer sampling backend

原始 PR 作者 froststeam 合并时间 2026-05-15 11:23 文件变更 9 提交数 8 评论 5 代码增减 +229 / -4

执行摘要

为 MUSA 添加 FlashInfer 采样后端

Add FlashInfer sampling backend support for MUSA to enable optimized sampling operations on MUSA devices.

值得阅读,特别是对 MUSA 后端的适配方式。设计决策中采用了与 CUDA 后端类似的接口封装,便于未来统一。

讨论亮点

Review 中 gemini-code-assist 指出了三个关键错误:在 musa.py 中错误地将 probs.device 作为上下文管理器使用( with probs.device as device ),这会导致运行时 AttributeError 。建议改为直接赋值变量后使用。作者已确认修复。此外 reviewer yeahdongcn 要求 rebase 并参考 flashinfer 官方实现的方式(使用 probs.device 而不是上下文管理器),并最终批准了 PR。

实现拆解

  1. C++ 接口声明:在 sgl-kernel/include/sgl_kernel_musa_ops.h 中添加了 min_p_sampling_from_probstop_p_sampling_from_probs 函数声明,供 Torch 绑定使用。
  2. Torch 算子注册:在 sgl-kernel/csrc/common_extension_musa.cc 中通过 TORCH_LIBRARY_EXPAND 注册上述算子,关联 MUSA 后端实现。
  3. Python 封装:在 sgl-kernel/python/sgl_kernel/musa.py 中新增内部函数和导出函数(如 top_k_renorm_probstop_p_sampling_from_probs 等),统一处理类型转换和参数传递。
  4. 采样器集成:在 python/sglang/srt/layers/sampler.py 中添加 if is_musa(): 条件分支,从 sgl_kernel 导入所需采样函数,使得 MUSA 设备运行时自动使用新后端。
  5. 包入口和构建:在 sgl-kernel/python/sgl_kernel/__init__.py 导出新函数,并在 sgl-kernel/setup_musa.py 中将 FlashInfer 的 sampling.cu 加入编译列表。
  6. 依赖更新:更新 python/pyproject_other.tomlsgl-kernel/pyproject_musa.toml3rdparty/amd/wheel/sglang/pyproject.toml 中的 torchada 版本至 0.1.56。
文件 模块 状态 重要度
sgl-kernel/python/sgl_kernel/musa.py 采样封装 modified 8.8
python/sglang/srt/layers/sampler.py 采样器集成 modified 5.85
sgl-kernel/include/sgl_kernel_musa_ops.h 采样接口 modified 5.89
sgl-kernel/csrc/common_extension_musa.cc 扩展注册 modified 5.29
sgl-kernel/python/sgl_kernel/__init__.py 包入口 modified 4.54

关键符号

top_k_renorm_probs top_p_renorm_probs top_p_sampling_from_probs top_k_top_p_sampling_from_probs min_p_sampling_from_probs

关键源码片段

sgl-kernel/python/sgl_kernel/musa.py dependency-wiring

添加了所有 FlashInfer 采样函数的 Python 封装,是 MUSA 采样后端的核心接口

def _top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    indices: Optional[torch.Tensor],
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
    generator: Optional[torch.Generator],
) -> torch.Tensor:
    # 获取设备并转换概率为 float
    device = probs.device
    probs = probs.float()
    # 类型转换:top_p 数组转为 float(若提供)
    maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
    # 预分配输出张量( int32 )
    samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
    # 调用底层 MUSA 算子
    torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
        probs, samples, indices, maybe_top_p_arr, top_p_val, deterministic, generator,
    )
    return samples
​
​
def top_p_sampling_from_probs(
    probs: torch.Tensor,
    top_p: Union[torch.Tensor, float],
    indices: Optional[torch.Tensor] = None,
    deterministic: bool = True,
    generator: Optional[torch.Generator] = None,
    check_nan: bool = False,
) -> torch.Tensor:
    # 可选的 NaN 检查
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("Input probs contains NaN.")
    # 将标量 top_p 参数转换为 (tensor, val) 统一格式
    return _top_p_sampling_from_probs_internal(
        probs, indices, *_to_tensor_scalar_tuple(top_p), deterministic, generator
    )

评论区精华

torch.device 上下文管理器错误 正确性

gemini-code-assist 指出使用 `with probs.device as device:` 会导致 AttributeError,建议改为直接赋值。

结论:作者已修复(fixed.) · 已解决

建议参考 flashinfer 官方实现使用 probs.device 设计

yeahdongcn 要求 rebase 并引用 flashinfer 官方代码( sampling.py#L214 ),建议使用 probs.device 而非上下文管理器。

结论:作者已采纳并修复,yeahdongcn 最终批准。 · 已解决

风险与影响

  • 技术风险:新后端可能在某些 MUSA 设备上不稳定,缺少充分的单元测试(没有测试文件变更)。
  • 性能风险:新后端可能引入性能回归,但预期是优化。
  • 兼容性风险:torchada 版本更新可能影响其他依赖;但只在 MUSA 下生效,不影响其他后端。
  • 用户:MUSA 用户可使用 FlashInfer 采样,获得性能提升。
  • 系统:增加了条件导入,非 MUSA 环境无变化。
  • 团队:需要维护 MUSA 特有的采样代码,但核心逻辑来自 FlashInfer,降低维护成本。
MUSA 新后端 torch.device 错误已修复 依赖版本变更 缺少单元测试

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论