Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 7 additions & 28 deletions components/src/dynamo/sglang/request_handlers/handler_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import base64
import json
import logging
import random
import socket
Expand All @@ -12,7 +10,6 @@
from typing import Any, AsyncGenerator, Dict, Optional, Tuple

import sglang as sgl
from sglang.srt.tracing import trace as sglang_trace
from sglang.srt.utils import get_local_ip_auto

from dynamo._core import Component, Context
Expand Down Expand Up @@ -143,38 +140,20 @@ def _get_bootstrap_info(engine: sgl.Engine) -> Tuple[str, int]:

return bootstrap_host, bootstrap_port

def _propagate_trace_context_to_sglang(
self, context: Context, bootstrap_room: int = 0
):
"""Propagate Dynamo's trace context to SGLang for distributed tracing. SGLang expects a certain
format derived by looking at https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/tracing/trace.py
in the to_dict() method.
def _get_trace_header(self, context: Context) -> Optional[Dict[str, str]]:
"""Get trace header dict for passing to SGLang's external_trace_header parameter.

Args:
context: Dynamo Context object containing trace information.
bootstrap_room: Bootstrap room ID (0 for aggregated, actual room for disaggregated).

Returns:
Dict with traceparent header if trace context available, None otherwise.
"""
trace_id = context.trace_id
span_id = context.span_id
if not trace_id or not span_id:
return

# Build trace context for SGLang
trace_context = {
str(bootstrap_room): {
"root_span": {"traceparent": f"00-{trace_id}-{span_id}-01"},
"prev_span": {
"span_id": int(span_id, 16),
"trace_id": int(trace_id, 16),
},
}
}

# Encode and propagate
base64_context = base64.b64encode(
json.dumps(trace_context, ensure_ascii=False).encode("utf-8")
).decode("utf-8")
sglang_trace.trace_set_remote_propagate_context(base64_context)
return None
return {"traceparent": f"00-{trace_id}-{span_id}-01"}

async def _handle_cancellation(
self, request_id_future: asyncio.Future, context: Context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,9 @@ async def generate(
f"room={bootstrap_info['bootstrap_room']}"
)

if self.enable_trace:
self._propagate_trace_context_to_sglang(
context, bootstrap_info["bootstrap_room"]
)
trace_header = (
self._get_trace_header(context) if self.enable_trace else None
)

decode = await self.engine.async_generate(
**input_param,
Expand All @@ -131,6 +130,7 @@ async def generate(
bootstrap_host=bootstrap_info["bootstrap_host"],
bootstrap_port=bootstrap_info["bootstrap_port"],
bootstrap_room=bootstrap_info["bootstrap_room"],
external_trace_header=trace_header,
rid=trace_id,
)

Expand All @@ -141,13 +141,15 @@ async def generate(
async for out in self._process_text_stream(decode, context):
yield out
else:
if self.enable_trace:
self._propagate_trace_context_to_sglang(context)
trace_header = (
self._get_trace_header(context) if self.enable_trace else None
)

agg = await self.engine.async_generate(
**input_param,
sampling_params=sampling_params,
stream=True,
external_trace_header=trace_header,
rid=trace_id,
)
if self.skip_tokenizer_init:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,7 @@ async def generate(

input_param = self._get_input_param(inner_request)

# Propagate trace context to SGLang
if self.enable_trace:
self._propagate_trace_context_to_sglang(context, bootstrap_room)
trace_header = self._get_trace_header(context) if self.enable_trace else None

results = await self.engine.async_generate(
**input_param,
Expand All @@ -124,6 +122,7 @@ async def generate(
bootstrap_host=self.bootstrap_host,
bootstrap_port=self.bootstrap_port,
bootstrap_room=bootstrap_room,
external_trace_header=trace_header,
rid=trace_id,
)

Expand Down
113 changes: 66 additions & 47 deletions lib/llm/src/kv_router/prefill_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use futures::StreamExt;
use rand::Rng;
use tokio::sync::{OwnedSemaphorePermit, oneshot};
use tokio_util::sync::CancellationToken;
use tracing::Instrument;

use dynamo_runtime::{
component::Endpoint,
Expand Down Expand Up @@ -265,10 +266,14 @@ impl PrefillRouter {
InnerPrefillRouter::KvRouter(r) => r,
_ => return None,
};
match kv_router
.chooser
.find_best_match(None, &req.token_ids, None, false)
.await
match async {
kv_router
.chooser
.find_best_match(None, &req.token_ids, None, false)
.await
}
.instrument(tracing::info_span!("kv_find_best_match"))
.await
{
Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank),
Err(_) => return None,
Expand Down Expand Up @@ -405,19 +410,29 @@ impl PrefillRouter {
phase_permit: OwnedSemaphorePermit,
) {
let router = self.prefill_router.get().cloned();

tokio::spawn(async move {
match Self::execute_prefill(router, prefill_request, target_worker, Some(phase_permit))
// Capture current span to propagate trace context to the spawned task
let span = tracing::Span::current();

tokio::spawn(
async move {
match Self::execute_prefill(
router,
prefill_request,
target_worker,
Some(phase_permit),
)
.await
{
Ok(_) => {
tracing::debug!("Prefill background task completed");
}
Err(e) => {
tracing::warn!("Prefill background task error: {e:?}");
{
Ok(_) => {
tracing::debug!("Prefill background task completed");
}
Err(e) => {
tracing::warn!("Prefill background task error: {e:?}");
}
}
}
});
.instrument(span),
);
}

/// Call the prefill router and extract structured prefill result and worker ID.
Expand Down Expand Up @@ -491,47 +506,51 @@ impl
.as_ref()
.and_then(|r| r.prefill_worker_id);

let prefill_result = if let Some((worker_id, dp_rank, bootstrap_info)) = self
.build_bootstrap_info(&prefill_req, preselected_worker)
.await
{
// Bootstrap optimization path: spawn prefill in background
// We successfully used the peeked worker, so we must now advance the router state
// to ensure the next request gets a different worker.
if !self.router_mode.is_kv_routing()
&& let Some(router) = self.prefill_router.get()
let prefill_result = async {
if let Some((worker_id, dp_rank, bootstrap_info)) = self
.build_bootstrap_info(&prefill_req, preselected_worker)
.await
{
router.select_next_worker();
}
// Bootstrap optimization path: spawn prefill in background
// We successfully used the peeked worker, so we must now advance the router state
// to ensure the next request gets a different worker.
if !self.router_mode.is_kv_routing()
&& let Some(router) = self.prefill_router.get()
{
router.select_next_worker();
}

let routing = prefill_req.routing_mut();
routing.prefill_worker_id = Some(worker_id);
routing.dp_rank = Some(dp_rank);
prefill_req.bootstrap_info = Some(bootstrap_info.clone());
let routing = prefill_req.routing_mut();
routing.prefill_worker_id = Some(worker_id);
routing.dp_rank = Some(dp_rank);
prefill_req.bootstrap_info = Some(bootstrap_info.clone());

let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());

// Pass phase permit to spawned task - it drops after first output (record_worker complete)
// This allows set_phase(Decode) below to proceed only after prefill routing is done
self.spawn_prefill_task(prefill_context, Some(worker_id), prefill_phase_permit);
// Pass phase permit to spawned task - it drops after first output (record_worker complete)
// This allows set_phase(Decode) below to proceed only after prefill routing is done
self.spawn_prefill_task(prefill_context, Some(worker_id), prefill_phase_permit);

Ok((None, Some(worker_id), Some(bootstrap_info)))
} else {
// Original prefill path: wait for prefill to complete
tracing::debug!("Using original prefill path");
Ok((None, Some(worker_id), Some(bootstrap_info)))
} else {
// Original prefill path: wait for prefill to complete
tracing::debug!("Using original prefill path");

// Drop the phase permit before calling call_prefill - we wait for completion
// so there's no race with set_phase(Decode) below
drop(prefill_phase_permit);
// Drop the phase permit before calling call_prefill - we wait for completion
// so there's no race with set_phase(Decode) below
drop(prefill_phase_permit);

let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());

self.call_prefill(prefill_context)
.await
.map(|(result, worker_id)| (Some(result), worker_id, None))
};
self.call_prefill(prefill_context)
.await
.map(|(result, worker_id)| (Some(result), worker_id, None))
}
}
.instrument(tracing::info_span!("prefill_routing"))
.await;

// Abort if cancelled during prefill
if engine_ctx.is_stopped() || engine_ctx.is_killed() {
Expand Down
Loading