
DRAFT: Mistral Large 3 Extended Blackwell Support #29884

Closed
jdebache wants to merge 12 commits into vllm-project:main from jdebache:mistral_large_3_blackwell

Conversation

Contributor

@jdebache jdebache commented Dec 2, 2025

Purpose

Improve performance and support of Mistral Large 3 on Blackwell.

Details

  • Added per-tensor scaled Triton configs for MoE (for Eagle draft model)
  • (WIP) Added per-block scaled Triton configs for MoE (for target model)
  • Added support for Flashinfer TRTLLM per-tensor scaled FP8 MoE kernels (for Eagle draft model)
  • (WIP) Added support for Flashinfer TRTLLM per-block scaled FP8 MoE kernels (for target model)
  • Fixed Llama4 routing for FP4 MoE
  • Added support for Mistral config format in benchmarks/kernels/benchmark_moe.py
  • Added support for Mistral tokenizer in vllm/benchmarks/throughput.py
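
As background for the per-tensor vs. per-block items above, here is a toy pure-Python sketch of the two FP8 scaling granularities (illustrative only, not the actual Triton kernels; 448 is the FP8 e4m3 max value, and the block size of 128 is an assumption):

```python
def quantize_per_tensor(x, fp8_max=448.0):
    # One scale for the whole tensor: an outlier anywhere inflates the scale
    # and crushes small values everywhere.
    scale = max(abs(v) for v in x) / fp8_max
    return [round(v / scale) for v in x], scale


def quantize_per_block(x, block=128, fp8_max=448.0):
    # One scale per contiguous block: an outlier only affects its own block,
    # so small values elsewhere keep more precision.
    quantized, scales = [], []
    for i in range(0, len(x), block):
        chunk = x[i:i + block]
        scale = max(abs(v) for v in chunk) / fp8_max
        scales.append(scale)
        quantized.extend(round(v / scale) for v in chunk)
    return quantized, scales
```

With a tensor whose second half contains a large outlier, per-tensor scaling rounds the small first-half values down to a few quantization levels, while per-block scaling keeps them near the full FP8 range.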

Best Performance Usage

FP8 Checkpoint on DGX B200 (8 devices)

The FP8 model will fit on a single node.
At low concurrencies, deploy with TP8:

VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL=1 \
VLLM_ATTENTION_BACKEND=FLASHINFER_MLA \
VLLM_USE_FLASHINFER_MOE_FP8=1 \
VLLM_FLASHINFER_MOE_BACKEND=latency \
vllm serve /models/Mistral-Large-3-675B-Instruct-2512 \
-tp 8 --kv-cache-dtype fp8 --no-enable-prefix-caching \
--config-format mistral --load-format mistral --tokenizer-mode mistral \
--max_model_len 65536 --max_num_seqs 512 --limit-mm-per-prompt '{"image":10}' \
--tool-call-parser mistral --enable-auto-tool-choice

At higher concurrencies (128 concurrent requests and above), deploy with DP8 and expert parallelism:

VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL=1 \
VLLM_ATTENTION_BACKEND=FLASHINFER_MLA \
VLLM_USE_FLASHINFER_MOE_FP8=1 \
VLLM_FLASHINFER_MOE_BACKEND=latency \
vllm serve /models/Mistral-Large-3-675B-Instruct-2512 \
--data-parallel-size 8 --enable-expert-parallel --kv-cache-dtype fp8 --no-enable-prefix-caching \
--config-format mistral --load-format mistral --tokenizer-mode mistral \
--max_model_len 65536 --max_num_seqs 512 --limit-mm-per-prompt '{"image":10}' \
--tool-call-parser mistral --enable-auto-tool-choice

NVFP4

For NVFP4 checkpoints, add the following to leverage the optimized FlashInfer kernels:

VLLM_NVFP4_GEMM_BACKEND=cutlass VLLM_USE_FLASHINFER_MOE_FP4=1

With a FlashInfer version newer than 0.5.3:

VLLM_NVFP4_GEMM_BACKEND=flashinfer-cudnn VLLM_USE_FLASHINFER_MOE_FP4=1

A recently fixed auto-tuner bug (flashinfer-ai/flashinfer#2140) makes the flashinfer-cudnn backend usable.
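
Putting the pieces together, an NVFP4 deployment might look like the following sketch (the checkpoint path is an assumption; the remaining flags mirror the FP8 TP8 example above):

```shell
VLLM_NVFP4_GEMM_BACKEND=cutlass \
VLLM_USE_FLASHINFER_MOE_FP4=1 \
VLLM_ATTENTION_BACKEND=FLASHINFER_MLA \
vllm serve /models/Mistral-Large-3-675B-Instruct-2512-NVFP4 \
-tp 8 --no-enable-prefix-caching \
--config-format mistral --load-format mistral --tokenizer-mode mistral \
--max_model_len 65536 --max_num_seqs 512 --limit-mm-per-prompt '{"image":10}' \
--tool-call-parser mistral --enable-auto-tool-choice
```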

GB200 P/D Disaggregated Dynamo Deployment

There are two options for setting up a Dynamo P/D disaggregated deployment of this model. The first is available immediately and relies on Dynamo's request-processing pipeline. The second is pending a Dynamo PR that enables delegating pre-processing to the vLLM backend.

For compatibility with top-of-tree (ToT) vLLM, you might need to include some changes that are not yet in upstream Dynamo.

With Dynamo request processing

Start by copying config.json from Ministral to your model directory.

With delegated request processing

Pending some changes (TODO LINK) in Dynamo, you will be able to skip the file-copying step above.

Next steps

We have identified further optimizations which will be part of other PRs:

  • FP8 context attention
  • Routing GEMM optimization

Contributors

@dbari, @DanBlanaru, @evezhier, @hypdeb

@mergify mergify bot added ci/build performance Performance-related issues nvidia v1 labels Dec 2, 2025
Member

@mgoin mgoin left a comment


It seems there are a few things mixed into this one. Do you think we could prioritize the critical perf features like kernel support and tuned configs?

@@ -11,3 +11,4 @@ torchaudio==2.9.0
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.5.3
nvtx==0.2.13
Member


I think this isn't a big deal but do we need this dep?

Contributor Author


We've added some NVTX ranges to gpu_model_runner.py to make profiling easier; these use the nvtx package. This can be split into another PR if desirable.

Comment on lines +2987 to +3008

# Count context vs. decode requests
context_requests = 0
decode_requests = 0

for req in scheduler_output.scheduled_new_reqs:
    context_len = len(req.prompt_token_ids) if req.prompt_token_ids else 0
    num_computed = req.num_computed_tokens
    if num_computed < context_len:
        context_requests += 1
    else:
        decode_requests += 1

# For cached requests
for i, req_id in enumerate(scheduler_output.scheduled_cached_reqs.req_ids):
    context_len = self.requests[req_id].num_prompt_tokens
    num_computed = scheduler_output.scheduled_cached_reqs.num_computed_tokens[i]
    if num_computed < context_len:
        context_requests += 1
    else:
        decode_requests += 1
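
One hedged sketch of how the per-step cost of the counting loop above could be avoided when profiling is off (the env-var name and the simplified request representation are assumptions, not vLLM's actual API):

```python
import os

# Hypothetical opt-in flag; the real gating mechanism in vLLM may differ.
PROFILING_ENABLED = os.environ.get("VLLM_PROFILE_REQUEST_PHASES", "0") == "1"


def count_request_phases(requests):
    """Classify requests as context (prefill) vs. decode.

    `requests` is an iterable of (prompt_len, num_computed_tokens) pairs,
    a simplified stand-in for the scheduler's new + cached requests.
    """
    context_requests = 0
    decode_requests = 0
    for prompt_len, num_computed in requests:
        if num_computed < prompt_len:
            context_requests += 1  # still consuming the prompt
        else:
            decode_requests += 1  # prompt done; generating tokens
    return context_requests, decode_requests


def maybe_count_request_phases(requests):
    # Skip the O(num_requests) pass entirely when profiling is disabled.
    if not PROFILING_ENABLED:
        return None
    return count_request_phases(requests)
```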

Member


Can we put this in another PR? We don't want to eat this cost when profiling isn't enabled

Contributor Author


Will do.

Contributor Author

jdebache commented Dec 2, 2025

It seems there are a few things mixed into this one. Do you think we could prioritize the critical perf features like kernel support and tuned configs?

I'll start by splitting the NVTX stuff out, see how it looks after this.

Member

@mgoin mgoin left a comment


This looks straightforward to me, thanks for splitting it up!

@github-project-automation github-project-automation bot moved this to In review in NVIDIA Dec 5, 2025
@mgoin mgoin marked this pull request as ready for review December 5, 2025 01:15
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 5, 2025
@jdebache jdebache force-pushed the mistral_large_3_blackwell branch from 4e91a18 to 9f5af28 Compare December 6, 2025 11:29

mergify bot commented Dec 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @hypdeb.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 10, 2025
@jdebache jdebache force-pushed the mistral_large_3_blackwell branch from b240199 to 8d0de6b Compare December 10, 2025 07:51
@mergify mergify bot removed the needs-rebase label Dec 10, 2025

mergify bot commented Dec 10, 2025

Hi @hypdeb, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

jdebache and others added 8 commits December 15, 2025 01:36
Signed-off-by: jdebache <jdebache@nvidia.com>
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Signed-off-by: Julien Debache <jdebache@cpu-0002.cm.cluster>
Signed-off-by: Julien Debache <jdebache@cpu-0002.cm.cluster>
Signed-off-by: Dan Blanaru <48605845+DanBlanaru@users.noreply.github.com>
Signed-off-by: Julien Debache <jdebache@cpu-0002.cm.cluster>
Signed-off-by: Julien Debache <jdebache@cpu-0002.cm.cluster>
Signed-off-by: jdebache <jdebache@nvidia.com>
@jdebache jdebache force-pushed the mistral_large_3_blackwell branch from 8d0de6b to 0c16645 Compare December 15, 2025 09:36
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
…ock_scale_moe

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
@mergify mergify bot added the deepseek Related to DeepSeek models label Dec 18, 2025
Contributor

dbari commented Dec 18, 2025

I added support for FlashInfer TRTLLM per-block scaled FP8 MoE kernels for the target model, so you can cross it off the list. However, some fixes still need to be merged into FlashInfer and released before they can be integrated here:

  • Wait for flashinfer PR#2238 to be merged
  • Update the pinned FlashInfer version to the next release after the PR is merged


mergify bot commented Dec 18, 2025

Hi @hypdeb, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>

mergify bot commented Dec 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @hypdeb.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@jdebache
Contributor Author

Hey @mgoin, we are closing this one in favour of #33174 for starters. Other changes will come later.

@jdebache jdebache closed this Jan 27, 2026
@github-project-automation github-project-automation bot moved this from In review to Done in NVIDIA Jan 27, 2026
@jdebache jdebache deleted the mistral_large_3_blackwell branch March 25, 2026 10:19

Labels

ci/build, deepseek (Related to DeepSeek models), needs-rebase, nvidia, performance (Performance-related issues), ready (ONLY add when PR is ready to merge/full CI is needed), v1

Projects

Status: Done


4 participants