Skip to content

[Quantization] Enable FP8 online quantization for Z-image text encoder#1338

Merged
hsliuustc0106 merged 27 commits into
vllm-project:mainfrom
Isotr0py:refine-diffusion-loader
Apr 30, 2026
Merged

[Quantization] Enable FP8 online quantization for Z-image text encoder#1338
hsliuustc0106 merged 27 commits into
vllm-project:mainfrom
Isotr0py:refine-diffusion-loader

Conversation

@Isotr0py
Copy link
Copy Markdown
Member

@Isotr0py Isotr0py commented Feb 11, 2026

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED.

Purpose

Test Plan

python examples/offline_inference/text_to_image/text_to_image.py --model /mnt/data0/LLM/Z-Image-Turbo/ --width 512 --height 512 --tensor-parallel-size 2

Test Result

Main branch

Loading safetensors checkpoint shards:  33% Completed | 1/3 [00:01<00:03,  1.78s/it]
Loading safetensors checkpoint shards:  67% Completed | 2/3 [00:03<00:01,  1.96s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:04<00:00,  1.57s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:04<00:00,  1.65s/it]

[Stage-0] INFO 02-11 22:49:28 [diffusers_loader.py:256] Loading weights took 5.02 seconds
[Stage-0] INFO 02-11 22:49:28 [diffusers_loader.py:256] Loading weights took 5.15 seconds
[Stage-0] INFO 02-11 22:49:29 [diffusion_model_runner.py:102] Model loading took 10.7946 GiB and 10.276793 seconds

PR

Loading safetensors checkpoint shards:  33% Completed | 1/3 [00:01<00:03,  1.80s/it]
Loading safetensors checkpoint shards:  67% Completed | 2/3 [00:03<00:01,  1.96s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:04<00:00,  1.54s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:04<00:00,  1.64s/it]

Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  33% Completed | 1/3 [00:01<00:02,  1.42s/it]
Loading safetensors checkpoint shards:  67% Completed | 2/3 [00:02<00:01,  1.32s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:02<00:00,  1.26it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:02<00:00,  1.06it/s]

[Stage-0] INFO 02-11 22:42:52 [diffusers_loader.py:257] Loading weights took 7.75 seconds
[Stage-0] INFO 02-11 22:42:52 [diffusers_loader.py:257] Loading weights took 7.88 seconds
[Stage-0] INFO 02-11 22:42:53 [diffusion_model_runner.py:102] Model loading took 7.7586 GiB and 11.094385 seconds

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft.

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
…ader

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@Isotr0py Isotr0py marked this pull request as ready for review February 12, 2026 15:41
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d37a560ffb

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread vllm_omni/diffusion/models/z_image/pipeline_z_image.py Outdated
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Copy link
Copy Markdown
Collaborator

@lishunyang12 lishunyang12 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work enabling FP8 for the Z-image text encoder! The create_transformers_model + recursive_replace_linear pattern is clean and reusable. The weight loader changes to support model.safetensors.index.json are a practical fix too.

Left a couple of small comments below.

