From c19ff57fbe515b018071f3fd1d50be6fd9c44511 Mon Sep 17 00:00:00 2001 From: zyzshishui Date: Mon, 23 Feb 2026 06:33:31 +0000 Subject: [PATCH 1/3] custom_all_reduce: gate registered_input/output on SGLANG_MEMORY_SAVER_CUDA_GRAPH --- aiter/dist/device_communicators/custom_all_reduce.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aiter/dist/device_communicators/custom_all_reduce.py b/aiter/dist/device_communicators/custom_all_reduce.py index d01010ae53..4b19c9e652 100644 --- a/aiter/dist/device_communicators/custom_all_reduce.py +++ b/aiter/dist/device_communicators/custom_all_reduce.py @@ -15,6 +15,7 @@ * limitations under the License. """ +import os from contextlib import contextmanager from typing import Any, List, Optional, Union @@ -147,6 +148,7 @@ def __init__( # return self.disabled = False + self.tms_cudagraph = os.getenv("SGLANG_MEMORY_SAVER_CUDA_GRAPH", "0") # buffers memory are owned by this Python class and passed to C++ # meta data composes of two parts: meta data for synchronization # (256 bytes) and a temporary buffer for storing intermediate @@ -329,8 +331,8 @@ def custom_all_reduce( input, use_new=use_new, open_fp8_quant=open_fp8_quant, - registered_input=False, - registered_output=False + registered_input=not self.tms_cudagraph, + registered_output=not self.tms_cudagraph ) def reduce_scatter( From 8146fca7408f18b92d3a781da31fc2b117a89914 Mon Sep 17 00:00:00 2001 From: Yuzhen Zhou <82826991+zyzshishui@users.noreply.github.com> Date: Mon, 23 Feb 2026 18:53:30 -0800 Subject: [PATCH 2/3] custom_all_reduce: move tms_cudagraph gating off the warm-up branch --- aiter/dist/device_communicators/custom_all_reduce.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aiter/dist/device_communicators/custom_all_reduce.py b/aiter/dist/device_communicators/custom_all_reduce.py index 4b19c9e652..32e57e2c0c 100644 --- a/aiter/dist/device_communicators/custom_all_reduce.py +++ b/aiter/dist/device_communicators/custom_all_reduce.py @@ -315,8 +315,8 @@ def custom_all_reduce( input, use_new=use_new, open_fp8_quant=open_fp8_quant, - registered_input=True, - 
registered_output=True + registered_input=not self.tms_cudagraph, + registered_output=not self.tms_cudagraph ) else: # if warm up, mimic the allocation pattern @@ -331,8 +331,8 @@ def custom_all_reduce( input, use_new=use_new, open_fp8_quant=open_fp8_quant, - registered_input=not self.tms_cudagraph, - registered_output=not self.tms_cudagraph + registered_input=False, + registered_output=False ) def reduce_scatter( From a1d148fd9affe3bca4e8cc004a4b736a60b29c22 Mon Sep 17 00:00:00 2001 From: zyzshishui Date: Tue, 24 Feb 2026 18:50:48 -0800 Subject: [PATCH 3/3] custom_all_reduce: replace env lookup with enable_register_for_capturing argument --- aiter/dist/device_communicators/custom_all_reduce.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aiter/dist/device_communicators/custom_all_reduce.py b/aiter/dist/device_communicators/custom_all_reduce.py index 32e57e2c0c..bf15f70d43 100644 --- a/aiter/dist/device_communicators/custom_all_reduce.py +++ b/aiter/dist/device_communicators/custom_all_reduce.py @@ -15,7 +15,6 @@ * limitations under the License. 
""" -import os from contextlib import contextmanager from typing import Any, List, Optional, Union @@ -56,6 +55,7 @@ def __init__( group: ProcessGroup, device: Union[int, str, torch.device], max_size=8192 * 1024 * 8 * 2, # In allreduce 2stage writemode, use 2x tmp buffer + enable_register_for_capturing: bool = True, ) -> None: """ Args: @@ -148,7 +148,7 @@ def __init__( # return self.disabled = False - self.tms_cudagraph = os.getenv("SGLANG_MEMORY_SAVER_CUDA_GRAPH", "0") + self.enable_register_for_capturing = enable_register_for_capturing # buffers memory are owned by this Python class and passed to C++ # meta data composes of two parts: meta data for synchronization # (256 bytes) and a temporary buffer for storing intermediate @@ -315,8 +315,8 @@ def custom_all_reduce( input, use_new=use_new, open_fp8_quant=open_fp8_quant, - registered_input=not self.tms_cudagraph, - registered_output=not self.tms_cudagraph + registered_input=self.enable_register_for_capturing, + registered_output=self.enable_register_for_capturing ) else: # if warm up, mimic the allocation pattern