执行摘要
- 一句话:修复 EPLB 路由副本选择偏差
- 推荐动作:值得精读。该 PR 展示了一个典型的负载均衡边界问题及其巧妙解决方案—— Knuth 乘法哈希。对于关注 MoE 推理优化的工程师,代码注释清晰,测试设计有针对性,是学习 Triton kernel 开发和负载均衡策略的好范例。
功能与动机
当前 EPLB 融合 MoE 路由器的副本选择策略在 top_k 是副本数整数倍时严重不均,同一逻辑专家的多个副本间负载偏差超过 90%。PR body 指出“Current replica selection in the EPLB fused-MoE router distributes tokens unevenly across replicas of the same logical expert”,需要切换哈希实现来修复。
实现拆解
- 在 Triton kernel 中引入 Knuth 乘法哈希(
vllm/model_executor/layers/fused_moe/router/base_router.py):
- 在
_eplb_map_and_record_i32_kernel 中添加 num_active_experts 参数表示每个 token 的专家数。
- 将原来的
replica_idx = offs % replica_count 替换为基于 token 索引的哈希计算:token_idx = offs // num_active_experts,hashed = (token_idx * KNUTH_MULTIPLIER) & 0xFFFFFFFF,replica_idx = hashed % replica_count。
- 这样即使连续 token 的 top-k 偏移量是副本数的整数倍,哈希值也能均匀分散到不同副本。
- 更新 Triton 启动调用(同一文件
_eplb_map_and_record_triton 函数):
- 从
topk_ids_in.shape[-1] 提取 num_active_experts 并传递给 kernel。
- 添加回归测试(
tests/kernels/moe/test_routing.py):
- 新增
test_eplb_map_hot_expert_replica_balance 函数,构造一个“热专家”(逻辑专家 0)具有 R 个副本的场景,随机生成其他专家 ID,强制热专家出现在每个 token 的 top-1 位置。运行映射后验证前 R 个物理 ID 的负载最大/均值比 <1.15。
- 同时更新了现有
test_eplb_map_with_redundancy 的预期输出和注释,以匹配新的哈希算法。
关键文件:
vllm/model_executor/layers/fused_moe/router/base_router.py(模块 路由器;类别 source;类型 core-logic;符号 _eplb_map_and_record_i32_kernel, _eplb_map_and_record_triton): 核心变更文件,修改了 Triton kernel 中副本选择逻辑,引入 Knuth 乘法哈希替代简单取模
tests/kernels/moe/test_routing.py(模块 测试;类别 test;类型 test-coverage;符号 test_eplb_map_hot_expert_replica_balance): 新增回归测试验证热专家副本均衡性,更新现有测试预期输出以匹配新哈希算法
关键符号:_eplb_map_and_record_i32_kernel, _eplb_map_and_record_triton, test_eplb_map_hot_expert_replica_balance
关键源码片段
vllm/model_executor/layers/fused_moe/router/base_router.py
核心变更文件,修改了 Triton kernel 中副本选择逻辑,引入 Knuth 乘法哈希替代简单取模
# Triton JIT kernel for EPLB mapping with Knuth multiplicative hash
@triton.jit
def _eplb_map_and_record_i32_kernel(
topk_ids_ptr,
logical_replica_count_ptr,
logical_to_physical_ptr,
out_ids_ptr,
out_ptr,
record_enabled_ptr,
num_logical_experts,
map_slots,
out_size,
numel,
num_active_experts, # number of experts per token (top-k)
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < numel
expert_id = tl.load(topk_ids_ptr + offs, mask=mask, other=0).to(tl.int64)
valid_expert = (expert_id >= 0) & (expert_id < num_logical_experts)
safe_expert_id = tl.where(valid_expert, expert_id, 0)
# Load replica count for this expert
replica_count = tl.load(
logical_replica_count_ptr + safe_expert_id,
mask=mask & valid_expert,
other=1,
)
replica_count = tl.maximum(replica_count, 1)
# Knuth multiplicative hash: floor(2^32 / phi)
KNUTH_MULTIPLIER = 2654435769
# token_idx is the token index in the batch (flattened position divided by num_active_experts)
token_idx = (offs // num_active_experts).to(tl.int64)
hashed = (token_idx * KNUTH_MULTIPLIER) & 0xFFFFFFFF
replica_idx = hashed % replica_count # uniform distribution across replicas
# Map to physical expert ID
map_index = safe_expert_id * map_slots + replica_idx
physical_id = tl.load(
logical_to_physical_ptr + map_index,
mask=mask & valid_expert,
other=-1,
)
tl.store(out_ids_ptr + offs, physical_id, mask=mask)
# Atomic load recording
record_enabled = tl.load(record_enabled_ptr) != 0
valid = mask & record_enabled & (physical_id >= 0) & (physical_id < out_size)
safe_physical_id = tl.where(physical_id >= 0, physical_id, 0)
tl.atomic_add(out_ptr + safe_physical_id, 1, mask=valid)
评论区精华
主要讨论围绕命名规范:
风险与影响
关联脉络
- PR #40412 fused_moe: treat NIXL EP as batched experts: 同为 fused_moe 模块的修复,涉及路由逻辑,可参考 EPLB 相关上下文
- PR #40794 [Bugfix][MoE] Unpad routed output before shared expert add [Fixes #35949]: 同为 MoE 路由相关 bugfix,展示 MoE 推理路径中常见问题
参与讨论