Prhub

#22690 [diffusion] model: Properly validate device for Mistral 3 attention

原始 PR 作者 avjves 合并时间 2026-04-16 15:29 文件变更 1 提交数 3 评论 4 代码增减 +4 / -1

执行摘要

修复 AMD ROCm 平台上 Mistral 3 注意力后端选择逻辑,避免误用 cuDNN 导致支持中断。

PR body 明确指出:PR #22423 的变更导致 AMD ROCm 平台上 Flux2 模型支持被破坏,因为 AMD 硬件也报告设备类型为 "cuda",但不支持 cuDNN 注意力。需要修复以恢复 AMD 支持。

该 PR 值得精读,尤其是关注 current_platform.is_cuda() 与设备类型检查的结合使用,这是处理跨平台兼容性问题的典型设计决策。

讨论亮点

review 中 gemini-code-assist[bot] 指出:移除 execution_tensor.device.type == "cuda" 检查可能导致在 CPU 张量(如测试或 GPU 机器上的 CPU offload 场景)上错误强制使用 SDPBackend.CUDNN_ATTENTION。建议结合两个检查以确保后端仅应用于支持硬件的 CUDA 张量。作者采纳建议,在最终实现中保留了设备类型检查并增加平台检测。

实现拆解

  1. 导入平台检测模块:在 python/sglang/multimodal_gen/runtime/models/encoders/mistral_3.py 中新增 from sglang.multimodal_gen.runtime.platforms import current_platform,引入平台检测能力。
  2. 修改注意力后端选择逻辑:在 forward 方法中,将 sdpa_context 的条件判断从 execution_tensor is not None and execution_tensor.device.type == "cuda" 改为 execution_tensor is not None and execution_tensor.device.type == "cuda" and current_platform.is_cuda(),确保同时满足张量在 CUDA 设备上且当前平台为 NVIDIA CUDA 硬件。
  3. 代码风格调整:在提交历史中,有单独的提交 "Fix styling" 调整代码格式,确保符合项目规范。
  4. 测试与验证:PR 未包含直接测试文件变更,但通过 CI 测试(包括 NVIDIA 和 AMD)验证了修复的有效性。
文件 模块 状态 重要度
python/sglang/multimodal_gen/runtime/models/encoders/mistral_3.py 扩散模型 modified 5.91

关键符号

forward

关键源码片段

python/sglang/multimodal_gen/runtime/models/encoders/mistral_3.py core-logic

唯一变更文件,包含 Mistral 3 编码器的核心注意力后端选择逻辑修复。

# 在 forward 方法中,修改注意力后端选择逻辑
execution_tensor = input_ids if input_ids is not None else inputs_embeds
sdpa_context = (
    sdpa_kernel(SDPBackend.CUDNN_ATTENTION)
    if execution_tensor is not None
    and execution_tensor.device.type == "cuda" # 检查张量是否在 CUDA 设备上
    and current_platform.is_cuda() # 新增:检查当前平台是否为 NVIDIA CUDA 硬件
    else nullcontext()
)
with sdpa_context:
    # FLUX.2 使用纯文本 Mistral3 路径,但仍期望与官方 HF 实现相同的本地 SDPA 内核选择。
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        cache_position=cache_position,
        image_sizes=image_sizes,
        **kwargs,
    )

评论区精华

注意力后端选择条件的正确性 正确性

gemini-code-assist[bot] 指出移除设备类型检查可能导致 CPU 张量错误使用 cuDNN,建议结合设备类型和平台检测。

结论:作者采纳建议,在最终实现中保留设备类型检查并增加平台检测。 · 已解决

风险与影响

低风险:变更仅影响 Mistral 3 编码器的注意力后端选择逻辑,范围有限。

  • 回归风险:修复了 AMD 平台的支持问题,但需确保在混合平台环境(如同时有 NVIDIA 和 AMD GPU)中逻辑正确。
  • 兼容性风险:新增平台检测依赖 current_platform.is_cuda(),需确保该函数在所有目标平台上正确实现。
  • 性能风险:无,仅条件判断增加一个布尔检查,开销可忽略。

对用户:恢复 AMD ROCm 平台上 Flux2 扩散模型的正常运行,提升跨平台兼容性。
对系统:确保注意力后端选择更精确,避免在不支持 cuDNN 的硬件上错误启用,提高系统鲁棒性。
对团队:展示了平台检测与设备类型检查结合的重要性,为类似跨平台问题提供参考模式。

跨平台兼容性

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论