diff --git a/components/src/dynamo/sglang/request_handlers/handler_base.py b/components/src/dynamo/sglang/request_handlers/handler_base.py index d4995a2b771..449ea91f610 100644 --- a/components/src/dynamo/sglang/request_handlers/handler_base.py +++ b/components/src/dynamo/sglang/request_handlers/handler_base.py @@ -2,8 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio -import base64 -import json import logging import random import socket @@ -12,7 +10,6 @@ from typing import Any, AsyncGenerator, Dict, Optional, Tuple import sglang as sgl -from sglang.srt.tracing import trace as sglang_trace from sglang.srt.utils import get_local_ip_auto from dynamo._core import Component, Context @@ -143,38 +140,20 @@ def _get_bootstrap_info(engine: sgl.Engine) -> Tuple[str, int]: return bootstrap_host, bootstrap_port - def _propagate_trace_context_to_sglang( - self, context: Context, bootstrap_room: int = 0 - ): - """Propagate Dynamo's trace context to SGLang for distributed tracing. SGLang expects a certain - format derived by loooking at https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/tracing/trace.py - in the to_dict() method. + def _get_trace_header(self, context: Context) -> Optional[Dict[str, str]]: + """Get trace header dict for passing to SGLang's external_trace_header parameter. Args: context: Dynamo Context object containing trace information. - bootstrap_room: Bootstrap room ID (0 for aggregated, actual room for disaggregated). + + Returns: + Dict with traceparent header if trace context available, None otherwise. """ trace_id = context.trace_id span_id = context.span_id if not trace_id or not span_id: - return - - # Build trace context for SGLang - trace_context = { - str(bootstrap_room): { - "root_span": {"traceparent": f"00-{trace_id}-{span_id}-01"}, - "prev_span": { - "span_id": int(span_id, 16), - "trace_id": int(trace_id, 16), - }, - } - } - - # Encode and propagate - base64_context = base64.b64encode( - json.dumps(trace_context, ensure_ascii=False).encode("utf-8") - ).decode("utf-8") - sglang_trace.trace_set_remote_propagate_context(base64_context) + return None + return {"traceparent": f"00-{trace_id}-{span_id}-01"} async def _handle_cancellation( self, request_id_future: asyncio.Future, context: Context diff --git a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py index 51196d72e8c..ad13e061053 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py @@ -119,10 +119,9 @@ async def generate( f"room={bootstrap_info['bootstrap_room']}" ) - if self.enable_trace: - self._propagate_trace_context_to_sglang( - context, bootstrap_info["bootstrap_room"] - ) + trace_header = ( + self._get_trace_header(context) if self.enable_trace else None + ) decode = await self.engine.async_generate( **input_param, @@ -131,6 +130,7 @@ async def generate( bootstrap_host=bootstrap_info["bootstrap_host"], bootstrap_port=bootstrap_info["bootstrap_port"], bootstrap_room=bootstrap_info["bootstrap_room"], + external_trace_header=trace_header, rid=trace_id, ) @@ -141,13 +141,15 @@ async def generate( async for out in self._process_text_stream(decode, context): yield out else: - if self.enable_trace: - self._propagate_trace_context_to_sglang(context) + trace_header = ( + self._get_trace_header(context) if self.enable_trace else None + ) agg = await self.engine.async_generate( **input_param, sampling_params=sampling_params, stream=True, + external_trace_header=trace_header, rid=trace_id, ) if self.skip_tokenizer_init: diff --git a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py index d0943b3b7a4..e426d9493e0 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py @@ -113,9 +113,7 @@ async def generate( input_param = self._get_input_param(inner_request) - # Propagate trace context to SGLang - if self.enable_trace: - self._propagate_trace_context_to_sglang(context, bootstrap_room) + trace_header = self._get_trace_header(context) if self.enable_trace else None results = await self.engine.async_generate( **input_param, @@ -124,6 +122,7 @@ async def generate( bootstrap_host=self.bootstrap_host, bootstrap_port=self.bootstrap_port, bootstrap_room=bootstrap_room, + external_trace_header=trace_header, rid=trace_id, ) diff --git a/lib/llm/src/kv_router/prefill_router.rs b/lib/llm/src/kv_router/prefill_router.rs index cff45f172e4..864d8548e51 100644 --- a/lib/llm/src/kv_router/prefill_router.rs +++ b/lib/llm/src/kv_router/prefill_router.rs @@ -8,6 +8,7 @@ use futures::StreamExt; use rand::Rng; use tokio::sync::{OwnedSemaphorePermit, oneshot}; use tokio_util::sync::CancellationToken; +use tracing::Instrument; use dynamo_runtime::{ component::Endpoint, @@ -265,10 +266,14 @@ impl PrefillRouter { InnerPrefillRouter::KvRouter(r) => r, _ => return None, }; - match kv_router - .chooser - .find_best_match(None, &req.token_ids, None, false) - .await + match async { + kv_router + .chooser + .find_best_match(None, &req.token_ids, None, false) + .await + } + .instrument(tracing::info_span!("kv_find_best_match")) + .await { Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank), Err(_) => return None, @@ -405,19 +410,29 @@ impl PrefillRouter { phase_permit: OwnedSemaphorePermit, ) { let router = self.prefill_router.get().cloned(); - - tokio::spawn(async move { - match Self::execute_prefill(router, prefill_request, target_worker, Some(phase_permit)) + // Capture current span to propagate trace context to the spawned task + let span = tracing::Span::current(); + + tokio::spawn( + async move { + match Self::execute_prefill( + router, + prefill_request, + target_worker, + Some(phase_permit), + ) .await - { - Ok(_) => { - tracing::debug!("Prefill background task completed"); - } - Err(e) => { - tracing::warn!("Prefill background task error: {e:?}"); + { + Ok(_) => { + tracing::debug!("Prefill background task completed"); + } + Err(e) => { + tracing::warn!("Prefill background task error: {e:?}"); + } } } - }); + .instrument(span), + ); } /// Call the prefill router and extract structured prefill result and worker ID. @@ -491,47 +506,51 @@ impl .as_ref() .and_then(|r| r.prefill_worker_id); - let prefill_result = if let Some((worker_id, dp_rank, bootstrap_info)) = self - .build_bootstrap_info(&prefill_req, preselected_worker) - .await - { - // Bootstrap optimization path: spawn prefill in background - // We successfully used the peeked worker, so we must now advance the router state - // to ensure the next request gets a different worker. - if !self.router_mode.is_kv_routing() - && let Some(router) = self.prefill_router.get() + let prefill_result = async { + if let Some((worker_id, dp_rank, bootstrap_info)) = self + .build_bootstrap_info(&prefill_req, preselected_worker) + .await { - router.select_next_worker(); - } + // Bootstrap optimization path: spawn prefill in background + // We successfully used the peeked worker, so we must now advance the router state + // to ensure the next request gets a different worker. + if !self.router_mode.is_kv_routing() + && let Some(router) = self.prefill_router.get() + { + router.select_next_worker(); + } - let routing = prefill_req.routing_mut(); - routing.prefill_worker_id = Some(worker_id); - routing.dp_rank = Some(dp_rank); - prefill_req.bootstrap_info = Some(bootstrap_info.clone()); + let routing = prefill_req.routing_mut(); + routing.prefill_worker_id = Some(worker_id); + routing.dp_rank = Some(dp_rank); + prefill_req.bootstrap_info = Some(bootstrap_info.clone()); - let prefill_context = Context::with_id(prefill_req, request_id.clone()); - engine_ctx.link_child(prefill_context.context()); + let prefill_context = Context::with_id(prefill_req, request_id.clone()); + engine_ctx.link_child(prefill_context.context()); - // Pass phase permit to spawned task - it drops after first output (record_worker complete) - // This allows set_phase(Decode) below to proceed only after prefill routing is done - self.spawn_prefill_task(prefill_context, Some(worker_id), prefill_phase_permit); + // Pass phase permit to spawned task - it drops after first output (record_worker complete) + // This allows set_phase(Decode) below to proceed only after prefill routing is done + self.spawn_prefill_task(prefill_context, Some(worker_id), prefill_phase_permit); - Ok((None, Some(worker_id), Some(bootstrap_info))) - } else { - // Original prefill path: wait for prefill to complete - tracing::debug!("Using original prefill path"); + Ok((None, Some(worker_id), Some(bootstrap_info))) + } else { + // Original prefill path: wait for prefill to complete + tracing::debug!("Using original prefill path"); - // Drop the phase permit before calling call_prefill - we wait for completion - // so there's no race with set_phase(Decode) below - drop(prefill_phase_permit); + // Drop the phase permit before calling call_prefill - we wait for completion + // so there's no race with set_phase(Decode) below + drop(prefill_phase_permit); - let prefill_context = Context::with_id(prefill_req, request_id.clone()); - engine_ctx.link_child(prefill_context.context()); + let prefill_context = Context::with_id(prefill_req, request_id.clone()); + engine_ctx.link_child(prefill_context.context()); - self.call_prefill(prefill_context) - .await - .map(|(result, worker_id)| (Some(result), worker_id, None)) - }; + self.call_prefill(prefill_context) + .await + .map(|(result, worker_id)| (Some(result), worker_id, None)) + } + } + .instrument(tracing::info_span!("prefill_routing")) + .await; // Abort if cancelled during prefill if engine_ctx.is_stopped() || engine_ctx.is_killed() {