diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 879c66df09cf..cfbde332fcf4 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -338,7 +338,7 @@ def __init__( # Draft workers are looked up via `SpeculativeAlgorithm` registry; new # algorithms should register their factory instead of patching this code. - if self.spec_algorithm.name in {"EAGLE", "EAGLE3"}: + if self.spec_algorithm.is_eagle(): draft_worker_kwargs["enable_overlap"] = self.enable_overlap self.draft_worker = self.spec_algorithm.create_draft_worker( **draft_worker_kwargs @@ -864,8 +864,16 @@ def init_disaggregation(self): ) self.disagg_metadata_buffers = MetadataBuffers( buffer_size, - hidden_size=self.model_config.hf_text_config.hidden_size, - hidden_states_dtype=self.model_config.dtype, + hidden_size=( + self.draft_worker.model_config.hidden_size + if self.spec_algorithm.is_eagle() + else 16 # minimal padding size for RDMA + ), + hidden_states_dtype=( + self.draft_worker.model_config.dtype + if self.spec_algorithm.is_eagle() + else torch.float32 + ), custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(), ) @@ -909,8 +917,16 @@ def init_disaggregation(self): ) self.disagg_metadata_buffers = MetadataBuffers( buffer_size, - hidden_size=self.model_config.hf_text_config.hidden_size, - hidden_states_dtype=self.model_config.dtype, + hidden_size=( + self.draft_worker.model_config.hidden_size + if self.spec_algorithm.is_eagle() + else 16 # minimal padding size for RDMA + ), + hidden_states_dtype=( + self.draft_worker.model_config.dtype + if self.spec_algorithm.is_eagle() + else torch.float32 + ), custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(), )