Use unreg path for custom all-reduce during CUDA graph capture #2075

Merged
valarLip merged 4 commits into ROCm:main from zyzshishui:ar on Mar 3, 2026

Conversation

@zyzshishui
Contributor

Motivation

Same as sgl-project/sglang#19162.

A very small fix, needed for compatibility with torch_memory_saver. The error path:

  1. torch_memory_saver hooks hipMalloc and replaces it with hipMemAddressReserve + hipMemMap (the VMM APIs).
  2. During register_graph_buffers, hipIpcGetMemHandle expects a pointer allocated by hipMalloc but in fact receives a buffer allocated by hipMemAddressReserve + hipMemMap.
  3. hipIpcGetMemHandle does not check the pointer and accepts it, producing an invalid handle (invalid because the Runtime API and the VMM API use different allocation tables); I raised a fix for this here.
  4. Other ranks call hipIpcOpenMemHandle(invalid_handle) and fail, causing a hang during CUDA graph capture.
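
The four steps above can be modelled with a small toy in plain Python. This is only a sketch of the reasoning, not how HIP is implemented: `ToyRuntime`, its two dictionaries, and every method name are hypothetical stand-ins chosen for illustration. The point it shows is that a handle exported without validation looks fine on the exporting side and only fails when a peer tries to open it.

```python
class ToyRuntime:
    """Hypothetical model of a runtime with two separate allocation tables."""

    def __init__(self):
        self.runtime_table = {}  # pointers from the hipMalloc-like allocator
        self.vmm_table = {}      # pointers from the reserve + map (VMM-like) path
        self._next_ptr = 0x1000

    def _bump(self, size):
        # Hand out a fresh fake address.
        ptr = self._next_ptr
        self._next_ptr += size
        return ptr

    def malloc(self, size):
        ptr = self._bump(size)
        self.runtime_table[ptr] = size
        return ptr

    def vmm_reserve_and_map(self, size):
        # Step 1: the hooked allocation records the pointer in the other table.
        ptr = self._bump(size)
        self.vmm_table[ptr] = size
        return ptr

    def ipc_get_handle(self, ptr):
        # Steps 2 and 3: no check that ptr came from malloc(), so a handle
        # is produced even for a VMM-backed pointer.
        return ("ipc-handle", ptr)

    def ipc_open_handle(self, handle):
        # Step 4: the peer resolves the handle against the runtime table only.
        _, ptr = handle
        if ptr not in self.runtime_table:
            raise LookupError(f"no runtime allocation for {hex(ptr)}")
        return ptr
```

In the toy, opening the bad handle raises immediately. In the scenario described above the failure surfaces as a hang during graph capture, which is why avoiding the registration call for such buffers is the simpler fix.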

Technical Details

Test Plan

Test Result

Submission Checklist

# return

self.disabled = False
self.tms_cudagraph = os.getenv("SGLANG_MEMORY_SAVER_CUDA_GRAPH", "0")
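
A side note on the line above: `os.getenv` returns a string, so the stored value is `"0"` or `"1"`, and both are truthy in a plain `if`. How the PR consumes `tms_cudagraph` is not shown here, so the following is only a hedged sketch of one common way to turn such a variable into a boolean. The helper name `env_flag` and the set of accepted spellings are assumptions, not part of aiter or sglang.

```python
import os


def env_flag(name, default="0", environ=None):
    """Hypothetical helper: read an environment variable as a boolean."""
    env = os.environ if environ is None else environ
    value = env.get(name, default)
    # Assumed convention: these spellings mean "enabled", anything else is off.
    return value.strip().lower() in {"1", "true", "yes", "on"}


# Sketch of use inside __init__:
# self.tms_cudagraph = env_flag("SGLANG_MEMORY_SAVER_CUDA_GRAPH")
```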
Collaborator

prefix with SGLANG_MEMORY_SAVER_CUDA_GRAPH?

Contributor Author

@zyzshishui zyzshishui Feb 23, 2026

Good catch. How about adding a tms_cudagraph parameter to __init__ and passing its value in from sglang? That would require sglang to use an updated aiter, though. Any suggestions?

Collaborator

good idea, let's add "enable_register" param to init

Contributor Author

@zyzshishui zyzshishui Feb 25, 2026

Added "enable_register_for_capturing", since this only controls behavior during capture, not during real calls. But again, we cannot make the follow-up change until sglang bumps its aiter version. I will keep an eye on that.
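
To make the intent of the new parameter concrete, here is a minimal sketch of a constructor flag that affects only the capture branch. Only the name `enable_register_for_capturing` comes from this thread. The class `ToyCustomAllReduce`, the `choose_path` method, the path labels, and the default of `True` are assumptions for illustration and do not reflect aiter's actual implementation.

```python
from dataclasses import dataclass, field


@dataclass
class ToyCustomAllReduce:
    """Hypothetical stand-in for a custom all-reduce wrapper."""

    # Assumed default: keep the existing registered behavior unless disabled.
    enable_register_for_capturing: bool = True
    # Buffers that would later be IPC-registered after capture.
    pending_graph_buffers: list = field(default_factory=list)

    def choose_path(self, buffer_name, capturing):
        if not capturing:
            # Real (non-capture) calls are unaffected by the flag.
            return "eager"
        if self.enable_register_for_capturing:
            # Registered path: the buffer must be IPC-registered afterwards.
            self.pending_graph_buffers.append(buffer_name)
            return "registered"
        # Unregistered path: nothing is queued for IPC registration.
        return "unregistered"
```

A caller that runs with torch_memory_saver could construct the object with `enable_register_for_capturing=False`, so no buffer is queued for registration during capture.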

@valarLip valarLip merged commit d4f5e52 into ROCm:main Mar 3, 2026
18 checks passed
@zyzshishui zyzshishui deleted the ar branch March 8, 2026 07:54
valarLip pushed a commit that referenced this pull request Mar 18, 2026
AMD-yanfeiwang pushed a commit to AMD-yanfeiwang/aiter that referenced this pull request Mar 18, 2026

3 participants