Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import PRECISION_TO_TYPE

logger = init_logger(__name__)

Expand Down Expand Up @@ -300,16 +301,26 @@ def _load_delight_model(self, server_args: ServerArgs):
local_path = None

if local_path and os.path.exists(local_path):
# Resolve precision from config with a simple fallback for CPU/MPS
dit_dtype = PRECISION_TO_TYPE.get(
getattr(self.config, "dit_precision", "fp16"), torch.float16
)
if self.device.type in ("cpu", "mps") and dit_dtype in (
torch.float16,
torch.bfloat16,
):
# Avoid half/bfloat on CPU/MPS to be safe
dit_dtype = torch.float32
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
local_path,
torch_dtype=torch.float16,
torch_dtype=dit_dtype,
safety_checker=None,
)
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipeline.scheduler.config
)
pipeline.set_progress_bar_config(disable=True)
self._delight_pipeline = pipeline.to(self.device, torch.float16)
self._delight_pipeline = pipeline.to(self.device, dit_dtype)
logger.info("Delight model loaded successfully")
else:
logger.warning(
Expand Down Expand Up @@ -347,6 +358,7 @@ def _run_delight(self, image):

image = self._delight_pipeline(
prompt=self.config.delight_prompt,
negative_prompt=getattr(self.config, "delight_negative_prompt", ""),
image=image,
generator=torch.manual_seed(42),
height=512,
Expand Down Expand Up @@ -559,10 +571,28 @@ def _do_load_paint(self, server_args: ServerArgs) -> None:
else:
raise FileNotFoundError(f"No VAE weights in {vae_dir}")
self.vae.load_state_dict(state_dict)
self.vae = self.vae.to(device=self.device, dtype=torch.float16).eval()
# Resolve VAE/DiT dtypes from config with simple CPU/MPS fallback
vae_dtype = PRECISION_TO_TYPE.get(
getattr(self.config, "vae_precision", "fp32"), torch.float32
)
if self.device.type in ("cpu", "mps") and vae_dtype in (
torch.float16,
torch.bfloat16,
):
vae_dtype = torch.float32
dit_dtype = PRECISION_TO_TYPE.get(
getattr(self.config, "dit_precision", "fp16"), torch.float16
)
if self.device.type in ("cpu", "mps") and dit_dtype in (
torch.float16,
torch.bfloat16,
):
dit_dtype = torch.float32

self.vae = self.vae.to(device=self.device, dtype=vae_dtype).eval()
self.transformer = UNet2p5DConditionModel.from_pretrained(
os.path.join(local_path, "unet"),
torch_dtype=torch.float16,
torch_dtype=dit_dtype,
).to(self.device)
self.is_turbo = bool(getattr(self.config, "paint_turbo_mode", False))
sched_path = os.path.join(local_path, "scheduler", "scheduler_config.json")
Expand Down
Loading