Conversation

@Isotr0py Isotr0py (Member) commented Jun 15, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

  • Flex Attention currently doesn't work with tensor parallelism because num_gpu_blocks is never updated properly in the workers' cache_config:
(VllmWorker rank=1 pid=2540) ERROR 06-15 05:58:54 [multiproc_executor.py:527]   File "/kaggle/working/vllm/vllm/v1/worker/gpu_model_runner.py", line 1211, in execute_model
(VllmWorker rank=1 pid=2540) ERROR 06-15 05:58:54 [multiproc_executor.py:527]     spec_decode_metadata) = (self._prepare_inputs(scheduler_output))
(VllmWorker rank=1 pid=2540) ERROR 06-15 05:58:54 [multiproc_executor.py:527]                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=2540) ERROR 06-15 05:58:54 [multiproc_executor.py:527]   File "/kaggle/working/vllm/vllm/v1/worker/gpu_model_runner.py", line 706, in _prepare_inputs
(VllmWorker rank=1 pid=2540) ERROR 06-15 05:58:54 [multiproc_executor.py:527]     attn_metadata_i = (builder.build(
(VllmWorker rank=1 pid=2540) ERROR 06-15 05:58:54 [multiproc_executor.py:527]                        ^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=2540) ERROR 06-15 05:58:54 [multiproc_executor.py:527]   File "/kaggle/working/vllm/vllm/v1/attention/backends/flex_attention.py", line 301, in build
(VllmWorker rank=1 pid=2540) ERROR 06-15 05:58:54 [multiproc_executor.py:527]     total_cache_tokens = (self.runner.cache_config.num_gpu_blocks *
(VllmWorker rank=1 pid=2540) ERROR 06-15 05:58:54 [multiproc_executor.py:527]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=2540) ERROR 06-15 05:58:54 [multiproc_executor.py:527] TypeError: unsupported operand type(s) for *: 'NoneType' and 'int'
  • This PR fixes this issue so that Flex Attention works with TP; a rough sketch of the idea is shown below.
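
For orientation, a rough sketch of the idea behind the fix (not the actual vLLM code: determine_available_blocks is a hypothetical stand-in, while the cache_config updates and the collective_rpc("initialize_cache", ...) call mirror the diff discussed in the review below):

def initialize_kv_caches(engine_core, vllm_config) -> None:
    # Sketch only: after the engine core decides how many KV-cache blocks fit,
    # it must also push those numbers to every worker process; otherwise each
    # worker's cache_config.num_gpu_blocks stays None under TP.
    num_gpu_blocks, num_cpu_blocks = engine_core.determine_available_blocks()  # hypothetical helper

    # Update the engine core's own copy of the cache config ...
    vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
    vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

    # ... and broadcast the same numbers to all workers via RPC, so backends
    # such as FlexAttention can read cache_config.num_gpu_blocks on any rank.
    engine_core.collective_rpc("initialize_cache",
                               args=(num_gpu_blocks, num_cpu_blocks))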

Test Plan

VLLM_ATTENTION_BACKEND=FLEX_ATTENTION python examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2

Test Result

(VllmWorker rank=1 pid=5600) INFO 06-15 06:06:11 [gpu_model_runner.py:2083] Graph capturing finished in 26 secs, took 0.20 GiB
(VllmWorker rank=0 pid=5599) INFO 06-15 06:06:11 [gpu_model_runner.py:2083] Graph capturing finished in 26 secs, took 0.20 GiB
INFO 06-15 06:06:11 [core.py:173] init engine (profile, create kv cache, warmup model) took 48.53 seconds
Adding requests: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 776.08it/s]
Processed prompts: 100%|████████████████████████████████████████████████████████████████████| 4/4 [00:11<00:00,  2.92s/it, est. speed input: 2.23 toks/s, output: 5.49 toks/s]
--------------------------------------------------
Prompt: 'Hello, my name is'
Generated text: ' Joel, I am a BYU lecturer and I would love to talk about your class'
--------------------------------------------------
Prompt: 'The president of the United States is'
Generated text: " stuck painting the country like it's 70 years old again, with behind-the"
--------------------------------------------------
Prompt: 'The capital of France is'
Generated text: " directly behind Germany's winning bid for elections as French business and French defence leaders fired"
--------------------------------------------------
Prompt: 'The future of AI is'
Generated text: ' predicted according to the 2020 report by a economists\nAI stand to be in the'
--------------------------------------------------

(Optional) Documentation Update

No need.

Isotr0py added 2 commits June 15, 2025 01:56
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @Isotr0py, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical bug that prevented the Flex Attention backend from functioning correctly when Tensor Parallelism is enabled. The core issue was the failure to initialize cache configuration parameters (num_gpu_blocks, num_cpu_blocks) on the worker processes in a distributed setup. The fix involves explicitly sending this configuration to the workers. Additionally, a temporary measure is included to handle a separate CUDA error encountered with the compiled block mask function under TP, falling back to a non-compiled version.

Highlights

  • Bug Fix: Addresses a TypeError in the Flex Attention backend when used with Tensor Parallelism (TP > 1). The error occurred because cache_config.num_gpu_blocks was None on the workers.
  • Cache Configuration Propagation: Ensures that the calculated num_gpu_blocks and num_cpu_blocks are correctly propagated to the worker processes via RPC calls during engine initialization.
  • Conditional Block Mask Creation: Introduces a temporary workaround in the Flex Attention backend to use the non-compiled create_block_mask function when TP > 1, due to a noted CUDA error with the compiled version (create_block_mask_compiled) in this configuration. A FIXME comment is added for future investigation.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@mergify mergify bot added the v1 and tpu (Related to Google TPUs) labels Jun 15, 2025
@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request effectively addresses two key issues preventing FlexAttention from functioning correctly with tensor parallelism (TP).

  1. It resolves a TypeError caused by num_gpu_blocks being None in the cache_config on worker processes. This is fixed by introducing an RPC call (initialize_cache) from the EngineCore to all workers, ensuring their cache_config is properly updated with the number of GPU and CPU blocks.
  2. It includes a workaround for a CUDA error: an illegal memory access that occurs when using the compiled version of create_block_mask with TP > 1. The fix involves conditionally using the non-compiled create_block_mask function in such scenarios (a sketch of this selection follows after this summary).

The changes are well-targeted, and the provided test results demonstrate successful operation with TP enabled. The codebase is updated in EngineCore, GPU/TPU workers, and the FlexAttention backend itself.
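
For reference, a minimal sketch of that conditional selection, assuming names close to the FlexAttention backend (the helper select_block_mask_fn is hypothetical; only the FIXME text and the fallback to the eager create_block_mask come from this PR):

import torch
from torch.nn.attention.flex_attention import create_block_mask

# Compiled variant; in the backend this is built once rather than per call.
create_block_mask_compiled = torch.compile(create_block_mask)

def select_block_mask_fn(tp_world_size: int):
    """Pick the block-mask builder based on the TP world size (sketch only)."""
    # FIXME: With TP>1, create_block_mask_compiled will raise
    # CUDA error: an illegal memory access was encountered
    if tp_world_size > 1:
        return create_block_mask  # eager fallback for tensor-parallel runs
    return create_block_mask_compiled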

Comment on lines +240 to +241
# FIXME: With TP>1, create_block_mask_compiled will raise
# CUDA error: an illegal memory access was encountered
Contributor

Severity: medium

The FIXME comment clearly explains the issue with create_block_mask_compiled when the tensor parallel world size is greater than 1. To ensure this is addressed in the future, consider creating a GitHub issue to track this underlying CUDA error if one doesn't exist already. This would help in eventually enabling the compiled version universally.

Member Author

The full traceback of the illegal memory access error:

Log
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/vllm/v1/attention/backends/flex_attention.py", line 262, in __post_init__
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     self.block_mask = self.build_block_mask()
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]                       ^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/vllm/v1/attention/backends/flex_attention.py", line 246, in build_block_mask
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return create_block_mask_fn(
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return fn(*args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py", line 824, in create_block_mask
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     def create_block_mask(
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return fn(*args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1201, in forward
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return compiled_fn(full_args)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 328, in runtime_wrapper
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     all_outs = call_func_at_runtime_with_args(
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     out = normalize_as_list(f(args))
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]                             ^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 689, in inner_fn
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     outs = compiled_fn(args)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 495, in wrapper
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return compiled_fn(runtime_args)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/output_code.py", line 460, in __call__
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return self.current_callable(inputs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1372, in run
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return compiled_fn(new_inputs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 387, in deferred_cudagraphify
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 448, in cudagraphify
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return manager.add_function(
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2308, in add_function
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return fn, fn(inputs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]                ^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 1997, in run
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     out = self._run(new_inputs, function_id)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2104, in _run
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     out = self.run_eager(new_inputs, function_id)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2269, in run_eager
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return node.run(new_inputs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 668, in run
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     out = self.wrapped_function.model(new_inputs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/root/.cache/vllm/torch_compile_cache/26b5568570/rank_0_0/inductor_cache/4o/c4osf7wcdszj5dy7kaxakhrrucni4ac5aiyysa63j3fmz37p6jxn.py", line 561, in call
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     triton_per_fused__to_copy_sum_7.run(buf18, buf22, 5718, triton_per_fused__to_copy_sum_7_r0_numel, stream=stream0)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 909, in run
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     self.autotune_to_one_config(*args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 763, in autotune_to_one_config
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     timings = self.benchmark_all_configs(*args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 738, in benchmark_all_configs
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     launcher: self.bench(launcher, *args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 616, in bench
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return benchmarker.benchmark_gpu(kernel_call, rep=40)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py", line 39, in wrapper
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return fn(self, *args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py", line 243, in benchmark_gpu
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     _callable()
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 601, in kernel_call
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     launcher(
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "<string>", line 5, in launcher
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 444, in __call__
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     self.launch(*args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527] RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

@drisspg Any idea about this error?

Collaborator

cc: @zou3519 for torch.compile related issue.

Contributor

Hey, I actually just noticed this too; this was not the case until pretty recently. I'm going to create an issue to track this.

@Isotr0py Isotr0py requested a review from houseroad June 15, 2025 06:29
@houseroad houseroad (Collaborator) left a comment

If we can add some unit tests, that would be great.

buffer.data.copy_(self._sleep_saved_buffers[name].data)
self._sleep_saved_buffers = {}

def initialize_cache(self, num_gpu_blocks: int,
Collaborator

This sounds more like "setting_cache_size" than initialize_cache?

Member Author

Hmmm, cache_config's num_gpu_blocks and num_cpu_blocks are updated in initialize_cache for the worker in v0, which is a base-class method:

def initialize_cache(self, num_gpu_blocks: int,
                     num_cpu_blocks: int) -> None:
    """Initialize the KV cache with the given size in blocks.
    """
    raise NotImplementedError

vllm/vllm/worker/worker.py (lines 312 to 325 in 3d330c4):

def initialize_cache(self, num_gpu_blocks: int,
                     num_cpu_blocks: int) -> None:
    """Allocate GPU and CPU KV cache with the specified number of blocks.

    This also warms up the model, which may record CUDA graphs.
    """
    raise_if_cache_size_invalid(
        num_gpu_blocks, self.cache_config.block_size,
        self.cache_config.is_attention_free,
        self.model_config.max_model_len,
        self.parallel_config.pipeline_parallel_size)
    self.cache_config.num_gpu_blocks = num_gpu_blocks
    self.cache_config.num_cpu_blocks = num_cpu_blocks

Although this method was not used by v1 before this PR, I think reusing the method shared with v0 keeps the worker implementation consistent.

Comment on lines +240 to +241
# FIXME: With TP>1, create_block_mask_compiled will raise
# CUDA error: an illegal memory access was encountered
Collaborator

cc: @zou3519 for torch.compile related issue.


vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.collective_rpc("initialize_cache",
Collaborator

wondering why only TP + FlexAttention needs this?

Member Author

Because FlexAttention needs num_gpu_blocks for its calculations, while other attention backends don't.

Not sure if this is intended, but in V1 only the engine core's cache_config gets num_gpu_blocks updated; a worker in a different process (the TP case) won't have num_gpu_blocks updated without a collective_rpc call.

Therefore, in distributed inference the worker's num_gpu_blocks is still None, which caused the error in the PR description (see the sketch below).
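
To make the failure mode concrete, a stripped-down sketch of the computation that blows up (the fields mirror the traceback above; everything else is simplified and illustrative):

from dataclasses import dataclass
from typing import Optional

@dataclass
class CacheConfig:
    block_size: int = 16
    num_gpu_blocks: Optional[int] = None  # stays None on TP workers before this PR

def total_cache_tokens(cache_config: CacheConfig) -> int:
    # With num_gpu_blocks still None, this is exactly the
    # "unsupported operand type(s) for *: 'NoneType' and 'int'" failure.
    return cache_config.num_gpu_blocks * cache_config.block_size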

Collaborator

I see. Could you check whether we need to add a condition to only call this function when tp > 1?

Member Author

In the single-process case, we use UniProcExecutor instead of MultiprocExecutor:

elif distributed_executor_backend == "mp":
    from vllm.v1.executor.multiproc_executor import MultiprocExecutor
    executor_class = MultiprocExecutor
elif distributed_executor_backend == "uni":
    executor_class = UniProcExecutor

Given that it also has collective_rpc implemented properly, it's safe to call collective_rpc here as well, especially since we only update cache_config, even though in the single-process case this has already been done by the preceding lines:

def collective_rpc(self,
                   method: Union[str, Callable],
                   timeout: Optional[float] = None,
                   args: Tuple = (),
                   kwargs: Optional[Dict] = None) -> List[Any]:
    if kwargs is None:
        kwargs = {}
    answer = run_method(self.driver_worker, method, args, kwargs)
    return [answer]

I have checked that TP=1 still works.

@houseroad houseroad (Collaborator) left a comment

Looks good. Left one comment. :-)

@Isotr0py Isotr0py enabled auto-merge (squash) June 16, 2025 03:56
@github-actions github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label Jun 16, 2025
@Isotr0py Isotr0py merged commit 1173804 into vllm-project:main Jun 16, 2025
78 checks passed
@Isotr0py Isotr0py deleted the flex-tp branch June 16, 2025 12:03