From 1a846ad1e97f4bbce672fb38e364ae3b4dd4e74d Mon Sep 17 00:00:00 2001 From: root Date: Wed, 18 Mar 2026 09:46:16 -0700 Subject: [PATCH 1/7] fix fi ar fusion workspace not initialized Signed-off-by: root Signed-off-by: wzhao18 --- .../passes/fusion/allreduce_rms_fusion.py | 116 ++++++++++++------ 1 file changed, 80 insertions(+), 36 deletions(-) diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index f141a7c171f7..4ea6b6f27541 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -94,6 +94,46 @@ MiB = 1024 * 1024 + def _initialize_fi_ar_workspaces( + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + group, + ) -> bool: + """Initialize FlashInfer AR workspaces. Returns True on success.""" + for workspace_init_fn in [ + initialize_fi_ar_workspace, + initialize_fi_ar_quant_workspace, + ]: + try: + workspace_init_fn( + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + group=group, + ) + except Exception as e: + if "multicast" in str(e).lower(): + logger.warning_once( + "AllReduce fusion pass is disabled: flashinfer workspace " + "creation failed: %s. This is expected on GPUs without " + "NVSwitch (e.g., NVLink bridge-only or PCIe topologies). " + "Falling back to non-fused allreduce.", + str(e), + ) + else: + logger.warning_once( + "Failed to initialize FlashInfer All Reduce workspace: %s. " + "AllReduce fusion pass will be disabled.", + e, + ) + return False + return True + def call_trtllm_fused_allreduce_norm( allreduce_in: torch.Tensor, residual: torch.Tensor, @@ -133,16 +173,40 @@ def call_trtllm_fused_allreduce_norm( # Select workspace based on pattern: quant patterns use the # trtllm quant workspace, non-quant patterns use the primary workspace. 
- if pattern_code in ( + is_quant_pattern = pattern_code in ( ar_fusion_patterns.kARResidualRMSNormFP8Quant, ar_fusion_patterns.kARResidualRMSNormFP4Quant, - ): - workspace = get_fi_ar_quant_workspace() - else: - workspace = get_fi_ar_workspace() - assert workspace is not None, ( - "Flashinfer workspace must be initialized when using flashinfer" ) + workspace = ( + get_fi_ar_quant_workspace() + if is_quant_pattern + else get_fi_ar_workspace() + ) + + if workspace is None: + # Workspace may not be initialized by AllReduceFusionPass if the function is + # directly loaded from torch AOT compile cache. Lazily initialize here. + _, hidden_dim = allreduce_in.shape + if not _initialize_fi_ar_workspaces( + world_size=world_size, + rank=get_tensor_model_parallel_rank(), + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=allreduce_in.dtype, + group=get_tp_group().device_group, + ): + raise RuntimeError( + "Failed to initialize FlashInfer All Reduce workspace" + ) + workspace = ( + get_fi_ar_quant_workspace() + if is_quant_pattern + else get_fi_ar_workspace() + ) + if workspace is None: + raise RuntimeError( + "FlashInfer workspace is None after initialization" + ) assert flashinfer_comm is not None if norm_out is None: norm_out = allreduce_in @@ -753,35 +817,15 @@ def __init__(self, config: VllmConfig) -> None: scope="global", ) - for workspace_init_fn in [ - initialize_fi_ar_workspace, - initialize_fi_ar_quant_workspace, - ]: - try: - workspace_init_fn( - world_size=self.tp_size, - rank=rank, - max_token_num=self.max_token_num, - hidden_dim=self.hidden_dim, - dtype=self.model_dtype, - group=self.group, - ) - except Exception as e: - if "multicast" in str(e).lower(): - logger.warning( - "AllReduce fusion pass is disabled: flashinfer workspace " - "creation failed: %s. This is expected on GPUs without " - "NVSwitch (e.g., NVLink bridge-only or PCIe topologies). 
" - "Falling back to non-fused allreduce.", - str(e), - ) - else: - logger.warning( - "Failed to initialize FlashInfer All Reduce workspace: %s. " - "AllReduce fusion pass will be disabled.", - e, - ) - return + if not _initialize_fi_ar_workspaces( + world_size=self.tp_size, + rank=rank, + max_token_num=self.max_token_num, + hidden_dim=self.hidden_dim, + dtype=self.model_dtype, + group=self.group, + ): + return self.allreduce_params = FlashInferFusedAllReduceParams( world_size=self.tp_size, From 4930908018ceb03d5601a46728f3f889986b3922 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 18 Mar 2026 11:09:37 -0700 Subject: [PATCH 2/7] Fix formatting Signed-off-by: <> Signed-off-by: wzhao18 --- vllm/compilation/passes/fusion/allreduce_rms_fusion.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index 4ea6b6f27541..f1b12bf368f0 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -178,9 +178,7 @@ def call_trtllm_fused_allreduce_norm( ar_fusion_patterns.kARResidualRMSNormFP4Quant, ) workspace = ( - get_fi_ar_quant_workspace() - if is_quant_pattern - else get_fi_ar_workspace() + get_fi_ar_quant_workspace() if is_quant_pattern else get_fi_ar_workspace() ) if workspace is None: @@ -204,9 +202,7 @@ def call_trtllm_fused_allreduce_norm( else get_fi_ar_workspace() ) if workspace is None: - raise RuntimeError( - "FlashInfer workspace is None after initialization" - ) + raise RuntimeError("FlashInfer workspace is None after initialization") assert flashinfer_comm is not None if norm_out is None: norm_out = allreduce_in From fe84ce8cc098f411ba655b36fc6d27ff9d183b14 Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Thu, 19 Mar 2026 18:51:34 -0700 Subject: [PATCH 3/7] Code clean up Signed-off-by: wzhao18 --- .../passes/fusion/allreduce_rms_fusion.py | 92 +++------- 
.../flashinfer_all_reduce.py | 163 +++++++++--------- 2 files changed, 104 insertions(+), 151 deletions(-) diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index f1b12bf368f0..435f2ef008be 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -86,54 +86,12 @@ destroy_fi_ar_workspace, get_fi_ar_quant_workspace, get_fi_ar_workspace, - initialize_fi_ar_quant_workspace, - initialize_fi_ar_workspace, ) ar_fusion_patterns = flashinfer_comm.AllReduceFusionPattern MiB = 1024 * 1024 - def _initialize_fi_ar_workspaces( - world_size: int, - rank: int, - max_token_num: int, - hidden_dim: int, - dtype: torch.dtype, - group, - ) -> bool: - """Initialize FlashInfer AR workspaces. Returns True on success.""" - for workspace_init_fn in [ - initialize_fi_ar_workspace, - initialize_fi_ar_quant_workspace, - ]: - try: - workspace_init_fn( - world_size=world_size, - rank=rank, - max_token_num=max_token_num, - hidden_dim=hidden_dim, - dtype=dtype, - group=group, - ) - except Exception as e: - if "multicast" in str(e).lower(): - logger.warning_once( - "AllReduce fusion pass is disabled: flashinfer workspace " - "creation failed: %s. This is expected on GPUs without " - "NVSwitch (e.g., NVLink bridge-only or PCIe topologies). " - "Falling back to non-fused allreduce.", - str(e), - ) - else: - logger.warning_once( - "Failed to initialize FlashInfer All Reduce workspace: %s. 
" - "AllReduce fusion pass will be disabled.", - e, - ) - return False - return True - def call_trtllm_fused_allreduce_norm( allreduce_in: torch.Tensor, residual: torch.Tensor, @@ -177,32 +135,20 @@ def call_trtllm_fused_allreduce_norm( ar_fusion_patterns.kARResidualRMSNormFP8Quant, ar_fusion_patterns.kARResidualRMSNormFP4Quant, ) - workspace = ( - get_fi_ar_quant_workspace() if is_quant_pattern else get_fi_ar_workspace() + get_workspace_fn = ( + get_fi_ar_quant_workspace if is_quant_pattern else get_fi_ar_workspace + ) + workspace = get_workspace_fn( + world_size=world_size, + rank=get_tensor_model_parallel_rank(), + max_token_num=max_token_num, + hidden_dim=hidden_size, + dtype=allreduce_in.dtype, + group=get_tp_group().device_group, + ) + assert workspace is not None, ( + "Flashinfer allreduce workspace must be initialized when using flashinfer" ) - - if workspace is None: - # Workspace may not be initialized by AllReduceFusionPass if the function is - # directly loaded from torch AOT compile cache. Lazily initialize here. 
- _, hidden_dim = allreduce_in.shape - if not _initialize_fi_ar_workspaces( - world_size=world_size, - rank=get_tensor_model_parallel_rank(), - max_token_num=max_token_num, - hidden_dim=hidden_dim, - dtype=allreduce_in.dtype, - group=get_tp_group().device_group, - ): - raise RuntimeError( - "Failed to initialize FlashInfer All Reduce workspace" - ) - workspace = ( - get_fi_ar_quant_workspace() - if is_quant_pattern - else get_fi_ar_workspace() - ) - if workspace is None: - raise RuntimeError("FlashInfer workspace is None after initialization") assert flashinfer_comm is not None if norm_out is None: norm_out = allreduce_in @@ -813,16 +759,21 @@ def __init__(self, config: VllmConfig) -> None: scope="global", ) - if not _initialize_fi_ar_workspaces( + workspace_kwargs = dict( world_size=self.tp_size, rank=rank, max_token_num=self.max_token_num, hidden_dim=self.hidden_dim, dtype=self.model_dtype, group=self.group, - ): + ) + if get_fi_ar_workspace(**workspace_kwargs) is None: return + self.supports_ar_quant_fusion = ( + get_fi_ar_quant_workspace(**workspace_kwargs) is not None + ) + self.allreduce_params = FlashInferFusedAllReduceParams( world_size=self.tp_size, max_token_num=self.max_token_num, @@ -833,9 +784,8 @@ def __init__(self, config: VllmConfig) -> None: @enable_fake_mode def register_patterns(self) -> None: - supports_quantization = get_fi_ar_quant_workspace() is not None for epsilon in [1e-5, 1e-6]: - if supports_quantization: + if self.supports_ar_quant_fusion: AllReduceFusedRMSNormStaticQuantFP8Pattern( epsilon, self.model_dtype, diff --git a/vllm/distributed/device_communicators/flashinfer_all_reduce.py b/vllm/distributed/device_communicators/flashinfer_all_reduce.py index 66e089182869..042a64a89a0d 100644 --- a/vllm/distributed/device_communicators/flashinfer_all_reduce.py +++ b/vllm/distributed/device_communicators/flashinfer_all_reduce.py @@ -29,50 +29,27 @@ except ImportError: pass -# Global workspace for standalone allreduce and non-quant ar+rms 
fusion +# Workspace for standalone allreduce and non-quant ar+rms fusion _fi_ar_workspace = None # Extra workspace for quant fusion patterns (only supported by trtllm backend) -# Only created if primary workspace is not already trtllm _fi_ar_quant_workspace = None -def get_fi_ar_workspace(): - return _fi_ar_workspace - - -def get_fi_ar_quant_workspace(): - return _fi_ar_quant_workspace - - -def initialize_fi_ar_workspace( +def _create_workspace( + backend: str, world_size: int, rank: int, max_token_num: int, hidden_dim: int, dtype: torch.dtype, group: ProcessGroup, -) -> None: - """ - Initialize the workspace if not already initialized. - - Currently, this function is called by either the AllReduceFusionPass - or the FlashInferAllReduce backend for standalone allreduce. - If the fusion pass is enabled via - --compilation-config.pass_config.fuse_allreduce_rms=true, - it will create the workspace first, and the standalone backend - will reuse the workspace. Otherwise, the standalone backend will - create the workspace. - """ - global _fi_ar_workspace - if _fi_ar_workspace is not None: - return - - backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND +): + """Create a flashinfer allreduce workspace, returning None on failure.""" comm_backend = TorchDistBackend(group=group) rng_state = random.getstate() try: random.seed(int.from_bytes(os.urandom(16), byteorder="big")) - _fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace( + workspace = flashinfer_comm.create_allreduce_fusion_workspace( backend=backend, world_size=world_size, rank=rank, @@ -81,9 +58,24 @@ def initialize_fi_ar_workspace( dtype=dtype, comm_backend=comm_backend, ) + except Exception as e: + if "multicast" in str(e).lower(): + logger.warning_once( + "Failed to initialize FlashInfer All Reduce workspace: %s. " + "AllReduce fusion pass will be disabled. 
This is expected " + "on GPUs without NVSwitch (e.g., NVLink bridge-only or " + "PCIe topologies).", + e, + ) + else: + logger.warning_once( + "Failed to initialize FlashInfer All Reduce workspace: %s. " + "AllReduce fusion pass will be disabled.", + e, + ) + return None finally: random.setstate(rng_state) - assert _fi_ar_workspace is not None logger.debug( "Initialized FlashInfer All Reduce workspace: backend=%s, " "world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s", @@ -94,52 +86,63 @@ def initialize_fi_ar_workspace( hidden_dim, dtype, ) + return workspace -def initialize_fi_ar_quant_workspace( +def get_fi_ar_workspace( world_size: int, rank: int, max_token_num: int, hidden_dim: int, dtype: torch.dtype, group: ProcessGroup, -) -> None: +): """ - Initialize the workspace used by quantization fusion patterns. + Return the allreduce workspace for non-quant patterns, initializing if needed. - Currently this always creates a workspace for trtllm backend as only it - supports quantization fusion (FP8/FP4). If the primary workspace - is already trtllm, the quant workspace aliases to it. + Used by AllReduceFusionPass (non-quant patterns) and FlashInferAllReduce + for standalone allreduce. Backend is controlled by + VLLM_FLASHINFER_ALLREDUCE_BACKEND env var. 
""" - global _fi_ar_quant_workspace - if _fi_ar_quant_workspace is not None: - return + global _fi_ar_workspace + if _fi_ar_workspace is not None: + return _fi_ar_workspace - # If primary workspace is already trtllm, reuse it - if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm": - _fi_ar_quant_workspace = _fi_ar_workspace - return + backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND - comm_backend = TorchDistBackend(group=group) - _fi_ar_quant_workspace = flashinfer_comm.create_allreduce_fusion_workspace( - backend="trtllm", - world_size=world_size, - rank=rank, - max_token_num=max_token_num, - hidden_dim=hidden_dim, - dtype=dtype, - comm_backend=comm_backend, + # Reuse the quant workspace if it was already created with the same backend + if _fi_ar_quant_workspace is not None and _fi_ar_quant_workspace.backend == backend: + _fi_ar_workspace = _fi_ar_quant_workspace + return _fi_ar_workspace + + _fi_ar_workspace = _create_workspace( + backend, world_size, rank, max_token_num, hidden_dim, dtype, group ) - assert _fi_ar_quant_workspace is not None - logger.debug( - "Initialized FlashInfer All Reduce workspace: backend=trtllm, " - "world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s", - world_size, - rank, - max_token_num, - hidden_dim, - dtype, + return _fi_ar_workspace + + +def get_fi_ar_quant_workspace( + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + group: ProcessGroup, +): + """ + Return the allreduce workspace for quant patterns, initializing if needed. + + Always uses trtllm backend as it is the only one supporting quantization + fusion (FP8/FP4). 
+ """ + global _fi_ar_quant_workspace + if _fi_ar_quant_workspace is not None: + return _fi_ar_quant_workspace + + _fi_ar_quant_workspace = _create_workspace( + "trtllm", world_size, rank, max_token_num, hidden_dim, dtype, group ) + return _fi_ar_quant_workspace _fi_ar_workspace_lock = threading.Lock() @@ -209,29 +212,21 @@ def __init__( def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool: """Ensure the all reduce workspace is initialized.""" - if get_fi_ar_workspace() is not None: - return True if self.max_num_tokens == 0: element_size = torch.tensor([], dtype=dtype, device="cpu").element_size() self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size) - try: - initialize_fi_ar_workspace( - world_size=self.world_size, - rank=self.rank, - max_token_num=self.max_num_tokens, - hidden_dim=hidden_dim, - dtype=dtype, - group=self.group, - ) - return True - except Exception as e: - logger.warning( - "Failed to initialize FlashInfer All Reduce workspace: %s. 
" - "FlashInfer All Reduce will be disabled.", - e, - ) + workspace = get_fi_ar_workspace( + world_size=self.world_size, + rank=self.rank, + max_token_num=self.max_num_tokens, + hidden_dim=hidden_dim, + dtype=dtype, + group=self.group, + ) + if workspace is None: self.disabled = True return False + return True def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool: if self.disabled: @@ -257,7 +252,15 @@ def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool: return self._ensure_workspace(hidden_dim, input_tensor.dtype) def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor: - workspace = get_fi_ar_workspace() + _, hidden_dim = input_tensor.shape + workspace = get_fi_ar_workspace( + world_size=self.world_size, + rank=self.rank, + max_token_num=self.max_num_tokens, + hidden_dim=hidden_dim, + dtype=input_tensor.dtype, + group=self.group, + ) return flashinfer_comm.allreduce_fusion( input=input_tensor, workspace=workspace, From 43f309b3f0226a3581ec11bfcc93cfedac12917c Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Thu, 19 Mar 2026 18:58:42 -0700 Subject: [PATCH 4/7] Code clean up Signed-off-by: wzhao18 --- vllm/compilation/passes/fusion/allreduce_rms_fusion.py | 8 ++++++-- .../device_communicators/flashinfer_all_reduce.py | 8 +++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index 435f2ef008be..f62535a555eb 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -768,9 +768,13 @@ def __init__(self, config: VllmConfig) -> None: group=self.group, ) if get_fi_ar_workspace(**workspace_kwargs) is None: + logger.warning_once( + "Failed to initialize Flashinfer allreduce workspace. " + "Flashinfer allreduce fusion will be disabled." 
+ ) return - self.supports_ar_quant_fusion = ( + self.supports_quant_fusion = ( get_fi_ar_quant_workspace(**workspace_kwargs) is not None ) @@ -785,7 +789,7 @@ def __init__(self, config: VllmConfig) -> None: @enable_fake_mode def register_patterns(self) -> None: for epsilon in [1e-5, 1e-6]: - if self.supports_ar_quant_fusion: + if self.supports_quant_fusion: AllReduceFusedRMSNormStaticQuantFP8Pattern( epsilon, self.model_dtype, diff --git a/vllm/distributed/device_communicators/flashinfer_all_reduce.py b/vllm/distributed/device_communicators/flashinfer_all_reduce.py index 042a64a89a0d..78ef9bbc7d3f 100644 --- a/vllm/distributed/device_communicators/flashinfer_all_reduce.py +++ b/vllm/distributed/device_communicators/flashinfer_all_reduce.py @@ -62,15 +62,13 @@ def _create_workspace( if "multicast" in str(e).lower(): logger.warning_once( "Failed to initialize FlashInfer All Reduce workspace: %s. " - "AllReduce fusion pass will be disabled. This is expected " - "on GPUs without NVSwitch (e.g., NVLink bridge-only or " - "PCIe topologies).", + "This is expected on GPUs without NVSwitch (e.g., NVLink " + "bridge-only or PCIe topologies).", e, ) else: logger.warning_once( - "Failed to initialize FlashInfer All Reduce workspace: %s. 
" - "AllReduce fusion pass will be disabled.", + "Failed to initialize FlashInfer All Reduce workspace: %s.", e, ) return None From 4a0240c8afd2b91f41ba0dd5e6e1d2c97bb1244b Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Thu, 19 Mar 2026 19:26:11 -0700 Subject: [PATCH 5/7] Add warning if AR quant fusion is disabled Signed-off-by: wzhao18 --- vllm/compilation/passes/fusion/allreduce_rms_fusion.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index f62535a555eb..623ff5913763 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -770,13 +770,18 @@ def __init__(self, config: VllmConfig) -> None: if get_fi_ar_workspace(**workspace_kwargs) is None: logger.warning_once( "Failed to initialize Flashinfer allreduce workspace. " - "Flashinfer allreduce fusion will be disabled." + "Flashinfer allreduce-norm fusion will be disabled." ) return self.supports_quant_fusion = ( get_fi_ar_quant_workspace(**workspace_kwargs) is not None ) + if not self.supports_quant_fusion: + logger.warning_once( + "Failed to initialize Flashinfer allreduce workspace. " + "Flashinfer allreduce-norm-quant fusion will be disabled." 
+ ) self.allreduce_params = FlashInferFusedAllReduceParams( world_size=self.tp_size, From 6ede8bdf89f63a241395246490d47695a73a5ef4 Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Thu, 19 Mar 2026 19:32:07 -0700 Subject: [PATCH 6/7] quant workspace may reuse non-quant Signed-off-by: wzhao18 --- .../device_communicators/flashinfer_all_reduce.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/distributed/device_communicators/flashinfer_all_reduce.py b/vllm/distributed/device_communicators/flashinfer_all_reduce.py index 78ef9bbc7d3f..8a5f91492fd9 100644 --- a/vllm/distributed/device_communicators/flashinfer_all_reduce.py +++ b/vllm/distributed/device_communicators/flashinfer_all_reduce.py @@ -137,6 +137,11 @@ def get_fi_ar_quant_workspace( if _fi_ar_quant_workspace is not None: return _fi_ar_quant_workspace + # Reuse the non-quant workspace if it was already created with trtllm + if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm": + _fi_ar_quant_workspace = _fi_ar_workspace + return _fi_ar_quant_workspace + _fi_ar_quant_workspace = _create_workspace( "trtllm", world_size, rank, max_token_num, hidden_dim, dtype, group ) From 863debdf6a467bce131eb9f3ebf84a4891eb4ab2 Mon Sep 17 00:00:00 2001 From: wzhao18 Date: Thu, 19 Mar 2026 19:41:18 -0700 Subject: [PATCH 7/7] Update destroy_fi_ar_workspace Signed-off-by: wzhao18 --- .../flashinfer_all_reduce.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/device_communicators/flashinfer_all_reduce.py b/vllm/distributed/device_communicators/flashinfer_all_reduce.py index 8a5f91492fd9..b2edfc15d731 100644 --- a/vllm/distributed/device_communicators/flashinfer_all_reduce.py +++ b/vllm/distributed/device_communicators/flashinfer_all_reduce.py @@ -152,18 +152,16 @@ def get_fi_ar_quant_workspace( def destroy_fi_ar_workspace(): - global _fi_ar_workspace - global _fi_ar_quant_workspace + global _fi_ar_workspace, _fi_ar_quant_workspace with 
_fi_ar_workspace_lock: - if ( - _fi_ar_quant_workspace is not None - and _fi_ar_quant_workspace is not _fi_ar_workspace - ): - _fi_ar_quant_workspace.destroy() - _fi_ar_quant_workspace = None + is_alias = _fi_ar_workspace is _fi_ar_quant_workspace + if _fi_ar_workspace is not None: _fi_ar_workspace.destroy() - _fi_ar_workspace = None + if _fi_ar_quant_workspace is not None and not is_alias: + _fi_ar_quant_workspace.destroy() + + _fi_ar_workspace = _fi_ar_quant_workspace = None atexit.register(destroy_fi_ar_workspace)