From 29782dff3c35c528d87e474bb1ca006b37470786 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Wed, 15 Oct 2025 16:06:12 -0700 Subject: [PATCH 1/7] disable graph partition in custom op Signed-off-by: Boyuan Feng --- vllm/model_executor/layers/fused_moe/layer.py | 26 ++++++++++--------- vllm/model_executor/utils.py | 24 +++++++++++++++++ 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9b117f3b5d41..559ebcfa462c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -46,7 +46,7 @@ QuantizationConfig, QuantizeMethodBase, ) -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.utils import disable_graph_partition, set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils import cdiv, direct_register_custom_op, has_deep_ep, has_pplx, round_up @@ -1900,17 +1900,19 @@ def select_experts( if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None - topk_weights, topk_ids = grouped_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - num_expert_group=num_expert_group, - topk_group=topk_group, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias, - ) + + with disable_graph_partition(): + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + ) if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) elif e_score_correction_bias is not None: diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 38cd230082f8..b76d69d01b05 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utils for model executor.""" +import contextlib import copy from typing import Any @@ -83,3 +84,26 @@ def get_moe_expert_mapping( if child_map is not None: return child_map() return [] + + +@contextlib.contextmanager +def disable_graph_partition(): + """Context manager to disable inductor graph partition. + This is used to avoid nested cudagraph capture. + + Example: + 1. We apply torch.compile directly on some ops (e.g., grouped_topk) wrapped + in custom ops. Inductor graph partition applies cudagraph within the custom op. + 2. At the same time, we compile the model which uses these custom ops. Inductor + graph partition also wraps each graph partition with CUDAGraph. Some partitions + may include custom ops, which has already been applied cudagraph. This leads to + nested cudagraph which is not supported. + + This context manager should be wrapped around torch.compile calls within custom ops + to avoid the nested cudagraph capture.""" + old_val = torch._inductor.config.graph_partition + try: + torch._inductor.config.graph_partition = False + yield + finally: + torch._inductor.config.graph_partition = old_val From 8e08521954a23d604212741bc314a527db3b3d1e Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Wed, 15 Oct 2025 17:17:45 -0700 Subject: [PATCH 2/7] rewrite as decorator Signed-off-by: Boyuan Feng --- .../layers/fused_moe/fused_moe.py | 2 ++ vllm/model_executor/layers/fused_moe/layer.py | 26 ++++++++--------- vllm/model_executor/utils.py | 29 +++++++++++++------ 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 9f66e47dcb96..8a27b8cd98d4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -46,6 +46,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme +from vllm.model_executor.utils import disable_inductor_graph_partition from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer @@ -1126,6 +1127,7 @@ def fused_topk_bias( # This is used by the Deepseek-V2 and Deepseek-V3 model +@disable_inductor_graph_partition @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def grouped_topk( hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 559ebcfa462c..9b117f3b5d41 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -46,7 +46,7 @@ QuantizationConfig, QuantizeMethodBase, ) -from vllm.model_executor.utils import disable_graph_partition, set_weight_attrs +from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils import cdiv, direct_register_custom_op, has_deep_ep, has_pplx, round_up @@ -1900,19 +1900,17 @@ def select_experts( if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None - - with disable_graph_partition(): - topk_weights, topk_ids = grouped_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - num_expert_group=num_expert_group, - topk_group=topk_group, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias, - ) + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + ) if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) elif e_score_correction_bias is not None: diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index b76d69d01b05..604060fabf01 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utils for model executor.""" -import contextlib import copy from typing import Any @@ -86,9 +85,8 @@ def get_moe_expert_mapping( return [] -@contextlib.contextmanager -def disable_graph_partition(): - """Context manager to disable inductor graph partition. +def disable_inductor_graph_partition(func): + """Decorator to disable inductor graph partition. This is used to avoid nested cudagraph capture. Example: @@ -100,10 +98,23 @@ def disable_graph_partition(): nested cudagraph which is not supported. This context manager should be wrapped around torch.compile calls within custom ops - to avoid the nested cudagraph capture.""" - old_val = torch._inductor.config.graph_partition - try: + to avoid the nested cudagraph capture. + + Expected Usage: + @disable_inductor_graph_partition + @torch.compile() + def op_eager_code(...): + ... + + Note that `@disable_inductor_graph_partition` should be applied before + `@torch.compile()` + """ + + def wrapper(*args, **kwargs): + old_val = torch._inductor.config.graph_partition torch._inductor.config.graph_partition = False - yield - finally: + out = func(*args, **kwargs) torch._inductor.config.graph_partition = old_val + return out + + return wrapper From 04aadb3b2f4304b74bf55b1f2180ef08e503f01c Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Wed, 15 Oct 2025 17:22:05 -0700 Subject: [PATCH 3/7] nit Signed-off-by: Boyuan Feng --- vllm/model_executor/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 604060fabf01..d5ff63a00f53 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -7,6 +7,8 @@ import torch +from vllm.utils import is_torch_equal_or_newer + def set_random_seed(seed: int) -> None: from vllm.platforms import current_platform @@ -117,4 +119,4 @@ def wrapper(*args, **kwargs): torch._inductor.config.graph_partition = old_val return out - return wrapper + return wrapper if is_torch_equal_or_newer("2.9.0.dev") else func From d5d36c343a19253698ff97caa271049e5aabc728 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Wed, 15 Oct 2025 17:22:53 -0700 Subject: [PATCH 4/7] Update vllm/model_executor/utils.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Luka Govedič Signed-off-by: Boyuan Feng --- vllm/model_executor/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index d5ff63a00f53..f630d2b39ed9 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -108,8 +108,8 @@ def disable_inductor_graph_partition(func): def op_eager_code(...): ... - Note that `@disable_inductor_graph_partition` should be applied before - `@torch.compile()` + Note that `@disable_inductor_graph_partition` should be applied on top of + `torch.compile()` """ def wrapper(*args, **kwargs): From 0ab71757c25fec18a5f6fb65b95fd9461e42b2e3 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Wed, 15 Oct 2025 17:49:17 -0700 Subject: [PATCH 5/7] lint Signed-off-by: Boyuan Feng --- vllm/model_executor/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index f630d2b39ed9..73a5838db445 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -108,7 +108,7 @@ def disable_inductor_graph_partition(func): def op_eager_code(...): ... - Note that `@disable_inductor_graph_partition` should be applied on top of + Note that `@disable_inductor_graph_partition` should be applied on top of `torch.compile()` """ From 6f9339a13a9cff8bca1ae588c4086069f261d2ea Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Wed, 15 Oct 2025 17:58:07 -0700 Subject: [PATCH 6/7] use torch.compile options Signed-off-by: Boyuan Feng --- .../layers/fused_moe/fused_moe.py | 8 ++-- vllm/model_executor/utils.py | 37 ------------------- 2 files changed, 5 insertions(+), 40 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 8a27b8cd98d4..7949204d6226 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -46,7 +46,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme -from vllm.model_executor.utils import disable_inductor_graph_partition from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer @@ -1127,8 +1126,11 @@ def fused_topk_bias( # This is used by the Deepseek-V2 and Deepseek-V3 model -@disable_inductor_graph_partition -@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +@torch.compile( + dynamic=True, + backend=current_platform.simple_compile_backend, + options={"graph_partition": False}, +) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 73a5838db445..38cd230082f8 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -7,8 +7,6 @@ import torch -from vllm.utils import is_torch_equal_or_newer - def set_random_seed(seed: int) -> None: from vllm.platforms import current_platform @@ -85,38 +83,3 @@ def get_moe_expert_mapping( if child_map is not None: return child_map() return [] - - -def disable_inductor_graph_partition(func): - """Decorator to disable inductor graph partition. - This is used to avoid nested cudagraph capture. - - Example: - 1. We apply torch.compile directly on some ops (e.g., grouped_topk) wrapped - in custom ops. Inductor graph partition applies cudagraph within the custom op. - 2. At the same time, we compile the model which uses these custom ops. Inductor - graph partition also wraps each graph partition with CUDAGraph. Some partitions - may include custom ops, which has already been applied cudagraph. This leads to - nested cudagraph which is not supported. - - This context manager should be wrapped around torch.compile calls within custom ops - to avoid the nested cudagraph capture. - - Expected Usage: - @disable_inductor_graph_partition - @torch.compile() - def op_eager_code(...): - ... - - Note that `@disable_inductor_graph_partition` should be applied on top of - `torch.compile()` - """ - - def wrapper(*args, **kwargs): - old_val = torch._inductor.config.graph_partition - torch._inductor.config.graph_partition = False - out = func(*args, **kwargs) - torch._inductor.config.graph_partition = old_val - return out - - return wrapper if is_torch_equal_or_newer("2.9.0.dev") else func From c1dfad6042f9a2203af100a18e3792dc3f5ecf49 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Wed, 15 Oct 2025 18:05:06 -0700 Subject: [PATCH 7/7] nit Signed-off-by: Boyuan Feng --- vllm/model_executor/layers/fused_moe/fused_moe.py | 3 ++- vllm/model_executor/utils.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7949204d6226..f07781452907 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -46,6 +46,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme +from vllm.model_executor.utils import maybe_disable_graph_partition from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer @@ -1129,7 +1130,7 @@ def fused_topk_bias( @torch.compile( dynamic=True, backend=current_platform.simple_compile_backend, - options={"graph_partition": False}, + options=maybe_disable_graph_partition(current_platform.simple_compile_backend), ) def grouped_topk( hidden_states: torch.Tensor, diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 38cd230082f8..5ffee6cb8d8b 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -7,6 +7,8 @@ import torch +from vllm.utils import is_torch_equal_or_newer + def set_random_seed(seed: int) -> None: from vllm.platforms import current_platform @@ -83,3 +85,10 @@ def get_moe_expert_mapping( if child_map is not None: return child_map() return [] + + +def maybe_disable_graph_partition(current_backend: str) -> dict[str, bool]: + if current_backend == "inductor" and is_torch_equal_or_newer("2.9.0.dev"): + return {"graph_partition": False} + else: + return {}