Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/sglang/srt/disaggregation/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:

polls = poll_and_all_reduce(
[req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
self.tp_worker.get_tp_cpu_group(),
self.attn_tp_cpu_group,
)

undone_reqs: List[Req] = []
Expand Down
12 changes: 5 additions & 7 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def init_memory_pool_and_cache(self):
self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tp_cache_group=self.tp_worker.get_tp_cpu_group(),
tp_cache_group=self.tp_cpu_group,
page_size=self.page_size,
hicache_ratio=server_args.hicache_ratio,
)
Expand Down Expand Up @@ -553,7 +553,7 @@ def init_disaggregation(self):

# The decode requests polling kv cache
self.disagg_decode_transfer_queue = DecodeTransferQueue(
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
gloo_group=self.attn_tp_cpu_group,
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
metadata_buffers=metadata_buffers,
)
Expand All @@ -568,7 +568,7 @@ def init_disaggregation(self):
scheduler=self,
transfer_queue=self.disagg_decode_transfer_queue,
tree_cache=self.tree_cache,
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
gloo_group=self.attn_tp_cpu_group,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
Expand Down Expand Up @@ -597,7 +597,7 @@ def init_disaggregation(self):
tp_rank=self.tp_rank,
tp_size=self.tp_size,
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
gloo_group=self.attn_tp_cpu_group,
transfer_backend=self.transfer_backend,
scheduler=self,
)
Expand Down Expand Up @@ -710,9 +710,7 @@ def event_loop_normal_disagg_decode(self):
# Generate fake extend output.
if batch.forward_mode.is_extend():
# Note: Logprobs should be handled on the prefill engine.
self.stream_output(
batch.reqs, [False for _ in range(len(batch.reqs))]
)
self.stream_output(batch.reqs, False)
else:
result = self.run_batch(batch)
self.process_batch_result(batch, result)
Expand Down
Loading