[PD][MoRI] Align hybrid state transfer with per-component schema#26539
Conversation
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
| state_type = getattr(self.kv_args, "state_type", "none") | ||
|
|
||
| if state_type == "none": | ||
| raise RuntimeError( | ||
| "PD state transfer failed: state_type is 'none' but state_indices were provided" | ||
| ) | ||
|
|
||
| if not peer_info.dst_state_mem_descs: | ||
| state_types = getattr(self.kv_args, "state_types", None) or [] |
There was a problem hiding this comment.
I think state_types = self.kv_args.state_types is enough. We have made sure this value will be set.
There was a problem hiding this comment.
Thanks, updated as suggested
|
CC: @HaiShaw |
|
/tag-and-rerun-ci |
|
@amd-bot ci-status |
CI Status for PR #26539PR: [PD][MoRI] Align hybrid state transfer with per-component schema AMD: 1 failure (0 likely related) | Others: 12 failures (0 related) AMD CI Failures
Other CI Failures
DetailsNone of the failures are related to this PR's changes (
Verdict: the failures are unrelated to this PR. Safe to ignore from a correctness standpoint; the
|
Motivation
PR #24932 ([PD] Refactor hybrid state transfer) migrated
KVArgsfrom a flat state layout (state_type: str,state_item_lens: List[int],state_dim_per_tensor: List[int]) to a per-component one (state_types: List[StateType], both*_item_lens/*_dim_per_tensorbecomeList[List[int]]). Mooncake and NIXL were migrated to the new schema in the same PR, but MoRI was only partially migrated — the inner-loop in_register_local_bufferswas updated, while_register_kv_args,send_state,send_metadata,TransferInfo, andKVArgsRegisterInfowere left on the old flat assumption.For any model with a non-empty state pool (DeepSeek V4, GLM-5, Qwen3.5) this manifests as
struct.error: required argument is not an integerat PD bootstrap (#26525), because_register_kv_argsdoesstruct.pack("I", item_len)on what is now a list. A flatten-on-send hack would silence that crash but still routes Mamba state buffers through the SWA/DSA contiguous-page logic on multi-component hybrids, so this change aligns MoRI with the per-component dispatch model Mooncake and NIXL already use.Modifications
state_item_lens/state_dim_per_tensortopack_int_lists("I")/unpack_int_lists("I"), switchstate_indicestopack_int_lists("i")/unpack_int_lists("i"), and add nested-msgpack helpers forList[List[MemoryDesc]].KVArgsRegisterInfo.dst_state_{mem_descs,item_lens,dim_per_tensor}andTransferInfo.dst_state_indicesbecomeList[List[...]]/List[np.ndarray].MoriKVManager.state_mem_descsbecomesList[List[MemoryDesc]];_register_local_buffersbuilds it per-component.send_stateiteratesstate_types[i]and dispatches each component to_send_mamba_stateor_send_swa_dsa_stateindependently (mirrorsMooncakeKVManager.maybe_send_extraandNixlKVManager.maybe_send_extra)._send_mamba_state/_send_swa_dsa_stateaccept a single component's slice instead of indexing intoself.kv_args.*directly._normalize_state_indices_per_componentravels each component's payload to 1-D once at the API boundary, removing the 2-D single-component DSA edge case at the source.Accuracy Tests
Cross-machine PD on AMD MI300X with
--disaggregation-transfer-backend mori.Qwen3-8B (pure transformer, validates non-hybrid path / empty state lists):
Qwen3.5-122B-A10B (hybrid linear attention, exercises the per-component mamba state transfer path — decode logs show
Mamba Cache is allocatedwithssm_state 18.02GB / TP rankandUsing hybrid linear attention backend for hybrid GDN models, and per-requestmamba usageis non-zero):No state-transfer-related errors in prefill or decode logs across all runs.
Speed Tests and Profiling
sglang.bench_serving --backend sglang-oai-chatagainst PD routerfronting 1P + 1D over RDMA:
Qwen3-8B (1P TP=4 + 1D TP=4):
Qwen3.5-122B-A10B (1P TP=8 + 1D TP=8):
Checklist
cc @Duyi-Wang
CI States
Latest PR Test (Base): ⏳ Run #26624086997
Latest PR Test (Extra): ❌ Run #26624086820