Prhub

#43982 [Bugfix] Fix Gemma4 MTP block_table batch_size mismatch under concurrent load

原始 PR 作者 Dymasik 合并时间 2026-06-04 08:11 文件变更 1 提交数 3 评论 1 代码增减 +6 / -1

执行摘要

修复 Gemma4 MTP 并发下 block table batch_size 不匹配

Gemma4 + MTP + FlashAttention 在并发负载下(batch 部分占用时)会触发 RuntimeError: batch_size must be equal to batch_size_k,导致服务崩溃。该问题由作者在测试中发现并修复。

建议尽快合并此修复,因为它直接解决了 Gemma4 MTP 在 FlashAttention 后端下的生产阻塞 bug。虽然改动极小,但 root cause 分析清晰,值得其他 speculative decoder 开发者在实现类似 per-group block table 时注意 batch 维度对齐。

讨论亮点

reviewer benchislett 批准了该 PR,并询问为何该 bug 这么久才暴露出来。没有其他深层讨论。

实现拆解

  1. 定位根因:在 vllm/v1/spec_decode/gemma4.pybuild_per_group_and_layer_attn_metadata 方法中,per-group block table 通过 set_per_group_block_table() 存储,其维度为 (num_reqs_padded, max_blocks)(CUDA graph 填充后的 padded 维度)。而 common_attn_metadata 在进入此方法前已通过 .unpadded() 裁剪到实际 num_reqs,导致二者 batch 维度不一致。
  2. 修复单行代码:在 build_per_group_and_layer_attn_metadata 中,从 common_attn_metadata 获取 batch_size(实际请求数 num_reqs),在赋值 cm.block_table_tensor 时进行 self._per_group_block_tables[gid][:batch_size] 切片,使其与 cu_seqlens_q 维度对齐。
  3. 测试验证:作者提供了详细的复现脚本和测试结果:在 8 卡环境、并发 8 请求、32 请求的负载下,修复前在 wave 1 就有 5/8 请求失败,随后服务崩溃;修复后所有请求成功。
文件 模块 状态 重要度
vllm/v1/spec_decode/gemma4.py 推测解码 modified 5.63

关键符号

build_per_group_and_layer_attn_metadata

关键源码片段

vllm/v1/spec_decode/gemma4.py core-logic

本 PR 唯一修改文件,核心修复在 `build_per_group_and_layer_attn_metadata` 方法中增加 `[:batch_size]` 切片操作,解决 padded block table 维度不匹配问题。

# vllm/v1/spec_decode/gemma4.pydef build_per_group_and_layer_attn_metadata(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    draft_index: int = 0,
) -> tuple[list[object], dict[str, object]]:
    """Build attention metadata using the correct block table per group.    Gemma4 has multiple KV cache groups (sliding vs full attention)
    with different block tables.  The base class receives a single
    common_attn_metadata whose block_table belongs to one group.
    We swap in the correct block table for each draft attention group.
    """
    per_group_attn_metadata: list[object] = []
    per_layer_attn_metadata: dict[str, object] = {}
    batch_size = common_attn_metadata.batch_size() # 获取实际 batch 大小
    for attn_group in self.draft_attn_groups:
        gid = attn_group.kv_cache_group_id
        if gid in self._per_group_block_tables:
            cm = copy(common_attn_metadata)
            # 关键修复:切片到实际 batch 大小,对齐 cu_seqlens_q 维度
            # 原始 block table 可能带有 CUDA graph 填充的 padded 维度
            cm.block_table_tensor = self._per_group_block_tables[gid][:batch_size]
        else:
            cm = common_attn_metadata
        attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
            common_attn_metadata=cm, draft_index=draft_index
        )
        per_group_attn_metadata.append(attn_metadata)
        for layer_name in attn_group.layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
    return per_group_attn_metadata, per_layer_attn_metadata

评论区精华

没有提炼出高价值讨论线程

当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。

风险与影响

风险极低:修改仅涉及一行代码(在 block_table_tensor 赋值时加 [:batch_size] 切片),且逻辑与 common_attn_metadata 的 padding 处理保持一致。不会影响其他模型或非 MTP 模式。注意需要确保 common_attn_metadata.batch_size() 返回的值是正确的实际 batch 大小,当前逻辑中该值来自 unpadded 的 metadata,是可靠的。

  • 用户:修复了 Gemma4 MTP 在并发负载下的崩溃问题,使得该模型可以稳定用于生产。
  • 系统:仅影响 Gemma4 MTP 的 speculative decoding 路径,其他模型或解码方式无影响。
  • 团队:合并后无需额外配置或迁移。
并发路径变更 缺少测试覆盖

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论