Prhub

#26303 [MoE] Extend kimi_k2_moe_fused_gate to support 256 experts (MiMo V2 Flash)

原始 PR 作者 mpdfdfl 合并时间 2026-06-01 16:03 文件变更 2 提交数 8 评论 19 代码增减 +203 / -154

执行摘要

MoE fused gate 内核扩展至 256 专家并优化性能

MiMo V2 Flash 使用与 Kimi K2 相同的 noaux_tc 路由但专家数为 256,因此需要扩展 fused gate 内核以避免回退到慢速通用实现。PR body 明确说明“so MiMo V2 Flash can take the same optimized fused gate as Kimi K2”。

建议详细阅读 CUDA 内核实现,特别是模板化 GateConfig 和 small-token 路径的优化技巧(bank-conflict-free 写入、单 pass renorm),对 CUDA 性能优化有参考价值。测试用例的 parametrize 重构也值得学习。

讨论亮点

审查中主要讨论包括:

1) xu-yfei 指出与 #22488 类似的优化,并建议采用 lane-strided 访问消除 bank conflict,作者采纳并实现;
2) xu-yfei 警告同时修改 sgl-kernel 和 SGLang 可能导致暂时不可用,作者回退 topk.py 修改;
3) ispobock 提议合并后对 fused moe gate 内核统一重构 (#26771)。

实现拆解

分两步实现:

1) 在 CUDA 内核中引入模板参数 N 表示专家数,编译期推导常量,并重写小 token 和大 token 内核以消除 bank conflict、向量化写回、单次 renorm;
2) 在 Python 测试用例中参数化 (num_experts, topk, scaling_factor) 同时覆盖 Kimi K2 和 MiMo V2 Flash 配置。由于版本风险,Python 包装器 topk.py 的修改被回退,最终仅内核和测试变更。

文件 模块 状态 重要度
sgl-kernel/csrc/moe/kimi_k2_moe_fused_gate.cu MoE 内核 modified 6.08
sgl-kernel/tests/test_kimi_k2_moe_fused_gate.py MoE 测试 modified 5.72

关键符号

kimi_k2_moe_fused_gate_kernel_small_token kimi_k2_moe_fused_gate_kernel test_kimi_k2_moe_fused_gate test_kimi_k2_specific_case

关键源码片段

sgl-kernel/tests/test_kimi_k2_moe_fused_gate.py test-coverage

测试覆盖从单一 384 专家扩展为 256 和 384 两种配置的参数化测试,保证功能正确。

import pytest
import torch
from sgl_kernel import kimi_k2_moe_fused_gate
from sglang.srt.layers.moe.topk import kimi_k2_biased_topk_impl# 定义被测配置 : (num_experts, topk, routed_scaling_factor)
_CONFIGS = [
    (384, 6, 2.872), # Kimi K2
    (256, 8, 1.0), # MiMo V2 Flash
]@pytest.mark.parametrize(
    'seq_length',
    list(range(1, 10)) + [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536],
)
@pytest.mark.parametrize('config', _CONFIGS, ids=['kimi384', 'mimo256'])
@pytest.mark.parametrize('apply_routed_scaling_factor_on_output', [False, True])
def test_kimi_k2_moe_fused_gate(seq_length, config, dtype, apply_routed_scaling_factor_on_output):
    # 解包配置
    num_experts, topk, routed_scaling_factor = config
    renormalize = True
    torch.manual_seed(seq_length)
    # 生成随机输入
    tensor = torch.rand((seq_length, num_experts), dtype=torch.float32, device='cuda')
    scores = tensor.clone()
    bias = torch.rand(num_experts, dtype=torch.float32, device='cuda')
​
    # 调用优化内核
    output, indices = kimi_k2_moe_fused_gate(
        tensor, bias, topk=topk, renormalize=renormalize,
        routed_scaling_factor=routed_scaling_factor,
        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
    )
    # 参考实现
    ref_output, ref_indices = kimi_k2_biased_topk_impl(
        scores, scores, bias, topk=topk, renormalize=renormalize,
        routed_scaling_factor=routed_scaling_factor,
        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
    )
    # 排序后比较权重 (MoE 实际使用的权重是排序后的 )
    assert torch.allclose(
        ref_output.sort()[0].to(torch.float32),
        output.sort()[0].to(torch.float32),
        rtol=1e-02, atol=1e-03,
    ), f'Output mismatch at seq_length {seq_length}, config {config}'

评论区精华

Bank conflict 优化与 lane-strided 访问 性能

xu-yfei 指出原始实现中因 chunked-per-lane 导致 STS.128 多路冲突,建议改为 lane-strided 使每轮 32 条线程连续访问。

结论:作者采纳该方案,重写大 token 内核的共享内存写入,实现 bank-conflict-free。 · 已解决

同时修改 sgl-kernel 和 SGLang 的风险 other

xu-yfei 警告 SGLang 依赖的 sgl-kernel 版本可能尚未更新,同时修改两个仓库会导致运行时版本不匹配,造成暂时不可用。

结论:作者回退 topk.py 更改,仅保留 sgl-kernel 修改和对应测试,确保先独立合并内核。 · 已解决

未来 Unified fused moe gate 内核重构 设计

ispobock 提议合并后对多个 fused gate 内核进行统一重构,创建 issue #26771 跟踪。

结论:作者愿意协助后续验证和基准测试。 · unresolved

风险与影响

主要风险:

1) 内核模板化后编译时间可能增加,但约束在 256/384 两种特化;
2) 新增 256 专家路径仅通过 Python 测试验证,可能未覆盖边缘情况(如极端序列长度);
3) 未来 sgl-kernel 重构可能与此 PR 的模板设计冲突。但由于无 API 变化且测试充分,整体风险较低。

对用户:MiMo V2 Flash 模型用户将自动受益于优化内核(升级 sgl-kernel 后),其他模型不受影响。对系统:无配置或依赖变化。对团队:此 PR 奠定了模板化多配置内核的基础,便于后续支持更多专家数。

核心路径变更 依赖协调风险 未来重构兼容

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论