diff --git a/docs/references/production_request_trace.md b/docs/references/production_request_trace.md
index 2d19570c2158..2c28e403cff2 100644
--- a/docs/references/production_request_trace.md
+++ b/docs/references/production_request_trace.md
@@ -1,6 +1,6 @@
 # Production Request Tracing
 
-SGlang exports request trace data based on the OpenTelemetry Collector. You can enable tracing by adding the `--enable-trace` and configure the OpenTelemetry Collector endpoint using `--otlp-traces-endpoint` when launching the server.
+SGLang exports request trace data based on the OpenTelemetry Collector. You can enable tracing by adding the `--trace-level` option and configuring the OpenTelemetry Collector endpoint with `--otlp-traces-endpoint` when launching the server. The `--trace-level` option accepts values from `0` to `3`, where `0` disables tracing and higher values produce more detailed traces. Additionally, you can use `--trace-module` to specify the module to trace; currently, only `request` is supported.
 
 You can find example screenshots of the visualization in https://github.com/sgl-project/sglang/issues/8965.
@@ -17,23 +17,23 @@ This section explains how to configure the request tracing and export the trace
    pip install opentelemetry-sdk opentelemetry-api opentelemetry-exporter-otlp opentelemetry-exporter-otlp-proto-grpc
    ```
 
-2. launch opentelemetry collector and jaeger
+2. Launch the OpenTelemetry Collector and Jaeger
 
    ```bash
    docker compose -f examples/monitoring/tracing_compose.yaml up -d
    ```
 
-3. start your SGLang server with tracing enabled
+3. Start your SGLang server with tracing enabled
 
    ```bash
    # set env variables
    export SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS=500
    export SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE=64
    # start the prefill and decode server
-   python -m sglang.launch_server --enable-trace --otlp-traces-endpoint 0.0.0.0:4317
+   python -m sglang.launch_server --trace-level 3 --otlp-traces-endpoint 0.0.0.0:4317 [--trace-module request]
    # start the mini lb
    python -m sglang_router.launch_router --enable-trace --otlp-traces-endpoint 0.0.0.0:4317
    ```
 
-   Replace `0.0.0.0:4317` with the actual endpoint of the opentelemetry collector. If you launched the openTelemetry collector with tracing_compose.yaml, the default receiving port is 4317.
+   Replace `0.0.0.0:4317` with the actual endpoint of the OpenTelemetry Collector. If you launched the OpenTelemetry Collector with `tracing_compose.yaml`, the default receiving port is 4317.
 
+   To use the HTTP/protobuf span exporter, set the following environment variable and point to an HTTP endpoint, for example, `http://0.0.0.0:4318/v1/traces`.
+   ```bash
@@ -41,15 +41,15 @@ This section explains how to configure the request tracing and export the trace
    ```
 
-4. raise some requests
+4. Send some requests
 
 5. Observe whether trace data is being exported
    * Access port 16686 of Jaeger using a web browser to visualize the request traces.
    * The OpenTelemetry Collector also exports trace data in JSON format to /tmp/otel_trace.json. In a follow-up patch, we will provide a tool to convert this data into a Perfetto-compatible format, enabling visualization of requests in the Perfetto UI.
 
-## How to add Tracing for slices you're interested in?
+## How to add tracing for slices you're interested in (API introduction)
 
 We have already inserted instrumentation points in the tokenizer and scheduler main threads. If you wish to trace additional request execution segments or perform finer-grained tracing, please use the APIs from the tracing package as described below.
 
-1. initialization
+1. Initialization
 
    Every process involved in tracing during the initialization phase should execute:
 
    ```python
@@ -63,98 +63,64 @@ We have already inserted instrumentation points in the tokenizer and scheduler m
    ```
 
    The "thread label" can be regarded as the name of the thread, used to distinguish different threads in the visualization view.
 
-2. Mark the beginning and end of a request
+2. Create a time recorder for a request
+
+   Each request needs to call `TraceMetricContext()` to initialize a time recorder, which is used to generate slice spans and request stage metrics. You can either store it within the request object or maintain it as a global variable. A set of APIs for managing the global time recorder is provided in `python/sglang/srt/tracing/trace_metric_wrapper.py`.
+
+3. Mark the beginning and end of a request
 
    ```
-   trace_req_start(rid, bootstrap_room)
-   trace_req_finish(rid)
+   # The time recorder calls trace_req_start() by default when it is created.
+   trace_metric_ctx.trace_req_finish()
   ```
 
-   These two APIs must be called within the same process, for example, in the tokenizer.
+   `TraceMetricContext()` and `trace_req_finish()` must be called within the same process, for example, in the tokenizer.
 
-3. Add tracing for slice
+4. Add tracing for a slice
   * Add slice tracing normally:
   ```python
-   trace_slice_start("slice A", rid)
-   trace_slice_end("slice A", rid)
+   trace_metric_ctx.slice_start(RequestStage.TOKENIZER)
+   trace_metric_ctx.slice_end(RequestStage.TOKENIZER)
   ```
-   - Use the "anonymous" flag to not specify a slice name at the start of the slice, allowing the slice name to be determined by trace_slice_end.
+   - Use `RequestStage.ANONYMOUS` to leave the slice name unspecified at the start of the slice, allowing the name to be determined by `slice_end`.
+     Note: Anonymous slices must not be nested.
   ```python
-   trace_slice_start("", rid, anonymous = True)
-   trace_slice_end("slice A", rid)
+   trace_metric_ctx.slice_start(RequestStage.ANONYMOUS)
+   trace_metric_ctx.slice_end(RequestStage.TOKENIZER)
   ```
-   - Use auto_next_anon in trace_slice_end to automatically create the next anonymous slice, which can reduce the number of instrumentation points needed.
+   - In slice_end, use auto_next_anon to automatically create the next anonymous slice, which can reduce the number of instrumentation points needed.
   ```python
-   trace_slice_start("", rid, anonymous = True)
-   trace_slice_end("slice A", rid, auto_next_anon = True)
-   trace_slice_end("slice B", rid, auto_next_anon = True)
-   trace_slice_end("slice C", rid, auto_next_anon = True)
-   trace_slice_end("slice D", rid)
+   trace_metric_ctx.slice_start(RequestStage.ANONYMOUS)
+   trace_metric_ctx.slice_end(RequestStage.A, auto_next_anon = True)
+   trace_metric_ctx.slice_end(RequestStage.B, auto_next_anon = True)
+   trace_metric_ctx.slice_end(RequestStage.C, auto_next_anon = True)
+   trace_metric_ctx.slice_end(RequestStage.D)
   ```
   - The end of the last slice in a thread must be marked with thread_finish_flag=True; otherwise, the thread's span will not be properly generated.
   ```python
-   trace_slice_end("slice D", rid, thread_finish_flag = True)
+   trace_metric_ctx.slice_end(RequestStage.D, thread_finish_flag = True)
   ```
 
-4. When the request execution flow transfers to another thread, the trace context needs to be explicitly propagated.
-   - sender: Execute the following code before sending the request to another thread via ZMQ
-   ```python
-   trace_context = trace_get_proc_propagate_context(rid)
-   req.trace_context = trace_context
-   ```
+5. When the request execution flow transfers to another thread, the thread context needs to be explicitly rebuilt.
   - receiver: Execute the following code after receiving the request via ZMQ
   ```python
-   trace_set_proc_propagate_context(rid, req.trace_context)
-   ```
-
-5. When the request execution flow transfers to another node(PD disaggregation), the trace context needs to be explicitly propagated.
-   - sender: Execute the following code before sending the request to node thread via http
-   ```python
-   trace_context = trace_get_remote_propagate_context(bootstrap_room_list)
-   headers = {"trace_context": trace_context}
-   session.post(url, headers=headers)
-   ```
-   - receiver: Execute the following code after receiving the request via http
-   ```python
-   trace_set_remote_propagate_context(request.headers['trace_context'])
+   trace_metric_ctx.rebuild_thread_context()
   ```
 
 ## How to Extend the Tracing Framework to Support Complex Tracing Scenarios
 
 The currently provided tracing package still has potential for further development. If you wish to build more advanced features upon it, you must first understand its existing design principles.
 
-The core of the tracing framework's implementation lies in the design of the span structure and the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a two-level trace context structure and a four-level span structure: `SglangTraceReqContext`, `SglangTraceThreadContext`. Their relationship is as follows:
+The core of the tracing framework's implementation lies in the design of the span structure and the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a three-level trace context and span structure: `TraceReqContext`, `TraceThreadContext`, and `TraceSliceContext`. Their relationship is as follows:
 
 ```
-SglangTraceReqContext (req_id="req-123")
-├── SglangTraceThreadContext(thread_label="scheduler", tp_rank=0)
+TraceReqContext (req_id="req-123")
+├── TraceThreadContext(thread_label="scheduler", tp_rank=0)
+|   └── TraceSliceContext(slice_name="prefill")
 |
-└── SglangTraceThreadContext(thread_label="scheduler", tp_rank=1)
+└── TraceThreadContext(thread_label="scheduler", tp_rank=1)
+    └── TraceSliceContext(slice_name="prefill")
 ```
 
-Each traced request maintains a global `SglangTraceReqContext`. For every thread processing the request, a corresponding `SglangTraceThreadContext` is recorded and composed within the `SglangTraceReqContext`. Within each thread, every currently traced slice (possibly nested) is stored in a list.
+Each traced request maintains a global `TraceReqContext` and creates a corresponding request span. For every thread that processes the request, a `TraceThreadContext` is recorded and a thread span is created. The `TraceThreadContext` is nested within the `TraceReqContext`, and each currently traced code slice (potentially nested) is stored in its associated `TraceThreadContext`.
 
 In addition to the above hierarchy, each slice also records its previous slice via Span.add_link(), which can be used to trace the execution flow.
-
-When the request execution flow transfers to a new thread, the trace context needs to be explicitly propagated. In the framework, this is represented by `SglangTracePropagateContext`, which contains the context of the request span and the previous slice span.
-
-
-We designed a four-level span structure, consisting of `bootstrap_room_span`, `req_root_span`, `thread_span`, and `slice_span`. Among them, `req_root_span` and `thread_span` correspond to `SglangTraceReqContext` and `SglangTraceThreadContext`, respectively, and `slice_span` is stored within the `SglangTraceThreadContext`. The `bootstrap_room_span` is designed to accommodate the separation of PD-disaggregation.
On different nodes, we may want to add certain attributes to the `req_root_span`. However, if the `req_root_span` is shared across all nodes, the Prefill and Decode nodes would not be allowed to add attributes due to the constraints imposed by OpenTelemetry's design. - -``` -bootstrap room span -├── router req root span -| └── router thread span -| └── slice span -├── prefill req root span -| ├── tokenizer thread span -| | └── slice span -| └── scheduler thread span -| └── slice span -└── decode req root span - ├── tokenizer thread span - | └── slice span - └── scheduler thread span - └── slice span -``` diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 51af67636336..24dc02b76b56 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -47,7 +47,7 @@ prepare_abort, ) from sglang.srt.layers.dp_attention import get_attention_tp_size -from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch +from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch from sglang.srt.managers.utils import GenerationBatchResult from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache @@ -60,7 +60,7 @@ ReqToTokenPool, ) from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool -from sglang.srt.tracing.trace import trace_event_batch, trace_slice_end +from sglang.srt.tracing.trace_metric_wrapper import RequestStage, trace_event_batch from sglang.srt.utils import get_int_env_var from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter @@ -344,8 +344,9 @@ def add(self, req: Req, is_retracted: bool = False) -> None: prefill_dp_rank=req.data_parallel_rank, ) - req.add_latency(RequestStage.DECODE_PREPARE) - trace_slice_end(RequestStage.DECODE_PREPARE, req.rid, auto_next_anon=True) + req.trace_metric_ctx.slice_end( + 
RequestStage.DECODE_PREPARE, auto_next_anon=True + ) self.queue.append( DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False) ) @@ -354,6 +355,7 @@ def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool: if len(req.origin_input_ids) > self.max_total_num_tokens: message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}" logger.error(message) + req.trace_metric_ctx.abort(abort_info={"abort_info": message}) prepare_abort(req, message, status_code=HTTPStatus.BAD_REQUEST) self.scheduler.stream_output([req], req.return_logprob) return True @@ -473,6 +475,9 @@ def pop_preallocated( ) failed_reqs.append(decode_req) indices_to_remove.add(i) + decode_req.req.trace_metric_ctx.abort( + abort_info=decode_req.req.finished_reason + ) # Then, preallocate the remaining requests if possible for i, decode_req in enumerate(self.queue): @@ -578,9 +583,9 @@ def pop_preallocated( decode_req.req.time_stats.decode_transfer_queue_entry_time = ( time.perf_counter() ) - decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP) - trace_slice_end( - RequestStage.DECODE_BOOTSTRAP, decode_req.req.rid, auto_next_anon=True + + decode_req.req.trace_metric_ctx.slice_end( + RequestStage.DECODE_BOOTSTRAP, auto_next_anon=True ) self.queue = [ @@ -762,9 +767,8 @@ def _commit_transfer_to_req(self, decode_req: DecodeRequest) -> None: decode_req.kv_receiver.clear() decode_req.kv_receiver = None - trace_slice_end( + decode_req.req.trace_metric_ctx.slice_end( RequestStage.DECODE_TRANSFERRED, - decode_req.req.rid, auto_next_anon=True, ) decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter() @@ -788,6 +792,9 @@ def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req except Exception as e: error_message += f" with exception {e}" logger.error(error_message) + decode_req.req.trace_metric_ctx.abort( + abort_info={"abort_info": error_message} + ) prepare_abort( decode_req.req, 
error_message, @@ -806,6 +813,7 @@ def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req self._commit_transfer_to_req(decode_req) indices_to_remove.add(i) transferred_reqs.append(decode_req.req) + elif poll in [ KVPoll.Bootstrapping, KVPoll.WaitingForInput, @@ -818,7 +826,6 @@ def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req for i in indices_to_remove: idx = self.queue[i].metadata_buffer_index assert idx != -1 - self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED) self.req_to_metadata_buffer_idx_allocator.free(idx) self.queue = [ @@ -967,7 +974,9 @@ def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]: # we can only add at least `num_not_used_batch` new batch to the running queue if i < num_not_used_batch: can_run_list.append(req) - req.add_latency(RequestStage.DECODE_WAITING) + req.trace_metric_ctx.slice_end( + RequestStage.DECODE_WAITING, auto_next_anon=True + ) req.init_next_round_input(self.tree_cache) else: waiting_queue.append(req) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index cbd18af03d6e..364a9543ea76 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -42,16 +42,11 @@ poll_and_all_reduce, prepare_abort, ) -from sglang.srt.managers.schedule_batch import ( - FINISH_LENGTH, - Req, - RequestStage, - ScheduleBatch, -) +from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch from sglang.srt.mem_cache.common import release_kv_cache from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, NSATokenToKVPool from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool -from sglang.srt.tracing.trace import trace_event_batch, trace_slice, trace_slice_end +from sglang.srt.tracing.trace_metric_wrapper import RequestStage, trace_event_batch if TYPE_CHECKING: from torch.distributed import ProcessGroup @@ -201,9 +196,10 @@ def add(self, 
req: Req, num_kv_heads: int) -> None: pp_rank=self.pp_rank, ) self._process_req(req) - req.add_latency(RequestStage.PREFILL_PREPARE) self.queue.append(req) - trace_slice_end(RequestStage.PREFILL_PREPARE, req.rid, auto_next_anon=True) + req.trace_metric_ctx.slice_end( + RequestStage.PREFILL_PREPARE, auto_next_anon=True + ) def extend(self, reqs: List[Req], num_kv_heads: int) -> None: for req in reqs: @@ -213,6 +209,7 @@ def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool: if len(req.origin_input_ids) > self.max_total_num_tokens: message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}" logger.error(message) + req.trace_metric_ctx.abort(abort_info={"abort_info": message}) prepare_abort(req, message, status_code=HTTPStatus.BAD_REQUEST) self.scheduler.stream_output([req], req.return_logprob) return True @@ -265,6 +262,7 @@ def pop_bootstrapped( except Exception as e: error_message += f" with exception {e}" logger.error(error_message) + req.trace_metric_ctx.abort(abort_info={"abort_info": error_message}) prepare_abort( req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR ) @@ -294,10 +292,9 @@ def pop_bootstrapped( bootstrapped_reqs.append(req) indices_to_remove.add(i) req.time_stats.wait_queue_entry_time = time.perf_counter() - req.add_latency(RequestStage.PREFILL_BOOTSTRAP) - trace_slice_end( - RequestStage.PREFILL_BOOTSTRAP, req.rid, auto_next_anon=True + req.trace_metric_ctx.slice_end( + RequestStage.PREFILL_BOOTSTRAP, auto_next_anon=True ) self.queue = [ @@ -449,8 +446,6 @@ def process_batch_result_disagg_prefill( # There is no output_ids for prefill req.output_ids.append(next_token_id) self.tree_cache.cache_unfinished_req(req) # update the tree and lock - req.add_latency(RequestStage.PREFILL_FORWARD) - trace_slice(RequestStage.PREFILL_FORWARD, req.rid, auto_next_anon=True) self.disagg_prefill_inflight_queue.append(req) if self.spec_algorithm.is_eagle() and batch.spec_info is 
not None: req.output_topk_p = batch.spec_info.topk_p[i] @@ -477,6 +472,9 @@ def process_batch_result_disagg_prefill( logprob_pt += num_input_logprobs self.send_kv_chunk(req, last_chunk=True) req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter() + req.trace_metric_ctx.slice_end( + RequestStage.PREFILL_FORWARD, auto_next_anon=True + ) if req.grammar is not None: # FIXME: this try-except block is for handling unexpected xgrammar issue. @@ -515,8 +513,9 @@ def process_batch_result_disagg_prefill( if self.enable_overlap: self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx) - trace_slice( - RequestStage.PREFILL_CHUNKED_FORWARD, req.rid, auto_next_anon=True + req.trace_metric_ctx.slice_end( + RequestStage.PREFILL_CHUNKED_FORWARD, + auto_next_anon=(req.is_chunked != 0), ) self.maybe_send_health_check_signal() @@ -565,6 +564,7 @@ def process_disagg_prefill_inflight_queue( except Exception as e: error_message += f" with exception {e}" logger.warning(error_message) + req.trace_metric_ctx.abort(abort_info={"abort_info": error_message}) release_kv_cache(req, self.tree_cache) # unlock the tree prepare_abort( req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR @@ -586,11 +586,10 @@ def process_disagg_prefill_inflight_queue( ) for req in done_reqs: req: Req - req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE) self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index) req.metadata_buffer_index = -1 - trace_slice( - RequestStage.PREFILL_TRANSFER_KV_CACHE, req.rid, thread_finish_flag=True + req.trace_metric_ctx.slice_end( + RequestStage.PREFILL_TRANSFER_KV_CACHE, thread_finish_flag=True ) self.disagg_prefill_inflight_queue = undone_reqs diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 6f69fd19b051..6c4f83180ad9 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -184,7 +184,7 @@ def __init__(self, **kwargs): 
self.send_to_rpc = None # Enable tracing - if server_args.enable_trace: + if server_args.trace_level > 0: process_tracing_init(server_args.otlp_traces_endpoint, "sglang") thread_label = "Tokenizer" if server_args.disaggregation_mode == "prefill": diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 88705cc35a96..8a7a18002dc7 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -257,7 +257,7 @@ async def lifespan(fast_api_app: FastAPI): enable_func_timer() # Init tracing - if server_args.enable_trace: + if server_args.trace_level > 0: process_tracing_init(server_args.otlp_traces_endpoint, "sglang") if server_args.disaggregation_mode == "prefill": thread_label = "Prefill" + thread_label diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 4f297a32d995..aabf42d2f7f9 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -35,20 +35,18 @@ TokenizedGenerateReqInput, WatchLoadUpdateReq, ) -from sglang.srt.managers.schedule_batch import Req, RequestStage +from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.server_args import ( DP_ATTENTION_HANDSHAKE_PORT_DELTA, PortArgs, ServerArgs, ) -from sglang.srt.tracing.trace import ( - process_tracing_init, - trace_get_proc_propagate_context, - trace_set_proc_propagate_context, - trace_set_thread_info, - trace_slice_end, - trace_slice_start, +from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info +from sglang.srt.tracing.trace_metric_wrapper import ( + NullContext, + RequestStage, + TraceMetricContext, ) from sglang.srt.utils import numa_utils from sglang.srt.utils.common import ( @@ -220,15 +218,16 @@ def handle_load_update_req(self, obj): 
self.dp_budget.update_budget(obj) def dispatching_with_trace(self, req: Req): - if self.server_args.enable_trace: - trace_set_proc_propagate_context(req.rid, req.trace_context) - trace_slice_start(RequestStage.DC_DISPATCH, req.rid) - req.trace_context = trace_get_proc_propagate_context(req.rid) + if isinstance(req.trace_metric_ctx, TraceMetricContext): + req.trace_metric_ctx.rebuild_thread_context() + else: + req.trace_metric_ctx = NullContext() + req.trace_metric_ctx.slice_start(RequestStage.DC_DISPATCH) self.dispatching(req) - - if self.server_args.enable_trace: - trace_slice_end(RequestStage.DC_DISPATCH, req.rid, thread_finish_flag=True) + req.trace_metric_ctx.slice_end( + RequestStage.DC_DISPATCH, thread_finish_flag=True + ) def init_dispatcher(self): self._request_dispatcher = TypeBasedDispatcher( @@ -580,7 +579,7 @@ def run_data_parallel_controller_process( parent_process = psutil.Process().parent() configure_logger(server_args) - if server_args.enable_trace: + if server_args.trace_level > 0: process_tracing_init(server_args.otlp_traces_endpoint, "sglang") thread_label = "DP Controller" if server_args.disaggregation_mode == "prefill": diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 2ecd8542f567..7d8c182f87e3 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -31,6 +31,7 @@ from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.multimodal.mm_utils import has_valid_data from sglang.srt.sampling.sampling_params import SamplingParams +from sglang.srt.tracing.trace_metric_wrapper import TraceMetricContext from sglang.srt.utils import ImageData # Handle serialization of Image for pydantic @@ -743,9 +744,6 @@ class TokenizedGenerateReqInput(BaseReq): # Whether to disallow logging for this request (e.g. 
due to ZDR) no_logs: bool = False - # tracing context - trace_context: Optional[Dict] = None - # (Internal) Whether to return bytes for image generation return_bytes: bool = False @@ -755,6 +753,9 @@ class TokenizedGenerateReqInput(BaseReq): need_wait_for_image: bool = False num_items_assigned: Optional[List] = None + # For observability + trace_metric_ctx: Optional[Union[TraceMetricContext, Dict]] = None + @dataclass class BatchTokenizedGenerateReqInput(BaseBatchReq): @@ -916,6 +917,8 @@ class TokenizedEmbeddingReqInput(BaseReq): priority: Optional[int] = None # The number of dimensions the resulting output embeddings should have. It is applicable for Matryoshka Embeddings. dimensions: Optional[int] = None + # For observability + trace_metric_ctx: Optional[Union[TraceMetricContext, Dict]] = None @dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 17ea53769094..6e15cef64dbf 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1,7 +1,5 @@ from __future__ import annotations -import enum - from sglang.srt.dllm.config import DllmConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -40,7 +38,6 @@ import dataclasses import logging import re -import time from enum import Enum, auto from http import HTTPStatus from itertools import chain @@ -83,6 +80,7 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs, get_global_server_args +from sglang.srt.tracing.trace_metric_wrapper import NullContext, TraceMetricContext from sglang.srt.utils import flatten_nested_list from sglang.srt.utils.cuda_ipc_transport_utils import CudaIpcTensorTransportProxy @@ -456,35 +454,6 @@ def merge(self, other: MultimodalInputs): # other args would be kept intact -class RequestStage(str, enum.Enum): - # Tokenizer - 
TOKENIZE = "tokenize" - TOKENIZER_DISPATCH = "dispatch" - - # DP controller - DC_DISPATCH = "dc_dispatch" - - # common/non-disaggregation - PREFILL_WAITING = "prefill_waiting" - REQUEST_PROCESS = "request_process" - DECODE_LOOP = "decode_loop" - PREFILL_FORWARD = "prefill_forward" - PREFILL_CHUNKED_FORWARD = "chunked_prefill" - - # disaggregation prefill - PREFILL_PREPARE = "prefill_prepare" - PREFILL_BOOTSTRAP = "prefill_bootstrap" - PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache" - - # disaggregation decode - DECODE_PREPARE = "decode_prepare" - DECODE_BOOTSTRAP = "decode_bootstrap" - DECODE_WAITING = "decode_waiting" - DECODE_TRANSFERRED = "decode_transferred" - DECODE_FAKE_OUTPUT = "fake_output" - DECODE_QUICK_FINISH = "quick_finish" - - class Req: """The input and output status of a request.""" @@ -520,6 +489,7 @@ def __init__( extra_key: Optional[str] = None, dimensions: Optional[int] = None, http_worker_ipc: Optional[str] = None, + trace_metric_ctx: Optional[TraceMetricContext] = None, ): # Input and output info self.rid = rid @@ -733,11 +703,11 @@ def __init__( self.retraction_count = 0 self.retraction_mb_id = None - # For metrics + # For metrics or trace self.metrics_collector = metrics_collector self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode) + self.trace_metric_ctx = trace_metric_ctx if trace_metric_ctx else NullContext() self.has_log_time_stats: bool = False - self.last_tic = time.monotonic() # For disaggregation self.bootstrap_host: str = bootstrap_host @@ -809,16 +779,6 @@ def pop_overallocated_kv_cache(self) -> Tuple[int, int]: self.kv_overallocated_freed = True return self.kv_committed_len, self.kv_allocated_len - def add_latency(self, stage: RequestStage): - if self.metrics_collector is None: - return - - now = time.monotonic() - self.metrics_collector.observe_per_stage_req_latency( - stage.value, now - self.last_tic - ) - self.last_tic = now - def extend_image_inputs(self, image_inputs): if self.multimodal_inputs is None: 
         self.multimodal_inputs = image_inputs
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 4f484fae5d4a..90e2957d7126 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -129,7 +129,6 @@
     ModelWorkerBatch,
     MultimodalInputs,
     Req,
-    RequestStage,
     ScheduleBatch,
 )
 from sglang.srt.managers.schedule_policy import (
@@ -166,14 +165,14 @@
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.tracing.trace import (
-    process_tracing_init,
+from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
+from sglang.srt.tracing.trace_metric_wrapper import (
+    NullContext,
+    RequestStage,
+    StageMetricContext,
+    TraceMetricContext,
+    metric_trace_slice_batch,
     trace_event_batch,
-    trace_set_proc_propagate_context,
-    trace_set_thread_info,
-    trace_slice_batch,
-    trace_slice_end,
-    trace_slice_start,
 )
 from sglang.srt.utils import (
     DynamicGradMode,
@@ -292,7 +291,6 @@ def __init__(
         self.enable_kv_cache_events = bool(
             server_args.kv_events_config and tp_rank == 0
         )
-        self.enable_trace = server_args.enable_trace
         self.stream_interval = server_args.stream_interval
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
@@ -1281,14 +1279,6 @@ def recv_requests(
         ):
             recv_reqs = self.mm_receiver.process_waiting_requests(recv_reqs)

-        if self.enable_trace:
-            for req in recv_reqs:
-                if isinstance(
-                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
-                ):
-                    trace_set_proc_propagate_context(req.rid, req.trace_context)
-                    trace_slice_start("", req.rid, anonymous=True)
-
         return recv_reqs

     def _split_work_and_control_reqs(self, recv_reqs: List):
@@ -1424,6 +1414,7 @@ def handle_generate_request(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
+        self._req_trace_metric_ctx_init(recv_req)
         # Create a new request
         if (
             recv_req.session_params is None
@@ -1468,6 +1459,7 @@
                 ),
                 http_worker_ipc=recv_req.http_worker_ipc,
                 dllm_config=self.dllm_config,
+                trace_metric_ctx=recv_req.trace_metric_ctx,
             )
             req.tokenizer = self.tokenizer
@@ -1479,6 +1471,9 @@
                     f"boostrap room id. {req.rid=}"
                 )
                 logger.error(error_msg)
+                recv_req.trace_metric_ctx.abort(
+                    abort_info={"abort_info": error_msg}
+                )
                 prepare_abort(req, error_msg, status_code=HTTPStatus.BAD_REQUEST)
                 self.stream_output([req], req.return_logprob)
                 return
@@ -1497,6 +1492,11 @@
             # Create a new request from a previous session
             session = self.sessions[recv_req.session_params.id]
             req = session.create_req(recv_req, self.tokenizer)
+            req.trace_metric_ctx = (
+                recv_req.trace_metric_ctx
+                if recv_req.trace_metric_ctx
+                else NullContext()
+            )
             if isinstance(req.finished_reason, FINISH_ABORT):
                 self.init_req_max_new_tokens(req)
                 self._add_request_to_queue(req)
@@ -1645,7 +1645,9 @@ def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
             self._prefetch_kvcache(req)
             self.waiting_queue.append(req)
             req.time_stats.wait_queue_entry_time = time.perf_counter()
-            trace_slice_end(RequestStage.REQUEST_PROCESS, req.rid, auto_next_anon=True)
+            req.trace_metric_ctx.slice_end(
+                RequestStage.REQUEST_PROCESS, auto_next_anon=True
+            )
         elif self.disaggregation_mode == DisaggregationMode.PREFILL:
             self._prefetch_kvcache(req)
             self.disagg_prefill_bootstrap_queue.add(
@@ -1679,6 +1681,7 @@
                 },
                 rid=req.rid,
             )
+            req.trace_metric_ctx.abort(abort_info=abort_req.finished_reason)
             self.send_to_tokenizer.send_output(abort_req, req)
             return False
         return True
@@ -1724,12 +1727,14 @@ def _abort_on_queued_limit(self, recv_req: Req) -> bool:
                 ),
                 req_to_abort,
             )
+            req_to_abort.trace_metric_ctx.abort(abort_info={"abort_info": message})
             return req_to_abort.rid == recv_req.rid

     def handle_embedding_request(
         self,
         recv_req: TokenizedEmbeddingReqInput,
     ):
+        self._req_trace_metric_ctx_init(recv_req)
         req = Req(
             recv_req.rid,
             recv_req.input_text,
@@ -1739,6 +1744,7 @@
             priority=recv_req.priority,
             dimensions=recv_req.dimensions,
             http_worker_ipc=recv_req.http_worker_ipc,
+            trace_metric_ctx=recv_req.trace_metric_ctx,
         )
         req.tokenizer = self.tokenizer
@@ -2022,11 +2028,6 @@
         if len(can_run_list) == 0:
             return None

-        if self.enable_metrics:
-            # only record queue time when enable_metrics is True to avoid overhead
-            for req in can_run_list:
-                req.add_latency(RequestStage.PREFILL_WAITING)
-
         self.waiting_queue = [
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]
@@ -2059,6 +2060,14 @@
                 self.metrics_collector.observe_queue_time(
                     req.time_stats.get_queueing_time(),
                 )
+            req.trace_metric_ctx.slice_end(
+                RequestStage.PREFILL_WAITING, auto_next_anon=True
+            )
+
+        if self.chunked_req and self.chunked_req.is_chunked == 1:
+            self.chunked_req.trace_metric_ctx.slice_start(
+                RequestStage.PREFILL_CHUNKED_FORWARD
+            )

         # Create a new batch
         new_batch = ScheduleBatch.init_new(
@@ -2352,7 +2361,7 @@ def process_batch_result(
     ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
-            trace_slice_batch(RequestStage.DECODE_LOOP, batch.reqs)
+            metric_trace_slice_batch(RequestStage.DECODE_LOOP, batch.reqs)
         elif batch.forward_mode.is_extend():
             if batch.is_dllm():
                 self.process_batch_result_dllm(batch, result)
@@ -2806,6 +2815,28 @@
     def get_remote_instance_transfer_engine_info(self):
         return self.tp_worker.get_remote_instance_transfer_engine_info()

+    def _req_trace_metric_ctx_init(
+        self, req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]
+    ):
+        if isinstance(req.trace_metric_ctx, TraceMetricContext):
+            req.trace_metric_ctx.rebuild_thread_context()
+            metrics_collector = (
+                self.metrics_collector if self.server_args.enable_metrics else None
+            )
+            req.trace_metric_ctx.reinit_metric_ctx(
+                self.server_args.enable_metrics,
+                metrics_collector=metrics_collector,
+            )
+        elif self.server_args.enable_metrics:
+            req.trace_metric_ctx = StageMetricContext(
+                True,
+                metrics_collector=self.metrics_collector,
+            )
+        else:
+            req.trace_metric_ctx = NullContext()
+
+        req.trace_metric_ctx.slice_start(RequestStage.ANONYMOUS)
+

 class IdleSleeper:
     """
@@ -2922,7 +2953,7 @@ def run_scheduler_process(
         numa_bind_to_node(numa_node[gpu_id])

     # Set up tracing
-    if server_args.enable_trace:
+    if server_args.trace_level > 0:
         process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
         thread_label = "Scheduler"
         if server_args.disaggregation_mode == "prefill":
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index e40586c24cc1..688454adb647 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -15,15 +15,10 @@
     BatchEmbeddingOutput,
     BatchTokenIDOutput,
 )
-from sglang.srt.managers.schedule_batch import (
-    BaseFinishReason,
-    Req,
-    RequestStage,
-    ScheduleBatch,
-)
+from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
 from sglang.srt.mem_cache.common import release_kv_cache
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.tracing.trace import trace_slice, trace_slice_batch, trace_slice_end
+from sglang.srt.tracing.trace_metric_wrapper import RequestStage

 if TYPE_CHECKING:
     from sglang.srt.managers.scheduler import (
@@ -52,15 +47,18 @@ def process_batch_result_prebuilt(self: Scheduler, batch: ScheduleBatch):
             req.time_stats.forward_entry_time = req.time_stats.completion_time = (
                 time.perf_counter()
             )
-            trace_slice_end(
+            req.trace_metric_ctx.slice_end(
                 RequestStage.DECODE_QUICK_FINISH,
-                req.rid,
                 thread_finish_flag=True,
             )
             release_kv_cache(req, self.tree_cache)
+        else:
+            req.trace_metric_ctx.slice_end(
+                RequestStage.DECODE_FAKE_OUTPUT,
+                auto_next_anon=True,
+            )

     # Note: Logprobs should be handled on the prefill engine.
-    trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs)
     self.stream_output(batch.reqs, batch.return_logprob)

 def maybe_collect_routed_experts(self: Scheduler, req: Req):
@@ -195,9 +193,8 @@ def process_batch_result_prefill(
                     self.abort_request(AbortReq(rid=req.rid))
                 req.grammar.finished = req.finished()

-            trace_slice(
+            req.trace_metric_ctx.slice_end(
                 RequestStage.PREFILL_FORWARD,
-                req.rid,
                 auto_next_anon=not req.finished(),
                 thread_finish_flag=req.finished(),
             )
@@ -230,10 +227,9 @@
                     )
                     logprob_pt += num_input_logprobs

-            trace_slice(
+            req.trace_metric_ctx.slice_end(
                 RequestStage.PREFILL_CHUNKED_FORWARD,
-                req.rid,
-                auto_next_anon=True,
+                auto_next_anon=(req.is_chunked != 0),
             )
     else:  # embedding or reward model
@@ -270,6 +266,12 @@
                 req.output_ids.append(0)
                 req.check_finished()

+            req.trace_metric_ctx.slice_end(
+                RequestStage.PREFILL_FORWARD,
+                auto_next_anon=not req.finished(),
+                thread_finish_flag=req.finished(),
+            )
+
             if req.finished():
                 release_kv_cache(req, self.tree_cache)
             else:
@@ -278,12 +280,10 @@
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1

-            trace_slice(
-                RequestStage.PREFILL_FORWARD,
-                req.rid,
-                auto_next_anon=not req.finished(),
-                thread_finish_flag=req.finished(),
-            )
+            req.trace_metric_ctx.slice_end(
+                RequestStage.PREFILL_CHUNKED_FORWARD,
+                auto_next_anon=(req.is_chunked != 0),
+            )

         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 45cf37b48161..3963a6298827 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -72,7 +72,7 @@
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.managers.request_metrics_exporter import RequestMetricsExporterManager
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, RequestStage
+from sglang.srt.managers.schedule_batch import MultimodalDataItem
 from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
 from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
@@ -87,14 +87,15 @@
     set_global_server_args_for_tokenizer,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.tracing.trace import (
-    extract_trace_headers,
-    trace_get_proc_propagate_context,
-    trace_req_finish,
-    trace_req_start,
-    trace_set_remote_propagate_context,
-    trace_slice_end,
-    trace_slice_start,
+from sglang.srt.tracing.trace import extract_trace_headers
+from sglang.srt.tracing.trace_metric_wrapper import (
+    RequestStage,
+    TraceMetricContext,
+    global_del_trace_metric_ctx,
+    global_get_trace_metric_ctx,
+    global_set_trace_metric_ctx,
+    metric_trace_slice_batch,
+    metric_trace_slice_start_batch,
 )
 from sglang.srt.utils import (
     configure_gc_warning,
@@ -186,7 +187,6 @@ def __init__(
         self.enable_metrics = server_args.enable_metrics
         self.preferred_sampling_params = server_args.preferred_sampling_params
         self.crash_dump_folder = server_args.crash_dump_folder
-        self.enable_trace = server_args.enable_trace
         set_global_server_args_for_tokenizer(server_args)

         # Init model config
@@ -478,8 +478,8 @@ async def generate_request(
         # Normalize the request
         obj.normalize_batch_and_arguments()

-        if self.enable_trace:
-            self._trace_request_start(obj, created_time, request)
+        if self.server_args.trace_level > 0:
+            self._req_trace_metric_ctx_init(obj, created_time, request)

         if self.server_args.language_only:
             self._handle_epd_disaggregation_encode_request(obj)

         if self.server_args.tokenizer_worker_num > 1:
@@ -723,7 +723,7 @@
             mm_inputs = None

         self._validate_one_request(obj, input_ids)
-        trace_slice_end(RequestStage.TOKENIZE, obj.rid)
+        global_get_trace_metric_ctx(obj.rid).slice_end(RequestStage.TOKENIZE)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
         )
@@ -939,6 +939,8 @@
             http_worker_ipc=obj.http_worker_ipc,
         )

+        tokenized_obj.trace_metric_ctx = global_get_trace_metric_ctx(obj.rid)
+
         return tokenized_obj

     async def _batch_tokenize_and_process(
@@ -982,7 +984,7 @@
                     req, req.text, input_ids_list[i], None, None, token_type_ids
                 )
             )
-            trace_slice_end(RequestStage.TOKENIZE, req.rid)
+            global_get_trace_metric_ctx(req.rid).slice_end(RequestStage.TOKENIZE)

         logger.debug(f"Completed batch processing for {batch_size} requests")
         return tokenized_objs
@@ -1038,14 +1040,13 @@ def _send_one_request(
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
-        trace_slice_start(RequestStage.TOKENIZER_DISPATCH, obj.rid)
-        tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
+        tokenized_obj.trace_metric_ctx.slice_start(RequestStage.TOKENIZER_DISPATCH)
         self.send_to_scheduler.send_pyobj(tokenized_obj)

         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         state.request_sent_to_scheduler_ts = time.time()
         self.rid_to_state[obj.rid] = state
-        trace_slice_end(
-            RequestStage.TOKENIZER_DISPATCH, obj.rid, thread_finish_flag=True
+        tokenized_obj.trace_metric_ctx.slice_end(
+            RequestStage.TOKENIZER_DISPATCH, thread_finish_flag=True
         )
         return state
@@ -1058,6 +1059,7 @@ def _send_batch_request(
         created_time: Optional[float] = None,
     ):
         """Send a batch of tokenized requests as a single batched request to the scheduler."""
+        metric_trace_slice_start_batch(RequestStage.TOKENIZER_DISPATCH, tokenized_objs)
         if isinstance(tokenized_objs[0], TokenizedGenerateReqInput):
             batch_req = BatchTokenizedGenerateReqInput(batch=tokenized_objs)
         else:
@@ -1071,6 +1073,9 @@
                 [], False, asyncio.Event(), tmp_obj, created_time=created_time
             )
             self.rid_to_state[tmp_obj.rid] = state
+        metric_trace_slice_batch(
+            RequestStage.TOKENIZER_DISPATCH, tokenized_objs, thread_finish_flag=True
+        )

     async def _wait_one_response(
         self,
@@ -1572,7 +1577,9 @@
             if self.enable_metrics:
                 self._calculate_timing_metrics(meta_info, state, recv_obj, i)

-            trace_req_finish(rid, ts=int(state.finished_time * 1e9))
+            trace_metric_ctx = global_get_trace_metric_ctx(rid)
+            trace_metric_ctx.trace_req_finish(ts=int(state.finished_time * 1e9))
+            global_del_trace_metric_ctx(rid)

             del self.rid_to_state[rid]

@@ -2160,7 +2167,7 @@ async def _resolve_lora_path(self, obj: Union[GenerateReqInput, EmbeddingReqInpu
         # Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
         obj.lora_id = await self.lora_registry.acquire(obj.lora_path)

-    def _trace_request_start(
+    def _req_trace_metric_ctx_init(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         created_time: Optional[float] = None,
@@ -2168,10 +2175,7 @@
     ):
         external_trace_header = None
         if request:
-            if "trace_context" in request.headers:
-                trace_set_remote_propagate_context(request.headers["trace_context"])
-            else:
-                external_trace_header = extract_trace_headers(request.headers)
+            external_trace_header = extract_trace_headers(request.headers)
         elif obj.external_trace_header:
             # When the request comes form the rust grpc server or Engine there isn't a
             # real request object but we still need to propagate the trace context from
@@ -2182,14 +2186,24 @@
             bootstrap_room = (
                 obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
             )
-            trace_req_start(
-                obj.rid,
-                bootstrap_room,
-                ts=int(created_time * 1e9),
-                role=self.server_args.disaggregation_mode,
-                external_trace_header=external_trace_header,
+            trace_metric_ctx = TraceMetricContext(
+                rid=obj.rid,
+                bootstrap_room=bootstrap_room,
+                module_name="request",
+                server_args=self.server_args,
+            )
+            if not trace_metric_ctx.tracing_enable:
+                return
+
+            # Store the context in a global table because trace_metric_ctx
+            # cannot be passed to _handle_batch_output.
+            global_set_trace_metric_ctx(trace_metric_ctx)
+            trace_metric_ctx.trace_req_start(
+                ts=int(created_time * 1e9), external_trace_header=external_trace_header
+            )
+            trace_metric_ctx.slice_start(
+                RequestStage.ANONYMOUS, ts=int(created_time * 1e9)
             )
-            trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
         else:
             for i in range(len(obj.rid)):
                 bootstrap_room = (
@@ -2197,15 +2211,23 @@
                     if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
                     else None
                 )
-                trace_req_start(
-                    obj.rid[i],
-                    bootstrap_room,
+                trace_metric_ctx = TraceMetricContext(
+                    rid=obj.rid[i],
+                    bootstrap_room=bootstrap_room,
+                    module_name="request",
+                    server_args=self.server_args,
+                )
+                if not trace_metric_ctx.tracing_enable:
+                    return
+
+                global_set_trace_metric_ctx(trace_metric_ctx)
+                trace_metric_ctx.trace_req_start(
                     ts=int(created_time * 1e9),
-                    role=self.server_args.disaggregation_mode,
                     external_trace_header=external_trace_header,
                 )
-                trace_slice_start(
-                    "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
+                trace_metric_ctx.slice_start(
+                    RequestStage.ANONYMOUS,
+                    ts=int(created_time * 1e9),
                 )

     def _handle_epd_disaggregation_encode_request(
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index ca32e98c4221..499ca3e2077b 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -356,7 +356,8 @@ class ServerArgs:
     decode_log_interval: int = 40
     enable_request_time_stats_logging: bool = False
     kv_events_config: Optional[str] = None
-    enable_trace: bool = False
+    trace_level: int = 0
+    trace_module: str = "request"
     otlp_traces_endpoint: str = "localhost:4317"

     # RequestMetricsExporter configuration
@@ -3144,9 +3145,16 @@ def add_cli_args(parser: argparse.ArgumentParser):
         help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
     )
     parser.add_argument(
-        "--enable-trace",
-        action="store_true",
-        help="Enable opentelemetry trace",
+        "--trace-level",
+        type=int,
+        default=ServerArgs.trace_level,
+        help="0: disable tracing. 1: Trace important slices. 2: Trace all slices except nested ones. 3: Trace all slices.",
+    )
+    parser.add_argument(
+        "--trace-module",
+        type=str,
+        default=ServerArgs.trace_module,
+        help="The module to trace. Currently only 'request' is supported.",
     )
     parser.add_argument(
         "--otlp-traces-endpoint",
diff --git a/python/sglang/srt/tracing/trace.py b/python/sglang/srt/tracing/trace.py
index 355c477b0576..a03ec962537f 100644
--- a/python/sglang/srt/tracing/trace.py
+++ b/python/sglang/srt/tracing/trace.py
@@ -15,8 +15,6 @@
 from __future__ import annotations

-import base64
-import json
 import logging
 import os
 import random
@@ -24,17 +22,13 @@
 import time
 import uuid
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import Any, Dict, List, Mapping, Optional

 from sglang.srt.utils import get_int_env_var

-if TYPE_CHECKING:
-    from sglang.srt.managers.scheduler import Req
-from typing import Any, Dict, List, Mapping, Optional
-
 logger = logging.getLogger(__name__)

 opentelemetry_imported = False
-tracing_enabled = False
+opentelemetry_initialized = False
 _trace_context_propagator = None
 TRACE_HEADERS = ["traceparent", "tracestate"]
@@ -53,6 +47,7 @@
     from opentelemetry.sdk.resources import SERVICE_NAME, Resource
     from opentelemetry.sdk.trace import TracerProvider, id_generator
     from opentelemetry.sdk.trace.export import BatchSpanProcessor
+    from opentelemetry.trace import Status, StatusCode
     from opentelemetry.trace.propagation.tracecontext import (
         TraceContextTextMapPropagator,
     )
@@ -69,16 +64,12 @@ class IdGenerator:
     logger.debug("opentelemetry package is not installed, tracing disabled")


-def is_tracing_enabled() -> bool:
-    return tracing_enabled
-
-
 def extract_trace_headers(headers: Mapping[str, str]) -> Optional[Dict]:
     return {h: headers[h] for h in TRACE_HEADERS if h in headers}


 @dataclass
-class SglangTraceThreadInfo:
+class TraceThreadInfo:
     host_id: str
     pid: int
     thread_label: str
@@ -88,79 +79,39 @@

 @dataclass
-class SglangTraceSliceContext:
+class TraceEvent:
+    event_name: str
+    ts: int
+    attrs: Dict[str, Any]
+
+
+@dataclass
+class TraceSliceContext:
     slice_name: str
+    start_time_ns: int
+    end_time_ns: Optional[int] = None
     span: Optional[trace.span.Span] = None
     # When True, defers slice_name assignment until trace_slice_end()
     anonymous: bool = False
+    # For nested slices, if the parent slice is anonymous, the child slice
+    # will be created lazily until the parent slice_name is assigned.
+    lazy_flag: bool = False
+    level: int = 1
+    attrs: Optional[Dict[str, Any]] = None
+    events: Optional[List[TraceEvent]] = None
+    parent_slice: Optional[TraceSliceContext] = None
+    child_slices: Optional[List[TraceSliceContext]] = None
+    prev_span_context: Optional[trace.span.SpanContext] = None


 @dataclass
-class SglangTraceThreadContext:
-    thread_info: SglangTraceThreadInfo
-    cur_slice_stack: List[SglangTraceSliceContext]
+class TraceThreadContext:
+    thread_info: TraceThreadInfo
+    cur_slice: Optional[TraceSliceContext] = None
     thread_span: Optional[trace.span.Span] = None
-    # Record the most recently completed span as the previous span for the next span to be created.
-    last_span_context: Optional[trace.span.SpanContext] = None
-
-
-@dataclass
-class SglangTraceReqContext:
-    rid: str
-    start_time_ns: int
-    threads_context: Dict[int, SglangTraceThreadContext]
-    bootstrap_room: Optional[int] = None
-
-    # Indicates whether this instance is a replica from the main process.
-    # When True, root_span is None and only root_span_context is preserved.
-    is_copy: bool = False
-    bootstrap_room_span: Optional[trace.span.Span] = None
-    bootstrap_room_span_context: Optional[context.Context] = None
-    root_span: Optional[trace.span.Span] = None
-    root_span_context: Optional[context.Context] = None
-
-
-@dataclass
-class SglangTracePropagateContext:
-    root_span_context: context.Context
-    prev_span_context: Optional[trace.span.SpanContext]
-
-    def to_dict(self):
-        carrier: dict[str, str] = {}
-        propagate.inject(carrier, self.root_span_context)
-
-        if self.prev_span_context:
-            return {
-                "root_span": carrier,
-                "prev_span": {
-                    "span_id": self.prev_span_context.span_id,
-                    "trace_id": self.prev_span_context.trace_id,
-                },
-            }
-        else:
-            return {"root_span": carrier, "prev_span": "None"}
-
-    @classmethod
-    def instance_from_dict(cls, d):
-        if "root_span" not in d or "prev_span" not in d:
-            return None
-        carrier = d["root_span"]
-        root_span_context = propagate.extract(carrier)
-        if d["prev_span"] == "None":
-            prev_span_context = None
-        else:
-            prev_span_context = trace.span.SpanContext(
-                trace_id=d["prev_span"]["trace_id"],
-                span_id=d["prev_span"]["span_id"],
-                is_remote=True,
-            )
-
-        return cls(root_span_context, prev_span_context)
-
-
-class SglangTraceCustomIdGenerator(id_generator.IdGenerator):
+class TraceCustomIdGenerator(id_generator.IdGenerator):
     """
     The default IdGenerator may produce duplicate trace IDs across multiple
     TP scheduler processes, hence a custom IdGenerator is implemented.
@@ -179,11 +130,11 @@ def generate_span_id(self) -> int:

 # global variables
-remote_trace_contexts: Dict[str, SglangTracePropagateContext] = {}
-threads_info: Dict[int, SglangTraceThreadInfo] = {}
-reqs_context: Dict[str, SglangTraceReqContext] = {}
+threads_info: Dict[int, TraceThreadInfo] = {}

-__get_cur_time_ns = lambda: int(time.time() * 1e9)
+get_cur_time_ns = lambda: int(time.time() * 1e9)
+if hasattr(time, "time_ns"):
+    get_cur_time_ns = lambda: int(time.time_ns())


 def __get_host_id() -> str:
@@ -208,12 +159,13 @@
 # Should be called by each tracked process.
 def process_tracing_init(otlp_endpoint, server_name):
-    global tracing_enabled
-    global __get_cur_time_ns
+    global opentelemetry_initialized
+    global get_cur_time_ns

     if not opentelemetry_imported:
-        logger.warning(f"Tracing is disabled because the packages cannot be imported.")
-        tracing_enabled = False
-        return
+        opentelemetry_initialized = False
+        raise RuntimeError(
+            "The opentelemetry packages are not installed. "
+            "Please disable tracing or install them."
+        )

     try:
         resource = Resource.create(
@@ -222,7 +174,7 @@
             }
         )
         tracer_provider = TracerProvider(
-            resource=resource, id_generator=SglangTraceCustomIdGenerator()
+            resource=resource, id_generator=TraceCustomIdGenerator()
        )

        schedule_delay_millis = get_int_env_var(
@@ -240,16 +192,16 @@
         tracer_provider.add_span_processor(processor)
         trace.set_tracer_provider(tracer_provider)
     except Exception as e:
-        logger.error(
-            f"Initialize OpenTelemetry error: {e}. Please set correct otlp endpoint."
+        opentelemetry_initialized = False
+        raise RuntimeError(
+            f"Failed to initialize OpenTelemetry: {e}. Please set a correct OTLP endpoint."
         )
-        tracing_enabled = False
-        return

-    if hasattr(time, "time_ns"):
-        __get_cur_time_ns = lambda: int(time.time_ns())
+    opentelemetry_initialized = True
+
-    tracing_enabled = True

+def get_opentelemetry_initialized():
+    return opentelemetry_initialized


 def get_otlp_span_exporter(endpoint):
@@ -272,14 +224,14 @@ def trace_set_thread_info(
     thread_label: str, tp_rank: Optional[int] = None, dp_rank: Optional[int] = None
 ):
-    if not tracing_enabled:
+    if not opentelemetry_initialized:
         return

     pid = threading.get_native_id()
     if pid in threads_info:
         return

-    threads_info[pid] = SglangTraceThreadInfo(
+    threads_info[pid] = TraceThreadInfo(
         host_id=__get_host_id(),
         pid=pid,
         thread_label=thread_label,
@@ -289,451 +241,454 @@
-def __create_thread_context(pid, req_span_context, ts: Optional[int] = None):
-    if pid not in threads_info:
-        trace_set_thread_info("unknown")
-
-    thread_info = threads_info[pid]
-    thread_context = SglangTraceThreadContext(
-        thread_info=thread_info,
-        cur_slice_stack=[],
-    )
-
-    thread_name = f"{thread_info.thread_label}"
-    if thread_info.tp_rank is not None:
-        thread_name += f" [TP {thread_info.tp_rank}] "
-    thread_name += f"(host:{thread_info.host_id[:8]} | pid:{pid})"
-    ts = ts or __get_cur_time_ns()
-    thread_context.thread_span = thread_context.thread_info.tracer.start_span(
-        name=thread_name,
-        start_time=ts,
-        context=req_span_context,
-    )
-
-    if thread_info.tp_rank is not None:
-        thread_context.thread_span.set_attributes({"tp_rank": thread_info.tp_rank})
-
-    thread_context.thread_span.set_attributes(
-        {
-            "host_id": thread_info.host_id,
-            "pid": thread_info.pid,
-            "thread_label": thread_info.thread_label,
-        }
-    )
-
-    return thread_context
-
-
-def trace_get_proc_propagate_context(
-    rid, remote_propagate=False
-) -> Optional[Dict[str, Any]]:
-    if not tracing_enabled:
-        return None
-
-    rid = str(rid)
-    if rid not in reqs_context or not reqs_context[rid].root_span_context:
-        return None
-
-    pid = threading.get_native_id()
-    prev_span_context = None
-    thread_context = reqs_context[rid].threads_context[pid]
-    if thread_context.cur_slice_stack:
-        cur_slice_info = thread_context.cur_slice_stack[0]
-        prev_span_context = cur_slice_info.span.get_span_context()
-    elif thread_context.last_span_context:
-        prev_span_context = thread_context.last_span_context
-
-    root_span_context = reqs_context[rid].root_span_context
-    if remote_propagate:
-        root_span_context = reqs_context[rid].bootstrap_room_span_context
-
-    trace_context = SglangTracePropagateContext(root_span_context, prev_span_context)
-    return trace_context.to_dict()
-
-
-def trace_set_proc_propagate_context(rid, trace_context: Optional[Dict[str, Any]]):
-    if not tracing_enabled:
-        return
-    if not trace_context:
-        return
-
-    trace_context = SglangTracePropagateContext.instance_from_dict(trace_context)
-    if not trace_context:
-        return
-
-    rid = str(rid)
-    # Create a copy of the request context
-    if rid not in reqs_context:
-        reqs_context[rid] = SglangTraceReqContext(
-            rid=rid,
-            start_time_ns=__get_cur_time_ns(),
-            threads_context={},
-            root_span_context=trace_context.root_span_context,
-            is_copy=True,
+class TraceReqContext:
+    def __init__(
+        self,
+        rid,
+        bootstrap_room=None,
+        role="null",
+        tracing_enable=False,
+        trace_level=1,
+        module_name="",
+    ):
+        self.tracing_enable: bool = tracing_enable and opentelemetry_initialized
+        self.rid: str = str(rid)
+        if not self.tracing_enable:
+            return
+
+        self.start_time_ns: Optional[int] = None
+        self.thread_context: Optional[TraceThreadContext] = None
+        self.bootstrap_room: Optional[int] = bootstrap_room
+        self.role: str = role
+
+        self.trace_level = trace_level
+        self.module_name = module_name
+
+        # Indicates whether this instance is a replica from the main process.
+        # When True, root_span is None and only root_span_context is preserved.
+        self.is_copy: bool = False
+        self.root_span: Optional[trace.span.Span] = None
+        self.root_span_context: Optional[context.Context] = None
+        # Record the most recently completed span as the previous span for the next span to be created.
+        self.last_span_context: Optional[trace.span.SpanContext] = None
+
+        self.pid: int = threading.get_native_id()
+
+    def is_tracing_enabled(self) -> bool:
+        return self.tracing_enable
+
+    def __create_thread_context(self, ts: int):
+        if self.pid not in threads_info:
+            trace_set_thread_info("unknown")
+
+        thread_info = threads_info[self.pid]
+        thread_context = TraceThreadContext(
+            thread_info=thread_info,
         )
-    pid = threading.get_native_id()
-
-    if pid in reqs_context[rid].threads_context:
-        return
-
-    # Create new thread context.
-    reqs_context[rid].threads_context[pid] = __create_thread_context(
-        pid,
-        trace_context.root_span_context,
-        reqs_context[rid].start_time_ns,
-    )
-
-    reqs_context[rid].threads_context[
-        pid
-    ].last_span_context = trace_context.prev_span_context
-
-
-def trace_get_remote_propagate_context(bootstrap_room_list: List[str]):
-    if not tracing_enabled:
-        return ""
-
-    reqs_trace_contexts = {}
-    for bootstrap_room in bootstrap_room_list:
-        # In the router, rid is also the bootstrap room.
-        bootstrap_room = str(bootstrap_room)
-
-        if bootstrap_room not in reqs_context:
-            continue
-
-        _context = trace_get_proc_propagate_context(
-            bootstrap_room, remote_propagate=True
+        thread_name = f"{thread_info.thread_label}"
+        if thread_info.tp_rank is not None:
+            thread_name += f" [TP {thread_info.tp_rank}] "
+        thread_name += f"(host:{thread_info.host_id[:8]} | pid:{self.pid})"
+        thread_context.thread_span = thread_context.thread_info.tracer.start_span(
+            name=thread_name,
+            start_time=ts,
+            context=self.root_span_context,
         )
-        reqs_trace_contexts[bootstrap_room] = _context

-    json_str = json.dumps(reqs_trace_contexts, ensure_ascii=False)
-    return base64.b64encode(json_str.encode("utf-8")).decode("utf-8")
+        if thread_info.tp_rank is not None:
+            thread_context.thread_span.set_attributes({"tp_rank": thread_info.tp_rank})

+        thread_context.thread_span.set_attributes(
+            {
+                "host_id": thread_info.host_id,
+                "pid": thread_info.pid,
+                "thread_label": thread_info.thread_label,
+            }
+        )

-def trace_set_remote_propagate_context(base64_str):
-    if not tracing_enabled:
-        return
-
-    if base64_str is None or base64_str == "" or base64_str == "None":
-        return
-
-    base64_bytes = base64.b64decode(base64_str)
-    json_str = base64_bytes.decode("utf-8")
-    remote_reqs_trace_contexts = json.loads(json_str)
+        return thread_context
+
+    def __getstate__(self) -> Optional[Dict[str, Any]]:
+        if not self.tracing_enable:
+            return {"tracing_enable": False}
+
+        if not self.root_span_context:
+            return {"tracing_enable": False}
+
+        state = {
+            "tracing_enable": self.tracing_enable,
+            "rid": self.rid,
+            "bootstrap_room": self.bootstrap_room,
+            "start_time_ns": self.start_time_ns,
+            "role": self.role,
+            "trace_level": self.trace_level,
+            "module_name": self.module_name,
+            "is_copy": self.is_copy,
+            "pid": self.pid,
+            "thread_context": None,
+            "root_span": None,
+            "last_span_context": None,
+        }

-    for bootstrap_room in remote_reqs_trace_contexts:
-        if bootstrap_room in remote_trace_contexts:
-            continue
+        carrier: dict[str, str] = {}
+        propagate.inject(carrier, self.root_span_context)
+        state["root_span_context"] = carrier
+
+        prev_span_context = self.last_span_context
+        if self.thread_context and self.thread_context.cur_slice:
+            cur_slice = self.thread_context.cur_slice
+            if cur_slice.span:
+                prev_span_context = cur_slice.span.get_span_context()
+
+        if prev_span_context:
+            state["last_span_context"] = {
+                "span_id": prev_span_context.span_id,
+                "trace_id": prev_span_context.trace_id,
+            }

-        remote_trace_contexts[bootstrap_room] = (
-            SglangTracePropagateContext.instance_from_dict(
-                remote_reqs_trace_contexts[bootstrap_room]
+        return state
+
+    def __setstate__(self, state: Dict[str, Any]):
+        self.__dict__.update(state)
+        if not opentelemetry_initialized:
+            self.tracing_enable = False
+        if not self.tracing_enable:
+            return
+
+        self.is_copy = True
+        self.pid = threading.get_native_id()
+        self.root_span_context = propagate.extract(self.root_span_context)
+        if self.last_span_context:
+            self.last_span_context = trace.span.SpanContext(
+                trace_id=self.last_span_context["trace_id"],
+                span_id=self.last_span_context["span_id"],
+                is_remote=True,
             )
-        )
-

-def trace_req_start(
-    rid: str,
-    bootstrap_room: Optional[int] = None,
-    ts: Optional[int] = None,
-    role: Optional[str] = "null",
-    external_trace_header: Optional[Dict[str, str]] = None,
-):
-    if not tracing_enabled:
-        return
+    def rebuild_thread_context(self, ts: Optional[int] = None):
+        if not self.tracing_enable:
+            return

-    rid = str(rid)
+        ts = ts or get_cur_time_ns()
+        self.thread_context = self.__create_thread_context(ts)

-    ts = ts or __get_cur_time_ns()
+    def trace_req_start(
+        self,
+        ts: Optional[int] = None,
+        external_trace_header: Optional[Dict[str, str]] = None,
+    ):
+        if not self.tracing_enable:
+            return

-    pid = threading.get_native_id()
-    if pid not in threads_info:
-        return
+        ts = ts or get_cur_time_ns()

-    # create req context and root span
-    bootstrap_room = 0 if bootstrap_room is None else bootstrap_room
-    reqs_context[rid] = SglangTraceReqContext(
-        rid=rid,
-        start_time_ns=ts,
-        threads_context={},
-        bootstrap_room=bootstrap_room,
-        is_copy=False,
-    )
+        # create req context and root span
+        self.start_time_ns = ts

-    # create bootstrap room span
-    tracer = threads_info[pid].tracer
-    if str(bootstrap_room) not in remote_trace_contexts:
-        attrs = {"bootstrap_room": str(hex(bootstrap_room))}
+        tracer = threads_info[self.pid].tracer
         external_trace_context = _trace_context_propagator.extract(
             external_trace_header or {}
         )
-        bootstrap_room_span = tracer.start_span(
-            name=f"Bootstrap Room {hex(bootstrap_room)}",
+
+        # Drop the worker_id added by MultiTokenizer
+        orig_rid = self.rid.split("_")[-1]
+        role = "" if self.role == "null" else self.role
+        attrs = {"rid": orig_rid, "module": f"sglang::{self.module_name}"}
+        if self.bootstrap_room:
+            attrs["bootstrap_room"] = str(hex(self.bootstrap_room))
+        root_span = tracer.start_span(
+            name=f"{role} Req {orig_rid[:8]}",
             start_time=ts,
-            attributes=attrs,
             context=external_trace_context,
+            attributes=attrs,
         )
-        reqs_context[rid].bootstrap_room_span = bootstrap_room_span
-        bootstrap_room_span_context = trace.set_span_in_context(bootstrap_room_span)
-    else:
-        bootstrap_room_span_context = remote_trace_contexts[
-            str(bootstrap_room)
-        ].root_span_context
-
-    # Drop the worker_id added by MultiTokenizer
-    orig_rid = rid.split("_")[-1]
-    role = "" if role == "null" else role
-    attrs = {"rid": orig_rid}
-    root_span = tracer.start_span(
-        name=f"{role} Req {orig_rid[:8]}",
-        start_time=ts,
-        context=bootstrap_room_span_context,
-        attributes=attrs,
-    )
-
-    root_span.set_attributes(
-        {
-            "rid": rid,
-        }
-    )
-
-    reqs_context[rid].root_span = root_span
-    reqs_context[rid].root_span_context = trace.set_span_in_context(root_span)
-    reqs_context[rid].bootstrap_room_span_context = bootstrap_room_span_context
-
-    # create thread context and thread span
-    reqs_context[rid].threads_context[pid] = __create_thread_context(
-        pid,
-        reqs_context[rid].root_span_context,
-        ts,
-    )
-    if str(bootstrap_room) in remote_trace_contexts:
-        reqs_context[rid].threads_context[pid].last_span_context = (
-            remote_trace_contexts[str(bootstrap_room)].prev_span_context
-        )
-
-
-def trace_req_finish(
-    rid: str, ts: Optional[int] = None, attrs: Optional[Dict[str, Any]] = None
-):
-    if not tracing_enabled:
-        return
-
-    rid = str(rid)
-    if rid not in reqs_context:
-        return
-
-    req_context = reqs_context[rid]
-    ts = ts or __get_cur_time_ns()
-
-    # End all unclosed thread spans.
-    for thread_context in req_context.threads_context.values():
-        thread_context.thread_span.end(end_time=ts)
-
-    if attrs:
-        req_context.root_span.set_attributes(attrs)
-
-    req_context.root_span.end(end_time=ts)
-    if str(req_context.bootstrap_room) in remote_trace_contexts:
-        del remote_trace_contexts[str(req_context.bootstrap_room)]
-    elif req_context.bootstrap_room_span:
-        req_context.bootstrap_room_span.end(end_time=ts)
-
-    del reqs_context[rid]
-
-
-def trace_slice_start(
-    name: str,
-    rid: str,
-    ts: Optional[int] = None,
-    anonymous: bool = False,
-):
-    if not tracing_enabled:
-        return
-
-    rid = str(rid)
-    if rid not in reqs_context:
-        return
-
-    pid = threading.get_native_id()
-    if pid not in reqs_context[rid].threads_context:
-        return
-
-    thread_context = reqs_context[rid].threads_context[pid]
-
-    ts = ts or __get_cur_time_ns()
-
-    slice_info = SglangTraceSliceContext(
-        slice_name=name,
-        anonymous=anonymous,
-    )
-
-    # find prev slice
-    prev_span_context = None
-    if not thread_context.cur_slice_stack:
-        if thread_context.last_span_context:
-            prev_span_context = thread_context.last_span_context
-
-    parent_span = thread_context.thread_span
-    if thread_context.cur_slice_stack:
-        parent_span = thread_context.cur_slice_stack[-1].span
-
-    parent_span_context = trace.set_span_in_context(parent_span)
-    span = thread_context.thread_info.tracer.start_span(
-        name=slice_info.slice_name,
-        start_time=ts,
-        context=parent_span_context,
-    )
-
if prev_span_context: - span.add_link(prev_span_context) - - slice_info.span = span - thread_context.cur_slice_stack.append(slice_info) + self.root_span = root_span + self.root_span_context = trace.set_span_in_context(root_span) + # create thread context and thread span + self.thread_context = self.__create_thread_context(ts) -def trace_slice_end( - name: str, - rid: str, - ts: Optional[int] = None, - attrs: Optional[Dict[str, Any]] = None, - auto_next_anon: bool = False, - thread_finish_flag: bool = False, -): - if not tracing_enabled: - return - - rid = str(rid) - if rid not in reqs_context: - return - - pid = threading.get_native_id() - if pid not in reqs_context[rid].threads_context: - return - - thread_context = reqs_context[rid].threads_context[pid] - - if not thread_context.cur_slice_stack: - logger.warning(f"No matching with the SLICE_START event{name} is required.") - return - - ts = ts or __get_cur_time_ns() - slice_info = thread_context.cur_slice_stack[-1] - span = slice_info.span - - if slice_info.anonymous: - span.update_name(name) - else: - span = slice_info.span - if slice_info.slice_name != name: - span.set_status(trace.Status(trace.StatusCode.ERROR)) - logger.warning(f"Slice name mismatch: {name} != {slice_info.slice_name}") - - if attrs: - span.set_attributes(attrs) - - span.end(end_time=ts) - - thread_context.cur_slice_stack.pop() - if len(thread_context.cur_slice_stack) == 0: - thread_context.last_span_context = span.get_span_context() - - # If this is the last slice in the thread, - # release the thread context and check whether to release the request context. 
- if thread_finish_flag: - thread_context.thread_span.end(end_time=ts) - del reqs_context[rid].threads_context[pid] - if reqs_context[rid].is_copy and not reqs_context[rid].threads_context: - del reqs_context[rid] - return - - if auto_next_anon: - trace_slice_start("", rid, ts, True) - - -# alias -trace_slice = trace_slice_end - - -# Add event to the current slice on the same thread with the same rid. -def trace_event( - name: str, rid: str, ts: Optional[int] = None, attrs: Dict[str, Any] = None -): - if not tracing_enabled: - return - - rid = str(rid) - if rid not in reqs_context: - return + def trace_req_finish( + self, ts: Optional[int] = None, attrs: Optional[Dict[str, Any]] = None + ): + if not self.tracing_enable: + return - pid = threading.get_native_id() - if pid not in reqs_context[rid].threads_context: - return - - thread_context = reqs_context[rid].threads_context[pid] + ts = ts or get_cur_time_ns() - if not thread_context.cur_slice_stack: - logger.warning(f"No slice is currently being traced.") - return + # End all unclosed thread spans. + self.abort() - ts = ts or __get_cur_time_ns() + if attrs: + self.root_span.set_attributes(attrs) - slice_info = thread_context.cur_slice_stack[-1] - slice_info.span.add_event(name=name, timestamp=ts, attributes=attrs) + self.root_span.end(end_time=ts) + def __create_slice_span(self, _slice: TraceSliceContext): + if _slice.span: + return -# Add attrs to the current slice on the same thread with the same rid. 
-def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
-    if not tracing_enabled:
-        return
+        parent_span = self.thread_context.thread_span
+        if _slice.parent_slice:
+            parent_span = _slice.parent_slice.span
-    rid = str(rid)
-    if rid not in reqs_context:
-        return
+        parent_span_context = trace.set_span_in_context(parent_span)
+        span = self.thread_context.thread_info.tracer.start_span(
+            name=_slice.slice_name,
+            start_time=_slice.start_time_ns,
+            context=parent_span_context,
+        )
-    pid = threading.get_native_id()
-    if pid not in reqs_context[rid].threads_context:
-        return
+        if _slice.prev_span_context:
+            span.add_link(_slice.prev_span_context)
+
+        _slice.span = span
+
+        if _slice.attrs:
+            span.set_attributes(_slice.attrs)
+        if _slice.events:
+            for event in _slice.events:
+                span.add_event(
+                    name=event.event_name,
+                    timestamp=event.ts,
+                    attributes=event.attrs,
+                )
+        _slice.lazy_flag = False
+        _slice.anonymous = False
+        _slice.attrs = {}
+        _slice.events = []
+
+    def __end_slice_span(self, _slice: TraceSliceContext):
+        # child_slices may be non-empty while their spans were never created,
+        # because cur_slice.lazy_flag was True earlier.
+ if _slice.child_slices: + for child_slice in _slice.child_slices: + self.__create_slice_span(child_slice) + self.__end_slice_span(child_slice) + _slice.child_slices = [] + + _slice.span.end(end_time=_slice.end_time) + _slice.parent_slice = None + + def trace_slice_start( + self, + name: str, + ts: Optional[int] = None, + anonymous: bool = False, + level: int = 1, + ): + if not self.tracing_enable: + return + + if not self.thread_context: + return + + ts = ts or get_cur_time_ns() + + cur_slice = TraceSliceContext( + slice_name=name, + start_time_ns=ts, + anonymous=anonymous, + level=level, + attrs={}, + events=[], + parent_slice=self.thread_context.cur_slice, + child_slices=[], + ) + if self.thread_context.cur_slice: + self.thread_context.cur_slice.child_slices.append(cur_slice) + self.thread_context.cur_slice = cur_slice + + if level > self.trace_level: + cur_slice.lazy_flag = True + return + + # find prev span, only first level slice has previous span + if not cur_slice.parent_slice: + if self.last_span_context: + cur_slice.prev_span_context = self.last_span_context + + # check if span creation is lazy + if anonymous or (cur_slice.parent_slice and cur_slice.parent_slice.lazy_flag): + cur_slice.lazy_flag = True + return + + self.__create_slice_span(cur_slice) + + def __release_slice_reference_tree(self, _slice: TraceSliceContext): + for child_slice in _slice.child_slices: + self.__release_slice_reference_tree(child_slice) + _slice.child_slices = [] + _slice.parent_slice = None + + def __trace_slice_end_flag_process(self, auto_next_anon, thread_finish_flag, ts): + # If this is the last slice in the thread, + # release the thread context and check whether to release the request context. 
+        if thread_finish_flag:
+            self.abort(ts)
+            return
+
+        if auto_next_anon:
+            self.trace_slice_start("", ts=ts, anonymous=True)
+
+    def trace_slice_end(
+        self,
+        name: str,
+        ts: Optional[int] = None,
+        attrs: Optional[Dict[str, Any]] = None,
+        auto_next_anon: bool = False,
+        thread_finish_flag: bool = False,
+        level: int = 1,
+    ):
+        if not self.tracing_enable:
+            return
+
+        if not self.thread_context:
+            return
+
+        if not self.thread_context.cur_slice:
+            logger.warning(
+                f"No matching SLICE_START event for slice {name}."
+            )
+            return
+
+        cur_slice = self.thread_context.cur_slice
+        ts = ts or get_cur_time_ns()
+
+        if level > self.trace_level:
+            # release object reference cycles to avoid blocking GC
+            self.thread_context.cur_slice = cur_slice.parent_slice
+            if cur_slice.parent_slice:
+                cur_slice.parent_slice.child_slices.remove(cur_slice)
+            self.__release_slice_reference_tree(cur_slice)
+            self.__trace_slice_end_flag_process(auto_next_anon, thread_finish_flag, ts)
+            return
+
+        # check that slice_name and level match
+        # unlikely path, expected only on erroneous API usage
+        if not cur_slice.anonymous and (
+            cur_slice.slice_name != name or cur_slice.level != level
+        ):
+            logger.warning(
+                f"Slice name mismatch: {name} != {cur_slice.slice_name} or level mismatch: {level} != {cur_slice.level}"
+            )
+            self.thread_context.cur_slice = cur_slice.parent_slice
+            if cur_slice.parent_slice:
+                cur_slice.parent_slice.child_slices.remove(cur_slice)
+            self.__release_slice_reference_tree(cur_slice)
+            return
+
+        cur_slice.end_time = ts
+        cur_slice.slice_name = name
+        cur_slice.level = level
+
+        if cur_slice.lazy_flag:
+            # cur slice's span has not been created yet.
+ # check if parent slice is lazy, if so, mark cur slice as lazy + if cur_slice.parent_slice and cur_slice.parent_slice.lazy_flag: + if attrs: + cur_slice.attrs.update(attrs) + self.thread_context.cur_slice = cur_slice.parent_slice + self.__trace_slice_end_flag_process( + auto_next_anon, thread_finish_flag, ts + ) + return + + self.__create_slice_span(cur_slice) + + span = cur_slice.span + + if attrs: + span.set_attributes(attrs) + + self.thread_context.cur_slice = cur_slice.parent_slice + # only for first level slice + if not cur_slice.parent_slice: + self.last_span_context = span.get_span_context() + else: + cur_slice.parent_slice.child_slices.remove(cur_slice) + self.__end_slice_span(cur_slice) - thread_context = reqs_context[rid].threads_context[pid] + self.__trace_slice_end_flag_process(auto_next_anon, thread_finish_flag, ts) - if not thread_context.cur_slice_stack: - logger.warning(f"No slice is currently being traced.") - return + # Add event to the current slice on the same thread with the same rid. + def trace_event( + self, name: str, ts: Optional[int] = None, attrs: Dict[str, Any] = None + ): + if not self.tracing_enable: + return - slice_info = thread_context.cur_slice_stack[-1] - slice_info.span.set_attributes(attrs) + if not self.thread_context: + return + if not self.thread_context.cur_slice: + logger.warning(f"No slice is currently being traced.") + return -def trace_slice_batch( - name: str, - reqs: List[Req], -): - if not tracing_enabled: - return + cur_slice = self.thread_context.cur_slice + ts = ts or get_cur_time_ns() - for req in reqs: - trace_slice( - name, - req.rid, - auto_next_anon=not req.finished(), - thread_finish_flag=req.finished(), - ) + if cur_slice.span: + cur_slice.span.add_event(name=name, timestamp=ts, attributes=attrs) + else: + cur_slice.events.append(TraceEvent(name, ts, attrs)) + # Add attrs to the current slice on the same thread with the same rid. 
+    def trace_slice_add_attr(self, attrs: Dict[str, Any]):
+        if not self.tracing_enable:
+            return
-def trace_event_batch(
-    name: str,
-    reqs: List[Req],
-    ts: Optional[int] = None,
-    attrs: Dict[str, Any] = {},
-):
-    if not tracing_enabled:
-        return
+        if not self.thread_context:
+            return
-    bid = uuid.uuid4().hex[:8]
-    _attrs = {"bid": bid, "batch_size": len(reqs)}
-    _attrs.update(attrs)
+        if not self.thread_context.cur_slice:
+            logger.warning("No slice is currently being traced.")
+            return
-    for req in reqs:
-        trace_event(name, req.rid, ts=ts, attrs=_attrs)
+        cur_slice = self.thread_context.cur_slice
+        if cur_slice.span:
+            cur_slice.span.set_attributes(attrs)
+        else:
+            cur_slice.attrs.update(attrs)  # span not created yet; stash attrs
+
+    def abort(self, ts=None, abort_info: Optional[Dict] = None):
+        if not self.tracing_enable:
+            return
+
+        if not self.thread_context:
+            return
+
+        # close all unclosed slice spans (only expected on erroneous API usage)
+        ts = ts or get_cur_time_ns()
+        if self.thread_context.cur_slice is not None:
+            if self.thread_context.cur_slice.span:
+                self.thread_context.cur_slice.span.end(end_time=ts)
+
+            # end any nested parent spans as well
+            while self.thread_context.cur_slice.parent_slice:
+                self.thread_context.cur_slice = (
+                    self.thread_context.cur_slice.parent_slice
+                )
+                if self.thread_context.cur_slice.span:
+                    self.thread_context.cur_slice.span.end(end_time=ts)
+
+            # slice will be dropped directly if slice.lazy_flag is True
+            self.__release_slice_reference_tree(self.thread_context.cur_slice)
+            self.thread_context.cur_slice = None
+
+        # set abort info into thread span
+        if self.thread_context.thread_span:
+            if abort_info:
+                from sglang.srt.managers.schedule_batch import BaseFinishReason
+
+                if isinstance(abort_info, BaseFinishReason):
+                    abort_info = abort_info.to_json()
+                self.thread_context.thread_span.set_status(Status(StatusCode.ERROR))
+                self.thread_context.thread_span.set_attributes(abort_info)
+            self.thread_context.thread_span.end(end_time=ts)
+        self.thread_context = None
+
+    def 
__del__(self): + self.abort(abort_info={"abort_info": "have unclosed span, auto closed"}) diff --git a/python/sglang/srt/tracing/trace_metric_wrapper.py b/python/sglang/srt/tracing/trace_metric_wrapper.py new file mode 100644 index 000000000000..f8eb33fc6ab4 --- /dev/null +++ b/python/sglang/srt/tracing/trace_metric_wrapper.py @@ -0,0 +1,393 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""definition for requests stage timing recorder""" +import threading +import uuid +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +from sglang.srt.tracing.trace import ( + TraceReqContext, + get_cur_time_ns, + get_opentelemetry_initialized, +) + + +@dataclass +class RequestStageConfig: + stage_name: str + level: int = 0 + # whether to call metrics_collector.observe_per_stage_req_latency + metrics_is_observed: bool = False + + +class RequestStage: + # Tokenizer + TOKENIZE = RequestStageConfig( + "tokenize", + level=1, + ) + TOKENIZER_DISPATCH = RequestStageConfig( + "dispatch", + level=2, + ) + + # DP controller + DC_DISPATCH = RequestStageConfig( + "dc_dispatch", + level=2, + ) + + # common/non-disaggregation + REQUEST_PROCESS = RequestStageConfig( + "request_process", + level=1, + metrics_is_observed=True, + ) + PREFILL_WAITING = RequestStageConfig( + "prefill_waiting", + level=1, + # equal to "observe_queue_time" + 
metrics_is_observed=False, + ) + DECODE_LOOP = RequestStageConfig( + "decode_loop", + level=2, + ) + PREFILL_FORWARD = RequestStageConfig( + "prefill_forward", + level=1, + metrics_is_observed=True, + ) + PREFILL_CHUNKED_FORWARD = RequestStageConfig( + "chunked_prefill", + level=3, + metrics_is_observed=True, + ) + + # disaggregation prefill + PREFILL_PREPARE = RequestStageConfig( + "prefill_prepare", + level=1, + ) + PREFILL_BOOTSTRAP = RequestStageConfig( + "prefill_bootstrap", + level=1, + metrics_is_observed=True, + ) + PREFILL_TRANSFER_KV_CACHE = RequestStageConfig( + "prefill_transfer_kv_cache", + level=1, + metrics_is_observed=True, + ) + + # disaggregation decode + DECODE_PREPARE = RequestStageConfig( + "decode_prepare", + level=1, + metrics_is_observed=True, + ) + DECODE_BOOTSTRAP = RequestStageConfig( + "decode_bootstrap", + level=1, + metrics_is_observed=True, + ) + DECODE_WAITING = RequestStageConfig( + "decode_waiting", + level=1, + metrics_is_observed=True, + ) + DECODE_TRANSFERRED = RequestStageConfig( + "decode_transferred", + level=1, + metrics_is_observed=True, + ) + DECODE_FAKE_OUTPUT = RequestStageConfig( + "fake_output", + level=1, + metrics_is_observed=True, + ) + DECODE_QUICK_FINISH = RequestStageConfig( + "quick_finish", + level=1, + metrics_is_observed=True, + ) + + # mini lb + MINI_LB_LAUNCH = RequestStageConfig( + "mini_lb_launch", + level=1, + ) + + WAIT_PD_FINISH = RequestStageConfig( + "wait_pd_finish", + level=2, + ) + + # other + ANONYMOUS = RequestStageConfig("") + + +@dataclass +class NullContext: + tracing_enable: bool = False + enable_metrics: bool = False + + def __getattr__(self, name): + return self + + def __call__(self, *args, **kwargs): + return self + + +@dataclass +class StageMetricContext(NullContext): + tracing_enable: bool = False + + def __init__(self, enable_metrics: bool, metrics_collector=None): + self.enable_metrics = enable_metrics + self.metrics_collector = metrics_collector + if not metrics_collector: + 
self.enable_metrics = False + + self.last_ts_stack = [] + + def slice_start( + self, + stage: RequestStageConfig, + ts: Optional[int] = None, + ): + if self.enable_metrics: + ts = ts or get_cur_time_ns() + self.last_ts_stack.append(ts) + + def slice_end( + self, + stage: RequestStageConfig, + ts: Optional[int] = None, + attrs: Optional[Dict[str, Any]] = None, + auto_next_anon: bool = False, + thread_finish_flag: bool = False, + ): + if self.enable_metrics and len(self.last_ts_stack) > 0: + ts = ts or get_cur_time_ns() + last_ts = self.last_ts_stack.pop() + lat = (ts - last_ts) / 1e9 + + if stage.metrics_is_observed: + try: + self.metrics_collector.observe_per_stage_req_latency( + stage.stage_name, + lat, + ) + except AttributeError: + pass + + if auto_next_anon: + self.last_ts_stack.append(ts) + + +class TraceMetricContext(TraceReqContext, StageMetricContext): + def __init__( + self, + rid, + bootstrap_room, + module_name, + server_args, + metrics_collector=None, + role: Optional[str] = None, + ): + enable_metrics = getattr(server_args, "enable_metrics", False) + StageMetricContext.__init__(self, enable_metrics, metrics_collector) + + opentelemetry_initialized = get_opentelemetry_initialized() + trace_level = getattr(server_args, "trace_level", 0) + tracing_enable = ( + True + if getattr(server_args, "trace_module", None) == module_name + and trace_level > 0 + and opentelemetry_initialized + else False + ) + + self.disagg_mode = getattr(server_args, "disaggregation_mode", "null") + if not role: + role = self.disagg_mode + super().__init__( + rid=str(rid), + bootstrap_room=bootstrap_room, + role=role, + tracing_enable=tracing_enable, + trace_level=trace_level, + module_name=module_name, + ) + + def reinit_metric_ctx(self, enable_metrics: bool, metrics_collector=None): + StageMetricContext.__init__(self, enable_metrics, metrics_collector) + + def slice_start( + self, + stage: RequestStageConfig, + ts: Optional[int] = None, + ): + ts = ts or get_cur_time_ns() + 
StageMetricContext.slice_start(self, stage, ts) + + self.trace_slice_start( + stage.stage_name, + ts=ts, + anonymous=(stage == RequestStage.ANONYMOUS), + level=stage.level, + ) + + def slice_end( + self, + stage: RequestStageConfig, + ts: Optional[int] = None, + attrs: Optional[Dict[str, Any]] = None, + auto_next_anon: bool = False, + thread_finish_flag: bool = False, + ): + ts = ts or get_cur_time_ns() + StageMetricContext.slice_end( + self, stage, ts, attrs, auto_next_anon, thread_finish_flag + ) + + self.trace_slice_end( + stage.stage_name, + ts=ts, + attrs=attrs, + auto_next_anon=auto_next_anon, + thread_finish_flag=thread_finish_flag, + level=stage.level, + ) + + +def metric_trace_slice_start_batch( + stage: RequestStageConfig, + reqs: List, +): + if not reqs: + return + + ctx = reqs[0].trace_metric_ctx + if not ctx.tracing_enable and not ctx.enable_metrics: + return + + for req in reqs: + req.trace_metric_ctx.slice_start(stage) + + +def metric_trace_slice_batch( + stage: RequestStageConfig, + reqs: List, + auto_next_anon=False, + thread_finish_flag=False, +): + if not reqs: + return + + ctx = reqs[0].trace_metric_ctx + if not ctx.tracing_enable and not ctx.enable_metrics: + return + + for req in reqs: + finished = req.finished() if hasattr(req, "finished") else None + if finished is not None: + auto_next_anon = not finished + thread_finish_flag = finished + req.trace_metric_ctx.slice_end( + stage, + auto_next_anon=auto_next_anon, + thread_finish_flag=thread_finish_flag, + ) + + +def trace_event_batch( + name: str, + reqs: List, + ts: Optional[int] = None, + attrs: Dict[str, Any] = {}, +): + if not reqs or not reqs[0].trace_metric_ctx.tracing_enable: + return + + bid = uuid.uuid4().hex[:8] + _attrs = {"bid": bid, "batch_size": len(reqs)} + _attrs.update(attrs) + + for req in reqs: + req.trace_metric_ctx.trace_event(name, ts=ts, attrs=_attrs) + + +""" +Used when the trace_metric_ctx cannot be integrated into the request object. 
+
+format:
+    {
+        thread_id: {
+            "rid": TraceMetricContext
+        }
+    }
+"""
+global_trace_metric_ctx_table: Dict[int, Dict[str, TraceMetricContext]] = {}
+
+
+def global_init_trace_metric_ctx(
+    rid,
+    bootstrap_room,
+    module_name,
+    server_args,
+    metrics_collector=None,
+    time_stat_cls=None,
+    role: Optional[str] = None,
+):
+    pid = threading.get_native_id()
+    rid = str(rid)
+    trace_metric_ctx = TraceMetricContext(
+        rid=rid,
+        bootstrap_room=bootstrap_room,
+        module_name=module_name,
+        server_args=server_args,
+        metrics_collector=metrics_collector,
+        # time_stat_cls is not forwarded: TraceMetricContext.__init__ does not accept it
+        role=role,
+    )
+
+    global_trace_metric_ctx_table.setdefault(pid, {})[rid] = trace_metric_ctx
+
+    return trace_metric_ctx
+
+
+def global_get_trace_metric_ctx(rid) -> Union[TraceMetricContext, NullContext]:
+    pid = threading.get_native_id()
+    rid = str(rid)
+    if pid in global_trace_metric_ctx_table:
+        if rid in global_trace_metric_ctx_table[pid]:
+            return global_trace_metric_ctx_table[pid][rid]
+    return NullContext()
+
+
+def global_set_trace_metric_ctx(trace_metric_ctx):
+    pid = threading.get_native_id()
+    rid = trace_metric_ctx.rid
+    global_trace_metric_ctx_table.setdefault(pid, {})[rid] = trace_metric_ctx
+
+
+def global_del_trace_metric_ctx(rid):
+    pid = threading.get_native_id()
+    rid = str(rid)
+    if pid in global_trace_metric_ctx_table:
+        if rid in global_trace_metric_ctx_table[pid]:
+            del global_trace_metric_ctx_table[pid][rid]
diff --git a/test/manual/test_tracing.py b/test/manual/test_tracing.py
index 4e3763ac414e..c5b2d92df751 100644
--- a/test/manual/test_tracing.py
+++ b/test/manual/test_tracing.py
@@ -4,7 +4,7 @@
 import time
 import unittest
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Dict, Optional, Union
 
 import requests
 import zmq
@@ -24,7 +24,7 @@
 @dataclass
 class Req:
     rid: int
-    trace_context: Optional[Dict[str, Any]] = None
+    req_context: Optional[Union[TraceReqContext, Dict]] = None
 
 
 class TestTrace(CustomTestCase):
@@ -65,15 +65,23 @@ def __clear_trace_file(self):
         except:
             pass
 
-    def test_trace_enable(self):
+    def __test_trace_enable(self, trace_level, trace_module, expect_export_data):
         self.__clear_trace_file()
         assert self.__launch_otel_jaeger()
+        self.addCleanup(self.__stop_otel_jaeger)
 
         process = popen_launch_server(
             DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
             DEFAULT_URL_FOR_TEST,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--enable-trace", "--otlp-traces-endpoint", "0.0.0.0:4317"],
+            other_args=[
+                "--trace-level",
+                trace_level,
+                "--trace-module",
+                trace_module,
+                "--otlp-traces-endpoint",
+                "0.0.0.0:4317",
+            ],
         )
 
         try:
@@ -101,15 +109,37 @@ def test_trace_enable(self):
 
             # check trace file
             assert os.path.isfile("/tmp/otel_trace.json"), "trace file not exist"
-            assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
+            if expect_export_data:
+                assert (
+                    os.path.getsize("/tmp/otel_trace.json") > 0
+                ), "trace file is empty"
+            else:
+                assert (
+                    os.path.getsize("/tmp/otel_trace.json") == 0
+                ), "trace file is not empty"
         finally:
             kill_process_tree(process.pid)
-            assert self.__stop_otel_jaeger()
+
+    def test_trace_enable_level_1(self):
+        self.__test_trace_enable("1", "request", True)
+
+    def test_trace_enable_level_2(self):
+        self.__test_trace_enable("2", "request", True)
+
+    def test_trace_enable_level_3(self):
+        self.__test_trace_enable("3", "request", True)
+
+    def test_trace_enable_level_0(self):
+        self.__test_trace_enable("0", "request", False)
+
+    def test_trace_enable_module_invalid(self):
+        self.__test_trace_enable("1", "invalid_module", False)
 
     def test_trace_engine_enable(self):
         self.__clear_trace_file()
         assert self.__launch_otel_jaeger()
+        self.addCleanup(self.__stop_otel_jaeger)
 
         prompt = "Today is a sunny day and I like"
         model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -119,7 +149,7 @@ def test_trace_engine_enable(self):
         engine = Engine(
             model_path=model_path,
             random_seed=42,
-            enable_trace=True,
+            trace_level=1,
             otlp_traces_endpoint="localhost:4317",
         )
 
@@ -134,11 +164,11 @@ def test_trace_engine_enable(self):
             assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
         finally:
             engine.shutdown()
-            assert self.__stop_otel_jaeger()
 
     def test_trace_engine_encode(self):
         self.__clear_trace_file()
         assert self.__launch_otel_jaeger()
+        self.addCleanup(self.__stop_otel_jaeger)
 
         prompt = "Today is a sunny day and I like"
         model_path = "Qwen/Qwen2-7B"
@@ -146,7 +176,7 @@ def test_trace_engine_encode(self):
         engine = Engine(
             model_path=model_path,
             random_seed=42,
-            enable_trace=True,
+            trace_level=1,
             otlp_traces_endpoint="localhost:4317",
             is_embedding=True,
         )
@@ -162,19 +192,20 @@ def test_trace_engine_encode(self):
             assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
         finally:
             engine.shutdown()
-            assert self.__stop_otel_jaeger()
 
     def test_slice_trace_simple(self):
        self.__clear_trace_file()
         assert self.__launch_otel_jaeger()
+        self.addCleanup(self.__stop_otel_jaeger)
 
         try:
             process_tracing_init("0.0.0.0:4317", "test")
             trace_set_thread_info("Test")
-            trace_req_start(0)
-            trace_slice_start("test slice", 0)
+            req_context = TraceReqContext(0, tracing_enable=True)
+            req_context.trace_req_start()
+            req_context.trace_slice_start("test slice")
             time.sleep(1)
-            trace_slice_end("test slice", 0)
-            trace_req_finish(0)
+            req_context.trace_slice_end("test slice")
+            req_context.trace_req_finish()
 
             # sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
             time.sleep(10)
@@ -182,23 +213,25 @@ def test_slice_trace_simple(self):
             assert os.path.isfile("/tmp/otel_trace.json"), "trace file not exist"
             assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
         finally:
-            assert self.__stop_otel_jaeger()
+            pass
 
     def test_slice_trace_complex(self):
         self.__clear_trace_file()
         assert self.__launch_otel_jaeger()
+        self.addCleanup(self.__stop_otel_jaeger)
 
         try:
             process_tracing_init("0.0.0.0:4317", "test")
             trace_set_thread_info("Test")
-            trace_req_start(0)
-            trace_slice_start("", 0, anonymous=True)
+            req_context = TraceReqContext(0, tracing_enable=True)
+            req_context.trace_req_start()
+            req_context.trace_slice_start("", anonymous=True)
             time.sleep(1)
-            trace_slice_end("slice A", 0, auto_next_anon=True)
+            req_context.trace_slice_end("slice A", auto_next_anon=True)
             time.sleep(1)
-            trace_slice_end("slice B", 0, auto_next_anon=True)
+            req_context.trace_slice_end("slice B", auto_next_anon=True)
             time.sleep(1)
-            trace_slice_end("slice C", 0, thread_finish_flag=True)
-            trace_req_finish(0)
+            req_context.trace_slice_end("slice C", thread_finish_flag=True)
+            req_context.trace_req_finish()
 
             # sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
             time.sleep(10)
@@ -206,7 +239,7 @@ def test_slice_trace_complex(self):
             assert os.path.isfile("/tmp/otel_trace.json"), "trace file not exist"
             assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
         finally:
-            assert self.__stop_otel_jaeger()
+            pass
 
     def test_trace_context_propagete(self):
         def __process_work():
@@ -220,16 +253,17 @@ def __process_work():
 
             try:
                 req = recv_from_main.recv_pyobj()
-                trace_set_proc_propagate_context(req.rid, req.trace_context)
-                trace_slice_start("work", req.rid)
+                req.req_context.rebuild_thread_context()
+                req.req_context.trace_slice_start("work")
                 time.sleep(1)
-                trace_slice_end("work", req.rid, thread_finish_flag=True)
+                req.req_context.trace_slice_end("work", thread_finish_flag=True)
             finally:
                 recv_from_main.close()
                 context.term()
 
         self.__clear_trace_file()
         assert self.__launch_otel_jaeger()
+        self.addCleanup(self.__stop_otel_jaeger)
 
         context = zmq.Context(2)
         send_to_subproc = get_zmq_socket(
@@ -246,15 +280,18 @@ def __process_work():
             time.sleep(1)
 
             req = Req(rid=0)
-            trace_req_start(req.rid)
-            trace_slice_start("dispatch", req.rid)
+            req.req_context = TraceReqContext(0, tracing_enable=True)
+            req.req_context.trace_req_start()
+            req.req_context.trace_slice_start("dispatch")
             time.sleep(1)
 
-            req.trace_context = trace_get_proc_propagate_context(req.rid)
+            req_context = req.req_context
             send_to_subproc.send_pyobj(req)
-            trace_slice_end("dispatch", req.rid)
+            # restore
+            req.req_context = req_context
+            req.req_context.trace_slice_end("dispatch")
 
             subproc.join()
-            trace_req_finish(req.rid)
+            req.req_context.trace_req_finish()
 
             # sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
             time.sleep(10)
@@ -265,7 +302,6 @@ def __process_work():
         finally:
             send_to_subproc.close()
             context.term()
-            assert self.__stop_otel_jaeger()
 
 
 if __name__ == "__main__":