From 6eaa8103e2947c021a338fd198b2ee9876a78b0e Mon Sep 17 00:00:00 2001 From: "jiayi.song" Date: Thu, 2 Apr 2026 16:08:03 -0700 Subject: [PATCH] [Fix] Respect configured precision in Qwen layered path Pass configured VAE and text encoder dtypes into the layered Qwen stage so it no longer hardcodes bf16 for the VAE, text encoder, and preprocessed image tensor. --- .../runtime/pipelines/qwen_image.py | 5 +++++ .../qwen_image_layered.py | 18 ++++++++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/pipelines/qwen_image.py b/python/sglang/multimodal_gen/runtime/pipelines/qwen_image.py index 81695190a8f5..c61c31864f98 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines/qwen_image.py +++ b/python/sglang/multimodal_gen/runtime/pipelines/qwen_image.py @@ -13,6 +13,7 @@ ) from sglang.multimodal_gen.runtime.server_args import ServerArgs from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import PRECISION_TO_TYPE # TODO(will): move PRECISION_TO_TYPE to better place @@ -122,6 +123,10 @@ def create_pipeline_stages(self, server_args: ServerArgs): transformer=self.get_module("transformer"), scheduler=self.get_module("scheduler"), model_path=self.model_path, + vae_dtype=PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision], + text_encoder_dtype=PRECISION_TO_TYPE[ + server_args.pipeline_config.text_encoder_precisions[0] + ], ) ) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py index 2cb35deadb3f..06e33b969be9 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py @@ -111,10 +111,20 @@ def retrieve_timesteps( class QwenImageLayeredBeforeDenoisingStage(PipelineStage): def __init__( - self, vae, tokenizer, processor, transformer, scheduler, model_path + self, + vae, + tokenizer, + processor, + transformer, + scheduler, + model_path, + vae_dtype: torch.dtype, + text_encoder_dtype: torch.dtype, ) -> None: super().__init__() - self.vae = vae.to(torch.bfloat16) + self.vae = vae.to(dtype=vae_dtype) + self.vae_dtype = vae_dtype + self.text_encoder_dtype = text_encoder_dtype from transformers import Qwen2_5_VLForConditionalGeneration self.text_encoder = ( @@ -122,7 +132,7 @@ def __init__( model_path, subfolder="text_encoder" ) .to(get_local_torch_device()) - .to(torch.bfloat16) + .to(dtype=self.text_encoder_dtype) ) self.tokenizer = tokenizer self.processor = processor @@ -441,7 +451,7 @@ def forward( image, calculated_height, calculated_width ) image = image.unsqueeze(2) - image = image.to(dtype=torch.bfloat16) + image = image.to(dtype=self.vae_dtype) prompt = self.get_image_caption( prompt_image, use_en_prompt=use_en_prompt, device=device