filter(lambda f: file_exists(model_name_or_path, f, revision=revision), possible_index_files)
)
assert len(available_index_file) <= 1, (
f"Multiple index files found in {model_name_or_path} with subfolder {subfolder}: {available_index_file}"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering about the assert len(available_index_file) <= 1 here. If a model repo somehow ships both diffusion_pytorch_model.safetensors.index.json and model.safetensors.index.json in the same subfolder, this will crash with an AssertionError and no user-friendly message.

Could we either:

  1. Use if len(...) > 1: raise ValueError(...) so the error survives python -O, or
  2. Pick one with a defined priority (e.g., prefer diffusion_pytorch_model over model) and log a warning?

Option 2 might be more resilient since we can't always control what model authors upload.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model repo somehow ships both diffusion_pytorch_model.safetensors.index.json and model.safetensors.index.json in the same subfolder

I think this is a quite rare case for diffusion pipeline with multiple components., especially diffusion_pytorch_model.safetensors.index.json and model.safetensors.index.json are two different style index file for diffusers and transformers respectively.

I prefer to choose option 1 for now before we actually encountered the index file mixing case.


def _recursive_replace(module: nn.Module, prefix: str):
for child_name, child_module in module.named_children():
new_module = child_module
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice utility! One small thing I noticed: recursive_replace_linear always sets style = "replicate" for every nn.Linear in the model. This works correctly for FP8 quantization today, but if this utility is later reused for tensor-parallel text encoders, we'd need per-layer style selection.

Would it be worth accepting an optional style_map: dict[str, Style] parameter (defaulting to None = all replicate) to make this future-proof? Not a blocker at all -- just thinking about reusability since the function name suggests general-purpose use.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but if this utility is later reused for tensor-parallel text encoders, we'd need per-layer style selection.

Would it be worth accepting an optional style_map: dict[str, Style] parameter (defaulting to None = all replicate) to make this future-proof?

We can reuse tp_plan from Transformers model like vLLM's Transformers backend, but I would like to leave it to a following PR because it can make things quite complicated:
https://github.com/vllm-project/vllm/blob/bebfe55b1c17c2e0fedb1b402df1dddfc1a04684/vllm/model_executor/models/transformers/base.py#L285-L296

return loader.load_weights(weights)
loaded_weights = loader.load_weights(weights)
loaded_weights |= {f"vae.{name}" for name, _ in self.vae.named_parameters()}
return loaded_weights
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was curious about this: we're adding all VAE parameter names to loaded_weights so the weight loader doesn't complain about "unloaded" weights, but the VAE was already loaded via AutoencoderKL.from_pretrained() above. This makes sense as a workaround.

But I noticed that self.vae is loaded with from_pretrained (which uses HF's default dtype and device handling), while the text encoder now goes through create_transformers_model (which uses od_config.dtype and meta init). Could there be a dtype mismatch between the two if od_config.dtype differs from the default? Probably fine in practice since VAE is typically float32 anyway, but wanted to flag it.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, let's replace .from_pretrained with diffusion loader to load vae weights as well then.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could there be a dtype mismatch between the two if od_config.dtype differs from the default? Probably fine in practice since VAE is typically float32 anyway, but wanted to flag it.

Latents is usually casted to vae's dtype before decoding,, so I think dtype mismatch won't be a critical issue here:

if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Comment on lines +195 to +196
vae_config = AutoencoderKL.load_config(model, subfolder="vae", local_files_only=local_files_only)
self.vae = AutoencoderKL.from_config(vae_config).to(self._execution_device)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the FP8 kernel is not unsuitable for vae, so let's not convert it with vllm quantization layer for now:

[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] Error executing method 'generate'. This might cause issues in distributed execution.
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] Traceback (most recent call last):
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/vllm_omni/diffusion/worker/diffusion_worker.py", line 674, in execute_method
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return func(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/vllm_omni/diffusion/worker/diffusion_worker.py", line 163, in generate
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return self.execute_model(request, self.od_config)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/vllm_omni/diffusion/worker/diffusion_worker.py", line 185, in execute_model
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return self.model_runner.execute_model(req)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return func(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/vllm_omni/diffusion/worker/diffusion_model_runner.py", line 196, in execute_model
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     output = self.pipeline.forward(req)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/vllm_omni/diffusion/models/z_image/pipeline_z_image.py", line 667, in forward
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     image = self.vae.decode(latents, return_dict=False)[0]
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return method(self, *args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 237, in decode
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     decoded = self._decode(z).sample
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]               ^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 208, in _decode
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     dec = self.decoder(z)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]           ^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return self._call_impl(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return forward_call(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/diffusers/models/autoencoders/vae.py", line 298, in forward
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     sample = self.mid_block(sample, latent_embeds)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return self._call_impl(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return forward_call(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 745, in forward
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     hidden_states = attn(hidden_states, temb=temb)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return self._call_impl(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return forward_call(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/diffusers/models/attention_processor.py", line 605, in forward
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return self.processor(
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/diffusers/models/attention_processor.py", line 2740, in __call__
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     query = attn.to_q(hidden_states)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]             ^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return self._call_impl(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return forward_call(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/vllm/model_executor/layers/linear.py", line 413, in forward
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     output = self.quant_method.apply(self, x, bias)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/fp8.py", line 501, in apply
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return apply_fp8_marlin_linear(
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py", line 69, in apply_fp8_marlin_linear
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     output = ops.marlin_gemm(
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]              ^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/vllm/_custom_ops.py", line 1246, in marlin_gemm
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return torch.ops._C.marlin_gemm(
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]   File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/_ops.py", line 1209, in __call__
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]     return self._op(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678]            ^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] RuntimeError: A.stride(1) is not 1

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Copy link
Copy Markdown
Collaborator

@lishunyang12 lishunyang12 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Followed up on the latest changes -- the VAE loader migration and ValueError fix look good, left a few more nits.

Comment thread vllm_omni/diffusion/model_loader/diffusers_loader.py Outdated
Comment thread vllm_omni/diffusion/model_loader/diffusers_loader.py
Comment thread vllm_omni/diffusion/models/utils.py
@Gaohan123
Copy link
Copy Markdown
Collaborator

@Isotr0py Please resolve conflicts. Thanks!

@Gaohan123 Gaohan123 added this to the v0.18.0 milestone Mar 17, 2026
…ader

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
…ader

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
…ader

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@Isotr0py
Copy link
Copy Markdown
Member Author

@lishunyang12 Can we merge this PR in v0.20.0?

@hsliuustc0106 hsliuustc0106 added the ready label to trigger buildkite CI label Apr 24, 2026
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
…ader

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@hsliuustc0106 hsliuustc0106 merged commit ac66282 into vllm-project:main Apr 30, 2026
6 of 8 checks passed
NumberWan pushed a commit to NumberWan/vllm-omni that referenced this pull request Apr 30, 2026
vllm-project#1338)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: NumberWan <wantszkin2003@gmail.com>
xiaohajiayou pushed a commit to xiaohajiayou/vllm-omni that referenced this pull request Apr 30, 2026
lengrongfu pushed a commit to lengrongfu/vllm-omni that referenced this pull request May 1, 2026
BeatSeat pushed a commit to BeatSeat/vllm-omni that referenced this pull request May 2, 2026
sphinxkkkbc pushed a commit to sphinxkkkbc/vllm-omni that referenced this pull request May 4, 2026
vllm-project#1338)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants