执行摘要
- 一句话:支持GLM-Image多设备并行生成
- 推荐动作:值得精读。尤其关注新增的
MAIN_RANK_ONLY_AND_SEND_TO_OTHERS并行模式设计,它解决了自回归生成阶段在多卡环境中必须保持token一致性的问题。这种“单卡执行后广播”的范式对于混合不同并行策略的流水线很有参考价值。同时,AR阶段与扩散阶段的拆分也体现了模块化思想。
功能与动机
目前GLM-Image只能在一个设备上生成图像,sp>1出现形状错误,tp>1存在准确性问题。本PR旨在通过合理拆分阶段并引入新的并行策略,使GLM-Image支持多卡加速并保持精度。
实现拆解
- 拆分AR阶段:在
python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/glm_image.py中,将原本混在一起的GlmImageBeforeDenoisingStage拆分为两个独立阶段:GlmImageAR(仅负责自回归token生成)和简化的GlmImageBeforeDenoisingStage(仅负责扩散前处理)。GlmImageAR通过parallelism_type声明为MAIN_RANK_ONLY_AND_SEND_TO_OTHERS。
- 新增并行模式枚举:在
python/sglang/multimodal_gen/runtime/pipelines_core/stages/base.py的StageParallelismType中添加MAIN_RANK_ONLY_AND_SEND_TO_OTHERS。
- 实现新并行分支:在
python/sglang/multimodal_gen/runtime/pipelines_core/executors/parallel_executor.py的_execute_stages中增加对应分支:仅rank 0执行阶段,然后通过broadcast_pyobj将结果广播给所有其他rank,最后barrier同步。
- 调整管道注册顺序:在
python/sglang/multimodal_gen/runtime/pipelines/glm_image.py的create_pipeline_stages中,先添加GlmImageAR,再添加GlmImageBeforeDenoisingStage和DenoisingStage,移除GlmImageBeforeDenoisingStage对processor和vision_language_encoder的依赖。
- 修复旋转嵌入SP分片:在
python/sglang/multimodal_gen/configs/pipeline_configs/glm_image.py的get_freqs_cis中,将rotary_emb返回的cos/sin通过shard_rotary_emb_for_sp分片,并返回元组而非单个张量。
- 升级依赖:在
python/pyproject_npu.toml中将cache-dit从1.2.1升级到1.3.5,以兼容SP相关修复。
关键文件:
python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/glm_image.py(模块 扩散模型;类别 source;类型 core-logic;符号 GlmImageAR, GlmImageBeforeDenoisingStage, parallelism_type): 核心变更文件,将AR阶段拆分为独立PipelineStage,定义新并行模式。
python/sglang/multimodal_gen/runtime/pipelines_core/executors/parallel_executor.py(模块 执行器;类别 source;类型 core-logic;符号 _execute_stages): 实现新并行模式的分支逻辑,确保AR阶段结果正确分发。
python/sglang/multimodal_gen/runtime/pipelines/glm_image.py(模块 管道注册;类别 source;类型 core-logic;符号 create_pipeline_stages): 调整管道阶段注册顺序,匹配拆分后的阶段。
python/sglang/multimodal_gen/configs/pipeline_configs/glm_image.py(模块 配置计算;类别 source;类型 core-logic;符号 get_freqs_cis, shard_rotary_emb_for_sp): 修复旋转嵌入的SP分片,确保多卡推理位置编码正确。
python/sglang/multimodal_gen/runtime/pipelines_core/stages/base.py(模块 框架基础;类别 source;类型 core-logic;符号 StageParallelismType): 新增并行模式枚举值,定义新模式的语义。
python/pyproject_npu.toml(模块 依赖管理;类别 config;类型 configuration): 升级cache-dit依赖版本,兼容SP相关修复。
关键符号:GlmImageAR.init, GlmImageAR.parallelism_type, GlmImageAR._compute_generation_params, GlmImageBeforeDenoisingStage.init, GlmImagePipelineConfig.get_freqs_cis, ParallelExecutor._execute_stages, shard_rotary_emb_for_sp
关键源码片段
python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/glm_image.py
核心变更文件,将AR阶段拆分为独立PipelineStage,定义新并行模式。
class GlmImageAR(PipelineStage):
"""自回归生成阶段:仅在主 rank 执行,结果广播给其他 rank。"""
def __init__(self, processor, vision_language_encoder):
super().__init__()
self.processor = processor
self.vision_language_encoder = vision_language_encoder
@property
def parallelism_type(self) -> StageParallelismType:
# 声明并行策略:主 rank 执行后广播
return StageParallelismType.MAIN_RANK_ONLY_AND_SEND_TO_OTHERS
@staticmethod
def _compute_generation_params(image_grid_thw, is_text_to_image):
# 根据图像网格 shape 计算生成 tokens 数、偏移和目标分辨率
grid_sizes = []
grid_hw = []
for i in range(image_grid_thw.shape[0]):
t, h, w = image_grid_thw[i].tolist()
grid_sizes.append(int(h * w))
grid_hw.append((int(h), int(w)))
if not is_text_to_image:
max_new_tokens = grid_sizes[-1] + 1
large_image_start_offset = 0
target_grid_h, target_grid_w = grid_hw[-1]
else:
total_tokens = sum(grid_sizes)
max_new_tokens = total_tokens + 1
large_image_start_offset = sum(grid_sizes[1:])
target_grid_h, target_grid_w = grid_hw[0]
return max_new_tokens, large_image_start_offset, target_grid_h, target_grid_w
python/sglang/multimodal_gen/runtime/pipelines_core/executors/parallel_executor.py
实现新并行模式的分支逻辑,确保AR阶段结果正确分发。
elif paradigm == StageParallelismType.MAIN_RANK_ONLY_AND_SEND_TO_OTHERS:
# 仅主 rank(rank 0)执行阶段,其他 rank 等待
if rank == 0:
self.before_stage(stage, stage_index, batch, server_args)
batch = stage(batch, server_args)
self.after_stage(stage_index)
torch.distributed.barrier()
# 将 batch 从主 rank 广播到所有 rank
obj_list = [batch] if rank == 0 else []
broadcasted_list = broadcast_pyobj(
obj_list, rank=rank, dist_group=group.cpu_group, src=0
)
if rank != 0:
batch = broadcasted_list[0]
torch.distributed.barrier()
评论区精华
本PR无公开review讨论,审核者ping1jing2直接批准。从提交历史看,经过两次合并主分支后最终提交稳定。
风险与影响
- 风险:
- 新并行模式验证不足:
MAIN_RANK_ONLY_AND_SEND_TO_OTHERS仅在GLM-Image上测试,其他扩散模型未使用,但模式本身是通用的,若被复用需确保语义正确。
- AR阶段单点瓶颈:AR生成仅在rank 0执行,如果AR生成时间远大于扩散阶段,可能成为瓶颈;但当前数据下AR阶段耗时较少。
- 旋转嵌入分片影响:
shard_rotary_emb_for_sp仅在GLM-Image配置中调用,不影响其他模型,但若未来其他模型复用需检查兼容性。
- 依赖升级风险:
cache-dit从1.2.1跳至1.3.5,可能引入新行为,需关注CI测试结果。
- 影响:
- 用户:GLM-Image用户现在可以指定
--num-gpus N --sp-degree N实现多卡加速,实测SP=2时速度提升约2倍(从31.66s降至16.86s)。
- 系统:扩展了扩散模型的并行策略选项,为其他模型提供可参考的“单卡执行后广播”范式。
- 团队:需维护新增的并行枚举值和对应逻辑,但设计上与现有模式保持一致,维护成本可控。
- 风险标记:新并行模式仅在GLM-Image验证, AR阶段单点瓶颈, 旋转嵌入分片影响范围有限, cache-dit版本升级
关联脉络
参与讨论