8 changes: 7 additions & 1 deletion vllm/v1/attention/backends/flex_attention.py
@@ -13,6 +13,7 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -236,7 +237,12 @@ def final_mask_mod(

def build_block_mask(self) -> BlockMask:
assert self.mask_mod is not None
return create_block_mask_compiled(
# FIXME: With TP>1, create_block_mask_compiled will raise
# CUDA error: an illegal memory access was encountered
Comment on lines +240 to +241
Contributor

medium

The FIXME comment clearly explains the issue with create_block_mask_compiled when the tensor parallel world size is greater than 1. To ensure this is addressed in the future, consider creating a GitHub issue to track this underlying CUDA error if one doesn't exist already. This would help in eventually enabling the compiled version universally.

Member Author

The full traceback of the illegal memory access error:

Log
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/vllm/v1/attention/backends/flex_attention.py", line 262, in __post_init__
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     self.block_mask = self.build_block_mask()
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]                       ^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/vllm/v1/attention/backends/flex_attention.py", line 246, in build_block_mask
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return create_block_mask_fn(
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return fn(*args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/nn/attention/flex_attention.py", line 824, in create_block_mask
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     def create_block_mask(
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return fn(*args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1201, in forward
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return compiled_fn(full_args)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 328, in runtime_wrapper
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     all_outs = call_func_at_runtime_with_args(
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     out = normalize_as_list(f(args))
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]                             ^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 689, in inner_fn
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     outs = compiled_fn(args)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 495, in wrapper
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return compiled_fn(runtime_args)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/output_code.py", line 460, in __call__
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return self.current_callable(inputs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1372, in run
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return compiled_fn(new_inputs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 387, in deferred_cudagraphify
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 448, in cudagraphify
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return manager.add_function(
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2308, in add_function
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return fn, fn(inputs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]                ^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 1997, in run
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     out = self._run(new_inputs, function_id)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2104, in _run
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     out = self.run_eager(new_inputs, function_id)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2269, in run_eager
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return node.run(new_inputs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 668, in run
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     out = self.wrapped_function.model(new_inputs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/root/.cache/vllm/torch_compile_cache/26b5568570/rank_0_0/inductor_cache/4o/c4osf7wcdszj5dy7kaxakhrrucni4ac5aiyysa63j3fmz37p6jxn.py", line 561, in call
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     triton_per_fused__to_copy_sum_7.run(buf18, buf22, 5718, triton_per_fused__to_copy_sum_7_r0_numel, stream=stream0)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 909, in run
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     self.autotune_to_one_config(*args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 763, in autotune_to_one_config
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     timings = self.benchmark_all_configs(*args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 738, in benchmark_all_configs
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     launcher: self.bench(launcher, *args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 616, in bench
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return benchmarker.benchmark_gpu(kernel_call, rep=40)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py", line 39, in wrapper
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     return fn(self, *args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py", line 243, in benchmark_gpu
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     _callable()
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 601, in kernel_call
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     launcher(
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "<string>", line 5, in launcher
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]   File "/kaggle/working/vllm/.venv/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 444, in __call__
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527]     self.launch(*args, **kwargs)
(VllmWorker rank=0 pid=13527) ERROR 06-15 06:22:40 [multiproc_executor.py:527] RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

@drisspg Any idea about this error?

Collaborator

cc @zou3519 for the torch.compile-related issue.

Contributor

Hey, I actually just noticed this too; this was not the case until pretty recently. Going to create an issue + tracking for this.

create_block_mask_fn = (create_block_mask_compiled
if get_tensor_model_parallel_world_size() == 1
else create_block_mask)
return create_block_mask_fn(
self.mask_mod,
None,
None,
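For reference, a standalone sketch of the workaround in this hunk, written against PyTorch's public flex_attention API; the tp_world_size argument and the causal mask_mod are assumptions for illustration, not vLLM's actual code path:

import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask

# vLLM pre-compiles the builder; the compiled variant is what currently hits
# the illegal memory access once tensor parallelism is enabled.
create_block_mask_compiled = torch.compile(create_block_mask, fullgraph=True)


def causal_mask_mod(b, h, q_idx, kv_idx):
    # Simple causal mask, for illustration only.
    return q_idx >= kv_idx


def build_block_mask(tp_world_size: int, seq_len: int = 128,
                     device: str = "cuda") -> BlockMask:
    # Fall back to the eager builder whenever TP > 1 until the underlying
    # CUDA/Triton issue is resolved.
    fn = create_block_mask_compiled if tp_world_size == 1 else create_block_mask
    return fn(causal_mask_mod, B=None, H=None,
              Q_LEN=seq_len, KV_LEN=seq_len, device=device)
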
2 changes: 2 additions & 0 deletions vllm/v1/engine/core.py
@@ -84,6 +84,8 @@ def __init__(self,

vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.collective_rpc("initialize_cache",
Collaborator

wondering why only TP + FlexAttention needs this?

Member Author

Because FlexAttention needs num_gpu_blocks for its calculations, while other attention backends don't.

Not sure if this is intended, but in V1 only the engine core's cache_config gets the updated num_gpu_blocks; a worker in a different process (the TP case) won't have num_gpu_blocks updated without a collective_rpc call.

Therefore, in distributed inference the worker's num_gpu_blocks is still None, which causes the error in the PR description.
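
A minimal, self-contained sketch of that failure mode (hypothetical names, not vLLM's actual classes): the worker's cache config still holds None, so anything that sizes the KV cache from it breaks until the engine core pushes the block counts over the RPC:

from dataclasses import dataclass
from typing import Optional


@dataclass
class CacheConfigSketch:
    block_size: int = 16
    # Never updated in a TP worker process unless the engine core sends the RPC.
    num_gpu_blocks: Optional[int] = None


def total_cache_tokens(cfg: CacheConfigSketch) -> int:
    # A FlexAttention-style backend needs a concrete block count to size its
    # block mask over the whole KV cache.
    return cfg.num_gpu_blocks * cfg.block_size


worker_cfg = CacheConfigSketch()
try:
    total_cache_tokens(worker_cfg)
except TypeError as err:
    print(f"worker without initialize_cache: {err}")  # None * int fails

# Effect of collective_rpc("initialize_cache", ...) landing in the worker:
worker_cfg.num_gpu_blocks = 8192
print(total_cache_tokens(worker_cfg))  # 131072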

Collaborator

I see. Could you check whether we need to add a condition so this is only called when TP > 1?

Member Author

For the single-process case, we use UniProcExecutor instead of MultiprocExecutor:

elif distributed_executor_backend == "mp":
    from vllm.v1.executor.multiproc_executor import MultiprocExecutor
    executor_class = MultiprocExecutor
elif distributed_executor_backend == "uni":
    executor_class = UniProcExecutor

Given that UniProcExecutor also implements collective_rpc properly, it's safe to call collective_rpc there as well, especially since we only update cache_config here, which the preceding lines have already done for the single-process case:

def collective_rpc(self,
                   method: Union[str, Callable],
                   timeout: Optional[float] = None,
                   args: Tuple = (),
                   kwargs: Optional[Dict] = None) -> List[Any]:
    if kwargs is None:
        kwargs = {}
    answer = run_method(self.driver_worker, method, args, kwargs)
    return [answer]

I have checked that TP=1 still works.
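
As a usage note, here is a small runnable illustration (stub classes with assumed names, not vLLM's actual executors) of why the unconditional call is harmless in the single-process case: collective_rpc there reduces to a direct method call on the one driver worker, so it amounts to an idempotent re-assignment:

from typing import Any


class WorkerStub:
    """Stands in for a worker that stores cache sizes on its cache config."""

    def __init__(self) -> None:
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks


class UniProcExecutorStub:
    """Mirrors the collective_rpc shown above: no IPC, just a call on the driver."""

    def __init__(self) -> None:
        self.driver_worker = WorkerStub()

    def collective_rpc(self, method: str, args: tuple = ()) -> list[Any]:
        # Equivalent of run_method(self.driver_worker, method, args, {}).
        return [getattr(self.driver_worker, method)(*args)]


executor = UniProcExecutorStub()
executor.collective_rpc("initialize_cache", args=(8192, 0))
print(executor.driver_worker.num_gpu_blocks)  # 8192, same effect as a direct call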

args=(num_gpu_blocks, num_cpu_blocks))

self.structured_output_manager = StructuredOutputManager(vllm_config)

5 changes: 5 additions & 0 deletions vllm/v1/worker/gpu_worker.py
@@ -112,6 +112,11 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
buffer.data.copy_(self._sleep_saved_buffers[name].data)
self._sleep_saved_buffers = {}

def initialize_cache(self, num_gpu_blocks: int,
Collaborator

this sounds more like "setting_cache_size" instead of initialize_cache?

Member Author

Hmmm, cache_config's num_gpu_blocks and num_cpu_blocks are updated in initialize_cache for the worker in v0, which is a base-class method:

def initialize_cache(self, num_gpu_blocks: int,
                     num_cpu_blocks: int) -> None:
    """Initialize the KV cache with the given size in blocks."""
    raise NotImplementedError

vllm/vllm/worker/worker.py, lines 312 to 325 in 3d330c4:

def initialize_cache(self, num_gpu_blocks: int,
                     num_cpu_blocks: int) -> None:
    """Allocate GPU and CPU KV cache with the specified number of blocks.

    This also warms up the model, which may record CUDA graphs.
    """
    raise_if_cache_size_invalid(
        num_gpu_blocks, self.cache_config.block_size,
        self.cache_config.is_attention_free,
        self.model_config.max_model_len,
        self.parallel_config.pipeline_parallel_size)
    self.cache_config.num_gpu_blocks = num_gpu_blocks
    self.cache_config.num_cpu_blocks = num_cpu_blocks

Although this method was not used by v1 before this PR, I think reusing the method shared with v0 keeps the worker implementation consistent.

num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

def init_device(self):
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
5 changes: 5 additions & 0 deletions vllm/v1/worker/tpu_worker.py
@@ -93,6 +93,11 @@ def __init__(
if self.model_config.seed is None:
self.model_config.seed = 0

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

def init_device(self):
os.environ["PJRT_DEVICE"] = "TPU"
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D