[Model] Add TP-aware MistralEncoder for FLUX.2-dev TP #2465
vraiti wants to merge 1 commit into vllm-project:main
Conversation
@nuclearwu PTAL
alex-jw-brooks left a comment
Thanks for opening this! Some thoughts
> Follows the same pattern as T5EncoderModel: uses vLLM's parallel linear layers
> for tensor parallelism but simple scaled_dot_product_attention instead of
> PagedAttention, so it can be used as a standalone encoder without VllmConfig.
Curious about this: is it because of the way it is initialized for convenience, or is there another reason?
My understanding is that since FLUX.2 is a diffusion model and MistralEncoder runs as a subcomponent of FLUX.2, MistralEncoder doesn't actually have access to the vLLM engine to do PagedAttention.
It may be possible to run MistralEncoder as a separate stage with vLLM, but I'm not sure how much performance that would buy us.
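To make the distinction concrete, here is a minimal sketch (hypothetical code, not the PR's actual implementation) of the standalone attention path: plain `scaled_dot_product_attention` with a causal mask, with no PagedAttention or paged-KV-cache machinery and no engine dependency:

```python
import torch
import torch.nn.functional as F


def encoder_attention(q, k, v, padding_mask=None):
    """Standalone causal attention for a single encoder forward pass.

    q, k, v: [batch, heads, seq, head_dim]. In the TP-sharded model the
    head dimension would already be the per-rank slice; this sketch is
    single-rank and needs no vLLM engine or paged KV cache.
    """
    seq = q.shape[-2]
    # Mistral is decoder-only, so a causal (lower-triangular) mask applies
    # even when the model is used purely as a text encoder.
    mask = torch.ones(seq, seq, dtype=torch.bool, device=q.device).tril()
    if padding_mask is not None:
        # padding_mask: [batch, seq], True where the token is real.
        mask = mask & padding_mask[:, None, None, :]
    else:
        mask = mask[None, None]
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```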
force-pushed from 82b32bf to fd14948
@vraiti Requesting changes; I found two blocking issues:

Inline comment 1: This change swaps
Inline comment 2:
That means rotary embeddings depend on uninitialized memory unless a checkpoint happens to provide this
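For context, the usual fix for this class of bug is to compute the rotary inverse frequencies deterministically at construction time rather than leaving a buffer for the checkpoint to fill. A minimal sketch (hypothetical helper name, not the PR's code):

```python
import torch


def make_inv_freq(head_dim: int, rope_theta: float = 1_000_000.0) -> torch.Tensor:
    """Deterministically compute RoPE inverse frequencies.

    Registering an empty buffer and hoping the checkpoint provides it leaves
    the values uninitialized when it doesn't; computing them here avoids that.
    """
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    return 1.0 / (rope_theta ** exponents)
```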
Fixed. The bigger problem is that the full Mistral model is much larger than just the encoder (obviously) and puts undue memory pressure on a system that doesn't otherwise want to provide upsampling. So, I added
Fixed.
@vraiti Could you please add more test plans to verify the modification? With those, this looks good to me.
@nuclearwu Done.
I think the design is still inconsistent here. This PR conceptually replaces
Because of that, I don't think the right fix is to keep branching on two different text encoder
In other words: if
What do you think about it? @hsliuustc0106 @wtomin
lishunyang12 left a comment
Review: [Model] Add TP-aware MistralEncoder for FLUX.2-dev TP
Overall this is a well-structured PR that follows the established pattern (similar to T5EncoderModel) for implementing a TP-aware encoder. The architecture is correct and the test coverage is solid. Approving with a few minor suggestions.
What works well
- TP sharding is correct: `QKVParallelLinear`, `MergedColumnParallelLinear`, `RowParallelLinear`, and `VocabParallelEmbedding` are the right vLLM primitives. The head count division (`num_heads // tp_size`, `max(1, num_kv_heads // tp_size)`) and the `num_kv_groups` derivation are correct.
- Weight loading: The `stacked_params_mapping` for q/k/v -> `qkv_proj` and gate/up -> `gate_up_proj` matches vLLM conventions. The skip list for `lm_head`, `vision_tower`, and `multi_modal_projector` correctly drops weights not needed by the encoder-only use case.
- RoPE implementation: Standard rotate-half RoPE with configurable theta. The precomputation in `MistralRotaryEmbedding.forward()` is clean.
- Causal masking: Correctly uses a causal (upper-triangular) mask, since Mistral is decoder-only even when used as an encoder. The padding mask combination is also correct.
- Pipeline integration: The `enable_caption_upsampling` flag cleanly separates the two code paths: the full `Mistral3ForConditionalGeneration` (needed for caption upsampling/generation) vs. the lightweight TP-sharded encoder. The `AutoWeightsLoader` delegation through `weights_sources` is properly set up.
- Tests: Good coverage of config parsing (plain/nested/defaults), RoPE correctness, weight loading (QKV, gate_up, skip patterns), and module structure.
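The head-count arithmetic referred to above can be sketched as a standalone helper (hypothetical function, mirroring the expressions quoted from the PR):

```python
def shard_heads(num_heads: int, num_kv_heads: int, tp_size: int) -> tuple[int, int, int]:
    """Per-rank head counts under tensor parallelism.

    When tp_size exceeds num_kv_heads, each KV head is replicated rather
    than split, hence the max(1, ...) floor. num_kv_groups is how many
    query heads share one KV head on a rank (grouped-query attention).
    """
    heads_per_rank = num_heads // tp_size
    kv_heads_per_rank = max(1, num_kv_heads // tp_size)
    num_kv_groups = heads_per_rank // kv_heads_per_rank
    return heads_per_rank, kv_heads_per_rank, num_kv_groups
```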
Minor suggestions (non-blocking)
- `MistralEncoderOutput` placement (style): The output dataclass is defined after the model class that uses it. While this works due to `from __future__ import annotations`, consider moving it above `MistralEncoderModel` for readability.
- `enable_caption_upsampling` discoverability: `od_config.model_config.get("enable_caption_upsampling", False)` introduces a new config key. It would be good to document this in the model's configuration reference, or at least add a comment explaining when a user would want to set it to True (i.e., when they have enough VRAM for the full Mistral3 model and want caption upsampling).
- Trailing comma in `weights_sources` (line 361 in the diff): Adding a trailing comma to the existing `ComponentSource` entry is good practice, but it is a cosmetic-only change that could be split out to keep the diff focused. Very minor.
- Type annotation on `_get_mistral_3_small_prompt_embeds`: The type hint `MistralEncoderModel | Mistral3ForConditionalGeneration` is correct for the union, but the parameter order puts `MistralEncoderModel` first. Consider putting `Mistral3ForConditionalGeneration` first, since it is the more feature-complete type (and the one used when caption upsampling is enabled). This is purely a style preference.
Verified correctness
- `output_hidden_states=True` produces the right tuple structure (embedding + per-layer + final norm), compatible with the `hidden_states_layers=(10, 20, 30)` indexing in `_get_mistral_3_small_prompt_embeds`.
- `use_cache=False` passed from the pipeline is harmlessly absorbed by `**kwargs`.
- `AutoWeightsLoader` will correctly delegate to `MistralEncoderModel.load_weights` after stripping the `text_encoder.` prefix.
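The tuple layout being verified can be illustrated with a small sketch (hypothetical helper names, tensors replaced by labels for brevity):

```python
def collect_hidden_states(embedding, layer_outputs, final_norm_output):
    """Mimic the output_hidden_states=True layout: index 0 is the embedding
    output, indices 1..N are the per-layer outputs, and the last entry is
    the final-norm output."""
    return (embedding, *layer_outputs, final_norm_output)


def select_layers(hidden_states, layers=(10, 20, 30)):
    # hidden_states[i] for i >= 1 is the output of decoder layer i, so the
    # layer indices can be used directly against the tuple.
    return [hidden_states[i] for i in layers]
```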
LGTM. Nice work enabling FLUX.2-dev on smaller GPU clusters.
force-pushed from 567cf60 to e9d9be7
force-pushed from 0462e69 to d667d4c
@nuclearwu Fixed. MistralEncoderModel now supports generate() as well, fully replacing Mistral3ForConditionalGeneration in the FLUX.2-dev pipeline.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: vraiti <vraiti@redhat.com>
Purpose
Addresses #2464
Currently, FLUX.2-dev uses transformers.Mistral3ForConditionalGeneration to perform text encoding. Running FLUX.2-dev on smaller devices (e.g. A100-40G) fails even with TP, because the 24 GiB encoder must be fully replicated on every device. This PR implements a custom MistralEncoder class that can be sharded with tensor parallelism, enabling FLUX.2-dev to run on clusters of smaller devices.
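As a back-of-envelope illustration of why sharding matters here (hypothetical helper; the 24 GiB figure is from this description, and small replicated parts such as norms are ignored):

```python
def encoder_gib_per_rank(total_gib: float, tp_size: int, sharded: bool) -> float:
    """Per-rank encoder memory: a replicated encoder costs its full size on
    every rank, while a TP-sharded one splits parameters across ranks."""
    return total_gib / tp_size if sharded else total_gib
```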
Additionally, FLUX.2-dev's Flux2Pipeline.transformer (a Flux2Transformer2DModel) was being initialized with guidance=None, which seemed to be causing garbage output. It is now initialized with a tensor of guidance_scale instead.

Test Plan
Reproduced on 4x A100-SXM4-40GB.
1. Unit tests (26 tests, CPU-only, TP=2 mocked)
2. Online serving — server startup
Stage config (flux2.yaml):

3. Online serving — single request via curl
4. E2E latency benchmark — without upsampling (5 prompts)
5. Prompt upsampling (T2I and I2I via Python Omni API)
The /v1/images/generations HTTP endpoint does not support caption_upsample_temperature or I2I. Use the Python Omni API:

6. Image-to-image editing
```
python examples/offline_inference/image_to_image/image_edit.py \
    --model black-forest-labs/FLUX.2-dev \
    --image output.png \
    --prompt "replace the llama in the image with a lion" \
    --output output-edit.png \
    --seed 42 \
    --tensor-parallel-size 4 \
    --num-inference-steps 50 \
    --guidance-scale 4.0
```

7. Memory profiling
Checked at idle (after model load) and at peak (during inference).
Test Result
Reproduced on 4x A100-SXM4-40GB, eager mode, TP=4.
1. Unit tests
26/26 passed (CPU-only, TP=2 mocked):
Test classes: TestConfigParsing (3), TestRoPEInitialization (4), TestWeightLoading (4), TestModelStructure (2), TestKVCache (6), TestRoPEOffset (2), TestGenerate (4), TestComputeLogits (1).
2. Online serving
Server started successfully. All 4 workers loaded in 13.3s. MistralEncoderModel loaded 242 params (10.66 GiB total param memory per rank).
3. Single request via curl
Returned valid PNG image (status 200).
4. E2E latency — without upsampling (1024x1024, 50 steps)
5. Prompt upsampling (T2I and I2I)
Generated 4 images: T2I baseline, T2I upsampled, I2I baseline, I2I upsampled. TP=4, seed 42, guidance_scale 4.0, 50 steps, caption_upsample_temperature=0.15.

T2I — baseline (left) vs upsampled (right). The upsampled prompt adds environmental detail (warm sunlight, garden with flowers) that is absent from the baseline:
I2I — baseline (left) vs upsampled (right). The I2I system message is intentionally conservative (50-80 words), so differences are subtle:
Upsampling adds ~1.3s mean overhead per request (~3.4%).
6. Image-to-image editing
Input: 1024x1024 PNG, prompt: "replace the bunny in the image with a dog", seed 42, 50 steps, guidance_scale 4.0, TP=4.
Generation time: 62.49s. Peak GPU memory: 29.76 GB reserved, 29.37 GB allocated.
Input: (image attached to the PR)

Output: (image attached to the PR)
7. Memory profiling
Total model size: 26.97 GiB per rank. Peak overhead during inference: ~2.3 GB per GPU.
8. CFG parallel
Not tested. CFG parallel with cfg_parallel_size=2 requires TP x CFG GPUs. On our 4x A100-SXM4-40GB node, TP=4 is needed to fit the model (~27.46 GiB per rank), leaving no room for a second CFG group. Larger GPUs (e.g. A100-80GB) could run TP=2/CFG=2 on 4 GPUs.
Sample image
Logs
Sample log (successful):
flux2-success.log
Sample log (old, error):
flux2-error.log
Essential Elements of an Effective PR Description Checklist
- `supported_models.md` and `examples` for a new model. Please run `mkdocs serve` to sync the documentation editions to `./docs`.