执行摘要
- 一句话:修复AMD GPU上使用默认页大小+推测解码时的内存访问错误
- 推荐动作:PR值得快速合并,是专为AMD GPU上的Triton编译器兼容性问题而设计的安全修复。建议后续为
get_last_loc_triton_safe添加单元测试,以避免类似编译器回归。
功能与动机
在AMD GPU上使用--page-size 16和推测解码(EAGLE3)运行服务器时,出现HSA内存访问错误(memory access fault)。该问题仅在HIP/Triton后端下触发,根源是Triton HIP编译器对混合宽度存储的错误编译。
实现拆解
- 导入与全局检测(
python/sglang/srt/mem_cache/common.py):新增from sglang.srt.utils import is_hip导入,并定义模块级常量_is_hip = is_hip(),用于在模块作用域内提前判断当前运行时是否为HIP平台。
- 新增安全内核(
common.py):定义_get_last_loc_safe_kernel Triton JIT内核,其输出缓冲区result_i32固定为int32类型。内核内部根据PREFIX_DTYPE_IS_I64条件选择是否进行类型提升,最终存储结果时保持int32,避免混合宽度存储。
- 新增安全函数封装(
common.py):定义get_last_loc_triton_safe,在Triton外分配int32类型的输出张量,调用安全内核,最后用.to()提升到与输入一致的dtype,从而在Triton之外完成类型转换。
- 修改调度逻辑(
get_last_loc):将原先的attention_backend条件提取为uses_triton_dispatch变量,并在其之上增加_is_hip and uses_triton_dispatch条件分支,将HIP平台下本应走get_last_loc_triton的路径全部导向get_last_loc_triton_safe。非HIP平台保持不变。
关键文件:
python/sglang/srt/mem_cache/common.py(模块 缓存层;类别 source;类型 core-logic;符号 _get_last_loc_safe_kernel, get_last_loc_triton_safe): 唯一修改的文件,包含所有核心变更:新增安全内核_get_last_loc_safe_kernel、封装函数get_last_loc_triton_safe,以及修改调度函数get_last_loc以在HIP路径下使用安全变体。
关键符号:_get_last_loc_safe_kernel, get_last_loc_triton_safe, get_last_loc
关键源码片段
python/sglang/srt/mem_cache/common.py
唯一修改的文件,包含所有核心变更:新增安全内核_get_last_loc_safe_kernel、封装函数get_last_loc_triton_safe,以及修改调度函数get_last_loc以在HIP路径下使用安全变体。
# python/sglang/srt/mem_cache/common.py
@triton.jit
def _get_last_loc_safe_kernel(
req_to_token,
req_pool_indices_tensor,
prefix_lens_tensor,
result_i32, # 固定为 int32 输出缓冲区,避免 Triton HIP 后端对混合宽度存储的错误编译
num_tokens,
req_to_token_stride,
BLOCK_SIZE: tl.constexpr,
PREFIX_DTYPE_IS_I64: tl.constexpr, # 编译期常量:输入 dtype 是否为 int64
):
pid = tl.program_id(0)
offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
mask = offset < num_tokens
if PREFIX_DTYPE_IS_I64:
# 当输入已经是 int64 时,直接进行乘法避免额外转换
prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
else:
# 当输入是 int32 时,将索引操作数显式提升为 int64,提升精度
prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
token_index = req_pool_indices.to(tl.int64) * req_to_token_stride + (
prefix_lens.to(tl.int64) - 1
)
token_mask = mask & (prefix_lens > 0)
tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
# 结果存储进 int32 缓冲区,后续由调用者在 Triton 外交互完成类型提升
tl.store(result_i32 + offset, tokens, mask=mask)
def get_last_loc_triton_safe(
req_to_token: torch.Tensor,
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
"""int32安全的last_loc Triton实现:使用int32输出缓冲区,由PyTorch完成最终类型提升。"""
num_tokens = prefix_lens_tensor.shape[0]
BLOCK_SIZE = 256
# 分配 int32 缓冲区,避免 Triton 内发生混合宽度存储
result_i32 = torch.empty(
num_tokens, dtype=torch.int32, device=prefix_lens_tensor.device
)
grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
_get_last_loc_safe_kernel[grid](
req_to_token,
req_pool_indices_tensor,
prefix_lens_tensor,
result_i32,
num_tokens,
req_to_token.stride(0),
BLOCK_SIZE=BLOCK_SIZE,
PREFIX_DTYPE_IS_I64=(prefix_lens_tensor.dtype == torch.int64),
)
# 在 Triton 外部将结果提升回目标 dtype,避免编译器 bug
return result_i32.to(prefix_lens_tensor.dtype)
评论区精华
本PR无公开review评论。commit message中提到该PR是从#23146拆分出的,应HaiShaw要求独立合并以加速修复。PR由HaiShaw批准。
风险与影响
关联脉络
- PR #23146 [WIP] Original PR containing this fix: 本PR从中拆分出来,应审查者要求独立合并以加速AMD修复进度。
参与讨论