执行摘要
- 一句话:集中扩散模型FSDP分片条件并修复权重加载与包装器兼容性
- 推荐动作:值得精读,尤其是FSDP分片条件的集中设计和通用回退机制。设计决策(如基于类名和通用编号块的自动分片)具有借鉴意义。但需关注回归问题的修复进展。
功能与动机
扩散模型使用FSDP进行分片推理时,不同模型的分片条件分散在各自配置中,难以维护且容易遗漏。FSDP与张量并行(TP)的权重加载存在冲突,需要先加载预分片权重再分发。FSDP包装后的模型无法通过简单 getattr 访问内部属性。本PR集中分片条件、修复加载逻辑并添加通用回退机制,以降低维护成本并提高模型兼容性。
实现拆解
- 集中FSDP分片条件函数:新增
configs/models/fsdp.py,定义 is_module_list_entry、is_layer、is_block、is_double_block 等系列函数,统一通过模块名称判断分片边界。各模型配置(如 hunyuanvideo、gemma2 等)删除内联函数并改为导入集中函数。
- 重构FSDP加载器:修改
runtime/loader/fsdp_load.py,新增 _get_param_for_weight_loading 函数以支持带 weight_loader 的参数;新增 _is_common_numbered_block 和 _resolve_fsdp_shard_conditions 函数,提供基于模块类名或通用编号块的回退分片策略;在 maybe_load_fsdp_model 中先保存所有带 weight_loader 的参数,再执行FSDP包装。
- 增强包装兼容性:修改
runtime/pipelines_core/stages/denoising.py,新增 _get_transformer_attr 方法递归搜索FSDP包装模块(_fsdp_wrapped_module、module、_orig_mod)以获取属性;将 prepare_extra_func_kwargs 重构为 _get_extra_func_kwarg_names 并添加 _extra_func_kwarg_names_cache 缓存,避免每步签名检查的开销。
- 更新性能基线:更新
test/registered/fp8/fp8_perf_baselines.txt 以反映当前FSDP推理性能。
- 测试与配置调整:修改多个模型配置文件的导入,确保兼容集中式函数。
关键文件:
python/sglang/multimodal_gen/configs/models/fsdp.py(模块 扩散模型;类别 source;类型 data-contract;符号 is_module_list_entry, is_module_list_entry_in, is_layer, is_block): 新增的核心文件,集中定义了所有FSDP分片条件函数,是整个PR的基石。
python/sglang/multimodal_gen/runtime/loader/fsdp_load.py(模块 加载器;类别 source;类型 dependency-wiring;符号 _get_param_for_weight_loading, _make_class_name_shard_condition, _is_common_numbered_block, _resolve_fsdp_shard_conditions): 修改了权重加载逻辑,新增预分片参数保存、回退分片条件函数,修复FSDP+TP兼容性。
python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py(模块 降噪阶段;类别 source;类型 core-logic;符号 _get_transformer_attr, prepare_extra_func_kwargs, _get_extra_func_kwarg_names): 修改了降噪阶段,添加包装模块属性访问和kwarg签名缓存,提升兼容性和性能。
python/sglang/multimodal_gen/configs/models/dits/hunyuanvideo.py(模块 扩散模型;类别 source;类型 data-contract;符号 is_double_block, is_single_block, is_refiner_block, is_txt_in): 删除了内联分片条件函数,改为从集中文件导入,体现集中化模式。
python/sglang/multimodal_gen/configs/models/encoders/gemma2.py(模块 编码器;类别 source;类型 data-contract;符号 _is_transformer_layer, _is_embeddings, _is_final_norm): 类似 hunyuanvideo,删除内联条件改为导入集中函数,代表所有编码器的统一变更。
关键符号:is_module_list_entry, is_module_list_entry_in, _get_param_for_weight_loading, _resolve_fsdp_shard_conditions, _is_common_numbered_block, _get_transformer_attr, _get_extra_func_kwarg_names
关键源码片段
python/sglang/multimodal_gen/configs/models/fsdp.py
新增的核心文件,集中定义了所有FSDP分片条件函数,是整个PR的基石。
# SPDX-License-Identifier: Apache-2.0
def is_module_list_entry(name: str, container_name: str) -> bool:
# 匹配直接属于某容器的子模块(数字索引),排除其内部子模块
parts = name.split(".")
return len(parts) >= 2 and parts[-2] == container_name and parts[-1].isdigit()
def is_module_list_entry_in(name: str, container_names: tuple[str, ...]) -> bool:
# 匹配直接属于多个候选容器之一的子模块
parts = name.split(".")
return len(parts) >= 2 and parts[-2] in container_names and parts[-1].isdigit()
def is_layer(name: str, module: object) -> bool:
# 匹配 "layers" 容器下的直接子模块
return is_module_list_entry(name, "layers")
def is_block(name: str, module: object) -> bool:
# 匹配 "blocks" 容器下的直接子模块
return is_module_list_entry(name, "blocks")
def is_double_block(name: str, module: object) -> bool:
# 匹配 "double_blocks" 容器下的直接子模块(用于 HunyuanVideo 等模型)
return is_module_list_entry(name, "double_blocks")
def is_single_block(name: str, module: object) -> bool:
# 匹配 "single_blocks" 容器下的直接子模块
return is_module_list_entry(name, "single_blocks")
def is_refiner_block(name: str, module: object) -> bool:
# 匹配 "refiner_blocks" 容器下的直接子模块
return is_module_list_entry(name, "refiner_blocks")
python/sglang/multimodal_gen/runtime/loader/fsdp_load.py
修改了权重加载逻辑,新增预分片参数保存、回退分片条件函数,修复FSDP+TP兼容性。
def _get_param_for_weight_loading(
model: torch.nn.Module,
param_dict: dict[str, torch.nn.Parameter],
param_name: str,
) -> torch.nn.Parameter | None:
# 优先返回带 weight_loader 的参数,用于自定义加载逻辑
actual_param = param_dict.get(param_name)
if actual_param is not None and getattr(actual_param, "weight_loader", None):
return actual_param
# 查找预 FSDP 保存的带 weight_loader 的参数(在包装前保存的)
pre_fsdp_weight_loader_params = getattr(model, "_pre_fsdp_weight_loader_params", {})
pre_fsdp_param = pre_fsdp_weight_loader_params.get(param_name)
if pre_fsdp_param is not None:
return pre_fsdp_param
return actual_param
def _make_class_name_shard_condition(class_names: set[str]):
# 根据模块类名创建分片条件
def shard_condition(n: str, m: nn.Module) -> bool:
return type(m).__name__ in class_names
return shard_condition
def _is_common_numbered_block(n: str, m: nn.Module) -> bool:
# 通用编号块匹配:识别常见容器名下的数字索引子模块
return is_module_list_entry_in(
n,
(
"blocks",
"layers",
"double_blocks",
"single_blocks",
"refiner_blocks",
"noise_refiner",
"context_refiner",
"transformer_blocks",
"single_transformer_blocks",
),
)
def _resolve_fsdp_shard_conditions(
model: torch.nn.Module,
fsdp_shard_conditions: list[Callable[[str, nn.Module], bool]] | None,
) -> tuple[list[Callable[[str, nn.Module], bool]], str]:
# 优先级:显式条件 > 基于模型 _repeated_blocks/_no_split_modules 的类名条件 > 通用编号块条件
if fsdp_shard_conditions:
return fsdp_shard_conditions, "explicit"
block_class_names = set(getattr(model, "_repeated_blocks", []) or [])
block_class_names.update(getattr(model, "_no_split_modules", []) or [])
if block_class_names:
return [_make_class_name_shard_condition(block_class_names)], "block-class"
return [_is_common_numbered_block], "common-numbered-block"
python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py
修改了降噪阶段,添加包装模块属性访问和kwarg签名缓存,提升兼容性和性能。
def _get_transformer_attr(self, name: str) -> Any:
# 递归搜索 transformer 及其 FSDP 包装属性,找到第一个非 None 的值
seen: set[int] = set()
stack = [self.transformer]
while stack:
module = stack.pop()
if module is None or id(module) in seen:
continue
seen.add(id(module))
value = getattr(module, name, None)
if value is not None:
return value
# 遍历常见的 FSDP 包装属性名称
for wrapper_attr in ("_fsdp_wrapped_module", "module", "_orig_mod"):
wrapped = getattr(module, wrapper_attr, None)
if wrapped is not None:
stack.append(wrapped)
return None
def _get_extra_func_kwarg_names(self, func) -> tuple[bool, frozenset[str]]:
# 缓存函数的可变参标志和参数名集合,避免每次调用都反射签名
import functools
# 处理 cache-dit 的 partial 包装
if isinstance(func, functools.partial):
func = func.func
target_func = inspect.unwrap(func)
cache_target = (
target_func.__func__ if inspect.ismethod(target_func) else target_func
)
cache_key = id(cache_target)
cached = self._extra_func_kwarg_names_cache.get(cache_key)
if cached is not None:
return cached
params = inspect.signature(target_func).parameters
result = (
any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()),
frozenset(params),
)
self._extra_func_kwarg_names_cache[cache_key] = result
return result
评论区精华
唯一的评论来自 gemini-code-assist[bot],指出 fsdp_load.py 中一处 requires_grad 赋值是冗余的(因为 _make_param_like 已设置 requires_grad=False),建议简化。作者未在评论中回复,该行在最终版本中已保留,可能认为该冗余在上下文中无害。
- 移除冗余requires_grad赋值 (style): 未采纳该建议,该行在最终版本中保留,可能认为其无害或为保险起见。
风险与影响
- 风险:
- 回归风险:评论指出模型
Efficient-Large-Model/Sana_600M_512px_diffusers 在该PR后无法运行,可能源于分片条件或加载逻辑变化。作者已同意修复,但合并前需确保该模型正常。
- 集中函数可能改变原有分片边界,影响性能或正确性。
- 缓存机制可能引入陈旧数据,需确保缓存键正确更新。
- 影响:影响所有使用FSDP推理的扩散模型,包括 HunyuanVideo、Gemma、LLaMA、T5 等编码器/解码器。变更涉及29个文件,但核心逻辑集中在加载器和降噪阶段。对于未使用FSDP的模型无影响。
- 风险标记:回归风险, 模型兼容性, 缓存一致性
关联脉络
参与讨论