[Model] Add TP-aware MistralEncoder for FLUX.2-dev TP#2465

Open
vraiti wants to merge 1 commit into vllm-project:main from vraiti:flux2-tp

Conversation

@vraiti
Contributor

@vraiti vraiti commented Apr 2, 2026

Purpose

Addresses #2464

Currently, FLUX.2-dev uses transformers.Mistral3ForConditionalGeneration to perform text encoding. Running FLUX.2-dev on smaller devices (e.g. A100-40G) fails even with TP because the 24 GiB encoder must be fully replicated on every device. This PR implements a custom MistralEncoder class that can be sharded with tensor parallelism, enabling FLUX.2-dev to run on clusters of smaller devices.
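The memory argument can be sketched with some back-of-the-envelope arithmetic (the helper and numbers below are illustrative, not from this PR's code):

```python
# Illustrative arithmetic only: why sharding the text encoder matters on 40 GiB GPUs.
# A replicated encoder costs its full size on every rank; an evenly TP-sharded one
# costs size / tp_size per rank.
def encoder_gib_per_rank(encoder_gib: float, tp_size: int, sharded: bool) -> float:
    return encoder_gib / tp_size if sharded else encoder_gib

replicated = encoder_gib_per_rank(24.0, tp_size=4, sharded=False)  # 24.0 GiB per rank
sharded = encoder_gib_per_rank(24.0, tp_size=4, sharded=True)      # 6.0 GiB per rank
```

With replication, the encoder alone consumes over half of an A100-40G before the diffusion transformer is even loaded; sharded across TP=4, it drops to 6 GiB per rank.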

Additionally, FLUX.2-dev Flux2Pipeline.transformer (a Flux2Transformer2DModel) was being initialized with guidance=None, which seemed to be causing garbage output. This PR initializes it instead with a tensor built from guidance_scale.

Test Plan

Reproduced on 4x A100-SXM4-40GB.

1. Unit tests (26 tests, CPU-only, TP=2 mocked)

python3 -m pytest tests/diffusion/models/mistral_encoder/test_mistral_encoder_tp.py -v

2. Online serving — server startup

export HF_TOKEN=<token>
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

vllm-omni serve black-forest-labs/FLUX.2-dev \
  --omni --port 8000 \
  --stage-configs-path vllm_omni/model_executor/stage_configs/flux2.yaml

Stage config (flux2.yaml):

stage_args:
  - stage_id: 0
    stage_type: diffusion
    runtime:
      devices: "0,1,2,3"
    engine_args:
      model_arch: Flux2Pipeline
      max_num_seqs: 1
      enforce_eager: true
      trust_remote_code: true
      distributed_executor_backend: mp
      parallel_config:
        tensor_parallel_size: 4
      model_config:
        enable_caption_upsampling: true
    final_output: true
    final_output_type: image
    default_sampling_params:
      seed: 42
      num_inference_steps: 50
      guidance_scale: 4.0
      height: 1024
      width: 1024
runtime:
  enabled: true
  defaults:
    window_size: -1
    max_inflight: 1

3. Online serving — single request via curl

curl -s http://localhost:8000/v1/images/generations \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "a lovely bunny holding a sign that says vllm-omni",
    "size": "1024x1024",
    "n": 1,
    "response_format": "b64_json"
  }' | python3 -c "import sys,json,base64; d=json.load(sys.stdin); open('output.png','wb').write(base64.b64decode(d['data'][0]['b64_json']))"

4. E2E latency benchmark — without upsampling (5 prompts)

import time, requests

url = 'http://localhost:8000/v1/images/generations'
prompts = [
    'a cat sitting on a windowsill',
    'a mountain landscape at sunset',
    'a futuristic city skyline',
    'a bowl of fresh fruit on a table',
    'an astronaut floating in space',
]
latencies = []
for i, prompt in enumerate(prompts):
    payload = {'prompt': prompt, 'size': '1024x1024', 'n': 1, 'response_format': 'b64_json'}
    t0 = time.perf_counter()
    r = requests.post(url, json=payload)
    elapsed = time.perf_counter() - t0
    latencies.append(elapsed)
    print(f'[{i+1}/5] {elapsed:6.2f}s  status={r.status_code}')
print(f'Mean: {sum(latencies)/len(latencies):.2f}s')

5. Prompt upsampling (T2I and I2I via Python Omni API)

The /v1/images/generations HTTP endpoint does not support caption_upsample_temperature or I2I. Use the Python Omni API:

"""Generate baseline + upsampled images for both T2I and I2I."""
import os, torch
from PIL import Image
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform

img = Image.open("sample-input.png").convert("RGB")
omni = Omni(model="black-forest-labs/FLUX.2-dev",
            parallel_config=DiffusionParallelConfig(tensor_parallel_size=4))
os.makedirs("outputs/upsample", exist_ok=True)

for mode, prompt, image in [
    ("t2i", "a cat sitting on a windowsill", None),
    ("i2i", "transform this into a dramatic sunset scene", [img]),
]:
    for label, temp in [("baseline", None), ("upsampled", 0.15)]:
        out = omni.generate(
            {"prompt": prompt, "multi_modal_data": {"image": image}},
            OmniDiffusionSamplingParams(
                generator=torch.Generator(
                    device=current_omni_platform.device_type).manual_seed(42),
                guidance_scale=4.0, num_inference_steps=50,
                extra_args=({"caption_upsample_temperature": temp}
                            if temp else {}),
            ),
        )
        path = f"outputs/upsample/flux2-{mode}-{label}.png"
        out[0].request_output.images[0].save(path)
        print(f"[{mode}/{label}] {os.path.abspath(path)}")

6. Image-to-image editing

python examples/offline_inference/image_to_image/image_edit.py \
    --model black-forest-labs/FLUX.2-dev \
    --image output.png \
    --prompt "replace the llama in the image with a lion" \
    --output output-edit.png \
    --seed 42 \
    --tensor-parallel-size 4 \
    --num-inference-steps 50 \
    --guidance-scale 4.0

7. Memory profiling

nvidia-smi --query-gpu=index,memory.used,memory.total --format=csv,noheader

Checked at idle (after model load) and at peak (during inference).

Test Result

Reproduced on 4x A100-SXM4-40GB, eager mode, TP=4.

1. Unit tests

26/26 passed (CPU-only, TP=2 mocked):

tests/diffusion/models/mistral_encoder/test_mistral_encoder_tp.py  26 passed

Test classes: TestConfigParsing (3), TestRoPEInitialization (4), TestWeightLoading (4), TestModelStructure (2), TestKVCache (6), TestRoPEOffset (2), TestGenerate (4), TestComputeLogits (1).

2. Online serving

Server started successfully. All 4 workers loaded in 13.3s. MistralEncoderModel loaded 242 weight tensors (10.66 GiB total parameter memory per rank).

3. Single request via curl

Returned valid PNG image (status 200).

4. E2E latency — without upsampling (1024x1024, 50 steps)

Request   Latency
1         37.58s
2         37.76s
3         38.29s
4         38.93s
5         39.48s
Mean      38.41s
Median    38.29s

5. Prompt upsampling (T2I and I2I)

Generated 4 images: T2I baseline, T2I upsampled, I2I baseline, I2I upsampled. TP=4, seed 42, guidance_scale 4.0, 50 steps, caption_upsample_temperature=0.15.

T2I — baseline (left) vs upsampled (right). The upsampled prompt adds environmental detail (warm sunlight, garden with flowers) that is absent from the baseline:

flux2-t2i-baseline flux2-t2i-upsampled

I2I — baseline (left) vs upsampled (right). The I2I system message is intentionally conservative (50-80 words), so differences are subtle:

flux2-i2i-baseline flux2-i2i-upsampled

Upsampling adds ~1.3s mean overhead per request (~3.4%).

6. Image-to-image editing

Input: 1024x1024 PNG, prompt: "replace the bunny in the image with a dog", seed 42, 50 steps, guidance_scale 4.0, TP=4.

Generation time: 62.49s. Peak GPU memory: 29.76 GB reserved, 29.37 GB allocated.

Input:
sample-input

Output:
flux2-dev-edit-seed42

7. Memory profiling

Metric                   GPU 0       GPU 1       GPU 2       GPU 3
After model load         27.46 GiB   27.46 GiB   27.46 GiB   27.46 GiB
Idle (nvidia-smi)        31,603 MiB  31,603 MiB  31,603 MiB  31,603 MiB
Peak during inference    29.77 GB reserved, 29.37 GB allocated (identical on all GPUs)

Total model size: 26.97 GiB per rank. Peak overhead during inference: ~2.3 GB per GPU.

8. CFG parallel

Not tested. CFG parallel with cfg_parallel_size=2 requires tensor_parallel_size × cfg_parallel_size GPUs. On our 4x A100-SXM4-40GB node, TP=4 is needed to fit the model (~27.46 GiB per rank), leaving no room for a second CFG group. Larger GPUs (e.g. A100-80GB) could run TP=2/CFG=2 on 4 GPUs.

Sample image

llama-eating-gpus

Logs

Sample log (successful):
flux2-success.log

Sample log (old, error):
flux2-error.log



@vraiti vraiti requested a review from hsliuustc0106 as a code owner April 2, 2026 20:52
@vraiti vraiti changed the title Fix FLUX.2-dev tensor parallelism Add TP-aware MistralEncoder for FLUX.2-dev TP Apr 2, 2026
@vraiti vraiti changed the title Add TP-aware MistralEncoder for FLUX.2-dev TP [Model] Add TP-aware MistralEncoder for FLUX.2-dev TP Apr 2, 2026
@hsliuustc0106 hsliuustc0106 requested a review from ZJY0516 April 3, 2026 14:12
@hsliuustc0106
Collaborator

@nuclearwu PTAL

@alex-jw-brooks
Contributor

Thanks @vraiti! The guidance issue should actually be fixed now; I saw the same thing a couple of days ago in #2433

Contributor

@alex-jw-brooks alex-jw-brooks left a comment


Thanks for opening this! Some thoughts

Comment thread vllm_omni/diffusion/models/flux2/pipeline_flux2.py Outdated
Comment thread vllm_omni/diffusion/models/mistral_encoder/mistral_encoder.py

Follows the same pattern as T5EncoderModel: uses vLLM's parallel linear layers
for tensor parallelism but simple scaled_dot_product_attention instead of
PagedAttention, so it can be used as a standalone encoder without VllmConfig.


Curious about this, is this because of the way it is initialized for convenience, or is there another reason?

Contributor Author

@vraiti vraiti Apr 6, 2026


My understanding is that since FLUX.2 is a diffusion model, and MistralEncoder runs as a subcomponent of FLUX.2, MistralEncoder doesn't actually have access to the vLLM engine to do PagedAttention.

It may be possible to run MistralEncoder as a separate stage with vLLM, but I'm not sure how much performance that would buy us.
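The standalone-attention point above can be illustrated with plain PyTorch (shapes below are arbitrary toy values, not the model's actual dimensions):

```python
import torch
import torch.nn.functional as F

# Toy shapes for illustration: (batch, num_heads, seq_len, head_dim).
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# No engine, scheduler, or KV-cache manager is needed: a single call computes
# full causal attention over the whole prompt, which is all an encoder pass requires.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```

PagedAttention pays off for incremental decoding with a managed KV cache; a one-shot encoder forward pass has neither.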

Comment thread vllm_omni/diffusion/models/mistral_encoder/mistral_encoder.py Outdated
Comment thread vllm_omni/diffusion/models/mistral_encoder/mistral_encoder.py
Comment thread vllm_omni/diffusion/models/mistral_encoder/mistral_encoder.py
@vraiti vraiti force-pushed the flux2-tp branch 5 times, most recently from 82b32bf to fd14948 Compare April 6, 2026 18:22
@nuclearwu
Contributor

nuclearwu commented Apr 10, 2026

@vraiti
Overall review summary

Request changes.

I found two blocking issues:

  1. This PR replaces Mistral3ForConditionalGeneration with an encoder-only MistralEncoderModel, but
    Flux2Pipeline.upsample_prompt() still calls self.text_encoder.generate(...). That regresses the existing
    caption upsampling path, including the image-conditioned path.

  2. MistralRotaryEmbedding.inv_freq is registered with torch.empty(...) and then used immediately to build
    RoPE cos/sin tables, but I don't see any initialization from config (rope_theta, etc.). HF Mistral computes
    this from config instead of relying on checkpoint loading, so this looks like undefined / incorrect rotary
    embeddings.

Inline comment 1
Target: vllm_omni/diffusion/models/flux2/pipeline_flux2.py around line 380 or line 668

This change swaps self.text_encoder from Mistral3ForConditionalGeneration to MistralEncoderModel, but
upsample_prompt() still calls self.text_encoder.generate(...) below.

MistralEncoderModel is encoder-only and does not implement generate(), so the existing caption upsampling
path will now fail at runtime. This also seems to drop the image-conditioned upsampling capability, since the
old path relied on the multimodal conditional-generation model rather than a text-only encoder.

Inline comment 2
Target: vllm_omni/diffusion/models/mistral_encoder/mistral_encoder.py around line 39

inv_freq is registered as torch.empty(...) here, but I don't see it ever being initialized from config
before it is used in forward() to build RoPE cos/sin tables.

That means rotary embeddings depend on uninitialized memory unless a checkpoint happens to provide this
buffer. HF's Mistral implementation computes inv_freq from config (rope_theta, head dim, etc.) and
typically treats it as a derived buffer rather than something that must come from weights, so I think this
needs explicit initialization here.

@vraiti
Contributor Author

vraiti commented Apr 10, 2026

@vraiti Overall review summary

Request changes.

I found two blocking issues:

  1. This PR replaces Mistral3ForConditionalGeneration with an encoder-only MistralEncoderModel, but
    Flux2Pipeline.upsample_prompt() still calls self.text_encoder.generate(...). That regresses the existing
    caption upsampling path, including the image-conditioned path.

Fixed. Flux2Pipeline.upsample_prompt() was actually dead code since caption_upsample_temperature is never read from sampling_params. Fixing that was a simple 1-line change.

The bigger problem is that the full Mistral model is much larger than just the encoder (obviously) and will put significant undue memory pressure on a system that doesn't otherwise want to provide upsampling.

So, I added enable_caption_upsampling as a pipeline config option for FLUX.2 that causes the pipeline to use transformers.Mistral3ForConditionalGeneration instead of the new MistralEncoderModel. Now, caption_upsample_temperature is only silently ignored if enable_caption_upsampling is false.

  2. MistralRotaryEmbedding.inv_freq is registered with torch.empty(...) and then used immediately to build
    RoPE cos/sin tables, but I don't see any initialization from config (rope_theta, etc.). HF Mistral computes
    this from config instead of relying on checkpoint loading, so this looks like undefined / incorrect rotary
    embeddings.

Fixed.

@nuclearwu
Contributor

@vraiti Could you please add more test plans to verify the modification? With that, it looks good to me.

@vraiti
Contributor Author

vraiti commented Apr 13, 2026

@vraiti Could you please add more test plans to verify the modification? With that, it looks good to me.

@nuclearwu done

@nuclearwu
Contributor

nuclearwu commented Apr 14, 2026

@vraiti Could you please add more test plans and paste test results, referring to #1629 or #2010 as examples?

@nuclearwu
Contributor

I think the design is still inconsistent here.

This PR conceptually replaces Mistral3ForConditionalGeneration with MistralEncoderModel as
self.text_encoder, so MistralEncoderModel should satisfy the interface contract that the pipeline
expects from self.text_encoder. Right now that is not true: the pipeline still relies on generation
behavior in upsample_prompt(), but MistralEncoderModel is only an encoder-only subset, not a full
conditional-generation model.

Because of that, I don't think the right fix is to keep branching on two different text encoder
implementations inside the pipeline. Instead, we should introduce a single text backend / wrapper
abstraction for FLUX.2 that owns both prompt encoding and prompt upsampling behavior. Internally, that
abstraction can choose whether to use the TP-aware encoder path or the full
Mistral3ForConditionalGeneration path, but the pipeline should not need to know which concrete
implementation is underneath.

In other words: if MistralEncoderModel is not intended to be a full drop-in replacement for
Mistral3ForConditionalGeneration, then it should not be exposed directly as self.text_encoder in
places where the pipeline still expects generation semantics.
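One way to express that contract is a small protocol the pipeline depends on, with the concrete backend chosen behind it. This is only an illustrative sketch; every name in it is hypothetical, not from this PR:

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class Flux2TextBackend(Protocol):
    """Hypothetical interface: the pipeline would call only these two methods."""
    def encode_prompt(self, prompt: str) -> list[float]: ...
    def upsample_prompt(self, prompt: str, temperature: float) -> str: ...

class EncoderOnlyBackend:
    """TP-sharded encoder path: encoding works, upsampling degrades to a no-op."""
    def encode_prompt(self, prompt: str) -> list[float]:
        return [0.0] * 8  # placeholder embedding, stands in for the real forward pass
    def upsample_prompt(self, prompt: str, temperature: float) -> str:
        return prompt  # no generator available; return the prompt unchanged

backend = EncoderOnlyBackend()
print(isinstance(backend, Flux2TextBackend))  # True (structural check)
```

The pipeline then never branches on the concrete class; a full Mistral3-backed implementation would satisfy the same protocol with real generation behind upsample_prompt.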

What do you think about it? @hsliuustc0106 @wtomin

Collaborator

@lishunyang12 lishunyang12 left a comment


Review: [Model] Add TP-aware MistralEncoder for FLUX.2-dev TP

Overall this is a well-structured PR that follows the established pattern (similar to T5EncoderModel) for implementing a TP-aware encoder. The architecture is correct and the test coverage is solid. Approving with a few minor suggestions.

What works well

  • TP sharding is correct: QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear, and VocabParallelEmbedding are the right vLLM primitives. The head count division (num_heads // tp_size, max(1, num_kv_heads // tp_size)) and the num_kv_groups derivation are correct.
  • Weight loading: The stacked_params_mapping for q/k/v -> qkv_proj and gate/up -> gate_up_proj matches vLLM conventions. The skip list for lm_head, vision_tower, and multi_modal_projector correctly drops weights not needed by the encoder-only use case.
  • RoPE implementation: Standard rotate-half RoPE with configurable theta. The precomputation in MistralRotaryEmbedding.forward() is clean.
  • Causal masking: Correctly uses a causal (upper-triangular) mask since Mistral is decoder-only even when used as an encoder. The padding mask combination is also correct.
  • Pipeline integration: The enable_caption_upsampling flag cleanly separates the two code paths - full Mistral3ForConditionalGeneration (needed for caption upsampling/generation) vs. the lightweight TP-sharded encoder. The AutoWeightsLoader delegation through weights_sources is properly set up.
  • Tests: Good coverage of config parsing (plain/nested/defaults), RoPE correctness, weight loading (QKV, gate_up, skip patterns), and module structure.
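The causal-plus-padding mask combination noted above can be sketched in plain PyTorch (toy sizes, for illustration only):

```python
import torch
import torch.nn.functional as F

seq_len = 6
# Causal mask: query i may attend to keys j <= i (Mistral is decoder-only,
# even when used as an encoder).
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
# Padding mask: here the last two tokens of the sequence are padding.
key_valid = torch.tensor([1, 1, 1, 1, 0, 0], dtype=torch.bool)
# A key position must be both causally visible and non-padding.
combined = causal & key_valid.unsqueeze(0)
attn_mask = torch.where(combined, 0.0, float("-inf"))  # additive float mask

q = k = v = torch.randn(1, 2, seq_len, 8)  # (batch, heads, seq, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```

The (seq, seq) float mask broadcasts over batch and heads; blocked positions contribute -inf before the softmax and thus zero attention weight.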

Minor suggestions (non-blocking)

  1. MistralEncoderOutput placement (style): The output dataclass is defined after the model class that uses it. While this works due to from __future__ import annotations, consider moving it above MistralEncoderModel for readability.

  2. enable_caption_upsampling discoverability: od_config.model_config.get("enable_caption_upsampling", False) introduces a new config key. It would be good to document this in the model's configuration reference or at least add a comment explaining when a user would want to set this to True (i.e., when they have enough VRAM for the full Mistral3 model and want caption upsampling).

  3. Trailing comma in weights_sources (line 361 in the diff): The addition of a trailing comma to the existing ComponentSource entry is a good practice but is a cosmetic-only change that could be split out to keep the diff focused. Very minor.

  4. Type annotation on _get_mistral_3_small_prompt_embeds: The type hint MistralEncoderModel | Mistral3ForConditionalGeneration is correct for the union, but the parameter order puts MistralEncoderModel first. Consider putting Mistral3ForConditionalGeneration first since it's the more feature-complete type (and the one used when caption upsampling is enabled). This is purely a style preference.

Verified correctness

  • output_hidden_states=True produces the right tuple structure (embedding + per-layer + final norm), compatible with the hidden_states_layers=(10, 20, 30) indexing in _get_mistral_3_small_prompt_embeds.
  • use_cache=False passed from the pipeline is harmlessly absorbed by **kwargs.
  • AutoWeightsLoader will correctly delegate to MistralEncoderModel.load_weights after stripping the text_encoder. prefix.

LGTM. Nice work enabling FLUX.2-dev on smaller GPU clusters.

@vraiti vraiti force-pushed the flux2-tp branch 8 times, most recently from 567cf60 to e9d9be7 Compare April 21, 2026 21:18
@vraiti vraiti force-pushed the flux2-tp branch 2 times, most recently from 0462e69 to d667d4c Compare April 22, 2026 14:06
@vraiti
Contributor Author

vraiti commented Apr 22, 2026

I think the design is still inconsistent here.

This PR conceptually replaces Mistral3ForConditionalGeneration with MistralEncoderModel as self.text_encoder, so MistralEncoderModel should satisfy the interface contract that the pipeline expects from self.text_encoder. Right now that is not true: the pipeline still relies on generation behavior in upsample_prompt(), but MistralEncoderModel is only an encoder-only subset, not a full conditional-generation model.

@nuclearwu Fixed. MistralEncoderModel now supports generate() as well, fully replacing Mistral3ForConditionalGeneration in the FLUX.2-dev pipeline.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: vraiti <vraiti@redhat.com>