Prhub

#22312 Make GDN support non-continuous B/A Tensor input to fix the accuracy regression of Qwen3.5-27B

sgl-project/sglang · 作者 cs-cat · 合并时间 2026-04-10 18:58

分析状态 已生成
文件变更 3提交数 2 · 评论 5
代码增减 +272 / -8
bugfix run-ci sgl-kernel test

执行摘要

修复 GDN 内核以支持非连续 B/A 张量输入,解决 Qwen3.5-27B 准确性回归问题。

PR body 和关联 Issue #22311 指出,commit 5bdc07d974f6cf236fa765a685453ea5e587a838 的优化导致 Qwen3.5-27B 在 fallback 路径(v_per_group = 3)下,BA 投影被分割成非连续的 a 和 b 视图。GND Triton 内核假设连续内存布局并硬编码步幅,从而错误读取内存,引发准确性严重下降(从 49/50 降至 3/50)。修复旨在使内核支持非连续输入,恢复模型正常行为。

建议工程师精读此 PR,以学习内核中处理非连续内存布局的技术细节,以及如何通过显式步幅参数扩展内核通用性。关注测试文件中的模拟方法,可作为类似场景的参考。

讨论亮点

Review 过程中没有实质性讨论,reviewer 'yizhang2077' 直接批准了 PR。所有修改细节已在 PR body 和 commit 消息中明确描述,无需额外争议或决策。

实现拆解

实现方案分为三个部分:首先,在 fused_gdn_gating.py 中,为内核函数添加 stride_astride_b 参数,并在加载 a 和 b 张量时使用这些步幅替代硬编码偏移。其次,在 fused_sigmoid_gating_recurrent.py 中,类似添加 stride_a 参数,确保在循环更新中正确处理非连续布局。最后,新增测试文件 test_gdn_noncontiguous_stride.py,模拟 Qwen3.5 的 split 操作生成非连续张量,并验证两个内核函数在连续与非连续输入下输出一致。

文件 模块 状态 重要度
python/sglang/srt/layers/attention/fla/fused_gdn_gating.py attention/fla modified 8.0
python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py attention/fla modified 7.0
test/registered/attention/test_gdn_noncontiguous_stride.py test added 6.0

分析完成后,这里会展示 LLM 生成的相对完整源码片段和详细注释。

关键符号

fused_gdn_gating fused_sigmoid_gating_delta_rule_update

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

主要风险是原先的内存读取错误已通过修复消除,避免了准确性回归。新增测试覆盖了非连续输入场景,降低了未来类似回归风险。但修改涉及内核步幅计算,需确保不影响其他模型路径的性能或正确性;测试文件已验证 Qwen3.5 相关形状,但未覆盖所有可能布局。

直接影响 Qwen3.5-27B 用户,准确性恢复至正常水平,提升模型在长上下文推理中的可靠性。间接影响使用类似 GDN 路径的其他模型(如 Qwen3.5 其他配置),确保非连续输入得到正确处理。系统层面,修复了核心注意力模块的关键 bug,增强了整体推理稳定性。

内存布局假设错误 核心路径变更 缺少非连续输入测试

关联 Issue

#22311 Qwen3.5-27B accuracy regression caused by non-contiguous GDN split views

完整报告

执行摘要

本 PR 修复了 Qwen3.5-27B 模型因 GDN 内核不支持非连续输入而导致的准确性回归问题,通过更新内核步幅处理恢复准确性至正常水平,直接影响模型用户并提升系统可靠性。

功能与动机

动机源于 commit 5bdc07d 引入的优化,该优化导致 Qwen3.5-27B 在 fallback 路径下产生非连续 BA 张量视图。GND Triton 内核假设连续布局,硬编码步幅,引发内存读取错误和准确性严重下降(从 49/50 降至 3/50)。PR body 明确引用 Issue #22311,指出问题根源在于内核未处理 split 操作产生的非连续视图,修复目标是使内核支持非连续输入。

实现拆解

改动按模块拆解如下:

  • 核心内核更新
    • fused_gdn_gating.py:在内核函数中添加 stride_astride_b 参数,替换硬编码偏移为 a + i_b * stride_a + head_offb + i_b * stride_b + head_off 的加载逻辑。
    • fused_sigmoid_gating_recurrent.py:类似添加 stride_a 参数,并在循环更新中使用 p_a = a + bos * stride_a + i_hv * K + o_k(KDA 路径)或 p_a = a + bos * stride_a + i_hv(GDN 路径)确保正确指针移动。
  • 测试覆盖:新增 test_gdn_noncontiguous_stride.py,通过 _make_noncontiguous_ab 函数模拟 Qwen3.5 split 操作,生成非连续张量并对比内核输出与连续版本的差异,验证修复正确性。

评论区精华

Review 过程中无实质性讨论,reviewer 'yizhang2077' 直接批准,表明修改清晰且必要,所有技术细节已在 PR body 和 commit 中阐述。

风险与影响

  • 风险:原先内存读取错误已修复,但需确保步幅计算不影响其他模型路径;新增测试覆盖了 Qwen3.5 形状,但未全面验证所有潜在布局变体。
  • 影响:直接恢复 Qwen3.5-27B 准确性,提升用户信任;间接增强 GDN 模块对非连续输入的鲁棒性,可能惠及类似模型配置。

关联脉络

与历史 PR #22444 相关,后者也涉及 GDN 模块的性能优化(修改 gdn_backend.py),共同体现了对 GDN 内核的持续改进趋势。本 PR 作为准确性修复,补充了性能优化后的正确性保障。

参与讨论