-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Return expert routing info to support MoE routing replay #9499
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4d045c0
2d97ed8
ed0e583
6face71
5a02fa8
3b95e67
93dfc0a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -85,6 +85,111 @@ | |||||||||||||||||
|
|
||||||||||||||||||
| _is_cuda = is_cuda() | ||||||||||||||||||
|
|
||||||||||||||||||
| # ---------------- MoE routing helpers ---------------- | ||||||||||||||||||
| from typing import Dict as _Dict, List as _List, Iterable as _Iterable, Set as _Set | ||||||||||||||||||
| from collections import Counter | ||||||||||||||||||
| import torch | ||||||||||||||||||
|
|
||||||||||||||||||
| def _records_to_per_rids_subset( | ||||||||||||||||||
| records_obj, | ||||||||||||||||||
| wanted_rids: _Set[str], | ||||||||||||||||||
| *, | ||||||||||||||||||
| allow_multi_active_sequences: bool = False, | ||||||||||||||||||
| ) -> _Dict[str, _Dict]: | ||||||||||||||||||
| """ | ||||||||||||||||||
| Fast path: single pass over records, accumulating only the RIDs we care about. | ||||||||||||||||||
| Handles both PREFILL (flattened tokens) and DECODE (1 token per row). | ||||||||||||||||||
| Output per rid: { "topk_ids_of_layer": Tensor[L, T_total, K], "positions": List[int], "physical_to_logical_map": Tensor[L,E_phys], "shape": {...} } | ||||||||||||||||||
| """ | ||||||||||||||||||
| if not wanted_rids: | ||||||||||||||||||
| return {} | ||||||||||||||||||
|
|
||||||||||||||||||
| records = records_obj.get("records", []) | ||||||||||||||||||
| phys_to_log = records_obj.get("last_physical_to_logical_map", None) | ||||||||||||||||||
|
|
||||||||||||||||||
| per_rid_chunks: dict[str, _List[torch.Tensor]] = {rid: [] for rid in wanted_rids} | ||||||||||||||||||
| per_rid_positions: dict[str, _List[int]] = {rid: [] for rid in wanted_rids} | ||||||||||||||||||
|
|
||||||||||||||||||
| def _append_piece(rid: str, topk_piece: torch.Tensor, pos_slice: _Iterable[int]): | ||||||||||||||||||
| per_rid_chunks[rid].append(topk_piece) | ||||||||||||||||||
| per_rid_positions[rid].extend(pos_slice) | ||||||||||||||||||
|
|
||||||||||||||||||
| for rec in records: | ||||||||||||||||||
| topk: torch.Tensor = rec["topk_ids_of_layer"] # [L, T, K] (CPU) | ||||||||||||||||||
| rids: _List[str] = rec["rids"] | ||||||||||||||||||
| positions: _List[int] = rec["positions"] | ||||||||||||||||||
| ext: _List[int] | None = rec.get("extend_seq_lens", None) | ||||||||||||||||||
| T = int(topk.shape[1]) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Decode heuristic: exactly one token per active row | ||||||||||||||||||
| is_decode = (ext is None) or (len(positions) == len(rids)) | ||||||||||||||||||
|
|
||||||||||||||||||
| if not allow_multi_active_sequences: | ||||||||||||||||||
| cnt = Counter(rids) | ||||||||||||||||||
| dups = [rid for rid, c in cnt.items() if c > 1] | ||||||||||||||||||
| if dups: | ||||||||||||||||||
| # If any of those dups intersect our wanted set, it's a real violation for our use-case | ||||||||||||||||||
| if any(r in wanted_rids for r in dups): | ||||||||||||||||||
| stage = "DECODE" if is_decode else "PREFILL" | ||||||||||||||||||
| raise AssertionError( | ||||||||||||||||||
| f"Multiple active sequences detected in {stage} pass for RID(s): {sorted(set(dups) & wanted_rids)}. " | ||||||||||||||||||
| f"Disable fan-out (n=1, no beam/best_of) or set allow_multi_active_sequences=True." | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| if is_decode: | ||||||||||||||||||
| assert T == len(rids) == len(positions), \ | ||||||||||||||||||
| f"Decode mismatch: T={T}, |rids|={len(rids)}, |positions|={len(positions)}" | ||||||||||||||||||
| # Column i belongs to rids[i] | ||||||||||||||||||
| for i, rid in enumerate(rids): | ||||||||||||||||||
| if rid in wanted_rids: | ||||||||||||||||||
| _append_piece(rid, topk[:, i:i+1, :], [positions[i]]) | ||||||||||||||||||
| else: | ||||||||||||||||||
| assert sum(ext) == T == len(positions), \ | ||||||||||||||||||
| f"Prefill mismatch: sum(ext)={sum(ext)}, T={T}, |positions|={len(positions)}" | ||||||||||||||||||
| off = 0 | ||||||||||||||||||
| for rid, n in zip(rids, ext): | ||||||||||||||||||
| if rid in wanted_rids and n > 0: | ||||||||||||||||||
| sl = slice(off, off + n) | ||||||||||||||||||
| _append_piece(rid, topk[:, sl, :], positions[off:off+n]) | ||||||||||||||||||
| off += n | ||||||||||||||||||
|
|
||||||||||||||||||
| out: _Dict[str, _Dict] = {} | ||||||||||||||||||
| p2l = phys_to_log.cpu() if isinstance(phys_to_log, torch.Tensor) else phys_to_log | ||||||||||||||||||
| for rid in wanted_rids: | ||||||||||||||||||
| parts = per_rid_chunks[rid] | ||||||||||||||||||
| if not parts: | ||||||||||||||||||
| continue | ||||||||||||||||||
| cat = torch.cat(parts, dim=1).cpu() # [L, T_total, K] | ||||||||||||||||||
| out[rid] = { | ||||||||||||||||||
| "topk_ids_of_layer": cat, | ||||||||||||||||||
| "positions": per_rid_positions[rid], | ||||||||||||||||||
| "physical_to_logical_map": p2l, | ||||||||||||||||||
| "shape": {"num_layers": int(cat.shape[0]), "num_tokens": int(cat.shape[1]), "top_k": int(cat.shape[2])}, | ||||||||||||||||||
| } | ||||||||||||||||||
| return out | ||||||||||||||||||
|
|
||||||||||||||||||
| def _attach_routing_to_ret(ret, per_rid: _Dict[str, _Dict]) -> None: | ||||||||||||||||||
| """Attach per-RID routing to each response item's meta_info.""" | ||||||||||||||||||
| def _serialize(entry: _Dict) -> _Dict: | ||||||||||||||||||
| topk = entry["topk_ids_of_layer"] | ||||||||||||||||||
| p2l = entry["physical_to_logical_map"] | ||||||||||||||||||
| return { | ||||||||||||||||||
| "topk_ids_of_layer": topk.tolist(), | ||||||||||||||||||
| "positions": entry["positions"], | ||||||||||||||||||
| "physical_to_logical_map": (p2l.tolist() if isinstance(p2l, torch.Tensor) else None), | ||||||||||||||||||
| "shape": entry["shape"], | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| if isinstance(ret, list): | ||||||||||||||||||
| for item in ret: | ||||||||||||||||||
| rid = item.get("meta_info", {}).get("id") | ||||||||||||||||||
| if rid and rid in per_rid: | ||||||||||||||||||
| item.setdefault("meta_info", {})["moe_routing"] = {"rid": rid, **_serialize(per_rid[rid])} | ||||||||||||||||||
| elif isinstance(ret, dict): | ||||||||||||||||||
| rid = ret.get("meta_info", {}).get("id") | ||||||||||||||||||
| if rid and rid in per_rid: | ||||||||||||||||||
| ret.setdefault("meta_info", {})["moe_routing"] = {"rid": rid, **_serialize(per_rid[rid])} | ||||||||||||||||||
| # ----------------------------------------------------- | ||||||||||||||||||
|
|
||||||||||||||||||
| class Engine(EngineBase): | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
@@ -137,6 +242,8 @@ def __init__(self, **kwargs): | |||||||||||||||||
| context, zmq.DEALER, self.port_args.rpc_ipc_name, True | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| self._expert_routing_lock = asyncio.Lock() | ||||||||||||||||||
|
|
||||||||||||||||||
| def generate( | ||||||||||||||||||
| self, | ||||||||||||||||||
| # The input prompt. It can be a single prompt or a batch of prompts. | ||||||||||||||||||
|
|
@@ -246,6 +353,7 @@ async def async_generate( | |||||||||||||||||
| bootstrap_port: Optional[Union[List[int], int]] = None, | ||||||||||||||||||
| bootstrap_room: Optional[Union[List[int], int]] = None, | ||||||||||||||||||
| data_parallel_rank: Optional[int] = None, | ||||||||||||||||||
| return_expert_routing: Optional[Union[List[bool], bool]] = False, | ||||||||||||||||||
| ) -> Union[Dict, AsyncIterator[Dict]]: | ||||||||||||||||||
| """ | ||||||||||||||||||
| The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. | ||||||||||||||||||
|
|
@@ -283,12 +391,74 @@ async def async_generate( | |||||||||||||||||
| bootstrap_room=bootstrap_room, | ||||||||||||||||||
| data_parallel_rank=data_parallel_rank, | ||||||||||||||||||
| ) | ||||||||||||||||||
| generator = self.tokenizer_manager.generate_request(obj, None) | ||||||||||||||||||
|
|
||||||||||||||||||
| if stream is True: | ||||||||||||||||||
| return generator | ||||||||||||||||||
|
|
||||||||||||||||||
| # Attach routing only if requested | ||||||||||||||||||
| def _any_true(x): | ||||||||||||||||||
| if isinstance(x, list): | ||||||||||||||||||
| return any(bool(v) for v in x) | ||||||||||||||||||
| return bool(x) | ||||||||||||||||||
|
|
||||||||||||||||||
| if not _any_true(return_expert_routing): | ||||||||||||||||||
| generator = self.tokenizer_manager.generate_request(obj, None) | ||||||||||||||||||
|
|
||||||||||||||||||
| if stream is True: | ||||||||||||||||||
| return generator | ||||||||||||||||||
| else: | ||||||||||||||||||
| return await generator.__anext__() | ||||||||||||||||||
| else: | ||||||||||||||||||
| return await generator.__anext__() | ||||||||||||||||||
| await self._expert_routing_lock.acquire() | ||||||||||||||||||
| try: | ||||||||||||||||||
| if stream is True: | ||||||||||||||||||
| raise NotImplementedError("Expert statistics logging is not availale in streaming mode.") | ||||||||||||||||||
| else: | ||||||||||||||||||
| await self.tokenizer_manager.start_expert_distribution_record() | ||||||||||||||||||
| generator = self.tokenizer_manager.generate_request(obj, None) | ||||||||||||||||||
|
|
||||||||||||||||||
| ret = await generator.__anext__() | ||||||||||||||||||
| await self.tokenizer_manager.stop_expert_distribution_record() | ||||||||||||||||||
| records_obj = await self.tokenizer_manager.dump_expert_distribution_record_object() | ||||||||||||||||||
| # per_req = _split_routing_per_request(records_obj) | ||||||||||||||||||
| # Attach under meta_info; consumers can ignore if unused | ||||||||||||||||||
| # ret.setdefault("meta_info", {})["moe_routing_per_request"] = { | ||||||||||||||||||
| # rid: { | ||||||||||||||||||
| # "topk_ids_of_layer": v["topk_ids_of_layer"].tolist(), | ||||||||||||||||||
| # "positions": v["positions"], | ||||||||||||||||||
| # "physical_to_logical_map": v["physical_to_logical_map"].cpu().tolist(), | ||||||||||||||||||
| # } | ||||||||||||||||||
| # for rid, v in per_req.items() | ||||||||||||||||||
| # } | ||||||||||||||||||
|
|
||||||||||||||||||
| # Determine which RIDs we actually need to attach | ||||||||||||||||||
| if isinstance(ret, list): | ||||||||||||||||||
| wanted = {item.get("meta_info", {}).get("id") for item in ret if item.get("meta_info")} | ||||||||||||||||||
| else: | ||||||||||||||||||
| wanted = {ret.get("meta_info", {}).get("id")} if ret.get("meta_info") else set() | ||||||||||||||||||
| wanted = {rid for rid in wanted if rid} # drop Nones | ||||||||||||||||||
|
|
||||||||||||||||||
| # breakpoint() | ||||||||||||||||||
| # Fast single-pass subset aggregation | ||||||||||||||||||
| per_rid = _records_to_per_rids_subset( | ||||||||||||||||||
| records_obj, | ||||||||||||||||||
| wanted_rids=wanted, | ||||||||||||||||||
| allow_multi_active_sequences=False, # TODO flip to True only if you explicitly support fan-out | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| _attach_routing_to_ret(ret, per_rid) | ||||||||||||||||||
| # breakpoint() | ||||||||||||||||||
|
|
||||||||||||||||||
| # # TODO DEBUGGING ONLY | ||||||||||||||||||
| # breakpoint() # TODO NOTE what if one prompt reaches EOS. records-1 or simply padding token id in the input_ids, where we can try using the rwo_idx to match back | ||||||||||||||||||
| # # TODO MOVE per req all to cpu | ||||||||||||||||||
| # ret.setdefault("meta_info", {})["moe_routing_per_request"] = per_req | ||||||||||||||||||
|
|
||||||||||||||||||
| return ret | ||||||||||||||||||
| finally: | ||||||||||||||||||
| # Safety: ensure recorder is not left on, and release the lock | ||||||||||||||||||
| try: | ||||||||||||||||||
| await self.tokenizer_manager.stop_expert_distribution_record() | ||||||||||||||||||
| except Exception: | ||||||||||||||||||
| pass | ||||||||||||||||||
|
Comment on lines
+457
to
+460
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Swallowing exceptions with a
Suggested change
|
||||||||||||||||||
| self._expert_routing_lock.release() | ||||||||||||||||||
|
|
||||||||||||||||||
| def encode( | ||||||||||||||||||
| self, | ||||||||||||||||||
|
|
||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -358,7 +358,7 @@ def __init__( | |||||||||||||
| server_args.chunked_prefill_size * 8, | ||||||||||||||
| self._TOP_K_NUM, | ||||||||||||||
| ), | ||||||||||||||
| dtype=torch.int32, | ||||||||||||||
| dtype=torch.uint8, | ||||||||||||||
| device=server_args.device, | ||||||||||||||
| ) | ||||||||||||||
| self._misc_objects: List[Dict[str, Any]] = [] | ||||||||||||||
|
|
@@ -367,11 +367,32 @@ def __init__( | |||||||||||||
| ), "DetailSinglePassGatherer does not support TBO yet" | ||||||||||||||
| # TODO assert shared experts fusion is disabled, o/w data is wrong | ||||||||||||||
|
|
||||||||||||||
| def _ensure_capacity(self, need_tokens: int): | ||||||||||||||
| L, cur, K = self._topk_ids_of_layer.shape | ||||||||||||||
| if need_tokens > cur: | ||||||||||||||
| new = torch.full((L, need_tokens, K), -1, | ||||||||||||||
| dtype=self._topk_ids_of_layer.dtype, | ||||||||||||||
| device=self._topk_ids_of_layer.device) | ||||||||||||||
|
Comment on lines
+373
to
+375
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
|
||||||||||||||
| new[:, :cur, :] = self._topk_ids_of_layer | ||||||||||||||
| self._topk_ids_of_layer = new | ||||||||||||||
|
|
||||||||||||||
| def on_forward_pass_start(self, forward_batch: ForwardBatch): | ||||||||||||||
| assert self._metadata is None | ||||||||||||||
|
|
||||||||||||||
| # shape assertion, tho may be redundant | ||||||||||||||
| if forward_batch.forward_mode == "decode": | ||||||||||||||
| assert ( | ||||||||||||||
| len(forward_batch.rids) | ||||||||||||||
| == forward_batch.input_ids.numel() | ||||||||||||||
| == forward_batch.positions.numel() | ||||||||||||||
| ), "Decode: rids/input_ids/positions must align 1:1" | ||||||||||||||
|
|
||||||||||||||
| T_pass = forward_batch.input_ids.numel() # prefill: sum(extend_seq_lens); decode: batch size | ||||||||||||||
| self._ensure_capacity(T_pass) | ||||||||||||||
|
|
||||||||||||||
| self._metadata = dict( | ||||||||||||||
| # TODO pr-chain | ||||||||||||||
| # rids=forward_batch.rids, | ||||||||||||||
| rids=forward_batch.rids, # # record request IDs so we can split tokens by prompt | ||||||||||||||
| input_ids=forward_batch.input_ids.cpu().tolist(), | ||||||||||||||
| positions=forward_batch.positions.cpu().tolist(), | ||||||||||||||
| extend_seq_lens=forward_batch.extend_seq_lens_cpu, | ||||||||||||||
|
|
@@ -726,16 +747,35 @@ def reset(self): | |||||||||||||
| super().reset() | ||||||||||||||
| self._records.clear() | ||||||||||||||
|
|
||||||||||||||
| # def dump(self, output_mode: _OutputMode): | ||||||||||||||
| # assert output_mode == "file" | ||||||||||||||
| # output = dict( | ||||||||||||||
| # records=self._records, | ||||||||||||||
| # # NOTE: This may change during recording, so here we say it is the "last" one | ||||||||||||||
| # last_physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map, | ||||||||||||||
| # ) | ||||||||||||||
| # _dump_to_file( | ||||||||||||||
| # f"expert_distribution_recorder_{time.time()}_{self._rank}.pt", output | ||||||||||||||
| # ) | ||||||||||||||
|
Comment on lines
+750
to
+759
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||
|
|
||||||||||||||
| def dump(self, output_mode: _OutputMode): | ||||||||||||||
| assert output_mode == "file" | ||||||||||||||
| output = dict( | ||||||||||||||
| records=self._records, | ||||||||||||||
| # NOTE: This may change during recording, so here we say it is the "last" one | ||||||||||||||
| # copy the list so resetting doesn’t wipe caller’s view | ||||||||||||||
| records=list(self._records), | ||||||||||||||
| # Last mapping used during recording | ||||||||||||||
| last_physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map, | ||||||||||||||
| ) | ||||||||||||||
| _dump_to_file( | ||||||||||||||
| f"expert_distribution_recorder_{time.time()}_{self._rank}.pt", output | ||||||||||||||
| ) | ||||||||||||||
| if output_mode == "file": | ||||||||||||||
| _dump_to_file( | ||||||||||||||
| f"expert_distribution_recorder_{time.time()}_{self._rank}.pt", | ||||||||||||||
| output, | ||||||||||||||
| ) | ||||||||||||||
| elif output_mode == "object": | ||||||||||||||
| # return the raw object without saving | ||||||||||||||
| return output | ||||||||||||||
| else: | ||||||||||||||
| raise NotImplementedError(f"Unknown output mode: {output_mode}") | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class _StatAccumulator(_UtilizationRateAccumulatorMixin): | ||||||||||||||
|
|
||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1084,11 +1084,19 @@ class ExpertDistributionReq(Enum): | |
| START_RECORD = 1 | ||
| STOP_RECORD = 2 | ||
| DUMP_RECORD = 3 | ||
| # New: request the in-memory dump (Python object) instead of writing files | ||
| DUMP_RECORD_OBJECT = 4 | ||
|
|
||
|
|
||
| @dataclass | ||
| class ExpertDistributionReqOutput: | ||
| pass | ||
| # success/failure for the op | ||
| success: bool | ||
| # optional details | ||
| message: str = "" # message: Optional[str] = None | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| # New: used only by DUMP_RECORD_OBJECT; holds the recorder dump object | ||
| # shape: {"records": [...], "last_physical_to_logical_map": Tensor or list} | ||
| payload: Optional[Any] = None | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are several commented-out blocks of code, including
breakpoint()calls and old implementation logic. These should be removed to improve code clarity and maintainability.