class skip_init_modules:
    """Temporarily turn ``reset_parameters`` into a no-op for common layers.

    While the context is active, ``reset_parameters`` on ``nn.Linear``,
    ``nn.Conv1d``, ``nn.Conv2d`` and ``nn.Conv3d`` does nothing, so building
    those layers skips their weight initialisation.  It is used around model
    construction (``model_cls(model_config)`` in ``load_model`` and
    ``vae_cls(vae_config)`` in the VAE loader's ``load``).

    Caveats:
        * Parameters that are not overwritten afterwards (for example by
          loading a checkpoint) keep whatever values the layer constructor
          allocated, so callers must load every weight.
        * The patch is applied to the layer classes themselves, so it affects
          every thread that constructs these layers while the context is
          active.

    The context may be nested, and the same instance may be re-entered: only
    the outermost context that actually installed the patch removes it.
    """

    # Marks a class that did not define ``reset_parameters`` itself (it was
    # inherited), so that restoring removes our override instead of leaving a
    # new attribute behind on the class.
    _INHERITED = object()

    def __init__(self):
        # One entry per active ``__enter__`` on this instance: a mapping of
        # patched class -> what has to be put back on exit.
        self._saved_stack = []

    @staticmethod
    def _noop_reset_parameters(module):
        """Replacement for ``reset_parameters`` that leaves ``module`` untouched."""
        return None

    def __enter__(self) -> "skip_init_modules":
        noop = type(self)._noop_reset_parameters
        saved = {}
        for cls in (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d):
            own = cls.__dict__.get("reset_parameters", self._INHERITED)
            if own is noop:
                # Already disabled by an enclosing context; that context owns
                # the restore, so record nothing here.
                continue
            saved[cls] = own
            cls.reset_parameters = noop  # skip init
        self._saved_stack.append(saved)
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Undo only what the matching ``__enter__`` installed.  Returning
        # ``None`` lets any exception raised inside the block propagate.
        saved = self._saved_stack.pop()
        for cls, original in saved.items():
            if original is self._INHERITED:
                # Drop the override so the inherited method is visible again.
                del cls.reset_parameters
            else:
                cls.reset_parameters = original