180 changes: 175 additions & 5 deletions python/sglang/srt/entrypoints/engine.py
@@ -85,6 +85,111 @@

_is_cuda = is_cuda()

# ---------------- MoE routing helpers ----------------
from typing import Dict as _Dict, List as _List, Iterable as _Iterable, Set as _Set
from collections import Counter
import torch

def _records_to_per_rids_subset(
records_obj,
wanted_rids: _Set[str],
*,
allow_multi_active_sequences: bool = False,
) -> _Dict[str, _Dict]:
"""
Fast path: single pass over records, accumulating only the RIDs we care about.
Handles both PREFILL (flattened tokens) and DECODE (1 token per row).
    Output per rid:
        {
          "topk_ids_of_layer": Tensor[L, T_total, K],
          "positions": List[int],
          "physical_to_logical_map": Tensor[L, E_phys],
          "shape": {...},
        }
"""
if not wanted_rids:
return {}

records = records_obj.get("records", [])
phys_to_log = records_obj.get("last_physical_to_logical_map", None)

per_rid_chunks: dict[str, _List[torch.Tensor]] = {rid: [] for rid in wanted_rids}
per_rid_positions: dict[str, _List[int]] = {rid: [] for rid in wanted_rids}

def _append_piece(rid: str, topk_piece: torch.Tensor, pos_slice: _Iterable[int]):
per_rid_chunks[rid].append(topk_piece)
per_rid_positions[rid].extend(pos_slice)

for rec in records:
topk: torch.Tensor = rec["topk_ids_of_layer"] # [L, T, K] (CPU)
rids: _List[str] = rec["rids"]
positions: _List[int] = rec["positions"]
ext: _List[int] | None = rec.get("extend_seq_lens", None)
T = int(topk.shape[1])

# Decode heuristic: exactly one token per active row
is_decode = (ext is None) or (len(positions) == len(rids))

if not allow_multi_active_sequences:
cnt = Counter(rids)
dups = [rid for rid, c in cnt.items() if c > 1]
if dups:
# If any of those dups intersect our wanted set, it's a real violation for our use-case
if any(r in wanted_rids for r in dups):
stage = "DECODE" if is_decode else "PREFILL"
raise AssertionError(
f"Multiple active sequences detected in {stage} pass for RID(s): {sorted(set(dups) & wanted_rids)}. "
f"Disable fan-out (n=1, no beam/best_of) or set allow_multi_active_sequences=True."
)

if is_decode:
assert T == len(rids) == len(positions), \
f"Decode mismatch: T={T}, |rids|={len(rids)}, |positions|={len(positions)}"
# Column i belongs to rids[i]
for i, rid in enumerate(rids):
if rid in wanted_rids:
_append_piece(rid, topk[:, i:i+1, :], [positions[i]])
else:
assert sum(ext) == T == len(positions), \
f"Prefill mismatch: sum(ext)={sum(ext)}, T={T}, |positions|={len(positions)}"
off = 0
for rid, n in zip(rids, ext):
if rid in wanted_rids and n > 0:
sl = slice(off, off + n)
_append_piece(rid, topk[:, sl, :], positions[off:off+n])
off += n

out: _Dict[str, _Dict] = {}
p2l = phys_to_log.cpu() if isinstance(phys_to_log, torch.Tensor) else phys_to_log
for rid in wanted_rids:
parts = per_rid_chunks[rid]
if not parts:
continue
cat = torch.cat(parts, dim=1).cpu() # [L, T_total, K]
out[rid] = {
"topk_ids_of_layer": cat,
"positions": per_rid_positions[rid],
"physical_to_logical_map": p2l,
"shape": {"num_layers": int(cat.shape[0]), "num_tokens": int(cat.shape[1]), "top_k": int(cat.shape[2])},
}
return out

def _attach_routing_to_ret(ret, per_rid: _Dict[str, _Dict]) -> None:
"""Attach per-RID routing to each response item's meta_info."""
def _serialize(entry: _Dict) -> _Dict:
topk = entry["topk_ids_of_layer"]
p2l = entry["physical_to_logical_map"]
return {
"topk_ids_of_layer": topk.tolist(),
"positions": entry["positions"],
"physical_to_logical_map": (p2l.tolist() if isinstance(p2l, torch.Tensor) else None),
"shape": entry["shape"],
}

if isinstance(ret, list):
for item in ret:
rid = item.get("meta_info", {}).get("id")
if rid and rid in per_rid:
item.setdefault("meta_info", {})["moe_routing"] = {"rid": rid, **_serialize(per_rid[rid])}
elif isinstance(ret, dict):
rid = ret.get("meta_info", {}).get("id")
if rid and rid in per_rid:
ret.setdefault("meta_info", {})["moe_routing"] = {"rid": rid, **_serialize(per_rid[rid])}
# -----------------------------------------------------
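
A quick sanity check of how these helpers behave. The record layout below is fabricated from the docstring above (illustrative only, not output from a real run):

# Illustrative only: one DECODE record with L=2 layers, T=2 tokens, K=2 experts.
_demo_records_obj = {
    "records": [
        {
            "topk_ids_of_layer": torch.zeros(2, 2, 2, dtype=torch.int64),  # [L, T, K]
            "rids": ["req-a", "req-b"],
            "positions": [5, 7],
            # no "extend_seq_lens" key -> treated as a decode pass
        }
    ],
    "last_physical_to_logical_map": None,
}
_demo_per_rid = _records_to_per_rids_subset(_demo_records_obj, wanted_rids={"req-a"})
# Only req-a is aggregated: one token column, all layers, all top-k slots.
assert _demo_per_rid["req-a"]["shape"] == {"num_layers": 2, "num_tokens": 1, "top_k": 2}
assert _demo_per_rid["req-a"]["positions"] == [5]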

class Engine(EngineBase):
"""
@@ -137,6 +242,8 @@ def __init__(self, **kwargs):
context, zmq.DEALER, self.port_args.rpc_ipc_name, True
)

self._expert_routing_lock = asyncio.Lock()

def generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
@@ -246,6 +353,7 @@ async def async_generate(
bootstrap_port: Optional[Union[List[int], int]] = None,
bootstrap_room: Optional[Union[List[int], int]] = None,
data_parallel_rank: Optional[int] = None,
return_expert_routing: Optional[Union[List[bool], bool]] = False,
) -> Union[Dict, AsyncIterator[Dict]]:
"""
        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
@@ -283,12 +391,74 @@
bootstrap_room=bootstrap_room,
data_parallel_rank=data_parallel_rank,
)
        # Attach routing only if requested
        def _any_true(x):
            if isinstance(x, list):
                return any(bool(v) for v in x)
            return bool(x)

        if not _any_true(return_expert_routing):
            generator = self.tokenizer_manager.generate_request(obj, None)

            if stream is True:
                return generator
            else:
                return await generator.__anext__()
        else:
            await self._expert_routing_lock.acquire()
try:
if stream is True:
                raise NotImplementedError("Expert statistics logging is not available in streaming mode.")
else:
await self.tokenizer_manager.start_expert_distribution_record()
generator = self.tokenizer_manager.generate_request(obj, None)

ret = await generator.__anext__()
await self.tokenizer_manager.stop_expert_distribution_record()
records_obj = await self.tokenizer_manager.dump_expert_distribution_record_object()
# per_req = _split_routing_per_request(records_obj)
# Attach under meta_info; consumers can ignore if unused
# ret.setdefault("meta_info", {})["moe_routing_per_request"] = {
# rid: {
# "topk_ids_of_layer": v["topk_ids_of_layer"].tolist(),
# "positions": v["positions"],
# "physical_to_logical_map": v["physical_to_logical_map"].cpu().tolist(),
# }
# for rid, v in per_req.items()
# }

# Determine which RIDs we actually need to attach
if isinstance(ret, list):
wanted = {item.get("meta_info", {}).get("id") for item in ret if item.get("meta_info")}
else:
wanted = {ret.get("meta_info", {}).get("id")} if ret.get("meta_info") else set()
wanted = {rid for rid in wanted if rid} # drop Nones

# breakpoint()
# Fast single-pass subset aggregation
per_rid = _records_to_per_rids_subset(
records_obj,
wanted_rids=wanted,
allow_multi_active_sequences=False, # TODO flip to True only if you explicitly support fan-out
)

_attach_routing_to_ret(ret, per_rid)
# breakpoint()

# # TODO DEBUGGING ONLY
                # breakpoint() # TODO NOTE what if one prompt reaches EOS. records-1 or simply padding token id in the input_ids, where we can try using the row_idx to match back
# # TODO MOVE per req all to cpu
# ret.setdefault("meta_info", {})["moe_routing_per_request"] = per_req

Comment on lines +420 to +453
Contributor
medium

There are several commented-out blocks of code, including breakpoint() calls and old implementation logic. These should be removed to improve code clarity and maintainability.

return ret
finally:
# Safety: ensure recorder is not left on, and release the lock
try:
await self.tokenizer_manager.stop_expert_distribution_record()
except Exception:
pass
Comment on lines +457 to +460
Contributor
medium

Swallowing exceptions with a pass statement can hide important errors and make debugging difficult. It's better to at least log the exception to be aware of potential issues during the execution of stop_expert_distribution_record.

Suggested change
-            try:
-                await self.tokenizer_manager.stop_expert_distribution_record()
-            except Exception:
-                pass
+            try:
+                await self.tokenizer_manager.stop_expert_distribution_record()
+            except Exception as e:
+                logger.warning(f"Error stopping expert distribution record: {e}")

self._expert_routing_lock.release()
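
For reference, a hedged consumer-side sketch of the new return_expert_routing flag. The model path and prompt are placeholders, and the meta_info["moe_routing"] payload follows _serialize above:

import asyncio

async def _demo():
    engine = Engine(model_path="some/moe-model")  # placeholder model path
    ret = await engine.async_generate(
        prompt="Hello",
        sampling_params={"max_new_tokens": 8},
        return_expert_routing=True,
    )
    routing = ret["meta_info"]["moe_routing"]
    # topk_ids_of_layer is nested lists shaped [num_layers, num_tokens, top_k]
    print(routing["shape"], routing["positions"])

asyncio.run(_demo())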

def encode(
self,
56 changes: 48 additions & 8 deletions python/sglang/srt/eplb/expert_distribution.py
@@ -358,7 +358,7 @@ def __init__(
server_args.chunked_prefill_size * 8,
self._TOP_K_NUM,
),
dtype=torch.int32,
dtype=torch.uint8,
device=server_args.device,
)
self._misc_objects: List[Dict[str, Any]] = []
@@ -367,11 +367,32 @@
), "DetailSinglePassGatherer does not support TBO yet"
# TODO assert shared experts fusion is disabled, o/w data is wrong

def _ensure_capacity(self, need_tokens: int):
L, cur, K = self._topk_ids_of_layer.shape
if need_tokens > cur:
new = torch.full((L, need_tokens, K), -1,
dtype=self._topk_ids_of_layer.dtype,
device=self._topk_ids_of_layer.device)
Comment on lines +373 to +375
Contributor
high

Using torch.full(..., -1, ...) with dtype=torch.uint8 will cause the fill value to wrap around to 255. If 255 is a valid expert ID, this could lead to subtle bugs. Additionally, uint8 limits the number of experts to 255. It would be safer to add an assertion in _DetailSinglePassGatherer.__init__ to ensure the number of experts is within the valid range for uint8 and use a fill value that is guaranteed not to be a valid expert ID, such as torch.iinfo(torch.uint8).max.

Suggested change
-            new = torch.full((L, need_tokens, K), -1,
-                             dtype=self._topk_ids_of_layer.dtype,
-                             device=self._topk_ids_of_layer.device)
+            new = torch.full((L, need_tokens, K), torch.iinfo(self._topk_ids_of_layer.dtype).max,
+                             dtype=self._topk_ids_of_layer.dtype,
+                             device=self._topk_ids_of_layer.device)
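
A quick demonstration of the wraparound the reviewer describes (plain PyTorch, independent of this diff):

import torch

# Casting -1 into uint8 wraps to 255, which equals torch.iinfo(torch.uint8).max,
# so the -1 "sentinel" silently collides with a would-be expert ID of 255.
assert torch.tensor(-1).to(torch.uint8).item() == torch.iinfo(torch.uint8).max == 255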

new[:, :cur, :] = self._topk_ids_of_layer
self._topk_ids_of_layer = new

def on_forward_pass_start(self, forward_batch: ForwardBatch):
assert self._metadata is None

        # Shape assertion, though it may be redundant
        if forward_batch.forward_mode.is_decode():
assert (
len(forward_batch.rids)
== forward_batch.input_ids.numel()
== forward_batch.positions.numel()
), "Decode: rids/input_ids/positions must align 1:1"

T_pass = forward_batch.input_ids.numel() # prefill: sum(extend_seq_lens); decode: batch size
self._ensure_capacity(T_pass)

self._metadata = dict(
# TODO pr-chain
# rids=forward_batch.rids,
            rids=forward_batch.rids,  # record request IDs so we can split tokens by prompt
input_ids=forward_batch.input_ids.cpu().tolist(),
positions=forward_batch.positions.cpu().tolist(),
extend_seq_lens=forward_batch.extend_seq_lens_cpu,
@@ -726,16 +747,35 @@ def reset(self):
super().reset()
self._records.clear()

# def dump(self, output_mode: _OutputMode):
# assert output_mode == "file"
# output = dict(
# records=self._records,
# # NOTE: This may change during recording, so here we say it is the "last" one
# last_physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
# )
# _dump_to_file(
# f"expert_distribution_recorder_{time.time()}_{self._rank}.pt", output
# )
Comment on lines +750 to +759
Contributor
medium

This commented-out implementation of the dump method should be removed to keep the codebase clean and avoid confusion.


    def dump(self, output_mode: _OutputMode):
        output = dict(
            # Copy the list so resetting doesn't wipe the caller's view
            records=list(self._records),
            # NOTE: This may change during recording, so here we say it is the "last" one
            last_physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
        )
        if output_mode == "file":
            _dump_to_file(
                f"expert_distribution_recorder_{time.time()}_{self._rank}.pt",
                output,
            )
        elif output_mode == "object":
            # Return the raw object without saving
            return output
        else:
            raise NotImplementedError(f"Unknown output mode: {output_mode}")



class _StatAccumulator(_UtilizationRateAccumulatorMixin):
10 changes: 9 additions & 1 deletion python/sglang/srt/managers/io_struct.py
@@ -1084,11 +1084,19 @@ class ExpertDistributionReq(Enum):
START_RECORD = 1
STOP_RECORD = 2
DUMP_RECORD = 3
# New: request the in-memory dump (Python object) instead of writing files
DUMP_RECORD_OBJECT = 4


@dataclass
class ExpertDistributionReqOutput:
# success/failure for the op
success: bool
# optional details
message: str = "" # message: Optional[str] = None
Contributor
medium

The commented-out type hint seems to be a leftover from development and should be removed for code cleanliness.

Suggested change
-    message: str = ""  # message: Optional[str] = None
+    message: str = ""

# New: used only by DUMP_RECORD_OBJECT; holds the recorder dump object
# shape: {"records": [...], "last_physical_to_logical_map": Tensor or list}
payload: Optional[Any] = None


@dataclass
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -1780,6 +1780,7 @@ def get_model_worker_batch(
),
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
launch_done=self.launch_done,
rids=[req.rid for req in self.reqs],
)

def copy(self):
@@ -1922,6 +1923,8 @@ class ModelWorkerBatch:
# Overlap event
launch_done: Optional[threading.Event] = None

rids: list[str] = None


@triton.jit
def write_req_to_token_pool_triton(
6 changes: 6 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -2484,10 +2484,16 @@ def slow_down(self, recv_req: SlowDownReqInput):
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
if recv_req == ExpertDistributionReq.START_RECORD:
get_global_expert_distribution_recorder().start_record()
return ExpertDistributionReqOutput(success=True, message="started")
elif recv_req == ExpertDistributionReq.STOP_RECORD:
get_global_expert_distribution_recorder().stop_record()
return ExpertDistributionReqOutput(success=True, message="stopped")
elif recv_req == ExpertDistributionReq.DUMP_RECORD:
get_global_expert_distribution_recorder().dump_record()
return ExpertDistributionReqOutput(success=True, message="dumped_to_files")
elif recv_req == ExpertDistributionReq.DUMP_RECORD_OBJECT:
obj = get_global_expert_distribution_recorder().dump_record(output_mode="object")
return ExpertDistributionReqOutput(success=True, payload=obj)
else:
raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
29 changes: 29 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -1062,6 +1062,35 @@ async def dump_expert_distribution_record(self):
self.auto_create_handle_loop()
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

async def dump_expert_distribution_record_object(self):
"""
Returns the expert-distribution recorder dump as a Python object
(not a file), merged across DP ranks if needed.
Requires the scheduler to support DUMP_RECORD_OBJECT and return
ExpertDistributionReqOutput with `payload` holding the object.
"""
self.auto_create_handle_loop()
results: List[ExpertDistributionReqOutput] = await self.expert_distribution_communicator(
ExpertDistributionReq.DUMP_RECORD_OBJECT
)
# `results` has one element per DP rank (fan_out = dp_size).
# Each `.payload` is either None or a dict like:
# {"records": [...], "last_physical_to_logical_map": Tensor[L, E_phys]}
objs = [r.payload for r in results if getattr(r, "payload", None) is not None]
if not objs:
# Fallback empty shape if nothing was recorded
return {"records": [], "last_physical_to_logical_map": None}

merged = {
"records": [],
"last_physical_to_logical_map": objs[-1]["last_physical_to_logical_map"],
}
for o in objs:
merged["records"].extend(o["records"])
# We keep the "last" physical_to_logical_map; in practice maps
# are identical across ranks for a given run.
return merged
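
To illustrate the merge semantics, a small self-contained sketch with two fabricated per-rank payloads:

# Fabricated payloads standing in for two DP ranks' dumps.
_rank_payloads = [
    {"records": [{"rids": ["a"]}], "last_physical_to_logical_map": [[0, 1]]},
    {"records": [{"rids": ["b"]}], "last_physical_to_logical_map": [[0, 1]]},
]
_merged = {
    "records": [rec for p in _rank_payloads for rec in p["records"]],
    "last_physical_to_logical_map": _rank_payloads[-1]["last_physical_to_logical_map"],
}
assert len(_merged["records"]) == 2  # records concatenated across ranks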

async def pause_generation(self):
async with self.is_pause_cond:
self.is_pause = True