Prhub

#24509 Support user-supplied recv-side transform in dumper grafter

原始 PR 作者 fzyzcjy 合并时间 2026-05-06 16:56 文件变更 2 提交数 6 评论 2 代码增减 +283 / -8

执行摘要

支持用户自定义 grafter 接收端张量变换函数

在 grafter 跨系统张量接枝中,baseline 和 target 的张量形状可能不匹配,或用户希望对接收到的张量进行后处理(如缩放、对比校验)。此 PR 允许用户通过配置提供自定义变换函数,替代默认的 identity-by-rank 降级,提升灵活性。

值得精读。该 PR 展示了如何在分布式调试工具中安全地支持用户自定义钩子,异常处理设计巧妙。同时应关注 review 中提出的多进程 rank 问题,确保未来扩展方向正确。

讨论亮点

两条 review 评论指出潜在的多进程问题:

  • global_rank 计算错误:所有 baseline 进程被赋 rank 0,导致自定义进程组初始化时 rank 碰撞。bot 建议基于 local_rank 偏移。
  • sender_rank 计算错误:T2B 方向发送者 rank 应为 baseline_world_size 而非 1,且仅单个发送者广播存在局限。
    这些问题主要影响多 world_size 场景,本 PR 未涉及相关代码,未在本次修复。

实现拆解

  1. 配置扩展:在 DumperConfig 中新增可选字段 grafter_transform_path,用户可指定完全限定 Python 函数路径。
  2. 函数动态加载:新增模块级函数 _load_function,通过点分路径动态加载可调用对象,提供清晰的错误分类(模块不存在、属性不存在、缺少点号、非可调用)。
  3. 变换调度:在 _Grafter 中新增 _apply_transform 方法,根据配置选择用户自定义变换或默认变换 _default_transform。默认变换按 rank 身份映射(发送 rank 0 对应接收 rank 0)。
  4. 安全包装:在 maybe_intercept 的 recv 分支中,将 transform 和 copy_ 包裹在 try-except 中,捕获所有异常并记录堆栈,保证错误不扩散。
  5. 默认降级_default_transform 静态方法直接返回 received_list[0],适用于当前 1+1 场景,未来可扩展。
  6. 测试覆盖:包含 _load_function 错误路径单元测试(ModuleNotFoundErrorAttributeErrorValueError、非可调用调用时),端到端用户 transform 测试(验证值翻倍),以及 shape 不匹配默认降级不崩溃测试。
文件 模块 状态 重要度
python/sglang/srt/debug_utils/dumper.py 调试工具 modified 8.06
test/registered/debug_utils/test_dumper.py 调试工具 modified 7.17

关键符号

_load_function _apply_transform _default_transform

关键源码片段

python/sglang/srt/debug_utils/dumper.py dependency-wiring

核心逻辑:新增配置字段、动态加载函数、变换调度和异常安全包装。

# 核心变换调度与安全执行
import traceback
from collections.abc import Callabledef _load_function(path: str) -> Callable:
    """从一个点分模块路径加载可调用对象。"""
    if '.' not in path:
        raise ValueError(f'missing dotted prefix in path: {path!r}')
    module_name, _, attr_name = path.rpartition('.')
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)class _Grafter:
    def _apply_transform(self, received_list, *, target):
        # 若无用户配置则使用默认身份映射
        path = self._config.grafter_transform_path
        if path is None:
            return self._default_transform(received_list, target=target)
        # 加载用户函数并调用
        return _load_function(path)(received_list, target)
​
    @staticmethod
    def _default_transform(received_list, *, target):
        # 默认 identity-by-rank(1+1 场景直接取第一元素)
        return received_list[0]
​
    def maybe_intercept(self, *, value, tags):
        # ... 发送端逻辑 ...
        if not is_send:
            received = obj_list[0]
            if isinstance(received, torch.Tensor):
                received = received.to(value.device)
            # 安全包装:用户变换异常不中断训练
            try:
                value_to_override = self._apply_transform([received], target=value)
                value.copy_(value_to_override)
            except Exception as e:
                _log(f'[Grafter] recv transform failed: {e}\n{traceback.format_exc()}')

评论区精华

global_rank 计算错误导致多进程 rank 碰撞 正确性

gemini-code-assist[bot] 指出当前 global_rank 计算方法在所有 baseline 进程中赋值为 0,导致自定义进程组初始化时 rank 碰撞,建议基于 local_rank 偏移。

结论:未修复。本 PR 未涉及该部分代码;整体 grafter 多进程支持需后续解决。 · 待处理

sender_rank 计算错误导致 T2B 方向广播 rank 偏移缺失 正确性

同一 review 指出 sender_rank 计算在 T2B 方向应使用 baseline_world_size 作为偏移,而非 1,且仅单个发送者广播存在局限。

结论:未修复。该问题仅在 multi-world-size 配置下显现。 · 待处理

风险与影响

  • 异常安全:用户自定义函数异常被 try-catch 捕获,不会导致训练崩溃,但会静默跳过该次 graft,可能导致调试遗漏(通过日志可发现)。
  • 动态加载_load_function 使用 importlib 加载模块,若路径错误或模块副作用可能影响环境,但风险可控。
  • 配置兼容:新增字段默认为 None,向后兼容。
  • 多进程局限_default_transform 仅直接取接收列表第一元素,当前仅支持 1+1 场景;多 world_size 需后续扩展。
  • 用户:使用 grafter 且需自定义张量处理时,只需设置 grafter_transform_path,未配置则行为不变。
  • 系统:性能影响忽略不计(一次函数调用 + copy)。异常路径不会影响系统稳定性。
  • 测试:新增测试覆盖主要功能和错误路径,降低回归风险。
  • 团队:增加配置复杂度,但测试充分,维护成本可控。
异常安全包装 动态函数加载 配置兼容性 多进程兼容

关联 Issue

未识别关联 Issue

当前没有检测到明确关联的 Issue 链接,后续同步到相关引用后会出现在这里。

完整报告

参与讨论