Skip to content

feat(sglang): disagg DP rank routing + backwards-compatible network imports#6736

Merged
ishandhanani merged 11 commits intomainfrom
idhanani/disagg-dp-rank-routing
Mar 25, 2026
Merged

feat(sglang): disagg DP rank routing + backwards-compatible network imports#6736
ishandhanani merged 11 commits intomainfrom
idhanani/disagg-dp-rank-routing

Conversation

@ishandhanani
Copy link
Copy Markdown
Contributor

@ishandhanani ishandhanani commented Mar 1, 2026

Summary

  • Pass data_parallel_rank from routing info to SGLang's prefill handler for disagg DP routing
  • Add _compat.py shim for SGLang network imports -- works with both 0.5.9 and main
  • Use NetworkAddress (real or polyfill) for IPv6-safe address handling in publisher, register, and handler_base

Changes

Disagg DP rank routing (prefill_handler.py)

Extracts dp_rank from the routing dict and passes data_parallel_rank to engine.async_generate(). The decode handler already does this; this closes the gap for disaggregated serving.

SGLang backwards compatibility (_compat.py)

SGLang post-0.5.9 moved get_local_ip_auto, get_zmq_socket from sglang.srt.utils to sglang.srt.utils.network and introduced NetworkAddress. Rather than pinning to one version, _compat.py tries the new import path first and falls back to the old path with a minimal NetworkAddress polyfill.

All SGLang network imports in the component now go through _compat.py instead of importing directly from sglang.srt.utils*.

Supersedes PR #7597

PR #7597 (Fix SGLang network helper imports) is a strict subset of this work. If this merges first, #7597 can be closed.

Validation

Tested all 4 combinations:

  • agg.sh + sglang main -- pass
  • disagg.sh + sglang main -- pass
  • agg.sh + sglang 0.5.9 (pip) -- pass
  • disagg.sh + sglang 0.5.9 (pip) -- pass

Summary by CodeRabbit

  • New Features

    • Added support for dp_rank (data parallel rank) parameter in request routing, enabling more granular control over distributed execution.
  • Documentation

    • Added SGLang backwards compatibility guidelines to ensure consistent handling across different SGLang versions.
  • Refactor

    • Improved internal network address handling and endpoint formatting for better version compatibility.

The prefill handler was missing the data_parallel_rank parameter in its
async_generate call, preventing DP rank-aware routing from working in
disaggregated mode. The decode handler already passes this correctly.

Extract dp_rank from the routing info (set by the KV router in
prefill_router.rs) and forward it to SGLang's engine so the prefill
scheduler directs work to the correct DP rank.

This works in conjunction with sgl-project/sglang#19168, which adds
per-request DP rank resolution on the SGLang side -- the decode worker
can now resolve the prefill DP rank via the bootstrap server rather
than relying on bootstrap_room % dp_size.
@github-actions github-actions bot added feat backend::sglang Relates to the sglang backend labels Mar 1, 2026
@ishandhanani ishandhanani changed the title feat: pass data_parallel_rank to prefill handler for disagg DP routing [do not merge] feat: pass data_parallel_rank to prefill handler for disagg DP routing Mar 1, 2026
@ishandhanani ishandhanani changed the title [do not merge] feat: pass data_parallel_rank to prefill handler for disagg DP routing [do not merge] feat: pass data_parallel_rank for disagg routing + API changes for SGL > 0.5.9 Mar 17, 2026
@pull-request-size pull-request-size bot added size/L and removed size/S labels Mar 17, 2026
MatejKosec added a commit that referenced this pull request Mar 20, 2026
- Revert publisher.py changes (PR #6736 handles the SGLang compat)
- Unify /// doc comments to // regular comments in reasoning parser tests

Signed-off-by: Matej Kosec <mkosec@nvidia.com>
MatejKosec added a commit that referenced this pull request Mar 21, 2026
- Revert publisher.py changes (PR #6736 handles the SGLang compat)
- Unify /// doc comments to // regular comments in reasoning parser tests

Signed-off-by: Matej Kosec <mkosec@nvidia.com>
SGLang post-0.5.9 moved network helpers to sglang.srt.utils.network
and introduced NetworkAddress. This adds _compat.py that tries the new
path first and falls back to the old path with a minimal polyfill,
so Dynamo works with both sglang 0.5.9 and main.
@github-actions github-actions bot added the documentation Improvements or additions to documentation label Mar 24, 2026
@ishandhanani ishandhanani changed the title [do not merge] feat: pass data_parallel_rank for disagg routing + API changes for SGL > 0.5.9 feat(sglang): disagg DP rank routing + backwards-compatible network imports Mar 24, 2026
@ishandhanani ishandhanani marked this pull request as ready for review March 24, 2026 11:31
@ishandhanani ishandhanani requested a review from a team as a code owner March 24, 2026 11:31
Each fallback branch must note which version it supports and when it
can be removed. Old fallbacks are cleaned up when that version falls
outside the 1-version-back support window.
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Mar 24, 2026

Walkthrough

This change introduces a SGLang network compatibility layer (_compat.py) to abstract version-fragile imports, then refactors four modules to use NetworkAddress for endpoint formatting and address resolution instead of manual socket operations and IPv6 handling. An additional module now extracts and forwards data parallel rank in request handling.

Changes

Cohort / File(s) Summary
SGLang Compatibility Documentation and Shim
components/src/dynamo/sglang/CLAUDE.md, components/src/dynamo/sglang/_compat.py
Establishes backwards-compatible import pattern for SGLang network utilities via try/except ImportError fallback in new shim. Shim re-exports NetworkAddress, get_local_ip_auto, and get_zmq_socket, with internal NetworkAddress polyfill supporting IPv6 detection, DNS resolution, and endpoint formatting when SGLang ≤ 0.5.9.
Endpoint and Address Resolution Refactoring
components/src/dynamo/sglang/publisher.py, components/src/dynamo/sglang/register.py, components/src/dynamo/sglang/request_handlers/handler_base.py
Replaces manual socket-based IPv6 parsing and DNS resolution with NetworkAddress abstraction. Updates endpoint formatting to validate tcp URLs and construct endpoints via NetworkAddress(...).to_tcp() instead of wildcard replacement. Removes explicit bracket-wrapping logic and socket.getaddrinfo calls across address resolution flows.
Request Handler Enhancement
components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py
Extracts dp_rank from routing dict and forwards it to engine as data_parallel_rank parameter alongside existing priority-based kwargs.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and specifically summarizes the two main changes: disaggregated DP rank routing and backwards-compatible SGLang network imports.
Description check ✅ Passed The description follows the template with Overview (Summary section), Details (Changes with subsections), and Related Issues. All required sections are present and well-documented.
Docstring Coverage ✅ Passed Docstring coverage is 80.00% which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
components/src/dynamo/sglang/publisher.py (1)

30-51: ⚠️ Potential issue | 🟠 Major

Only replace wildcard-style hosts here.

This rewrite now discards any explicit host in endpoint_template and always substitutes ip_address. That changes the old wildcard-replacement contract and breaks configs like tcp://127.0.0.1:5557, where the subscriber needs to keep the configured loopback host.

🔧 Suggested fix
 def format_zmq_endpoint(endpoint_template: str, ip_address: str) -> str:
@@
     parsed = urlparse(endpoint_template)
-    if parsed.scheme != "tcp" or parsed.port is None:
+    if parsed.scheme != "tcp" or parsed.port is None or parsed.hostname is None:
         raise ValueError(f"Expected tcp://host:port endpoint, got {endpoint_template!r}")
-    return NetworkAddress(ip_address, parsed.port).to_tcp()
+    host = parsed.hostname
+    if host in {"*", "0.0.0.0", "::"}:
+        host = ip_address
+    return NetworkAddress(host, parsed.port).to_tcp()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@components/src/dynamo/sglang/publisher.py` around lines 30 - 51,
format_zmq_endpoint currently always replaces the host with ip_address, breaking
explicit hosts; change it to only substitute when the template host is a
wildcard (e.g., '*' or the universal bind addresses '0.0.0.0' or '::'). In
format_zmq_endpoint, inspect parsed.hostname (from urlparse(endpoint_template))
and if it is not a wildcard/universal value, return the original
endpoint_template (or reconstruct and return parsed.scheme + host + port)
unchanged; otherwise call NetworkAddress(ip_address, parsed.port).to_tcp() as
before. Ensure you still validate scheme == "tcp" and parsed.port is present and
only perform substitution when hostname indicates a wildcard bind.
components/src/dynamo/sglang/register.py (1)

8-15: ⚠️ Potential issue | 🟡 Minor

Commit the isort rewrite for this import block.

Pre-merge is already failing on this file, so the new _compat import still needs the repository’s standard ordering applied before this can merge.

As per coding guidelines, "Follow import ordering via isort (stdlib → third-party → first-party) and run ruff format / ruff check or pre-commit on touched Python files."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@components/src/dynamo/sglang/register.py` around lines 8 - 15, The import
block in register.py is out of isort order; reorder imports into stdlib →
third-party → first-party groups and run the project's formatter (isort/ruff) so
the new `dynamo.sglang._compat` import is placed in the first-party group.
Specifically, group and sort the imports that reference sglang (sgl,
ServerArgs), dynamo.sglang._compat (NetworkAddress, get_local_ip_auto),
dynamo._core (Endpoint), dynamo.common.utils.output_modalities
(get_output_modalities), dynamo.llm (ModelInput, ModelRuntimeConfig, ModelType,
register_model), and dynamo.sglang.args (DynamoConfig) according to the repo's
import ordering rules and then run the pre-commit/ruff checks to ensure
formatting passes.
🧹 Nitpick comments (1)
components/src/dynamo/sglang/register.py (1)

89-110: Extract bootstrap host normalization into one helper.

The NetworkAddress(...).to_host_port_str().rsplit(":", 1)[0] sequence is now duplicated here and in components/src/dynamo/sglang/request_handlers/handler_base.py. Since discovery registration and the runtime bootstrap payload need to stay identical, centralizing it next to the shim would reduce drift risk.

As per coding guidelines, "Keep functions and methods concise and focused on a single responsibility."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@components/src/dynamo/sglang/register.py` around lines 89 - 110, Duplicate
bootstrap host normalization logic
(NetworkAddress(...).to_host_port_str().rsplit(":", 1)[0]) exists in register.py
and handler_base.py; extract this into a single helper (e.g.,
normalize_bootstrap_host or build_bootstrap_host) placed next to the shim so
both discovery registration and runtime bootstrap payload use the same code.
Update components/src/dynamo/sglang/register.py (the code paths that set
bootstrap_host from dist_init.resolved() and from
get_local_ip_auto()/local_addr) to call the new helper, and likewise replace the
duplicated sequence in
components/src/dynamo/sglang/request_handlers/handler_base.py with the same
helper; ensure the helper accepts a NetworkAddress (or host and port) and
returns the normalized host string, preserving IPv6 bracket handling.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@components/src/dynamo/sglang/request_handlers/handler_base.py`:
- Around line 21-23: Reorder the import block to follow isort ordering (stdlib →
third-party → first-party): ensure any stdlib imports come first, then
third-party like "import sglang as sgl", and then the local package import "from
dynamo.sglang._compat import NetworkAddress, get_local_ip_auto"; run isort/ruff
format or pre-commit on handler_base.py (referencing the top-level imports in
this file) to apply and persist the change before merging.

---

Outside diff comments:
In `@components/src/dynamo/sglang/publisher.py`:
- Around line 30-51: format_zmq_endpoint currently always replaces the host with
ip_address, breaking explicit hosts; change it to only substitute when the
template host is a wildcard (e.g., '*' or the universal bind addresses '0.0.0.0'
or '::'). In format_zmq_endpoint, inspect parsed.hostname (from
urlparse(endpoint_template)) and if it is not a wildcard/universal value, return
the original endpoint_template (or reconstruct and return parsed.scheme + host +
port) unchanged; otherwise call NetworkAddress(ip_address, parsed.port).to_tcp()
as before. Ensure you still validate scheme == "tcp" and parsed.port is present
and only perform substitution when hostname indicates a wildcard bind.

In `@components/src/dynamo/sglang/register.py`:
- Around line 8-15: The import block in register.py is out of isort order;
reorder imports into stdlib → third-party → first-party groups and run the
project's formatter (isort/ruff) so the new `dynamo.sglang._compat` import is
placed in the first-party group. Specifically, group and sort the imports that
reference sglang (sgl, ServerArgs), dynamo.sglang._compat (NetworkAddress,
get_local_ip_auto), dynamo._core (Endpoint),
dynamo.common.utils.output_modalities (get_output_modalities), dynamo.llm
(ModelInput, ModelRuntimeConfig, ModelType, register_model), and
dynamo.sglang.args (DynamoConfig) according to the repo's import ordering rules
and then run the pre-commit/ruff checks to ensure formatting passes.

---

Nitpick comments:
In `@components/src/dynamo/sglang/register.py`:
- Around line 89-110: Duplicate bootstrap host normalization logic
(NetworkAddress(...).to_host_port_str().rsplit(":", 1)[0]) exists in register.py
and handler_base.py; extract this into a single helper (e.g.,
normalize_bootstrap_host or build_bootstrap_host) placed next to the shim so
both discovery registration and runtime bootstrap payload use the same code.
Update components/src/dynamo/sglang/register.py (the code paths that set
bootstrap_host from dist_init.resolved() and from
get_local_ip_auto()/local_addr) to call the new helper, and likewise replace the
duplicated sequence in
components/src/dynamo/sglang/request_handlers/handler_base.py with the same
helper; ensure the helper accepts a NetworkAddress (or host and port) and
returns the normalized host string, preserving IPv6 bracket handling.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: ce1bbe20-a510-49e7-8eb6-8cf6fddb40e8

📥 Commits

Reviewing files that changed from the base of the PR and between 115512e and f572cda.

📒 Files selected for processing (6)
  • components/src/dynamo/sglang/CLAUDE.md
  • components/src/dynamo/sglang/_compat.py
  • components/src/dynamo/sglang/publisher.py
  • components/src/dynamo/sglang/register.py
  • components/src/dynamo/sglang/request_handlers/handler_base.py
  • components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py

@huitianbai
Copy link
Copy Markdown
Contributor

@ishandhanani Hi, If router_mode is not kv, prefill handler will always get a dp rank "0".
Here it is: https://github.com/ai-dynamo/dynamo/blob/main/lib/llm/src/kv_router/prefill_router.rs#L551
I think it should be avoided.
In my PR to return a special value #7214, this is tricky or let 'rust' never return a dp_rank ?

@ishandhanani
Copy link
Copy Markdown
Contributor Author

@ishandhanani Hi, If router_mode is not kv, prefill handler will always get a dp rank "0". Here it is: https://github.com/ai-dynamo/dynamo/blob/main/lib/llm/src/kv_router/prefill_router.rs#L551 I think it should be avoided. In my PR to return a special value #7214, this is tricky or let 'rust' never return a dp_rank ?

Good catch. Having @PeaBrane review your change. If it makes sense we can cherry pick it into this PR and get all changes in at once. As always will give you credit for the work 😄

@ishandhanani ishandhanani requested a review from a team as a code owner March 25, 2026 14:07
@github-actions github-actions bot added the router Relates to routing, KV-aware routing, etc. label Mar 25, 2026
@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Mar 25, 2026

ishandhanani and others added 2 commits March 25, 2026 15:09
When router_mode is not KV, the SimpleRouter path was hardcoding dp_rank
to 0, causing prefill to always target the first data parallel rank. Use
u32::MAX as a sentinel value instead, and treat it as None on the Python
side so SGLang picks the correct rank.

Cherry-picked from #7214.

Co-Authored-By: huitian bai <baihuitian.bht@gmail.com>
Avoid recomputing 2**32-1 on every request.

Co-Authored-By: huitian bai <baihuitian.bht@gmail.com>
@ishandhanani ishandhanani force-pushed the idhanani/disagg-dp-rank-routing branch from 258f8df to bb1a919 Compare March 25, 2026 14:10
@ishandhanani ishandhanani enabled auto-merge (squash) March 25, 2026 14:10
@ishandhanani ishandhanani merged commit 7edb07b into main Mar 25, 2026
85 of 89 checks passed
@ishandhanani ishandhanani deleted the idhanani/disagg-dp-rank-routing branch March 25, 2026 15:05
MatejKosec added a commit that referenced this pull request Mar 29, 2026
- Revert publisher.py changes (PR #6736 handles the SGLang compat)
- Unify /// doc comments to // regular comments in reasoning parser tests

Signed-off-by: Matej Kosec <mkosec@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

backend::sglang Relates to the sglang backend documentation Improvements or additions to documentation feat router Relates to routing, KV-aware routing, etc. size/L

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants