118 changes: 118 additions & 0 deletions docs/references/production_request_trace.md
@@ -0,0 +1,118 @@
SGLang exports request trace data via the OpenTelemetry Collector. You can enable tracing by adding the `--enable-trace` flag and configure the OpenTelemetry Collector endpoint with `--oltp-traces-endpoint` when launching the server.

You can find example screenshots of the visualization in https://github.com/sgl-project/sglang/issues/8965.

## Setup Guide
This section explains how to configure request tracing and export the trace data.

1. Install the required packages and tools
    * Install Docker and Docker Compose
    * Install the dependencies
    ```bash
    # enter the SGLang root directory
    pip install -e "python[tracing]"

    # or manually install the dependencies using pip
    pip install opentelemetry-sdk opentelemetry-api opentelemetry-exporter-otlp opentelemetry-exporter-otlp-proto-grpc
    ```

2. Launch the OpenTelemetry Collector and Jaeger
    ```bash
    docker compose -f examples/monitoring/tracing_compose.yaml up -d
    ```

3. Start your SGLang server with tracing enabled
    ```bash
    python -m sglang.launch_server --enable-trace --oltp-traces-endpoint 0.0.0.0:4317 <other options>
    ```

    Replace `0.0.0.0:4317` with the actual endpoint of your OpenTelemetry Collector. If you launched the collector with `tracing_compose.yaml`, the default receiving port is 4317.

4. Send some requests to the server
5. Check that trace data is being exported
    * Open port 16686 of Jaeger in a web browser to visualize the request traces.
    * The OpenTelemetry Collector also exports trace data in JSON format to `/tmp/otel_trace.json`. In a follow-up patch, we will provide a tool to convert this data into a Perfetto-compatible format, enabling visualization of requests in the Perfetto UI.
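For step 4, one quick way to generate traceable requests is the server's HTTP `/generate` endpoint. This sketch assumes the default server port 30000 and a running server; adjust host and port to your deployment:

```
curl http://localhost:30000/generate \
  -H "Content-Type: application/json" \
  -d '{"text": "The capital of France is", "sampling_params": {"max_new_tokens": 16}}'
```

Each request served while tracing is enabled should then show up as a trace in the Jaeger UI and in `/tmp/otel_trace.json`.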

## How to Add Tracing for Slices You Are Interested In
We have already inserted instrumentation points in the tokenizer and scheduler main threads. If you wish to trace additional request execution segments or perform finer-grained tracing, please use the APIs from the tracing package as described below.

1. Initialization

Every process involved in tracing during the initialization phase should execute:
```python
process_tracing_init(oltp_traces_endpoint, server_name)
```
`oltp_traces_endpoint` is taken from the server arguments; `server_name` can be set freely, but it must remain consistent across all processes.

Every thread involved in tracing during the initialization phase should execute:
```python
trace_set_thread_info("thread label", tp_rank, dp_rank)
```
The "thread label" can be regarded as the name of the thread, used to distinguish different threads in the visualization view.

2. Mark the beginning and end of a request
```python
trace_req_start(rid, bootstrap_room)
trace_req_finish(rid)
```
These two APIs must be called within the same process, for example, in the tokenizer.
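The expected call ordering around a request's lifetime can be shown with a toy event recorder. Note: the function names mirror the real APIs, but the bodies below are invented stand-ins for illustration, not the sglang implementations:

```python
# Toy event recorder illustrating the expected call order for one request.
events = []

def trace_req_start(rid, bootstrap_room=None):
    events.append(("req_start", rid))

def trace_slice_start(name, rid, anonymous=False):
    events.append(("slice_start", name or "<anon>", rid))

def trace_slice_end(name, rid):
    events.append(("slice_end", name, rid))

def trace_req_finish(rid):
    events.append(("req_finish", rid))

rid = "req-123"
trace_req_start(rid, bootstrap_room=0)   # request enters the system
trace_slice_start("tokenize", rid)       # slices happen between start/finish
trace_slice_end("tokenize", rid)
trace_req_finish(rid)                    # request leaves the system
```

The invariant being illustrated: every slice for a request is bracketed by that request's `trace_req_start`/`trace_req_finish` pair, both issued from the same process.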

3. Add tracing for slices

- Add slice tracing normally:
```python
trace_slice_start("slice A", rid)
trace_slice_end("slice A", rid)
```

- Use the `anonymous` flag to leave the slice name unspecified at the start of the slice, allowing the name to be determined by `trace_slice_end`. Note: anonymous slices must not be nested.
```python
trace_slice_start("", rid, anonymous=True)
trace_slice_end("slice A", rid)
```
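To see why nesting is disallowed, here is a tiny stand-alone sketch (the class and its fields are invented for illustration, not the real sglang internals) of how an anonymous slice gets its name at end time, and why a second pending anonymous slice would make that ambiguous:

```python
class SliceTracker:
    """Toy model of per-thread slice tracking (illustrative only)."""

    def __init__(self):
        self.stack = []      # open slices, innermost last; None = anonymous
        self.finished = []   # names of completed slices

    def slice_start(self, name, anonymous=False):
        if anonymous and self.stack and self.stack[-1] is None:
            # With two pending anonymous slices, the name supplied at
            # slice_end could close either one -- hence the no-nesting rule.
            raise RuntimeError("anonymous slices must not be nested")
        self.stack.append(None if anonymous else name)

    def slice_end(self, name):
        pending = self.stack.pop()
        # An anonymous slice takes the name supplied at its end.
        self.finished.append(pending if pending is not None else name)

tracker = SliceTracker()
tracker.slice_start("", anonymous=True)
tracker.slice_end("slice A")   # the anonymous slice is named here
```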

- In `trace_slice_end`, use `auto_next_anon` to automatically open the next anonymous slice, which reduces the number of instrumentation points needed.
```python
trace_slice_start("", rid, anonymous=True)
trace_slice_end("slice A", rid, auto_next_anon=True)
trace_slice_end("slice B", rid, auto_next_anon=True)
trace_slice_end("slice C", rid, auto_next_anon=True)
trace_slice_end("slice D", rid)
```
- The end of the last slice in a thread must be marked with `thread_finish_flag=True`; otherwise, the thread's span will not be generated properly.
```python
trace_slice_end("slice D", rid, thread_finish_flag=True)
```

4. When the request execution flow transfers to another thread, the trace context needs to be explicitly propagated.
- Sender: execute the following code before sending the request to another thread via ZMQ:
```python
trace_context = trace_get_proc_propagate_context(rid)
req.trace_context = trace_context
```
- Receiver: execute the following code after receiving the request via ZMQ:
```python
trace_set_proc_propagate_context(rid, req.trace_context)
```

## How to Extend the Tracing Framework to Support Complex Tracing Scenarios

The tracing package still has room for further development. If you wish to build more advanced features on top of it, you should first understand its existing design.

The core of the tracing framework implementation lies in the design of the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a three-level trace context structure: `SglangTraceReqContext`, `SglangTraceThreadContext`, and `SglangTraceSliceContext`. Their relationship is as follows:
```
SglangTraceReqContext (req_id="req-123")
├── SglangTraceThreadContext (thread_label="scheduler", tp_rank=0)
│   └── SglangTraceSliceContext (name="prefill")  # current slice
│
└── SglangTraceThreadContext (thread_label="scheduler", tp_rank=1)
    └── SglangTraceSliceContext (name="prefill")  # current slice
```

Each traced request maintains a global `SglangTraceReqContext`. For every thread processing the request, a corresponding `SglangTraceThreadContext` is recorded and composed within the `SglangTraceReqContext`. Within each thread, every currently traced slice (possibly nested) is represented by a `SglangTraceSliceContext`, stored in the `SglangTraceThreadContext`. When slice, thread, or request tracing ends, the framework generates the corresponding span and releases the associated context.
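As a rough illustration of that ownership structure, the following stand-alone sketch mirrors the three class names; the fields and construction code are simplified assumptions, not the real definitions:

```python
from dataclasses import dataclass, field
from typing import Dict, Optional

@dataclass
class SglangTraceSliceContext:
    name: str

@dataclass
class SglangTraceThreadContext:
    thread_label: str
    tp_rank: int
    cur_slice: Optional[SglangTraceSliceContext] = None  # innermost open slice

@dataclass
class SglangTraceReqContext:
    req_id: str
    # One entry per thread that has touched this request.
    threads: Dict[str, SglangTraceThreadContext] = field(default_factory=dict)

# One request traced across two scheduler threads (tp_rank 0 and 1):
req = SglangTraceReqContext(req_id="req-123")
for rank in (0, 1):
    thread = SglangTraceThreadContext(thread_label="scheduler", tp_rank=rank)
    thread.cur_slice = SglangTraceSliceContext(name="prefill")
    req.threads[f"scheduler-{rank}"] = thread
```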

In addition to this hierarchy, each slice also records a link to its previous slice via `Span.add_link()`, which can be used to trace the execution flow.

When the request execution flow transfers to a new thread, the trace context must be explicitly propagated. In the framework, this is represented by `SglangTracePropagateContext`, which contains the contexts of the request span and the previous slice span.
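A minimal sketch of that hand-off, with hypothetical field names (the real `SglangTracePropagateContext` wraps OpenTelemetry span contexts rather than plain dicts):

```python
from dataclasses import dataclass

@dataclass
class PropagateContextSketch:
    """Hypothetical stand-in for SglangTracePropagateContext."""
    req_span_context: dict         # identifies the request's root span
    prev_slice_span_context: dict  # lets the next slice link back

def get_propagate_context(req_span_context, prev_slice_span_context):
    # Sender side: bundle both contexts into one object attached to the
    # message (req.trace_context) before handing it to another thread.
    return PropagateContextSketch(req_span_context, prev_slice_span_context)

ctx = get_propagate_context({"trace_id": "0xabc"}, {"span_id": "0x01"})
```

On the receiving side, the counterpart call restores both spans so the new thread's first slice can be parented under the request span and linked to the previous slice.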
38 changes: 38 additions & 0 deletions examples/monitoring/opentelemetry.yaml
@@ -0,0 +1,38 @@
receivers:
  otlp:
    protocols:
      grpc:
        endpoint: 0.0.0.0:4317
      http:
        endpoint: 0.0.0.0:4318

processors:
  batch:

exporters:
  otlp:
    endpoint: jaeger:4317
    tls:
      insecure: true
  file:
    path: /tmp/otel_trace.json

extensions:
  health_check:
  pprof:
  zpages:

service:
  extensions: [health_check, pprof, zpages]
  pipelines:
    traces:
      receivers: [otlp]
      processors: [batch]
      exporters: [otlp, file]
    metrics:
      receivers: [otlp]
      processors: [batch]
      exporters: [otlp]
    logs:
      receivers: [otlp]
      processors: [batch]
      exporters: [otlp]
21 changes: 21 additions & 0 deletions examples/monitoring/tracing_compose.yaml
@@ -0,0 +1,21 @@
services:
  otel-collector:
    image: docker.io/otel/opentelemetry-collector
    volumes:
      - ./opentelemetry.yaml:/etc/otelcol/config.yaml
      - /tmp:/tmp
    ports:
      - "4317:4317" # OTLP gRPC
      - "4318:4318" # OTLP HTTP
    depends_on:
      - jaeger
    restart: unless-stopped

  jaeger:
    image: jaegertracing/all-in-one
    container_name: jaeger
    ports:
      - "16686:16686"
    environment:
      - COLLECTOR_OTLP_ENABLED=true
    restart: unless-stopped
7 changes: 7 additions & 0 deletions python/pyproject.toml
@@ -56,6 +56,13 @@ runtime_common = [
    "xgrammar==0.1.24",
]

tracing = [
    "opentelemetry-sdk",
    "opentelemetry-api",
    "opentelemetry-exporter-otlp",
    "opentelemetry-exporter-otlp-proto-grpc",
]

srt = [
    "sglang[runtime_common]",
    "sgl-kernel==0.3.9.post2",
8 changes: 8 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
@@ -33,6 +33,8 @@
import zmq.asyncio
from PIL.Image import Image

from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info

# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

@@ -138,6 +140,12 @@ def __init__(self, **kwargs):
            context, zmq.DEALER, self.port_args.rpc_ipc_name, True
        )

        if server_args.enable_trace:
            process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
            if server_args.disaggregation_mode == "null":
                thread_label = "Tokenizer"
                trace_set_thread_info(thread_label)

    def generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
15 changes: 15 additions & 0 deletions python/sglang/srt/entrypoints/http_server.py
@@ -31,6 +31,8 @@

import setproctitle

from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info

# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

@@ -177,6 +179,13 @@ async def init_multi_tokenizer() -> ServerArgs:
            scheduler_info=scheduler_info,
        )
    )

    if server_args.enable_trace:
        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
        if server_args.disaggregation_mode == "null":
            thread_label = f"MultiTokenizer-{tokenizer_manager.worker_id}"
            trace_set_thread_info(thread_label)

Collaborator: Seems like we can enable this when using the http server entrypoint. What do you think about enabling this for the sgl.Engine API as well?

Author: Yes, that makes sense. I'll add it, run some tests, and push an update shortly.

Author: I've enabled this for the sgl.Engine API.

Collaborator: Neat! Thank you. I've kicked off the PR checks.

return server_args


@@ -1171,6 +1180,12 @@ def launch_server(
        server_args=server_args,
    )

    if server_args.enable_trace:
        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
        if server_args.disaggregation_mode == "null":
            thread_label = "Tokenizer"
            trace_set_thread_info(thread_label)

    set_global_state(
        _GlobalState(
            tokenizer_manager=tokenizer_manager,
6 changes: 6 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -605,6 +605,9 @@ class TokenizedGenerateReqInput:
    # Image gen grpc migration
    return_bytes: bool = False

    # tracing context
    trace_context: Optional[Dict] = None


@dataclass
class BatchTokenizedGenerateReqInput:
@@ -654,6 +657,9 @@ class EmbeddingReqInput:
    # For background responses (OpenAI responses API)
    background: bool = False

    # tracing context
    trace_context: Optional[Dict] = None

    def normalize_batch_and_arguments(self):
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
45 changes: 45 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -145,6 +145,15 @@
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.tracing.trace import (
    process_tracing_init,
    trace_event,
    trace_set_proc_propagate_context,
    trace_set_thread_info,
    trace_slice,
    trace_slice_end,
    trace_slice_start,
)
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
from sglang.srt.utils import (
    DynamicGradMode,
@@ -814,6 +823,10 @@ def event_loop_normal(self):
            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                for req in batch.reqs:
                    trace_event("schedule", req.rid)

Contributor: Wrap this into a function and only call it if tracing is enabled. The principle: if tracing is not enabled, the overhead should be just a single if/else, not a for loop.

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
@@ -835,6 +848,10 @@ def event_loop_overlap(self):
            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                for req in batch.reqs:
                    trace_event("schedule", req.rid)

Contributor: There is an `if batch` condition right after this; should you put this block under L855?

Contributor (merrymercy, Oct 1, 2025), on lines +852 to +853: Wrap this into a function called `trace_event_batch`. The code in `event_loop_XXXX` should be very concise and only expose core logic.

            if batch:
                batch.launch_done = threading.Event()
                result = self.run_batch(batch)
@@ -1098,6 +1115,12 @@ def recv_requests(self) -> List[Req]:
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )

        for req in recv_reqs:
            if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)):
                trace_set_proc_propagate_context(req.rid, req.trace_context)
                trace_slice_start("", req.rid, anonymous=True)

        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
@@ -1327,6 +1350,7 @@ def _add_request_to_queue(self, req: Req):
        else:
            self._prefetch_kvcache(req)
            self.waiting_queue.append(req)
            trace_slice_end("process req", req.rid, auto_next_anon=True)

    def _prefetch_kvcache(self, req: Req):
        if self.enable_hicache_storage:
@@ -1880,8 +1904,23 @@ def process_batch_result(
    ):
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result, launch_done)
            for req in batch.reqs:
                trace_slice(
                    "decode loop",
                    req.rid,
                    auto_next_anon=not req.finished(),
                    thread_finish_flag=req.finished(),
                )

        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result, launch_done)
            for req in batch.reqs:
                trace_slice(
                    "prefill",
                    req.rid,
                    auto_next_anon=not req.finished(),
                    thread_finish_flag=req.finished(),
                )

Contributor: If tracing is not enabled, the overhead should be just a single if/else, not a for loop.

        elif batch.forward_mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
@@ -2550,6 +2589,12 @@ def run_scheduler_process(
    pipe_writer,
    balance_meta: Optional[DPBalanceMeta] = None,
):
    if server_args.enable_trace:
        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
        if server_args.disaggregation_mode == "null":
            thread_label = "Scheduler"
            trace_set_thread_info(thread_label, tp_rank, dp_rank)

    if (numa_node := server_args.numa_node) is not None:
        numa_bind_to_node(numa_node[gpu_id])

Expand Down