
[Feat] FP8 quantization support for LongCat-Image and LongCat-Image-Edit #2633

Open

lcukyfuture wants to merge 7 commits into vllm-project:main from lcukyfuture:feat/fp8-longcat-image

Conversation

@lcukyfuture (Contributor) commented Apr 9, 2026


Purpose

Add FP8 quantization support to LongCat-Image and LongCat-Image-Edit pipelines, following the unified quantization framework introduced in #1764.

Test Plan

Hardware: 1 × NVIDIA RTX 6000 Ada (48 GB)
Settings: 1024×1024, 50 steps, seed=42; quality measured with LPIPS against the BF16 baseline.
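For reference, the LPIPS numbers reported below can be reproduced roughly as follows with the lpips package; this is a minimal sketch under assumed file names, not the actual benchmark script.

```python
# Minimal LPIPS sketch (assumed workflow and file names, not the benchmark script).
import lpips
import numpy as np
import torch
from PIL import Image

def to_tensor(path: str) -> torch.Tensor:
    # Load an RGB image and scale it to [-1, 1], NCHW, as lpips expects.
    img = np.asarray(Image.open(path).convert("RGB"), dtype=np.float32) / 255.0
    return torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) * 2.0 - 1.0

loss_fn = lpips.LPIPS(net="alex")        # AlexNet backbone, the common default
ref = to_tensor("bf16_baseline.png")     # hypothetical output of the BF16 run
test = to_tensor("fp8_output.png")       # hypothetical output of the FP8 run
print(f"LPIPS: {loss_fn(ref, test).item():.4f}")
```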

Test Result

LongCat-Image (text-to-image)

| Config | Avg Time | Speedup | Memory (GiB) | Mem Reduction | Mean LPIPS |
|---|---|---|---|---|---|
| BF16 baseline | 19.87s | 1.00x | 33.46 | - | (ref) |
| fp8 | 13.24s | 1.50x | 29.56 | 12% | 0.0767 |
| fp8 + skip proj_out | 15.08s | 1.32x | 30.46 | 9% | 0.0192 |
[Output image comparison: BF16 baseline vs fp8 vs fp8 + skip proj_out]

LongCat-Image-Edit (image editing)

| Config | Avg Time | Speedup | Mean LPIPS |
|---|---|---|---|
| BF16 baseline | 41.49s | 1.00x | (ref) |
| fp8 | 30.35s | 1.37x | 0.0162 |
[Base and edited image comparison: BF16 baseline vs fp8]

Findings

LongCat's remaining bottleneck is its text encoder (Qwen2.5-VL), which cannot be quantized.




Add FP8 quantization support to LongCat-Image and LongCat-Image-Edit
pipelines, following the unified quantization framework introduced in
vllm-project#1764.

Changes:
- Replace plain `nn.Linear` layers in `LongCatImageTransformer2DModel`
  with quantization-aware vLLM linear layers (`ReplicatedLinear`,
  `QKVParallelLinear`, `RowParallelLinear`, `ColumnParallelLinear`)
  and propagate `quant_config` through `FeedForward`,
  `LongCatImageAttention`, `LongCatImageTransformerBlock`, and
  `LongCatImageSingleTransformerBlock`
- Pass `quant_config=od_config.quantization_config` to the transformer
  in both `LongCatImagePipeline` and `LongCatImageEditPipeline`
- Fix `load_weights` in both pipelines to include VAE and text encoder
  parameters in the returned loaded-weights set
- Fix `TypeError`: `LongCatImageSingleTransformerBlock.__init__` was
  receiving an unsupported `prefix` keyword argument, causing a crash
  on startup with any quantization config

Signed-off-by: lcukyfuture <zlf994478451@outlook.com>
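For illustration, a minimal sketch of the replacement pattern described in the change list above; the class and layer names come from this PR, the ReplicatedLinear arguments follow the standard vLLM signature, and everything else (dimensions, import path) is a simplified assumption.

```python
# Sketch only: shows how quant_config and a hierarchical prefix are threaded
# into a feed-forward block. Signatures are simplified assumptions.
import torch.nn as nn
from vllm.model_executor.layers.linear import ReplicatedLinear  # import path assumed

class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, quant_config=None, prefix: str = ""):
        super().__init__()
        # quant_config selects BF16 vs FP8 kernels for each layer; prefix is
        # what ignored_layers patterns are matched against.
        self.w_in = ReplicatedLinear(dim, hidden_dim, bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.w_in")
        self.w_out = ReplicatedLinear(hidden_dim, dim, bias=True,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.w_out")
```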

Signed-off-by: lcukyfuture <zlf994478451@outlook.com>
@lcukyfuture lcukyfuture force-pushed the feat/fp8-longcat-image branch 2 times, most recently from 7512391 to d9340fa Compare April 9, 2026 08:50
…-image

Signed-off-by: lcukyfuture <zlf994478451@outlook.com>
@lcukyfuture lcukyfuture force-pushed the feat/fp8-longcat-image branch from d9340fa to 90c7ebe Compare April 9, 2026 08:51
lcukyfuture and others added 2 commits April 9, 2026 17:32
Signed-off-by: Lingfeng Zhang <48312954+lcukyfuture@users.noreply.github.com>
Signed-off-by: lcukyfuture <zlf994478451@outlook.com>

self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.proj_out = ReplicatedLinear(
Collaborator

Your own numbers show quantizing this final proj_out regresses LPIPS from 0.0192 to 0.0767 (~4x worse) for only a ~12% speed gain. Flux keeps the final proj_out as nn.Linear for the same reason — let's match that here and skip quantization on this layer by default, rather than making users discover the ignored_layers flag.

help="Task type: t2i (text-to-image), t2v (text-to-video), or image_edit (image editing).",
)
parser.add_argument(
"--quantization",
Collaborator

Dropping nargs="+" is a breaking change — the earlier docstring example (--quantization fp8 int8 bitsandbytes) no longer works. Can you keep multi-method support?

@lishunyang12 (Collaborator) left a comment

Review: FP8 quantization support for LongCat-Image and LongCat-Image-Edit

Good work overall. The quantization wiring follows the established pattern from Flux and other models, the benchmark results are solid (1.50x speedup at acceptable LPIPS), and the test coverage additions are appropriate. I have a few issues to flag, one of which is likely to cause problems with the ignored_layers feature.


Issue 1 (Medium): Incomplete prefix propagation for quantization layer matching

In LongCatImageSingleTransformerBlock, the ReplicatedLinear layers use bare prefixes:

self.proj_mlp = ReplicatedLinear(..., prefix="proj_mlp")
self.proj_out = ReplicatedLinear(..., prefix="proj_out")

But in the Flux reference implementation, full hierarchical prefixes are passed:

self.proj_mlp = ReplicatedLinear(..., prefix=f"{prefix}.proj_mlp")
self.proj_out = ReplicatedLinear(..., prefix=f"{prefix}.proj_out")

And the parent model passes indexed prefixes like prefix=f"single_transformer_blocks.{i}" when creating blocks.

The LongCat PR does not pass prefix to LongCatImageTransformerBlock or LongCatImageSingleTransformerBlock, and internally uses bare names. This means:

  • Basic FP8 (quantize all linears) works fine, as confirmed by the benchmark.
  • The ignored_layers feature advertised in the benchmark script (e.g., {"method":"fp8","ignored_layers":["proj_out"]}) may behave incorrectly, since the quantization framework matches layers by their full prefix path, not bare names.

Similarly, in LongCatImageAttention, the QKVParallelLinear and RowParallelLinear use bare prefix="to_qkv", prefix="to_out", etc.

Recommendation: Propagate the prefix parameter from the transformer model down through each block and sub-module, following the Flux pattern, so that ignored_layers matching works correctly. The benchmark PR description shows fp8 + skip proj_out as a tested config, so this should be functional.
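For concreteness, a minimal sketch of the recommended Flux-style propagation; only the prefix plumbing matters here, and the constructor signatures are simplified assumptions.

```python
# Sketch: full hierarchical prefixes so ignored_layers can match paths like
# "single_transformer_blocks.0.proj_out". Signatures are simplified.
class LongCatImageSingleTransformerBlock(nn.Module):
    def __init__(self, dim, mlp_dim, quant_config=None, prefix: str = ""):
        super().__init__()
        self.proj_mlp = ReplicatedLinear(dim, mlp_dim,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.proj_mlp")
        self.proj_out = ReplicatedLinear(mlp_dim, dim,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.proj_out")

# In the parent transformer, each block then receives an indexed prefix:
self.single_transformer_blocks = nn.ModuleList([
    LongCatImageSingleTransformerBlock(dim, mlp_dim,
                                       quant_config=quant_config,
                                       prefix=f"single_transformer_blocks.{i}")
    for i in range(num_single_layers)
])
```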


Issue 2 (Nit): Benchmark switches from torch.cuda.max_memory_allocated() to pynvml process-wide memory

The _get_gpu_memory_gib() function uses pynvml.nvmlDeviceGetMemoryInfo(handle).used, which reports device-wide GPU memory usage (all processes), not just the current process. The previous code used torch.cuda.max_memory_allocated() which is process-scoped and reports peak allocation.

These measure fundamentally different things. On a shared GPU, the pynvml number will include memory from other processes. Also, the old metric was peak memory; the new one is instantaneous memory at the time of the call (post-generation), which may miss peak usage during inference.

The test file (test_quantization_quality.py) still correctly uses torch.cuda.max_memory_allocated() for its _generate_image_edit function, which is good. But the benchmark script's numbers in the PR description (memory column) may be slightly misleading.

This is fine if the benchmark is intended to run on a dedicated GPU, but worth a comment in the code.
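To make the distinction concrete, a minimal sketch of the two measurements; both calls are standard torch/pynvml APIs, and the generation step is elided.

```python
# The two memory metrics measure different things.
import torch
import pynvml

# Previous behaviour: process-scoped *peak* allocation since the last reset.
torch.cuda.reset_peak_memory_stats()
# ... run generation here ...
peak_gib = torch.cuda.max_memory_allocated() / 1024**3

# New behaviour: device-wide memory currently in use, including other processes.
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
used_gib = pynvml.nvmlDeviceGetMemoryInfo(handle).used / 1024**3
```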


Issue 3 (Minor): load_weights change marks VAE/text_encoder params as "loaded" without actually loading them

loaded_weights |= {f"vae.{name}" for name, _ in self.vae.named_parameters()}
loaded_weights |= {f"text_encoder.{name}" for name, _ in self.text_encoder.named_parameters()}

This follows the pattern from pipeline_z_image.py and pipeline_flux2_klein.py, so it's consistent with the codebase convention. Just confirming: these weights are loaded separately via from_pretrained() in __init__, so marking them as loaded prevents the weight loader from warning about unloaded parameters. This is correct.


Minor notes:

  • The .contiguous() calls added before quantized linear layers (lines 215, 232, 333-334 in the transformer) are a reasonable defensive measure for FP8 kernels that require contiguous inputs. This matches what other models do.
  • The benchmark refactoring from multi-quantization loop to single-quantization is a simplification that reduces complexity. The removal of the "Multiple quantization methods" example from the docstring is consistent.
  • Test configs use num_inference_steps=20 which is good for CI speed while still being meaningful for quality checks.

Summary: The core FP8 quantization integration is correct and well-tested. The main actionable item is fixing the prefix propagation (Issue 1) to ensure ignored_layers works as intended. The rest is solid.

@lcukyfuture (Contributor, Author)

@lishunyang12, thanks for your review. I have some important things to take care of in these two weeks, so I will fix these issues later (before the end of next week). Sorry for the delay.

@hsliuustc0106 (Collaborator)

Please explain which parts you quantize.

@lcukyfuture (Contributor, Author)

Quantized layers:

| LongCat Component | Quantized Layers |
|---|---|
| Transformer input projections | context_embedder, x_embedder |
| Dual-stream attention | to_qkv, add_kv_proj, to_out, to_add_out |
| Dual-stream FFN | ff.w_in, ff.w_out, ff_context.w_in, ff_context.w_out |
| Single-stream attention | attn.to_qkv |
| Single-stream MLP | proj_mlp, proj_out |
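For reference, the two configurations benchmarked above roughly correspond to configs of the following shape; the schema follows the ignored_layers example quoted in the review and should be treated as an assumption rather than the exact format.

```python
# Assumed config shapes, mirroring the ignored_layers example from the review.
quantize_all_listed_layers = {"method": "fp8"}
skip_final_projection = {
    "method": "fp8",
    "ignored_layers": ["proj_out"],  # with full prefix propagation this becomes
                                     # the hierarchical path to the final projection
}
```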

Signed-off-by: lcukyfuture <zlf994478451@outlook.com>
Signed-off-by: lcukyfuture <zlf994478451@outlook.com>
