执行摘要
缓存 diffusion 中 FP32LayerNorm 参数转换
Wan 扩散路径每次前向都调用 self.weight.float().to(device),推理时参数不变,可缓存。从 FastVideo PR #1245 移植。
值得关注缓存失效设计,测试覆盖全面。如使用扩散模型,建议合并。
无实质性讨论,直接合并。
Wan 扩散路径每次前向都调用 self.weight.float().to(device),推理时参数不变,可缓存。从 FastVideo PR #1245 移植。
值得关注缓存失效设计,测试覆盖全面。如使用扩散模型,建议合并。
无实质性讨论,直接合并。
| 文件 | 模块 | 状态 | 重要度 |
|---|---|---|---|
python/sglang/multimodal_gen/runtime/layers/layernorm.py |
层归一化 | modified | 6.69 |
python/sglang/multimodal_gen/test/unit/test_fp32_layernorm.py |
单元测试 | added | 6.69 |
python/sglang/multimodal_gen/runtime/layers/layernorm.py
core-logic
核心变更,添加 _cached_fp32_param 缓存逻辑,修改 forward 使用缓存。
def _cached_fp32_param(
self, attr: str, param: torch.Tensor | None, device: torch.device
) -> torch.Tensor | None:
# 参数为 None 时直接返回
if param is None:
return None
# 保持 autograd 语义不变:若 grad 启用,直接转换(不缓存)
if torch.is_grad_enabled():
return param.float().to(device=device)
# 构建 key:包含 data_ptr、_version、来源设备、目标设备、dtype
key = (
param.data_ptr(),
param._version,
param.device,
device,
param.dtype,
)
# 从实例的 __dict__ 中按 attr 名称查询缓存
cache = self.__dict__.get(attr)
if cache is not None and cache[0] == key:
return cache[1]
# 否则做转换并缓存
fp32_param = param.detach().to(device=device, dtype=torch.float32)
self.__dict__[attr] = (key, fp32_param)
return fp32_param
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
device = inputs.device
weight = self._cached_fp32_param('_weight_fp32_cache', self.weight, device)
bias = self._cached_fp32_param('_bias_fp32_cache', self.bias, device)
return F.layer_norm(
inputs.float(),
self.normalized_shape,
weight,
bias,
self.eps,
).to(origin_dtype)
python/sglang/multimodal_gen/test/unit/test_fp32_layernorm.py
test-coverage
新增单元测试,覆盖缓存正确性、复用、失效、grad 模式。
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA required')
def test_fp32_layernorm_cache_reuses_converted_params():
norm = FP32LayerNorm(16, eps=1e-5).cuda().to(torch.bfloat16)
inputs = torch.randn(4, 16, device='cuda', dtype=torch.bfloat16)
with torch.no_grad():
norm(inputs)
weight_cache = norm.__dict__['_weight_fp32_cache']
bias_cache = norm.__dict__['_bias_fp32_cache']
norm(inputs)
# 第二次调用应复用同一 tensor 对象
assert norm.__dict__['_weight_fp32_cache'][1] is weight_cache[1]
assert norm.__dict__['_bias_fp32_cache'][1] is bias_cache[1]
# 缓存应不出现在 state_dict 中
assert '_weight_fp32_cache' not in norm.state_dict()
assert '_bias_fp32_cache' not in norm.state_dict()
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA required')
def test_fp32_layernorm_grad_mode_preserves_autograd_path():
norm = FP32LayerNorm(16, eps=1e-5).cuda().to(torch.bfloat16)
inputs = torch.randn(4, 16, device='cuda', dtype=torch.bfloat16, requires_grad=True)
output = norm(inputs).float().sum()
output.backward()
# 应能正确反向传播
assert inputs.grad is not None
# 且缓存不被填充
assert '_weight_fp32_cache' not in norm.__dict__
assert '_bias_fp32_cache' not in norm.__dict__
当前评论区没有形成足够清晰的争议点或结论,后续有更多讨论时会体现在这里。
缓存依赖 data_ptr 和 version,若参数被就地修改但版本未变(不常见),可能导致错误结果。测试覆盖了 add 操作,grad 模式完全跳过缓存,风险较低。
仅影响扩散模型推理路径,训练无影响。性能提升微小(e2e 0.1%),但代码更清晰,为类似模式提供参考。
参与讨论