执行摘要
- 一句话:为AMD MoRI EP添加SDMA路径支持
- 推荐动作:该PR提供了明确的硬件加速路径,设计简洁,值得AMD相关开发者关注。建议合并后补充单元测试覆盖SDMA路径的dispatch/combine逻辑,并考虑增加版本检测以增强鲁棒性。
功能与动机
AMD平台需要利用SDMA(System DMA)的硬件能力降低MoE token调度延迟。通过将融合的dispatch/combine拆分为send和recv两个阶段,可以在通信和计算之间实现更细粒度的重叠,提升低延迟场景下的吞吐量。
实现拆解
- 在
init_mori_op函数中添加enable_sdma参数:允许调用方控制是否启用SDMA路径,该参数默认False,通过LRU缓存保持线程安全。
- 在
_MoriEPDispatcherImplBase.__init__中读取环境变量:使用get_bool_env_var("MORI_ENABLE_SDMA", "false")初始化self.enable_sdma,便于运行态切换。
- 修改低延迟模式判断逻辑:将
async_mode的条件从仅检查deepep_mode.enable_low_latency()改为其与enable_sdma的或运算,使得SDMA启用时自动选用EpMode.LOW_LATENCY配置。
- 在
_dispatch_core和_combine_core中根据enable_sdma选择不同API:启用SDMA时,使用dispatch_send + dispatch_recv和combine_send + combine_recv代替原有的融合dispatch和combine,保持参数接口一致。
关键文件:
python/sglang/srt/layers/moe/token_dispatcher/moriep.py(模块 调度器;类别 source;类型 core-logic;符号 init_mori_op, _MoriEPDispatcherImplBase.init, _MoriEPDispatcherImplBase.mori_op, _dispatch_core): 核心变更文件,包含SDMA初始化、配置和调度逻辑的全部修改。
关键符号:init_mori_op, _MoriEPDispatcherImplBase.init, _MoriEPDispatcherImplBase.mori_op, _dispatch_core, _combine_core
关键源码片段
python/sglang/srt/layers/moe/token_dispatcher/moriep.py
核心变更文件,包含SDMA初始化、配置和调度逻辑的全部修改。
# python/sglang/srt/layers/moe/token_dispatcher/moriep.py
# 在 init_mori_op 函数中新增 enable_sdma 参数
@lru_cache(maxsize=4)
def init_mori_op(
group,
router_topk,
num_experts,
num_local_experts,
hidden_size,
params_dtype,
num_max_dispatch_tokens_per_rank,
deepep_mode,
instance_id=0,
fp8_dispatch=False,
fp4_dispatch=False,
enable_sdma=False, # 新增参数 ,用于控制是否启用 SDMA 路径
):
...
# 修改 async_mode 判断逻辑:deepep_mode 或 enable_sdma 任一为真则进入低延迟模式
async_mode = deepep_mode.enable_low_latency() or enable_sdma
if async_mode:
mode = EpMode.LOW_LATENCY
...
# 在 _MoriEPDispatcherImplBase.__init__ 中读取环境变量
class _MoriEPDispatcherImplBase:
def __init__(self, group, router_topk, permute_fusion, num_experts,
num_local_experts, hidden_size, params_dtype, deepep_mode,
instance_id=0):
...
# 通过环境变量 MORI_ENABLE_SDMA 控制 SDMA 开关 ,默认关闭
self.enable_sdma = get_bool_env_var("MORI_ENABLE_SDMA", "false")
...
@property
def mori_op(self):
if self._mori_op is None:
...
# 将 enable_sdma 传入 init_mori_op
self._mori_op = init_mori_op(
self.group,
self.router_topk,
...
self.enable_sdma, # 新增参数
)
return self._mori_op
评论区精华
review中gemini-code-assist[bot]提出了代码重构建议:将dispatch/combine中的if/else分支统一为函数引用赋值的模式,减少代码重复。最终版本采纳了建议,使用了dispatch_fn = (self.mori_op.dispatch_send if self.enable_sdma else self.mori_op.dispatch)的形式。HaiShaw审阅后批准了PR,未提出其他疑虑。
- dispatch/combine 中代码重复重构 (design): 采纳建议,最终版本使用了
dispatch_fn = (self.mori_op.dispatch_send if self.enable_sdma else self.mori_op.dispatch) 等简洁写法。
风险与影响
关联脉络
参与讨论