Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def generate(
bootstrap_port: Optional[Union[List[int], int]] = None,
bootstrap_room: Optional[Union[List[int], int]] = None,
data_parallel_rank: Optional[int] = None,
external_trace_header: Optional[Dict] = None,
rid: Optional[Union[List[str], str]] = None,
) -> Union[Dict, Iterator[Dict]]:
"""
Expand Down Expand Up @@ -266,6 +267,7 @@ def generate(
bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room,
data_parallel_rank=data_parallel_rank,
external_trace_header=external_trace_header,
rid=rid,
)
generator = self.tokenizer_manager.generate_request(obj, None)
Expand Down Expand Up @@ -315,6 +317,7 @@ async def async_generate(
bootstrap_port: Optional[Union[List[int], int]] = None,
bootstrap_room: Optional[Union[List[int], int]] = None,
data_parallel_rank: Optional[int] = None,
external_trace_header: Optional[Dict] = None,
rid: Optional[Union[List[str], str]] = None,
) -> Union[Dict, AsyncIterator[Dict]]:
"""
Expand Down Expand Up @@ -352,6 +355,7 @@ async def async_generate(
bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room,
data_parallel_rank=data_parallel_rank,
external_trace_header=external_trace_header,
rid=rid,
)
generator = self.tokenizer_manager.generate_request(obj, None)
Expand All @@ -368,6 +372,7 @@ def encode(
audio_data: Optional[MultimodalDataInputFormat] = None,
video_data: Optional[MultimodalDataInputFormat] = None,
dimensions: Optional[int] = None,
external_trace_header: Optional[Dict] = None,
rid: Optional[Union[List[str], str]] = None,
) -> Dict:
"""
Expand All @@ -380,6 +385,7 @@ def encode(
audio_data=audio_data,
video_data=video_data,
dimensions=dimensions,
external_trace_header=external_trace_header,
rid=rid,
)
generator = self.tokenizer_manager.generate_request(obj, None)
Expand All @@ -393,6 +399,7 @@ async def async_encode(
audio_data: Optional[MultimodalDataInputFormat] = None,
video_data: Optional[MultimodalDataInputFormat] = None,
dimensions: Optional[int] = None,
external_trace_header: Optional[Dict] = None,
rid: Optional[Union[List[str], str]] = None,
) -> Dict:
"""
Expand All @@ -407,6 +414,7 @@ async def async_encode(
audio_data=audio_data,
video_data=video_data,
dimensions=dimensions,
external_trace_header=external_trace_header,
rid=rid,
)
generator = self.tokenizer_manager.generate_request(obj, None)
Expand Down
9 changes: 7 additions & 2 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,9 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin):
# Whether to return entropy
return_entropy: bool = False

# Propagates trace context via Engine.generate/async_generate
external_trace_header: Optional[Dict] = None

# For EPD-disaggregated inference
need_wait_for_image: Optional[bool] = None
num_items_assigned: Optional[List] = None
Expand Down Expand Up @@ -662,6 +665,7 @@ def __getitem__(self, i):
custom_labels=self.custom_labels,
return_bytes=self.return_bytes,
return_entropy=self.return_entropy,
external_trace_header=self.external_trace_header,
http_worker_ipc=self.http_worker_ipc,
**{
field: getattr(self, field)
Expand Down Expand Up @@ -796,8 +800,8 @@ class EmbeddingReqInput(BaseReq, APIServingTimingMixin):
# For background responses (OpenAI responses API)
background: bool = False

# tracing context
trace_context: Optional[Dict] = None
# Propagates trace context via Engine.encode/async_encode
external_trace_header: Optional[Dict] = None
Comment on lines +803 to +804
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Renaming trace_context to external_trace_header is a good change for consistency with GenerateReqInput.

However, I noticed a related pre-existing issue. TokenizedEmbeddingReqInput is missing the trace_context field, which TokenizedGenerateReqInput has. This will likely cause an AttributeError in tokenizer_manager.py inside the _send_one_request method when it tries to set tokenized_obj.trace_context for embedding requests. Since this PR is focused on improving tracing, it would be a good opportunity to fix this to ensure tracing works correctly for embedding requests as well.

You could fix this by adding trace_context: Optional[Dict] = None to the TokenizedEmbeddingReqInput class.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will fix the issue you mentioned in #13152.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mark here


# The number of dimensions the resulting output embeddings should have. It is applicable for Matryoshka Embeddings.
dimensions: Optional[int] = None
Expand Down Expand Up @@ -878,6 +882,7 @@ def __getitem__(self, i):
video_data=self.video_data[i] if self.video_data is not None else None,
sampling_params=self.sampling_params[i],
rid=self.rid[i],
external_trace_header=self.external_trace_header,
dimensions=self.dimensions,
http_worker_ipc=self.http_worker_ipc,
**{
Expand Down
14 changes: 6 additions & 8 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,15 +460,14 @@ async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
traceparent: Optional[str] = None,
):
created_time = obj.received_time if obj.received_time else time.time()
self.auto_create_handle_loop()

# Normalize the request
obj.normalize_batch_and_arguments()
if self.enable_trace:
self._trace_request_start(obj, created_time, request, traceparent)
self._trace_request_start(obj, created_time, request)
if self.server_args.language_only:
self._handle_epd_disaggregation_encode_request(obj)
if self.server_args.tokenizer_worker_num > 1:
Expand Down Expand Up @@ -2131,19 +2130,18 @@ def _trace_request_start(
obj: Union[GenerateReqInput, EmbeddingReqInput],
created_time: Optional[float] = None,
request: Optional[fastapi.Request] = None,
traceparent: Optional[str] = None,
):
external_trace_header = None
if request:
if "trace_context" in request.headers:
trace_set_remote_propagate_context(request.headers["trace_context"])
else:
external_trace_header = extract_trace_headers(request.headers)
elif traceparent:
# When the request comes form the rust grpc server there isn't a
# real request object but we still need to propagate the traceparent from
# the traceparent that is explicitly passed in
external_trace_header = {"traceparent": traceparent}
elif obj.external_trace_header:
# When the request comes from the Rust gRPC server or the Engine API there
# isn't a real request object, but we still need to propagate the trace
# context that is explicitly passed in
external_trace_header = obj.external_trace_header

if obj.is_single:
bootstrap_room = (
Expand Down
Loading