Prhub

#23596 [AMD] Fix memory access fault when `--page-size > 1` with speculative decoding on AMD GPUs

原始 PR 作者 hubertlu-tw 合并时间 2026-04-24 14:56 文件变更 1 提交数 2 评论 1 代码增减 +82 / -5

执行摘要

修复 AMD GPU 上使用默认页大小 + 推测解码时的内存访问错误

在AMD GPU上使用--page-size 16和推测解码(EAGLE3)运行服务器时,出现HSA内存访问错误(memory access fault)。该问题仅在HIP/Triton后端下触发,根源是Triton HIP编译器对混合宽度存储的错误编译。

PR值得快速合并,是专为AMD GPU上的Triton编译器兼容性问题而设计的安全修复。建议后续为get_last_loc_triton_safe添加单元测试,以避免类似编译器回归。

讨论亮点

本PR无公开review评论。commit message中提到该PR是从#23146拆分出的,应HaiShaw要求独立合并以加速修复。PR由HaiShaw批准。

实现拆解

  1. 导入与全局检测python/sglang/srt/mem_cache/common.py):新增from sglang.srt.utils import is_hip导入,并定义模块级常量_is_hip = is_hip(),用于在模块作用域内提前判断当前运行时是否为HIP平台。
  2. 新增安全内核common.py):定义_get_last_loc_safe_kernel Triton JIT内核,其输出缓冲区result_i32固定为int32类型。内核内部根据PREFIX_DTYPE_IS_I64条件选择是否进行类型提升,最终存储结果时保持int32,避免混合宽度存储。
  3. 新增安全函数封装common.py):定义get_last_loc_triton_safe,在Triton外分配int32类型的输出张量,调用安全内核,最后用.to()提升到与输入一致的dtype,从而在Triton之外完成类型转换。
  4. 修改调度逻辑get_last_loc):将原先的attention_backend条件提取为uses_triton_dispatch变量,并在其之上增加_is_hip and uses_triton_dispatch条件分支,将HIP平台下本应走get_last_loc_triton的路径全部导向get_last_loc_triton_safe。非HIP平台保持不变。
文件 模块 状态 重要度
python/sglang/srt/mem_cache/common.py 缓存层 modified 7.84

关键符号

_get_last_loc_safe_kernel get_last_loc_triton_safe get_last_loc

关键源码片段

python/sglang/srt/mem_cache/common.py core-logic

唯一修改的文件,包含所有核心变更:新增安全内核 `_get_last_loc_safe_kernel`、封装函数 `get_last_loc_triton_safe`,以及修改调度函数 `get_last_loc` 以在 HIP 路径下使用安全变体。

# python/sglang/srt/mem_cache/common.py@triton.jit
def _get_last_loc_safe_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result_i32, # 固定为 int32 输出缓冲区,避免 Triton HIP 后端对混合宽度存储的错误编译
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
    PREFIX_DTYPE_IS_I64: tl.constexpr, # 编译期常量:输入 dtype 是否为 int64
):
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens
​
    if PREFIX_DTYPE_IS_I64:
        # 当输入已经是 int64 时,直接进行乘法避免额外转换
        prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
        req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
        token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    else:
        # 当输入是 int32 时,将索引操作数显式提升为 int64,提升精度
        prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
        req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
        token_index = req_pool_indices.to(tl.int64) * req_to_token_stride + (
            prefix_lens.to(tl.int64) - 1
        )
​
    token_mask = mask & (prefix_lens > 0)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
    # 结果存储进 int32 缓冲区,后续由调用者在 Triton 外交互完成类型提升
    tl.store(result_i32 + offset, tokens, mask=mask)
​
​
def get_last_loc_triton_safe(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """int32安全的last_loc Triton实现:使用int32输出缓冲区,由PyTorch完成最终类型提升。"""
    num_tokens = prefix_lens_tensor.shape[0]
    BLOCK_SIZE = 256
    # 分配 int32 缓冲区,避免 Triton 内发生混合宽度存储
    result_i32 = torch.empty(
        num_tokens, dtype=torch.int32, device=prefix_lens_tensor.device
    )
    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
    _get_last_loc_safe_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result_i32,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
        PREFIX_DTYPE_IS_I64=(prefix_lens_tensor.dtype == torch.int64),
    )
    # 在 Triton 外部将结果提升回目标 dtype,避免编译器 bug
    return result_i32.to(prefix_lens_tensor.dtype)

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

  • 回归风险:修改集中在单一文件common.py中的get_last_loc调度函数,仅影响HIP平台下attention_backend为triton/aiter且使用推测解码的场景。非HIP(CUDA)平台逻辑完全不变,回归风险低。
  • 性能影响:新引入的get_last_loc_triton_safe增加了一次torch.empty分配和一次.to()类型转换,开销极小,在AMD GPU上可忽略不计。
  • 覆盖范围:仅修复了get_last_loc路径下的问题,其他使用Triton内核的地方若存在类似混合宽度存储问题,仍需单独修复。
  • 测试覆盖:无配套测试文件,依赖集成测试(PR提供了GSM8K精度测试)和手动复现。

用户影响:修复了AMD GPU用户在使用推测解码(特别是EAGLE)且页面大小大于1时的崩溃问题,使该配置可正常工作。
系统影响:无,仅改变HIP分支下的内核选择路径,不涉及数据结构或协议变更。
团队影响:代码意图清晰,注释充分,易于维护。

平台特定编译器 bug 核心路径变更 缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论