[BugFix] Respect configured precision in Qwen layered path #21980

Open
jy-song-hub wants to merge 12 commits into sgl-project:main from bytedance-iaas:fix/qwen-layered-precision-config

Conversation

@jy-song-hub
Contributor

Motivation

SGLang already exposes precision control via pipeline config (see the sketch after this list):

  • vae_precision
  • text_encoder_precisions
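
For illustration, a minimal sketch of these knobs, mirroring the shapes the unit test below uses (hypothetical and simplified; the real pipeline config object carries many more fields):

from types import SimpleNamespace

# Hypothetical, simplified pipeline config; only the two precision fields
# touched by this PR are shown, with the shapes the unit test below uses.
pipeline_config = SimpleNamespace(
    vae_precision="fp16",               # VAE dtype: "fp32" / "fp16" / "bf16"
    text_encoder_precisions=("fp32",),  # one entry per text encoder
)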

However, the layered Qwen image path did not respect these settings. Specifically, in qwen_image.py, the configured precisions were not passed into QwenImageLayeredBeforeDenoisingStage; in qwen_image_layered.py, the implementation instead hardcoded torch.bfloat16 for the VAE, text encoder, and input image tensor.

As a result, this path ignored user-configured precision and implicitly assumed bf16, creating a mismatch with the configuration surface and potentially causing issues on devices or backends without reliable bf16 support.
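
Schematically, the old stage behaved like this (a condensed illustration based on the description above, with a hypothetical helper name, not the verbatim source):

import torch

# Condensed illustration of the old behavior: every cast was pinned to
# bf16, regardless of the configured precisions.
def _old_before_denoising_setup(vae, text_encoder, image):
    vae = vae.to(torch.bfloat16)                    # VAE
    text_encoder = text_encoder.to(torch.bfloat16)  # text encoder
    image = image.to(torch.bfloat16)                # preprocessed image tensor
    return vae, text_encoder, image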

Modifications

  • Import PRECISION_TO_TYPE in python/sglang/multimodal_gen/runtime/pipelines/qwen_image.py.
  • Pass configured vae_dtype and text_encoder_dtype from QwenImageLayeredPipeline.create_pipeline_stages() into QwenImageLayeredBeforeDenoisingStage.
  • Update QwenImageLayeredBeforeDenoisingStage.__init__() in
    python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py to accept vae_dtype and text_encoder_dtype.
  • Replace the hardcoded torch.bfloat16 cast for the VAE with vae.to(dtype=vae_dtype).
  • Replace the hardcoded torch.bfloat16 cast for the text encoder with .to(dtype=self.text_encoder_dtype).
  • Replace the hardcoded torch.bfloat16 cast for the preprocessed image tensor with image.to(dtype=self.vae_dtype) (see the sketch after this list).
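
A condensed sketch of the resulting wiring, assuming PRECISION_TO_TYPE maps the usual precision strings to torch dtypes (consistent with the dtypes the unit test expects; the real stage constructor takes more arguments than shown):

import torch
from types import SimpleNamespace

# Assumed mapping; the unit test resolves "fp16" -> torch.float16 and
# "fp32" -> torch.float32 through it.
PRECISION_TO_TYPE = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

pipeline_config = SimpleNamespace(
    vae_precision="fp16", text_encoder_precisions=("fp32",)
)

# In create_pipeline_stages(): resolve configured precisions to dtypes and
# hand them to the stage instead of letting it hardcode bf16.
vae_dtype = PRECISION_TO_TYPE[pipeline_config.vae_precision]
text_encoder_dtype = PRECISION_TO_TYPE[pipeline_config.text_encoder_precisions[0]]

# QwenImageLayeredBeforeDenoisingStage(..., vae_dtype=vae_dtype,
#                                      text_encoder_dtype=text_encoder_dtype)
# and, inside the stage:
#   vae.to(dtype=vae_dtype)
#   text_encoder.to(dtype=self.text_encoder_dtype)
#   image.to(dtype=self.vae_dtype)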

Accuracy Tests

Validated via a unit test. To keep the PR surface area small, the test is not included in this PR; see the code snippet in the comments below for details.

The test was run separately and validates the following old-fails / new-passes behavior:

  • old behavior fails because the layered pipeline does not pass configured precision into the stage
  • old behavior fails because QwenImageLayeredBeforeDenoisingStage hardcodes torch.bfloat16 for the VAE and text encoder
  • old behavior fails because the forward path hardcodes torch.bfloat16 for the preprocessed image tensor
  • new behavior passes because the stage receives and uses configured vae_precision and text_encoder_precisions

Speed Tests and Profiling

This change is a correctness fix for dtype selection in the layered Qwen path. It is not intended as a performance optimization, so no dedicated speed benchmark or profiling result is included.

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

Pass configured VAE and text encoder dtypes into the layered Qwen stage so it no longer hardcodes bf16 for the VAE, text encoder, and preprocessed image tensor.

@github-actions github-actions Bot added the diffusion SGLang Diffusion label Apr 2, 2026
@jy-song-hub
Contributor Author

The following unit test was used to validate the correctness of this change. Because the code modification is small while the test is relatively lengthy, the test is posted here rather than added to the codebase, to avoid unnecessary bloat.

import importlib
import sys
import types
import unittest
from types import SimpleNamespace
from unittest.mock import Mock, patch

import torch

LAYERED_MODULE = (
    "sglang.multimodal_gen.runtime.pipelines_core.stages.model_specific_stages."
    "qwen_image_layered"
)
PIPELINE_MODULE = "sglang.multimodal_gen.runtime.pipelines.qwen_image"


# Stub standing in for diffusers' VaeImageProcessor; preprocess() returns a
# fixed fp32 tensor so downstream dtype conversions are observable.
class _DummyVaeImageProcessor:
    def __init__(self, *args, **kwargs):
        pass

    def resize(self, image, _height, _width):
        return image

    def preprocess(self, _image, _height, _width):
        return torch.ones((1, 4, 4, 4), dtype=torch.float32)


# Minimal module stand-in that records every .to() call and can refuse bf16
# to emulate a backend without bfloat16 support; also exposes the VAE
# attributes (temperal_downsample, z_dim) the stage reads.
class _FakeModule:
    def __init__(self, *, initial_dtype=torch.float32, supports_bf16=True):
        self.dtype = initial_dtype
        self.supports_bf16 = supports_bf16
        self.to_history = []
        self.temperal_downsample = [1]
        self.z_dim = 4

    def to(self, device=None, dtype=None, **kwargs):
        if isinstance(device, torch.dtype) and dtype is None:
            dtype = device
            device = None

        self.to_history.append((device, dtype))

        if dtype is not None:
            if dtype == torch.bfloat16 and not self.supports_bf16:
                raise RuntimeError("bfloat16 is unsupported on this backend")
            self.dtype = dtype

        return self


class _FakeImage:
    def __init__(self, size=(64, 64)):
        self.size = size

    def convert(self, _mode):
        return self


# Build fake diffusers modules so importing the pipeline code does not
# require the real diffusers dependency.
def _fake_diffusers_modules():
    fake_diffusers = types.ModuleType("diffusers")
    fake_image_processor = types.ModuleType("diffusers.image_processor")
    fake_image_processor.VaeImageProcessor = _DummyVaeImageProcessor
    fake_torch_utils = types.ModuleType("diffusers.utils.torch_utils")
    fake_torch_utils.randn_tensor = (
        lambda shape, generator=None, device=None, dtype=None: torch.zeros(
            shape, device=device, dtype=dtype
        )
    )
    fake_utils = types.ModuleType("diffusers.utils")
    fake_utils.torch_utils = fake_torch_utils

    fake_diffusers.image_processor = fake_image_processor
    fake_diffusers.utils = fake_utils

    return {
        "diffusers": fake_diffusers,
        "diffusers.image_processor": fake_image_processor,
        "diffusers.utils": fake_utils,
        "diffusers.utils.torch_utils": fake_torch_utils,
    }


# Fake transformers module whose text encoder can be configured to reject
# bf16, mirroring backends without bfloat16 support.
def _make_transformers_module(*, supports_bf16):
    fake_transformers = types.ModuleType("transformers")

    class _FakeTextEncoder(_FakeModule):
        @classmethod
        def from_pretrained(cls, model_path, subfolder=None):
            del model_path, subfolder
            return cls(supports_bf16=supports_bf16)

    fake_transformers.Qwen2_5_VLForConditionalGeneration = _FakeTextEncoder
    return fake_transformers


# Import the target module with the fake diffusers modules in place,
# evicting any cached copies so the patches take effect.
def _import_module(module_name):
    with patch.dict(sys.modules, _fake_diffusers_modules()):
        sys.modules.pop(LAYERED_MODULE, None)
        sys.modules.pop(PIPELINE_MODULE, None)
        sys.modules.pop(module_name, None)
        return importlib.import_module(module_name)


class TestQwenImageLayeredPrecision(unittest.TestCase):
    def test_pipeline_passes_configured_precisions_to_layered_stage(self):
        module = _import_module(PIPELINE_MODULE)

        pipeline = module.QwenImageLayeredPipeline.__new__(
            module.QwenImageLayeredPipeline
        )
        captured = {}

        pipeline.add_stage = lambda stage: captured.setdefault("stage", stage)
        pipeline.get_module = lambda name: name
        pipeline.model_path = "fake-model"
        pipeline.add_standard_timestep_preparation_stage = Mock()
        pipeline.add_standard_denoising_stage = Mock()
        pipeline.add_standard_decoding_stage = Mock()

        server_args = SimpleNamespace(
            pipeline_config=SimpleNamespace(
                vae_precision="fp16", text_encoder_precisions=("fp32",)
            )
        )

        with patch.object(
            module,
            "QwenImageLayeredBeforeDenoisingStage",
            side_effect=lambda **kwargs: kwargs,
        ):
            module.QwenImageLayeredPipeline.create_pipeline_stages(
                pipeline, server_args
            )

        self.assertEqual(captured["stage"]["vae_dtype"], torch.float16)
        self.assertEqual(captured["stage"]["text_encoder_dtype"], torch.float32)

    def test_stage_init_uses_passed_vae_and_text_encoder_dtypes(self):
        module = _import_module(LAYERED_MODULE)
        fake_transformers = _make_transformers_module(supports_bf16=False)
        fake_global_server_args = SimpleNamespace(comfyui_mode=False)

        with patch.dict(sys.modules, {"transformers": fake_transformers}), patch(
            "sglang.multimodal_gen.runtime.pipelines_core.stages.base.get_global_server_args",
            return_value=fake_global_server_args,
        ), patch.object(module, "get_local_torch_device", return_value="cpu"):
            stage = module.QwenImageLayeredBeforeDenoisingStage(
                vae=_FakeModule(supports_bf16=False),
                tokenizer=object(),
                processor=object(),
                transformer=SimpleNamespace(config=SimpleNamespace(in_channels=4)),
                scheduler=object(),
                model_path="fake-model",
                vae_dtype=torch.float16,
                text_encoder_dtype=torch.float32,
            )

        self.assertEqual(stage.vae.dtype, torch.float16)
        self.assertEqual(stage.text_encoder.dtype, torch.float32)

    def test_forward_uses_vae_dtype_for_preprocessed_image(self):
        module = _import_module(LAYERED_MODULE)

        stage = module.QwenImageLayeredBeforeDenoisingStage.__new__(
            module.QwenImageLayeredBeforeDenoisingStage
        )
        stage.vae_scale_factor = 1
        stage.vae_dtype = torch.float32
        stage.image_processor = _DummyVaeImageProcessor()
        stage.transformer = SimpleNamespace(config=SimpleNamespace(in_channels=4))
        stage.scheduler = object()
        stage.get_image_caption = Mock(return_value="caption")
        prompt_embeds = torch.ones((1, 2, 4), dtype=torch.float32)
        prompt_embeds_mask = torch.ones((1, 2), dtype=torch.long)
        stage.encode_prompt = Mock(
            side_effect=[
                (prompt_embeds, prompt_embeds_mask),
                (prompt_embeds, prompt_embeds_mask),
            ]
        )

        captured = {}

        def fake_prepare_latents(image, *args, **kwargs):
            del args, kwargs
            captured["image_dtype"] = image.dtype
            return (
                torch.ones((1, 4, 4), dtype=torch.float32),
                torch.ones((1, 4, 4), dtype=torch.float32),
            )

        stage.prepare_latents = fake_prepare_latents

        batch = SimpleNamespace(
            image_path=["fake.png"],
            num_frames=1,
            num_inference_steps=1,
            generator=torch.Generator(),
            negative_prompt="negative",
        )
        server_args = SimpleNamespace(pipeline_config=SimpleNamespace(resolution=256))

        with patch.object(
            module, "get_local_torch_device", return_value="cpu"
        ), patch.object(module, "load_image", return_value=_FakeImage()), patch.object(
            module, "retrieve_timesteps", return_value=(torch.tensor([1.0]), 1)
        ):
            result = module.QwenImageLayeredBeforeDenoisingStage.forward(
                stage, batch, server_args
            )

        self.assertIs(result, batch)
        self.assertEqual(captured["image_dtype"], torch.float32)


if __name__ == "__main__":
    unittest.main()
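
Assuming the snippet is saved as, say, test_qwen_image_layered_precision.py in an environment where torch and sglang are importable, it can be run directly with python test_qwen_image_layered_precision.py; the __main__ guard dispatches to unittest.main().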

@jy-song-hub
Contributor Author

@mickqian Please take a look. Thanks~

Comment thread on python/sglang/multimodal_gen/runtime/pipelines/qwen_image.py
@yhyang201
Collaborator

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label Apr 4, 2026
@yhyang201
Collaborator

/rerun-failed-ci

1 similar comment
@yhyang201
Collaborator

/rerun-failed-ci

@yhyang201
Collaborator

/tag-and-rerun-ci

@yhyang201
Collaborator

/rerun-failed-ci

1 similar comment
@yhyang201
Collaborator

/rerun-failed-ci

@yhyang201
Collaborator

/tag-and-rerun-ci

@yhyang201
Collaborator

/rerun-failed-ci

1 similar comment
@yhyang201
Collaborator

/rerun-failed-ci

@yhyang201
Collaborator

/tag-and-rerun-ci

@yhyang201
Collaborator

/rerun-failed-ci

1 similar comment
@yhyang201
Collaborator

/rerun-failed-ci

@jy-song-hub
Contributor Author

jy-song-hub commented Apr 29, 2026

/rerun-failed-ci
