Merged
Changes from 1 commit
6 changes: 4 additions & 2 deletions aiter/dist/device_communicators/custom_all_reduce.py
@@ -15,6 +15,7 @@
* limitations under the License.
"""

+ import os
from contextlib import contextmanager
from typing import Any, List, Optional, Union

@@ -147,6 +148,7 @@ def __init__(
# return

self.disabled = False
+ self.tms_cudagraph = os.getenv("SGLANG_MEMORY_SAVER_CUDA_GRAPH", "0")

Collaborator

prefix with SGLANG_MEMORY_SAVER_CUDA_GRAPH?

@zyzshishui zyzshishui Feb 23, 2026

Contributor Author

Good catch. How about adding a tms_cudagraph parameter to __init__ and passing the value in from sglang? That would require sglang to use an updated aiter, though. Any suggestions?

Collaborator

Good idea. Let's add an "enable_register" parameter to __init__.

@zyzshishui zyzshishui Feb 25, 2026

Contributor Author

Added "enable_register_for_capturing", since this only controls behavior during capture, not during real calls. But again, we cannot make the follow-up change until sglang bumps its aiter version. I will keep an eye on that.
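
For readers following the thread, here is a minimal sketch of the approach discussed above: an explicit `enable_register_for_capturing` constructor parameter, with the `SGLANG_MEMORY_SAVER_CUDA_GRAPH` environment variable used only as a fallback. `CommunicatorSketch`, `env_flag`, `capture_kwargs`, and the set of accepted truthy strings are hypothetical names and choices made for illustration; this is not aiter's actual class or the code that was merged. The sketch also parses the variable into a boolean, because `os.getenv` returns a string and any non-empty string, including `"0"`, is truthy in Python.

```python
import os
from typing import Optional

# Assumption: which strings count as "enabled" is chosen for this sketch only.
_TRUTHY = {"1", "true", "yes", "on"}


def env_flag(name: str, default: bool = False) -> bool:
    """Parse an environment variable into a real bool (hypothetical helper)."""
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUTHY


class CommunicatorSketch:
    """Hypothetical stand-in for the communicator class, not aiter's API."""

    def __init__(self, enable_register_for_capturing: Optional[bool] = None):
        if enable_register_for_capturing is None:
            # Fallback mirrors the diff's `not self.tms_cudagraph`:
            # memory-saver CUDA graph mode on -> do not register buffers.
            enable_register_for_capturing = not env_flag(
                "SGLANG_MEMORY_SAVER_CUDA_GRAPH"
            )
        self.enable_register_for_capturing = enable_register_for_capturing

    def capture_kwargs(self) -> dict:
        # Keyword arguments that would be forwarded during graph capture.
        return {
            "registered_input": self.enable_register_for_capturing,
            "registered_output": self.enable_register_for_capturing,
        }
```

With this shape, a caller that already knows its configuration can pass `CommunicatorSketch(enable_register_for_capturing=False)` directly, while callers on an older integration keep working through the environment-variable fallback.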

# buffers memory are owned by this Python class and passed to C++
# meta data composes of two parts: meta data for synchronization
# (256 bytes) and a temporary buffer for storing intermediate
@@ -329,8 +331,8 @@ def custom_all_reduce(
input,
use_new=use_new,
open_fp8_quant=open_fp8_quant,
- registered_input=False,
- registered_output=False
+ registered_input=not self.tms_cudagraph,
+ registered_output=not self.tms_cudagraph
)

def reduce_scatter(