Prhub

#26634 [CPU] fix incorrect index of b_ptr in fused_sigmoid_gating_delta_rule…

原始 PR 作者 blzheng 合并时间 2026-05-29 16:08 文件变更 2 提交数 2 评论 3 代码增减 +27 / -16

执行摘要

修复 CPU 核中 b_ptr 索引错误

本 PR 源于 issue #19484 在 review 中发现的 bug:fused_sigmoid_gating_delta_rule_update_kernel_impl 函数中 b_ptr 使用了错误的索引 ni(仅表示头索引),导致在多批次场景下读取到错误的 b 值。PR 作者 @blzheng 在 body 中特别感谢 @fadara01 指出该问题。

建议尽快合并。该修复为明确的 bugfix,且已有充分测试验证。对于关注 CPU 推理性能的团队值得关注。

讨论亮点

无 review 评论。仅有一条来自 gemini-code-assist bot 的警告(达到每日配额限制)。未发现争议点。

实现拆解

  1. C++ 内核修复:在 sgl-kernel/csrc/cpu/mamba/fla.cppfused_sigmoid_gating_delta_rule_update_kernel_impl 函数中,将 beta_val = 1 / (1 + std::exp(-b_ptr[ni])) 改为 b_ptr[bi * v_num_heads + ni],修正了索引,使其与其他参数(如 a_ptr 的索引方式)保持一致,确保在 batch_size > 1 时能正确获取每个样本的 b 值。
  2. Python 测试调整test/registered/cpu/test_mamba.py):
    • sigmoid_gating_delta_rule_update 函数中,将 g.unsqueeze(0)beta.unsqueeze(0) 改为 unsqueeze(1),使得 gbeta 的维度与 torch_recurrent_gated_delta_rule 的预期一致(在批次维度上添加新维度),避免维度不匹配导致的错误。
    • test_fused_sigmoid_gating_delta_rule_update 测试用例从单参数改为使用 @parametrize 装饰器,支持 batch_size=[1, 4] 等多种参数组合,增强测试覆盖。同时修正了 querykeyvalue 的 reshape 维度以匹配新参数,并将 query_start_loc 从固定 [0, 1] 改为 torch.arange(batch_size + 1) 以支持动态批量大小。
  3. 依赖导入调整:在测试文件中从 utils 导入 parametrize 以支持参数化测试。
文件 模块 状态 重要度
sgl-kernel/csrc/cpu/mamba/fla.cpp CPU 内核 modified 5.49
test/registered/cpu/test_mamba.py 测试 modified 6.22

关键符号

fused_sigmoid_gating_delta_rule_update_kernel_impl sigmoid_gating_delta_rule_update test_fused_sigmoid_gating_delta_rule_update

关键源码片段

test/registered/cpu/test_mamba.py test-coverage

测试文件,同时修复了 Python 函数中的 unsqueeze 维度错误,并增强了参数化测试,覆盖多批次场景。

# Python 参考函数,修正 g 和 beta 的 unsqueeze 维度:
# 之前使用 unsqueeze(0) 在批次前插入维度,但内核期望在批次后插入(unsqueeze(1))。
# 修正后与 torch_recurrent_gated_delta_rule 的输入布局一致。
def sigmoid_gating_delta_rule_update(...):
    beta = b.sigmoid()
    g = -A_log.float().exp() * softplus(a.float() + dt_bias)
    return torch_recurrent_gated_delta_rule(
        query, key, value,
        g.unsqueeze(1), # 原来为 unsqueeze(0),修正为 unsqueeze(1)
        beta.unsqueeze(1), # 同上
        initial_state, output_final_state,
        use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
    )# 测试用例被参数化,现在同时测试 batch_size=1 和 4:
@parametrize(
    batch_size=[1, 4], # 新增参数化,确保多批次正确性
    num_value_heads=[32],
    head_k_dim=[128],
    head_v_dim=[128],
    num_heads=[16],
    seq_len=[1],
    attn_tp_size=[1],
)
def test_fused_sigmoid_gating_delta_rule_update(self, batch_size, ...):
    # ... 内部 reshape 使用 batch_size 替代固定值 1
    query = query.view(1, batch_size, num_heads, head_k_dim)
    key = key.view(1, batch_size, num_heads, head_k_dim)
    value = value.view(1, batch_size, num_value_heads, head_v_dim)
    # query_start_loc 也动态生成
    query_start_loc = torch.arange(batch_size + 1, dtype=torch.int32)

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

风险较低。核心修复仅修改一行 C++ 索引,已通过参数化测试覆盖 batch_size=1 和 4。回归风险小。但请注意,该内核仅在 CPU 路径上生效,GPU 或其他硬件平台不受影响。

影响范围局限:主要影响 CPU 上使用 fused_sigmoid_gating_delta_rule_update 内核的 Mamba/SSM 模型推理。修复后保证了多批次(batch_size > 1)下计算的正确性。

单行修改 测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论