diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 2b576b409399..180c17e25135 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -85,6 +85,111 @@
 _is_cuda = is_cuda()
 
+
+# ---------------- MoE routing helpers ----------------
+from typing import Dict as _Dict, List as _List, Iterable as _Iterable, Set as _Set
+from collections import Counter
+import torch
+
+
+def _records_to_per_rids_subset(
+    records_obj,
+    wanted_rids: _Set[str],
+    *,
+    allow_multi_active_sequences: bool = False,
+) -> _Dict[str, _Dict]:
+    """
+    Fast path: a single pass over the records, accumulating only the RIDs we
+    care about. Handles both PREFILL (flattened tokens) and DECODE (one token
+    per row).
+
+    Output per rid: {
+        "topk_ids_of_layer": Tensor[L, T_total, K],
+        "positions": List[int],
+        "physical_to_logical_map": Tensor[L, E_phys],
+        "shape": {...},
+    }
+    """
+    if not wanted_rids:
+        return {}
+
+    records = records_obj.get("records", [])
+    phys_to_log = records_obj.get("last_physical_to_logical_map", None)
+
+    per_rid_chunks: _Dict[str, _List[torch.Tensor]] = {rid: [] for rid in wanted_rids}
+    per_rid_positions: _Dict[str, _List[int]] = {rid: [] for rid in wanted_rids}
+
+    def _append_piece(rid: str, topk_piece: torch.Tensor, pos_slice: _Iterable[int]):
+        per_rid_chunks[rid].append(topk_piece)
+        per_rid_positions[rid].extend(pos_slice)
+
+    for rec in records:
+        topk: torch.Tensor = rec["topk_ids_of_layer"]  # [L, T, K] (CPU)
+        rids: _List[str] = rec["rids"]
+        positions: _List[int] = rec["positions"]
+        ext: _List[int] | None = rec.get("extend_seq_lens", None)
+        T = int(topk.shape[1])
+
+        # Decode heuristic: exactly one token per active row.
+        is_decode = (ext is None) or (len(positions) == len(rids))
+
+        if not allow_multi_active_sequences:
+            cnt = Counter(rids)
+            dups = [rid for rid, c in cnt.items() if c > 1]
+            # Duplicates are only a real violation for our use-case if they
+            # intersect the wanted set.
+            if dups and any(r in wanted_rids for r in dups):
+                stage = "DECODE" if is_decode else "PREFILL"
+                raise AssertionError(
+                    f"Multiple active sequences detected in {stage} pass for "
+                    f"RID(s): {sorted(set(dups) & wanted_rids)}. Disable fan-out "
+                    f"(n=1, no beam/best_of) or set allow_multi_active_sequences=True."
+                )
+
+        if is_decode:
+            assert T == len(rids) == len(positions), (
+                f"Decode mismatch: T={T}, |rids|={len(rids)}, |positions|={len(positions)}"
+            )
+            # Column i belongs to rids[i].
+            for i, rid in enumerate(rids):
+                if rid in wanted_rids:
+                    _append_piece(rid, topk[:, i : i + 1, :], [positions[i]])
+        else:
+            assert sum(ext) == T == len(positions), (
+                f"Prefill mismatch: sum(ext)={sum(ext)}, T={T}, |positions|={len(positions)}"
+            )
+            off = 0
+            for rid, n in zip(rids, ext):
+                if rid in wanted_rids and n > 0:
+                    sl = slice(off, off + n)
+                    _append_piece(rid, topk[:, sl, :], positions[off : off + n])
+                off += n
+
+    out: _Dict[str, _Dict] = {}
+    p2l = phys_to_log.cpu() if isinstance(phys_to_log, torch.Tensor) else phys_to_log
+    for rid in wanted_rids:
+        parts = per_rid_chunks[rid]
+        if not parts:
+            continue
+        cat = torch.cat(parts, dim=1).cpu()  # [L, T_total, K]
+        out[rid] = {
+            "topk_ids_of_layer": cat,
+            "positions": per_rid_positions[rid],
+            "physical_to_logical_map": p2l,
+            "shape": {
+                "num_layers": int(cat.shape[0]),
+                "num_tokens": int(cat.shape[1]),
+                "top_k": int(cat.shape[2]),
+            },
+        }
+    return out
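For reference, a minimal sketch of how this helper consumes a recorder dump, assuming `_records_to_per_rids_subset` is importable; the record below is fabricated (shapes and RIDs invented) but follows the schema the gatherer emits:

```python
import torch

# Fabricated dump: 2 layers, top-2 routing, one decode step with two requests.
records_obj = {
    "records": [
        {
            "topk_ids_of_layer": torch.randint(0, 8, (2, 2, 2)),  # [L=2, T=2, K=2]
            "rids": ["req-a", "req-b"],  # decode: one column per active request
            "positions": [5, 7],
            "extend_seq_lens": None,
        }
    ],
    "last_physical_to_logical_map": torch.arange(8).repeat(2, 1),  # [L, E_phys]
}

per_rid = _records_to_per_rids_subset(records_obj, wanted_rids={"req-a"})
entry = per_rid["req-a"]
print(entry["shape"])      # {'num_layers': 2, 'num_tokens': 1, 'top_k': 2}
print(entry["positions"])  # [5]
```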
+
+
+def _attach_routing_to_ret(ret, per_rid: _Dict[str, _Dict]) -> None:
+    """Attach per-RID routing to each response item's meta_info."""
+
+    def _serialize(entry: _Dict) -> _Dict:
+        topk = entry["topk_ids_of_layer"]
+        p2l = entry["physical_to_logical_map"]
+        return {
+            "topk_ids_of_layer": topk.tolist(),
+            "positions": entry["positions"],
+            "physical_to_logical_map": (
+                p2l.tolist() if isinstance(p2l, torch.Tensor) else None
+            ),
+            "shape": entry["shape"],
+        }
+
+    if isinstance(ret, list):
+        for item in ret:
+            rid = item.get("meta_info", {}).get("id")
+            if rid and rid in per_rid:
+                item.setdefault("meta_info", {})["moe_routing"] = {
+                    "rid": rid,
+                    **_serialize(per_rid[rid]),
+                }
+    elif isinstance(ret, dict):
+        rid = ret.get("meta_info", {}).get("id")
+        if rid and rid in per_rid:
+            ret.setdefault("meta_info", {})["moe_routing"] = {
+                "rid": rid,
+                **_serialize(per_rid[rid]),
+            }
+# -----------------------------------------------------
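Concretely, a serialized entry is plain JSON-friendly data. A sketch of what a response item looks like after `_attach_routing_to_ret` (all values invented), plus how the map translates physical expert ids back to logical ones, which diverge once EPLB has rebalanced experts:

```python
# Illustrative response item after attachment (values are made up).
item = {
    "text": "...",
    "meta_info": {
        "id": "req-a",
        "moe_routing": {
            "rid": "req-a",
            "topk_ids_of_layer": [[[3, 1]], [[0, 2]]],  # [L][T_total][K] physical ids
            "positions": [5],
            "physical_to_logical_map": [[0, 1, 2, 3], [0, 1, 2, 3]],  # [L][E_phys]
            "shape": {"num_layers": 2, "num_tokens": 1, "top_k": 2},
        },
    },
}

# Translate physical expert ids to logical ids, per layer and token.
routing = item["meta_info"]["moe_routing"]
p2l = routing["physical_to_logical_map"]
logical = [
    [[p2l[layer][e] for e in tok] for tok in tokens]
    for layer, tokens in enumerate(routing["topk_ids_of_layer"])
]
```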
 
 
 class Engine(EngineBase):
     """
@@ -137,6 +242,8 @@ def __init__(self, **kwargs):
             context, zmq.DEALER, self.port_args.rpc_ipc_name, True
         )
 
+        self._expert_routing_lock = asyncio.Lock()
+
     def generate(
         self,
         # The input prompt. It can be a single prompt or a batch of prompts.
@@ -246,6 +353,7 @@ async def async_generate(
         bootstrap_port: Optional[Union[List[int], int]] = None,
         bootstrap_room: Optional[Union[List[int], int]] = None,
         data_parallel_rank: Optional[int] = None,
+        return_expert_routing: Union[List[bool], bool] = False,
     ) -> Union[Dict, AsyncIterator[Dict]]:
         """
         The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
@@ -283,12 +391,74 @@ async def async_generate(
             bootstrap_room=bootstrap_room,
             data_parallel_rank=data_parallel_rank,
         )
-        generator = self.tokenizer_manager.generate_request(obj, None)
-
-        if stream is True:
-            return generator
+
+        # Record and attach routing only if it was requested.
+        def _any_true(x):
+            if isinstance(x, list):
+                return any(bool(v) for v in x)
+            return bool(x)
+
+        if not _any_true(return_expert_routing):
+            generator = self.tokenizer_manager.generate_request(obj, None)
+
+            if stream is True:
+                return generator
+            else:
+                return await generator.__anext__()
         else:
-            return await generator.__anext__()
+            if stream is True:
+                raise NotImplementedError(
+                    "Returning expert routing is not available in streaming mode."
+                )
+            async with self._expert_routing_lock:
+                try:
+                    await self.tokenizer_manager.start_expert_distribution_record()
+                    generator = self.tokenizer_manager.generate_request(obj, None)
+
+                    ret = await generator.__anext__()
+                    await self.tokenizer_manager.stop_expert_distribution_record()
+                    records_obj = (
+                        await self.tokenizer_manager.dump_expert_distribution_record_object()
+                    )
+
+                    # Determine which RIDs we actually need to attach.
+                    if isinstance(ret, list):
+                        wanted = {
+                            item.get("meta_info", {}).get("id")
+                            for item in ret
+                            if item.get("meta_info")
+                        }
+                    else:
+                        wanted = (
+                            {ret.get("meta_info", {}).get("id")}
+                            if ret.get("meta_info")
+                            else set()
+                        )
+                    wanted = {rid for rid in wanted if rid}  # drop Nones
+
+                    # Fast single-pass subset aggregation.
+                    per_rid = _records_to_per_rids_subset(
+                        records_obj,
+                        wanted_rids=wanted,
+                        # TODO: flip to True only once fan-out (n>1, beam/best_of)
+                        # is explicitly supported.
+                        allow_multi_active_sequences=False,
+                    )
+
+                    # Attach under meta_info; consumers can ignore it if unused.
+                    _attach_routing_to_ret(ret, per_rid)
+
+                    # TODO: handle the case where one prompt reaches EOS before the
+                    # others (padding token ids in input_ids); the row index could
+                    # be used to match those tokens back to their requests.
+
+                    return ret
+                finally:
+                    # Safety: never leave the recorder running.
+                    try:
+                        await self.tokenizer_manager.stop_expert_distribution_record()
+                    except Exception:
+                        pass
 
     def encode(
         self,
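A minimal usage sketch of the new flag (the model path and prompt are placeholders; the rest follows the API as changed in this diff):

```python
import asyncio
import sglang as sgl

async def main():
    engine = sgl.Engine(model_path="some-moe-model")  # placeholder model path
    ret = await engine.async_generate(
        prompt="What is the capital of France?",
        sampling_params={"max_new_tokens": 8},
        return_expert_routing=True,  # the flag added in this diff
    )
    routing = ret["meta_info"]["moe_routing"]
    print(routing["shape"])          # {"num_layers": L, "num_tokens": T, "top_k": K}
    print(routing["positions"][:4])  # token positions covered by this request

asyncio.run(main())
```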
diff --git a/python/sglang/srt/eplb/expert_distribution.py b/python/sglang/srt/eplb/expert_distribution.py
index 1b3d573d8b29..28ef28df1640 100644
--- a/python/sglang/srt/eplb/expert_distribution.py
+++ b/python/sglang/srt/eplb/expert_distribution.py
@@ -358,7 +358,7 @@ def __init__(
                 server_args.chunked_prefill_size * 8,
                 self._TOP_K_NUM,
             ),
-            dtype=torch.int32,
+            dtype=torch.uint8,  # assumes < 256 physical experts
             device=server_args.device,
         )
         self._misc_objects: List[Dict[str, Any]] = []
@@ -367,11 +367,32 @@ def __init__(
         ), "DetailSinglePassGatherer does not support TBO yet"
         # TODO assert shared experts fusion is disabled, o/w data is wrong
 
+    def _ensure_capacity(self, need_tokens: int):
+        L, cur, K = self._topk_ids_of_layer.shape
+        if need_tokens > cur:
+            # -1 cannot be represented in unsigned dtypes, so use the dtype's
+            # max value as the "unused" sentinel there (255 for uint8).
+            dtype = self._topk_ids_of_layer.dtype
+            sentinel = -1 if dtype.is_signed else torch.iinfo(dtype).max
+            new = torch.full(
+                (L, need_tokens, K),
+                sentinel,
+                dtype=dtype,
+                device=self._topk_ids_of_layer.device,
+            )
+            new[:, :cur, :] = self._topk_ids_of_layer
+            self._topk_ids_of_layer = new
+
     def on_forward_pass_start(self, forward_batch: ForwardBatch):
         assert self._metadata is None
+
+        # Shape assertion; possibly redundant, but cheap.
+        if forward_batch.forward_mode.is_decode():
+            assert (
+                len(forward_batch.rids)
+                == forward_batch.input_ids.numel()
+                == forward_batch.positions.numel()
+            ), "Decode: rids/input_ids/positions must align 1:1"
+
+        # Prefill: sum(extend_seq_lens); decode: batch size.
+        T_pass = forward_batch.input_ids.numel()
+        self._ensure_capacity(T_pass)
+
         self._metadata = dict(
             # TODO pr-chain
-            # rids=forward_batch.rids,
+            rids=forward_batch.rids,  # record request IDs so we can split tokens by prompt
             input_ids=forward_batch.input_ids.cpu().tolist(),
             positions=forward_batch.positions.cpu().tolist(),
             extend_seq_lens=forward_batch.extend_seq_lens_cpu,
@@ -726,16 +747,35 @@ def reset(self):
         super().reset()
         self._records.clear()
 
     def dump(self, output_mode: _OutputMode):
-        assert output_mode == "file"
         output = dict(
-            records=self._records,
-            # NOTE: This may change during recording, so here we say it is the "last" one
+            # Copy the list so a later reset() does not wipe the caller's view.
+            records=list(self._records),
+            # NOTE: This may change during recording, so it is only the "last" mapping.
             last_physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
         )
-        _dump_to_file(
-            f"expert_distribution_recorder_{time.time()}_{self._rank}.pt", output
-        )
+        if output_mode == "file":
+            _dump_to_file(
+                f"expert_distribution_recorder_{time.time()}_{self._rank}.pt",
+                output,
+            )
+        elif output_mode == "object":
+            # Return the raw object without writing anything to disk.
+            return output
+        else:
+            raise NotImplementedError(f"Unknown output mode: {output_mode}")
 
 
 class _StatAccumulator(_UtilizationRateAccumulatorMixin):
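Why the sentinel logic in `_ensure_capacity` matters: recent PyTorch builds reject negative fill values for unsigned dtypes, so a plain `-1` fill would fail once the buffer is `uint8`. A standalone sketch of the grow-and-copy pattern (not part of the diff):

```python
import torch

# 255 stands in for "unused" since uint8 cannot hold -1.
buf = torch.full((2, 4, 2), 255, dtype=torch.uint8)
# torch.full((2, 4, 2), -1, dtype=torch.uint8)  # would raise: overflow for uint8

# Grow while preserving existing entries, as _ensure_capacity does.
bigger = torch.full((2, 8, 2), 255, dtype=buf.dtype, device=buf.device)
bigger[:, :4, :] = buf
buf = bigger
```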
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 917d387fe5db..a5db224923cb 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -1084,11 +1084,19 @@ class ExpertDistributionReq(Enum):
     START_RECORD = 1
     STOP_RECORD = 2
     DUMP_RECORD = 3
+    # New: request the in-memory dump (Python object) instead of writing files
+    DUMP_RECORD_OBJECT = 4
 
 
 @dataclass
 class ExpertDistributionReqOutput:
-    pass
+    # Success/failure of the operation
+    success: bool
+    # Optional human-readable details
+    message: str = ""
+    # New: used only by DUMP_RECORD_OBJECT; holds the recorder dump object,
+    # shaped {"records": [...], "last_physical_to_logical_map": Tensor or list}
+    payload: Optional[Any] = None
 
 
 @dataclass
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index a35ba025304f..9d3a3480a677 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1780,6 +1780,7 @@ def get_model_worker_batch(
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
             launch_done=self.launch_done,
+            rids=[req.rid for req in self.reqs],
         )
 
     def copy(self):
@@ -1922,6 +1923,8 @@ class ModelWorkerBatch:
     # Overlap event
     launch_done: Optional[threading.Event] = None
 
+    # Request ids, aligned with the rows of this batch
+    rids: Optional[List[str]] = None
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 38ff0ef145d6..5154ebf6573d 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -2484,10 +2484,16 @@ def slow_down(self, recv_req: SlowDownReqInput):
     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
         if recv_req == ExpertDistributionReq.START_RECORD:
             get_global_expert_distribution_recorder().start_record()
+            return ExpertDistributionReqOutput(success=True, message="started")
         elif recv_req == ExpertDistributionReq.STOP_RECORD:
             get_global_expert_distribution_recorder().stop_record()
+            return ExpertDistributionReqOutput(success=True, message="stopped")
         elif recv_req == ExpertDistributionReq.DUMP_RECORD:
             get_global_expert_distribution_recorder().dump_record()
+            return ExpertDistributionReqOutput(success=True, message="dumped_to_files")
+        elif recv_req == ExpertDistributionReq.DUMP_RECORD_OBJECT:
+            obj = get_global_expert_distribution_recorder().dump_record(
+                output_mode="object"
+            )
+            return ExpertDistributionReqOutput(success=True, payload=obj)
         else:
             raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
-        return ExpertDistributionReqOutput()
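The communicator fans the request out to every DP rank and hands back one `ExpertDistributionReqOutput` per rank; the merge in the next file then looks like this (standalone sketch with fabricated payloads, using the fields added in this diff):

```python
from sglang.srt.managers.io_struct import ExpertDistributionReqOutput

# One result per DP rank; rank 1 recorded nothing.
results = [
    ExpertDistributionReqOutput(
        success=True,
        payload={"records": [{"rids": ["req-a"]}], "last_physical_to_logical_map": None},
    ),
    ExpertDistributionReqOutput(success=True, payload=None),
]

objs = [r.payload for r in results if r.payload is not None]
merged = {
    "records": [rec for o in objs for rec in o["records"]],
    "last_physical_to_logical_map": objs[-1]["last_physical_to_logical_map"],
}
```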
+ """ + self.auto_create_handle_loop() + results: List[ExpertDistributionReqOutput] = await self.expert_distribution_communicator( + ExpertDistributionReq.DUMP_RECORD_OBJECT + ) + # `results` has one element per DP rank (fan_out = dp_size). + # Each `.payload` is either None or a dict like: + # {"records": [...], "last_physical_to_logical_map": Tensor[L, E_phys]} + objs = [r.payload for r in results if getattr(r, "payload", None) is not None] + if not objs: + # Fallback empty shape if nothing was recorded + return {"records": [], "last_physical_to_logical_map": None} + + merged = { + "records": [], + "last_physical_to_logical_map": objs[-1]["last_physical_to_logical_map"], + } + for o in objs: + merged["records"].extend(o["records"]) + # We keep the "last" physical_to_logical_map; in practice maps + # are identical across ranks for a given run. + return merged + async def pause_generation(self): async with self.is_pause_cond: self.is_pause = True diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 65c0a07f8ab1..6ed87a1d485b 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -306,6 +306,8 @@ class ForwardBatch: tbo_parent_token_range: Optional[Tuple[int, int]] = None tbo_children: Optional[List[ForwardBatch]] = None + rids: list[str] = None + @classmethod def init_new( cls, @@ -346,6 +348,7 @@ def init_new( input_embeds=batch.input_embeds, token_type_ids=batch.token_type_ids, tbo_split_seq_index=batch.tbo_split_seq_index, + rids=batch.rids, ) device = model_runner.device