[Quantization] Enable FP8 online quantization for Z-image text encoder#1338
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
…ader Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d37a560ffb
lishunyang12
left a comment
Nice work enabling FP8 for the Z-image text encoder! The create_transformers_model + recursive_replace_linear pattern is clean and reusable. The weight loader changes to support model.safetensors.index.json are a practical fix too.
Left a couple of small comments below.
| filter(lambda f: file_exists(model_name_or_path, f, revision=revision), possible_index_files) | ||
| ) | ||
| assert len(available_index_file) <= 1, ( | ||
| f"Multiple index files found in {model_name_or_path} with subfolder {subfolder}: {available_index_file}" |
I was wondering about the assert len(available_index_file) <= 1 here. If a model repo somehow ships both diffusion_pytorch_model.safetensors.index.json and model.safetensors.index.json in the same subfolder, this will crash with an AssertionError and no user-friendly message.
Could we either:
- Use if len(...) > 1: raise ValueError(...) so the error survives python -O, or
- Pick one with a defined priority (e.g., prefer diffusion_pytorch_model over model) and log a warning?
Option 2 might be more resilient since we can't always control what model authors upload.
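The two options can be sketched side by side. This is a minimal illustration and not the PR's code: select_index_file, POSSIBLE_INDEX_FILES and the strict flag are hypothetical names, and the priority order is an assumption. Only the two index filenames come from the discussion.

```python
from __future__ import annotations

import logging

logger = logging.getLogger(__name__)

# Hypothetical priority order: diffusers-style index first, transformers-style second.
POSSIBLE_INDEX_FILES = [
    "diffusion_pytorch_model.safetensors.index.json",
    "model.safetensors.index.json",
]


def select_index_file(available: list[str], strict: bool = True) -> str | None:
    """Return the index file to use, or None if no index file is present."""
    found = [name for name in POSSIBLE_INDEX_FILES if name in available]
    if len(found) > 1:
        if strict:
            # Option 1: an explicit exception is not stripped by `python -O`,
            # unlike an assert statement.
            raise ValueError(f"Multiple index files found: {found}")
        # Option 2: fall back to the first entry in priority order and warn.
        logger.warning("Multiple index files found: %s; using %s", found, found[0])
    return found[0] if found else None
```

With strict=True the mixed case fails loudly; with strict=False it degrades to a warning and a deterministic choice.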
model repo somehow ships both diffusion_pytorch_model.safetensors.index.json and model.safetensors.index.json in the same subfolder
I think this is quite a rare case for a diffusion pipeline with multiple components, especially since diffusion_pytorch_model.safetensors.index.json and model.safetensors.index.json are two different styles of index file, used by diffusers and transformers respectively.
I prefer to go with option 1 for now, until we actually encounter a repo that mixes both index files.
| def _recursive_replace(module: nn.Module, prefix: str): | ||
| for child_name, child_module in module.named_children(): | ||
| new_module = child_module |
Nice utility! One small thing I noticed: recursive_replace_linear always sets style = "replicate" for every nn.Linear in the model. This works correctly for FP8 quantization today, but if this utility is later reused for tensor-parallel text encoders, we'd need per-layer style selection.
Would it be worth accepting an optional style_map: dict[str, Style] parameter (defaulting to None = all replicate) to make this future-proof? Not a blocker at all -- just thinking about reusability since the function name suggests general-purpose use.
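A rough sketch of what the suggested style_map parameter could look like. To keep it runnable without PyTorch, Module below is a stand-in for nn.Module and not the real class; the kind field, the "quant_linear" marker and the style strings are all placeholders invented for this illustration.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Module:
    """Stand-in for nn.Module: a named tree of children (illustration only)."""

    kind: str = "container"
    children: dict[str, "Module"] = field(default_factory=dict)
    style: str | None = None  # set only on replaced linear layers


def recursive_replace_linear(
    model: Module, style_map: dict[str, str] | None = None
) -> None:
    """Replace every 'linear' child in place, choosing a style per qualified name."""
    style_map = style_map or {}

    def _recursive_replace(module: Module, prefix: str) -> None:
        for child_name, child_module in list(module.children.items()):
            qual_name = f"{prefix}.{child_name}" if prefix else child_name
            if child_module.kind == "linear":
                # Default mirrors the behaviour described above: replicate everywhere.
                style = style_map.get(qual_name, "replicate")
                module.children[child_name] = Module(kind="quant_linear", style=style)
            else:
                _recursive_replace(child_module, qual_name)

    _recursive_replace(model, "")
```

Passing no style_map keeps today's all-replicate behaviour, so existing callers would not need to change.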
but if this utility is later reused for tensor-parallel text encoders, we'd need per-layer style selection.
Would it be worth accepting an optional style_map: dict[str, Style] parameter (defaulting to None = all replicate) to make this future-proof?
We can reuse tp_plan from the Transformers model, as vLLM's Transformers backend does, but I would like to leave that to a follow-up PR because it can make things quite complicated:
https://github.com/vllm-project/vllm/blob/bebfe55b1c17c2e0fedb1b402df1dddfc1a04684/vllm/model_executor/models/transformers/base.py#L285-L296
| return loader.load_weights(weights) | ||
| loaded_weights = loader.load_weights(weights) | ||
| loaded_weights |= {f"vae.{name}" for name, _ in self.vae.named_parameters()} | ||
| return loaded_weights |
I was curious about this: we're adding all VAE parameter names to loaded_weights so the weight loader doesn't complain about "unloaded" weights, but the VAE was already loaded via AutoencoderKL.from_pretrained() above. This makes sense as a workaround.
But I noticed that self.vae is loaded with from_pretrained (which uses HF's default dtype and device handling), while the text encoder now goes through create_transformers_model (which uses od_config.dtype and meta init). Could there be a dtype mismatch between the two if od_config.dtype differs from the default? Probably fine in practice since VAE is typically float32 anyway, but wanted to flag it.
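The bookkeeping behind that workaround can be shown with plain sets. This is only a sketch: find_unloaded and the parameter names are hypothetical, and the real loader's check may differ. The in-place union is the same operation as in the diff.

```python
def find_unloaded(expected: set[str], loaded: set[str]) -> set[str]:
    """Names the model expects but that the loader never reported as loaded."""
    return expected - loaded


# Hypothetical parameter names for illustration.
expected = {"text_encoder.w", "transformer.w", "vae.decoder.w", "vae.encoder.w"}
loaded_weights = {"text_encoder.w", "transformer.w"}  # what the loader returned
vae_param_names = ["decoder.w", "encoder.w"]  # loaded separately, outside the loader

missing_before = find_unloaded(expected, loaded_weights)

# Same in-place set union as in the diff: mark the VAE parameters as loaded.
loaded_weights |= {f"vae.{name}" for name in vae_param_names}
missing_after = find_unloaded(expected, loaded_weights)
```

Before the union the VAE names show up as missing; afterwards nothing does, which is why the loader stops complaining.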
I see, then let's replace .from_pretrained with the diffusion loader so that it loads the VAE weights as well.
Could there be a dtype mismatch between the two if od_config.dtype differs from the default? Probably fine in practice since VAE is typically float32 anyway, but wanted to flag it.
Latents are usually cast to the VAE's dtype before decoding, so I think a dtype mismatch won't be a critical issue here:
vllm-omni/vllm_omni/diffusion/models/z_image/pipeline_z_image.py
Lines 634 to 640 in efbe411
| vae_config = AutoencoderKL.load_config(model, subfolder="vae", local_files_only=local_files_only) | ||
| self.vae = AutoencoderKL.from_config(vae_config).to(self._execution_device) |
Actually, the FP8 kernel is not suitable for the VAE, so let's not convert it with vLLM's quantization layers for now:
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] Error executing method 'generate'. This might cause issues in distributed execution.
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] Traceback (most recent call last):
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/vllm_omni/diffusion/worker/diffusion_worker.py", line 674, in execute_method
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return func(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/vllm_omni/diffusion/worker/diffusion_worker.py", line 163, in generate
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return self.execute_model(request, self.od_config)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/vllm_omni/diffusion/worker/diffusion_worker.py", line 185, in execute_model
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return self.model_runner.execute_model(req)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return func(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/vllm_omni/diffusion/worker/diffusion_model_runner.py", line 196, in execute_model
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] output = self.pipeline.forward(req)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/vllm_omni/diffusion/models/z_image/pipeline_z_image.py", line 667, in forward
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] image = self.vae.decode(latents, return_dict=False)[0]
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return method(self, *args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 237, in decode
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] decoded = self._decode(z).sample
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 208, in _decode
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] dec = self.decoder(z)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return self._call_impl(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return forward_call(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/diffusers/models/autoencoders/vae.py", line 298, in forward
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] sample = self.mid_block(sample, latent_embeds)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return self._call_impl(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return forward_call(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 745, in forward
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] hidden_states = attn(hidden_states, temb=temb)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return self._call_impl(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return forward_call(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/diffusers/models/attention_processor.py", line 605, in forward
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return self.processor(
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/diffusers/models/attention_processor.py", line 2740, in __call__
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] query = attn.to_q(hidden_states)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return self._call_impl(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return forward_call(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/vllm/model_executor/layers/linear.py", line 413, in forward
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] output = self.quant_method.apply(self, x, bias)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/fp8.py", line 501, in apply
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return apply_fp8_marlin_linear(
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py", line 69, in apply_fp8_marlin_linear
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] output = ops.marlin_gemm(
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/vllm/_custom_ops.py", line 1246, in marlin_gemm
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return torch.ops._C.marlin_gemm(
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] File "/home/mozf/develop-projects/vllm-omni/.venv/lib/python3.12/site-packages/torch/_ops.py", line 1209, in __call__
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] return self._op(*args, **kwargs)
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] ^^^^^^^^^^^^^^^^^^^^^^^^^
[Stage-0] ERROR 02-21 21:20:29 [diffusion_worker.py:678] RuntimeError: A.stride(1) is not 1
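The final RuntimeError says the kernel requires stride 1 along dimension 1 of its input A. One plausible reading, stated here as an assumption rather than something the thread confirms, is that the VAE attention block hands to_q a transposed view of a (batch, channels, height*width) tensor. The stride arithmetic for that case, in plain Python with made-up sizes:

```python
def contiguous_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Row-major strides (in elements) for a freshly allocated tensor."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def transpose(shape, strides, a, b):
    """A transpose is a view: it swaps shape and stride entries, no copy."""
    shape, strides = list(shape), list(strides)
    shape[a], shape[b] = shape[b], shape[a]
    strides[a], strides[b] = strides[b], strides[a]
    return tuple(shape), tuple(strides)


# Made-up sizes: batch=1, channels=512, height*width=64.
shape = (1, 512, 64)
strides = contiguous_strides(shape)  # (32768, 64, 1)
t_shape, t_strides = transpose(shape, strides, 1, 2)
last_dim_is_unit_stride = t_strides[-1] == 1
```

Here the transposed view has shape (1, 64, 512) and strides (32768, 1, 64), so its last dimension no longer has stride 1. In PyTorch, calling .contiguous() on such a view copies it into a layout that does; whether that would be the right fix here is not something the thread settles.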
lishunyang12
left a comment
Followed up on the latest changes -- the VAE loader migration and ValueError fix look good, left a few more nits.
@Isotr0py Please resolve conflicts. Thanks!
@lishunyang12 Can we merge this PR in v0.20.0?
vllm-project#1338) Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: NumberWan <wantszkin2003@gmail.com>
vllm-project#1338) Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
vllm-project#1338) Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: sphinxkkkbc <binchengkang8@gmail.com>
Purpose
Test Plan
Test Result
Main branch
PR