diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_paint.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_paint.py
index 970a0e7290a6..6e1f2e96fb18 100644
--- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_paint.py
+++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_paint.py
@@ -33,6 +33,7 @@
 )
 from sglang.multimodal_gen.runtime.server_args import ServerArgs
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
+from sglang.multimodal_gen.utils import PRECISION_TO_TYPE
 
 logger = init_logger(__name__)
 
@@ -300,16 +301,26 @@ def _load_delight_model(self, server_args: ServerArgs):
             local_path = None
 
         if local_path and os.path.exists(local_path):
+            # Resolve precision from config with a simple fallback for CPU/MPS
+            dit_dtype = PRECISION_TO_TYPE.get(
+                getattr(self.config, "dit_precision", "fp16"), torch.float16
+            )
+            if self.device.type in ("cpu", "mps") and dit_dtype in (
+                torch.float16,
+                torch.bfloat16,
+            ):
+                # Avoid half/bfloat on CPU/MPS to be safe
+                dit_dtype = torch.float32
             pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                 local_path,
-                torch_dtype=torch.float16,
+                torch_dtype=dit_dtype,
                 safety_checker=None,
             )
             pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
                 pipeline.scheduler.config
             )
             pipeline.set_progress_bar_config(disable=True)
-            self._delight_pipeline = pipeline.to(self.device, torch.float16)
+            self._delight_pipeline = pipeline.to(self.device, dit_dtype)
             logger.info("Delight model loaded successfully")
         else:
             logger.warning(
@@ -347,6 +358,7 @@ def _run_delight(self, image):
 
         image = self._delight_pipeline(
             prompt=self.config.delight_prompt,
+            negative_prompt=getattr(self.config, "delight_negative_prompt", ""),
             image=image,
             generator=torch.manual_seed(42),
             height=512,
@@ -559,10 +571,28 @@ def _do_load_paint(self, server_args: ServerArgs) -> None:
             else:
                 raise FileNotFoundError(f"No VAE weights in {vae_dir}")
             self.vae.load_state_dict(state_dict)
-            self.vae = self.vae.to(device=self.device, dtype=torch.float16).eval()
+            # Resolve VAE/DiT dtypes from config with simple CPU/MPS fallback
+            vae_dtype = PRECISION_TO_TYPE.get(
+                getattr(self.config, "vae_precision", "fp32"), torch.float32
+            )
+            if self.device.type in ("cpu", "mps") and vae_dtype in (
+                torch.float16,
+                torch.bfloat16,
+            ):
+                vae_dtype = torch.float32
+            dit_dtype = PRECISION_TO_TYPE.get(
+                getattr(self.config, "dit_precision", "fp16"), torch.float16
+            )
+            if self.device.type in ("cpu", "mps") and dit_dtype in (
+                torch.float16,
+                torch.bfloat16,
+            ):
+                dit_dtype = torch.float32
+
+            self.vae = self.vae.to(device=self.device, dtype=vae_dtype).eval()
             self.transformer = UNet2p5DConditionModel.from_pretrained(
                 os.path.join(local_path, "unet"),
-                torch_dtype=torch.float16,
+                torch_dtype=dit_dtype,
             ).to(self.device)
             self.is_turbo = bool(getattr(self.config, "paint_turbo_mode", False))
             sched_path = os.path.join(local_path, "scheduler", "scheduler_config.json")