diff --git a/tests/compile/silly_attention.py b/tests/compile/silly_attention.py index c0d3f908149f..20f0f6886e1b 100644 --- a/tests/compile/silly_attention.py +++ b/tests/compile/silly_attention.py @@ -8,7 +8,7 @@ import torch from torch.library import Library -from vllm.utils import direct_register_custom_op +from vllm.utils import direct_register_custom_op, tag_cudagraph_unsafe # Shared library for all compilation test operations # Using "silly" namespace to match existing test expectations @@ -55,6 +55,7 @@ def silly_attention_fake( return +# Register the unified attention operation # Register the unified attention operation direct_register_custom_op( op_name="attention", @@ -62,5 +63,5 @@ def silly_attention_fake( mutates_args=["out"], fake_impl=silly_attention_fake, target_lib=silly_lib, - tags=(torch._C.Tag.cudagraph_unsafe,), + tags=tag_cudagraph_unsafe, ) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 929c3b6a4906..023eb20f18e8 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -34,14 +34,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import current_platform -from vllm.utils import GiB_bytes, direct_register_custom_op +from vllm.utils import GiB_bytes, direct_register_custom_op, tag_cudagraph_unsafe logger = init_logger(__name__) USE_XFORMERS_OPS = None -try: - tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe,) -except AttributeError: - tag_cudagraph_unsafe = () # type: ignore[assignment] def check_xformers_availability(): diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index c8da83047a40..fcbd8b1225de 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -3395,6 +3395,14 @@ def length_from_prompt_token_ids_or_embeds( return prompt_token_len +if is_torch_equal_or_newer("2.9.0.dev"): + from vllm.platforms import current_platform + tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, + ) if current_platform.is_cuda_alike() else () +else: + tag_cudagraph_unsafe = () # type: ignore[assignment] + + @contextlib.contextmanager def set_env_var(key, value): old = os.environ.get(key)