diff --git a/docs/references/production_request_trace.md b/docs/references/production_request_trace.md index a60c68b9ea8b..d1dfdd2f067d 100644 --- a/docs/references/production_request_trace.md +++ b/docs/references/production_request_trace.md @@ -17,23 +17,23 @@ This section explains how to configure the request tracing and export the trace
    pip install opentelemetry-sdk opentelemetry-api opentelemetry-exporter-otlp opentelemetry-exporter-otlp-proto-grpc
    ```
 
-2. launch opentelemetry collector and jaeger
+2. Launch the OpenTelemetry collector and Jaeger
 
    ```bash
    docker compose -f examples/monitoring/tracing_compose.yaml up -d
    ```
 
-3. start your SGLang server with tracing enabled
+3. Start your SGLang server with tracing enabled
 
    ```bash
    # set env variables
    export SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS=500
    export SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE=64
    # start the prefill and decode server
    python -m sglang.launch_server --enable-trace --otlp-traces-endpoint 0.0.0.0:4317
-   # start the mini lb
+   # start the router (model gateway)
    python -m sglang_router.launch_router --enable-trace --otlp-traces-endpoint 0.0.0.0:4317
    ```
 
-   Replace `0.0.0.0:4317` with the actual endpoint of the opentelemetry collector. If you launched the openTelemetry collector with tracing_compose.yaml, the default receiving port is 4317.
+   Replace `0.0.0.0:4317` with the actual endpoint of the OpenTelemetry collector. If you launched the OpenTelemetry collector with `tracing_compose.yaml`, the default receiving port is 4317.
 
    To use the HTTP/protobuf span exporter, set the following environment variable and point to an HTTP endpoint, for example, `http://0.0.0.0:4318/v1/traces`.
 
    ```bash
@@ -41,15 +41,33 @@ This section explains how to configure the request tracing and export the trace
    ```
 
-4. raise some requests
+4. Send some requests
 
 5. Observe whether trace data is being exported
    * Access port 16686 of Jaeger using a web browser to visualize the request traces.
    * The OpenTelemetry Collector also exports trace data in JSON format to /tmp/otel_trace.json. In a follow-up patch, we will provide a tool to convert this data into a Perfetto-compatible format, enabling visualization of requests in the Perfetto UI.
 
-## How to add Tracing for slices you're interested in?
+6. Dynamically adjust the trace level
+    The trace level accepts values from `0` to `3`, with the following meanings:
+    ```
+    0: Disable tracing
+    1: Trace important slices
+    2: Trace all slices except nested ones
+    3: Trace all slices
+    ```
+    The trace level can be set dynamically via the HTTP API, for example:
+    ```bash
+    curl "http://0.0.0.0:30000/set_trace_level?level=2"
+    ```
+    Replace `0.0.0.0:30000` with your actual server address, and replace `level=2` with the level you want to set.
+
+    **Note**: You must launch the server with `--enable-trace`; otherwise, tracing remains disabled regardless of any dynamic adjustments to the trace level.
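+
+    For scripted use, the same endpoint can be called from Python. A minimal sketch using the `requests` library, assuming the server address from the example above:
+
+    ```python
+    import requests
+
+    # Equivalent to the curl command above; the endpoint accepts GET and POST
+    # and returns the plain-text body "success" with HTTP 200.
+    resp = requests.post("http://0.0.0.0:30000/set_trace_level", params={"level": 2})
+    assert resp.status_code == 200 and resp.text == "success"
+    ```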
+
+## How to add Tracing for slices you're interested in? (API introduction)
 
 We have already inserted instrumentation points in the tokenizer and scheduler main threads. If you wish to trace additional request execution segments or perform finer-grained tracing, please use the APIs from the tracing package as described below.
 
-1. initialization
+**All of the APIs below are implemented in `python/sglang/srt/observability/req_time_stats.py`. If you want to add another slice, add it there.**
+
+1. Initialization
     Every process involved in tracing during the initialization phase should execute:
 
     ```python
@@ -63,98 +81,53 @@ We have already inserted instrumentation points in the tokenizer and scheduler m
     ```
     The "thread label" can be regarded as the name of the thread, used to distinguish different threads in the visualization view.
 
-2. Mark the beginning and end of a request
+2. Create a trace context for a request
+    Each request needs to call `TraceReqContext()` to initialize a request context, which is used to generate slice spans and record request stage information. You can either store it within the request object or maintain it as a global variable.
+
+3. Mark the beginning and end of a request
     ```
-    trace_req_start(rid, bootstrap_room)
-    trace_req_finish(rid)
+    trace_ctx.trace_req_start()
+    trace_ctx.trace_req_finish()
     ```
-    These two APIs must be called within the same process, for example, in the tokenizer.
+    `trace_req_start()` and `trace_req_finish()` must be called within the same process, for example, in the tokenizer.
 
-3. Add tracing for slice
+4. Add tracing for a slice
    * Add slice tracing normally:
    ```python
-        trace_slice_start("slice A", rid)
-        trace_slice_end("slice A", rid)
-     ```
+        trace_ctx.trace_slice_start(RequestStage.TOKENIZER.stage_name)
+        trace_ctx.trace_slice_end(RequestStage.TOKENIZER.stage_name)
 
-    - Use the "anonymous" flag to not specify a slice name at the start of the slice, allowing the slice name to be determined by trace_slice_end.
-
-       Note: Anonymous slices must not be nested.
-    ```python
-        trace_slice_start("", rid, anonymous = True)
-        trace_slice_end("slice A", rid)
+        # or record a whole slice in one call:
+        trace_ctx.trace_slice(slice)  # slice: TraceSliceContext
     ```
-    - In trace_slice_end, use auto_next_anon to automatically create the next anonymous slice, which can reduce the number of instrumentation points needed.
+    - The end of the last slice in a thread must be marked with `thread_finish_flag=True`, or `trace_ctx.abort()` must be called explicitly; otherwise, the thread's span will not be properly generated.
     ```python
-        trace_slice_start("", rid, anonymous = True)
-        trace_slice_end("slice A", rid, auto_next_anon = True)
-        trace_slice_end("slice B", rid, auto_next_anon = True)
-        trace_slice_end("slice C", rid, auto_next_anon = True)
-        trace_slice_end("slice D", rid)
-     ```
-    - The end of the last slice in a thread must be marked with thread_finish_flag=True; otherwise, the thread's span will not be properly generated.
-    ```python
-        trace_slice_end("slice D", rid, thread_finish_flag = True)
+        trace_ctx.trace_slice_end(RequestStage.D.stage_name, thread_finish_flag=True)
+        # or
+        trace_ctx.abort()
     ```
 
-4. When the request execution flow transfers to another thread, the trace context needs to be explicitly propagated.
-    - sender: Execute the following code before sending the request to another thread via ZMQ
-    ```python
-        trace_context = trace_get_proc_propagate_context(rid)
-        req.trace_context = trace_context
-    ```
+5. When the request execution flow transfers to another thread, the thread context needs to be explicitly rebuilt (see the sketch after this list).
     - receiver: Execute the following code after receiving the request via ZMQ
     ```python
-        trace_set_proc_propagate_context(rid, req.trace_context)
-    ```
-
-5. When the request execution flow transfers to another node(PD disaggregation), the trace context needs to be explicitly propagated.
-    - sender: Execute the following code before sending the request to node thread via http
-    ```python
-        trace_context = trace_get_remote_propagate_context(bootstrap_room_list)
-        headers = {"trace_context": trace_context}
-        session.post(url, headers=headers)
-    ```
-    - receiver: Execute the following code after receiving the request via http
-    ```python
-        trace_set_remote_propagate_context(request.headers['trace_context'])
+        trace_ctx.rebuild_thread_context()
     ```
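+
+Putting the APIs together, here is a minimal end-to-end sketch for one thread. It assumes `TraceReqContext` takes no constructor arguments (as shown above) and that the names are importable from the observability package; check `req_time_stats.py` for the exact signatures:
+
+```python
+from sglang.srt.observability.req_time_stats import RequestStage, TraceReqContext
+
+trace_ctx = TraceReqContext()  # one context per request
+trace_ctx.trace_req_start()    # opens the request span
+
+trace_ctx.trace_slice_start(RequestStage.TOKENIZER.stage_name)
+# ... tokenization work ...
+# The last slice in this thread also closes the thread span.
+trace_ctx.trace_slice_end(RequestStage.TOKENIZER.stage_name, thread_finish_flag=True)
+
+# In a worker thread that receives the request object later:
+# trace_ctx.rebuild_thread_context()
+
+trace_ctx.trace_req_finish()   # closes the request span
+```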
 
 ## How to Extend the Tracing Framework to Support Complex Tracing Scenarios
 
 The currently provided tracing package still has potential for further development. If you wish to build more advanced features upon it, you must first understand its existing design principles.
 
-The core of the tracing framework's implementation lies in the design of the span structure and the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a two-level trace context structure and a four-level span structure: `SglangTraceReqContext`, `SglangTraceThreadContext`. Their relationship is as follows:
+The core of the tracing framework's implementation lies in the design of the span structure and the trace context. To aggregate scattered slices and enable concurrent tracking of multiple requests, we have designed a three-level trace context structure with a matching span hierarchy: `TraceReqContext`, `TraceThreadContext`, and `TraceSliceContext`. Their relationship is as follows:
 ```
-SglangTraceReqContext (req_id="req-123")
-├── SglangTraceThreadContext(thread_label="scheduler", tp_rank=0)
+TraceReqContext (req_id="req-123")
+├── TraceThreadContext(thread_label="scheduler", tp_rank=0)
+|   └── TraceSliceContext(slice_name="prefill")
 |
-└── SglangTraceThreadContext(thread_label="scheduler", tp_rank=1)
+└── TraceThreadContext(thread_label="scheduler", tp_rank=1)
+    └── TraceSliceContext(slice_name="prefill")
 ```
 
-Each traced request maintains a global `SglangTraceReqContext`. For every thread processing the request, a corresponding `SglangTraceThreadContext` is recorded and composed within the `SglangTraceReqContext`. Within each thread, every currently traced slice (possibly nested) is stored in a list.
+Each traced request maintains a global `TraceReqContext` and creates a corresponding request span. For every thread that processes the request, a `TraceThreadContext` is recorded and a thread span is created. The `TraceThreadContext` is nested within the `TraceReqContext`, and each currently traced code slice (possibly nested) is stored in its associated `TraceThreadContext`.
 
 In addition to the above hierarchy, each slice also records its previous slice via Span.add_link(), which can be used to trace the execution flow.
-
-When the request execution flow transfers to a new thread, the trace context needs to be explicitly propagated. In the framework, this is represented by `SglangTracePropagateContext`, which contains the context of the request span and the previous slice span.
-
-
-We designed a four-level span structure, consisting of `bootstrap_room_span`, `req_root_span`, `thread_span`, and `slice_span`. Among them, `req_root_span` and `thread_span` correspond to `SglangTraceReqContext` and `SglangTraceThreadContext`, respectively, and `slice_span` is stored within the `SglangTraceThreadContext`. The `bootstrap_room_span` is designed to accommodate the separation of PD-disaggregation. On different nodes, we may want to add certain attributes to the `req_root_span`. However, if the `req_root_span` is shared across all nodes, the Prefill and Decode nodes would not be allowed to add attributes due to the constraints imposed by OpenTelemetry's design.
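+
+The hierarchy and the add_link() relationship can be expressed directly with the OpenTelemetry Python API. The following is an illustrative sketch only (the real bookkeeping lives in the tracing package); the tracer name and span names here are made up:
+
+```python
+from opentelemetry import trace
+
+tracer = trace.get_tracer("sglang.request-tracing")
+
+# Request span -> thread span -> slice spans, mirroring the three context levels.
+with tracer.start_as_current_span("req-123"):
+    with tracer.start_as_current_span("scheduler tp0"):
+        prev_slice = tracer.start_span("tokenize")
+        prev_slice.end()
+        # Link the next slice to its predecessor to preserve execution order.
+        next_slice = tracer.start_span(
+            "dispatch", links=[trace.Link(prev_slice.get_span_context())]
+        )
+        next_slice.end()
+```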
- -``` -bootstrap room span -├── router req root span -| └── router thread span -| └── slice span -├── prefill req root span -| ├── tokenizer thread span -| | └── slice span -| └── scheduler thread span -| └── slice span -└── decode req root span - ├── tokenizer thread span - | └── slice span - └── scheduler thread span - └── slice span -``` diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 8d437988973b..85af3fcff936 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -21,7 +21,6 @@ from __future__ import annotations import logging -import time from collections import deque from dataclasses import dataclass from http import HTTPStatus @@ -47,7 +46,7 @@ prepare_abort, ) from sglang.srt.layers.dp_attention import get_attention_tp_size -from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch +from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch from sglang.srt.managers.utils import GenerationBatchResult from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache @@ -60,7 +59,10 @@ ReqToTokenPool, ) from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool -from sglang.srt.tracing.trace import trace_event_batch, trace_slice_end +from sglang.srt.observability.req_time_stats import ( + set_schedule_time_batch, + set_time_batch, +) from sglang.srt.utils import get_int_env_var from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter @@ -391,8 +393,6 @@ def _create_receiver_and_enqueue(self, req: Req, dp_rank: int) -> None: prefill_dp_rank=dp_rank, ) - req.add_latency(RequestStage.DECODE_PREPARE) - trace_slice_end(RequestStage.DECODE_PREPARE, req.rid, auto_next_anon=True) self.queue.append( DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False) ) @@ -668,13 +668,7 @@ def pop_preallocated( ) preallocated_reqs.append(decode_req) indices_to_remove.add(i) - decode_req.req.time_stats.decode_transfer_queue_entry_time = ( - time.perf_counter() - ) - decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP) - trace_slice_end( - RequestStage.DECODE_BOOTSTRAP, decode_req.req.rid, auto_next_anon=True - ) + decode_req.req.time_stats.set_decode_transfer_queue_entry_time() self.queue = [ entry for i, entry in enumerate(self.queue) if i not in indices_to_remove @@ -887,12 +881,7 @@ def _commit_transfer_to_req(self, decode_req: DecodeRequest) -> bool: decode_req.kv_receiver.clear() decode_req.kv_receiver = None - trace_slice_end( - RequestStage.DECODE_TRANSFERRED, - decode_req.req.rid, - auto_next_anon=True, - ) - decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter() + decode_req.req.time_stats.set_wait_queue_entry_time() return True def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req]: @@ -956,7 +945,6 @@ def pop_transferred(self, rids_to_check: Optional[List[str]] = None) -> List[Req for i in indices_to_remove: idx = self.queue[i].metadata_buffer_index assert idx != -1 - self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED) self.req_to_metadata_buffer_idx_allocator.free(idx) self.queue = [ @@ -1079,7 +1067,7 @@ def get_next_disagg_decode_batch_to_run( ret = self.maybe_prepare_mlp_sync_batch(ret) if ret: - trace_event_batch("schedule", ret.reqs) + set_schedule_time_batch(ret) return ret def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]: @@ -1107,7 +1095,6 @@ 
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]: # we can only add at least `num_not_used_batch` new batch to the running queue if i < num_not_used_batch: can_run_list.append(req) - req.add_latency(RequestStage.DECODE_WAITING) req.init_next_round_input(self.tree_cache) else: waiting_queue.append(req) @@ -1116,8 +1103,7 @@ def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]: if len(can_run_list) == 0: return None - for req in can_run_list: - req.time_stats.forward_entry_time = time.perf_counter() + set_time_batch(can_run_list, "set_forward_entry_time") # construct a schedule batch with those requests and mark as decode new_batch = ScheduleBatch.init_new( diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index fbc801635108..2922276ed5e3 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -20,7 +20,6 @@ from __future__ import annotations import logging -import time from collections import deque from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional, Type @@ -42,16 +41,11 @@ poll_and_all_reduce, prepare_abort, ) -from sglang.srt.managers.schedule_batch import ( - FINISH_LENGTH, - Req, - RequestStage, - ScheduleBatch, -) +from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch from sglang.srt.mem_cache.common import release_kv_cache from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, NSATokenToKVPool from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool -from sglang.srt.tracing.trace import trace_event_batch, trace_slice, trace_slice_end +from sglang.srt.observability.req_time_stats import set_schedule_time_batch if TYPE_CHECKING: from torch.distributed import ProcessGroup @@ -227,9 +221,7 @@ def add(self, req: Req, num_kv_heads: int) -> None: pp_rank=self.pp_rank, ) self._process_req(req) - req.add_latency(RequestStage.PREFILL_PREPARE) self.queue.append(req) - trace_slice_end(RequestStage.PREFILL_PREPARE, req.rid, auto_next_anon=True) def extend(self, reqs: List[Req], num_kv_heads: int) -> None: for req in reqs: @@ -239,6 +231,7 @@ def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool: if len(req.origin_input_ids) > self.max_total_num_tokens: message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}" logger.error(message) + req.time_stats.trace_ctx.abort(abort_info={"reason": message}) prepare_abort(req, message, status_code=HTTPStatus.BAD_REQUEST) self.scheduler.stream_output([req], req.return_logprob) return True @@ -291,6 +284,7 @@ def pop_bootstrapped( except Exception as e: error_message += f" with exception {e}" logger.error(error_message) + req.time_stats.trace_ctx.abort(abort_info={"reason": error_message}) prepare_abort( req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR ) @@ -319,12 +313,7 @@ def pop_bootstrapped( bootstrapped_reqs.append(req) indices_to_remove.add(i) - req.time_stats.wait_queue_entry_time = time.perf_counter() - req.add_latency(RequestStage.PREFILL_BOOTSTRAP) - - trace_slice_end( - RequestStage.PREFILL_BOOTSTRAP, req.rid, auto_next_anon=True - ) + req.time_stats.set_wait_queue_entry_time() self.queue = [ entry for i, entry in enumerate(self.queue) if i not in indices_to_remove @@ -354,7 +343,7 @@ def get_next_disagg_prefill_batch_to_run( batch = self.maybe_prepare_mlp_sync_batch(batch) if batch: - trace_event_batch("schedule", batch.reqs) + 
set_schedule_time_batch(batch) return batch @@ -469,14 +458,11 @@ def process_batch_result_disagg_prefill( zip(batch.reqs, next_token_ids, strict=True) ): if req.is_chunked <= 0: - if req.time_stats.prefill_finished_ts == 0.0: - req.time_stats.prefill_finished_ts = time.time() + req.time_stats.set_prefill_finished_time() # There is no output_ids for prefill req.output_ids.append(next_token_id) self.tree_cache.cache_unfinished_req(req) # update the tree and lock - req.add_latency(RequestStage.PREFILL_FORWARD) - trace_slice(RequestStage.PREFILL_FORWARD, req.rid, auto_next_anon=True) self.disagg_prefill_inflight_queue.append(req) if self.spec_algorithm.is_eagle() and batch.spec_info is not None: req.output_topk_p = batch.spec_info.topk_p[i] @@ -502,7 +488,7 @@ def process_batch_result_disagg_prefill( ) logprob_pt += num_input_logprobs self.send_kv_chunk(req, last_chunk=True) - req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter() + req.time_stats.set_prefill_transfer_queue_entry_time() if req.grammar is not None: # FIXME: this try-except block is for handling unexpected xgrammar issue. @@ -541,9 +527,7 @@ def process_batch_result_disagg_prefill( if self.enable_overlap: self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx) - trace_slice( - RequestStage.PREFILL_CHUNKED_FORWARD, req.rid, auto_next_anon=True - ) + req.time_stats.set_last_chunked_prefill_finish_time() self.maybe_send_health_check_signal() @@ -584,6 +568,7 @@ def process_disagg_prefill_inflight_queue( if hasattr(req.disagg_kv_sender, "clear"): req.disagg_kv_sender.clear() done_reqs.append(req) + req.time_stats.set_prefill_kv_transfer_finish_time() elif poll == KVPoll.Failed: error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}" try: @@ -591,6 +576,7 @@ def process_disagg_prefill_inflight_queue( except Exception as e: error_message += f" with exception {e}" logger.warning(error_message) + req.time_stats.trace_ctx.abort(abort_info={"reason": error_message}) release_kv_cache(req, self.tree_cache) # unlock the tree prepare_abort( req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR @@ -602,7 +588,7 @@ def process_disagg_prefill_inflight_queue( assert False, f"Unexpected polling state {poll=}" for req in done_reqs: - req.time_stats.completion_time = time.perf_counter() + req.time_stats.set_completion_time() # Stream requests which have finished transfer self.stream_output( @@ -612,13 +598,10 @@ def process_disagg_prefill_inflight_queue( ) for req in done_reqs: req: Req - req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE) + release_req_to_metadata_buffer( req, self.req_to_metadata_buffer_idx_allocator ) - trace_slice( - RequestStage.PREFILL_TRANSFER_KV_CACHE, req.rid, thread_finish_flag=True - ) self.disagg_prefill_inflight_queue = undone_reqs diff --git a/python/sglang/srt/dllm/mixin/scheduler.py b/python/sglang/srt/dllm/mixin/scheduler.py index 656355e20eed..5b87eff5fe33 100644 --- a/python/sglang/srt/dllm/mixin/scheduler.py +++ b/python/sglang/srt/dllm/mixin/scheduler.py @@ -1,14 +1,14 @@ from __future__ import annotations import logging -import time from typing import TYPE_CHECKING, List, Optional, Set, Union from sglang.srt.dllm.config import DllmConfig from sglang.srt.dllm.mixin.req import DllmReqPhase -from sglang.srt.managers.schedule_batch import Req, RequestStage, ScheduleBatch +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.managers.schedule_policy import AddReqResult, PrefillAdder from 
sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.observability.req_time_stats import set_time_batch logger = logging.getLogger(__name__) @@ -52,7 +52,8 @@ def get_new_batch_dllm(self: Scheduler) -> Optional[ScheduleBatch]: return None # Record metrics and update state - self._update_metrics_and_state_for_batch(can_run_list, adder, running_bs) + set_time_batch(can_run_list, "set_forward_entry_time") + self._update_state_for_batch(can_run_list, adder, running_bs) # Create and prepare batch new_batch = self._create_dllm_batch(can_run_list, forward_mode) @@ -147,13 +148,10 @@ def _process_batch_by_phase( if incoming_reqs: self.process_dllm_incoming_reqs(adder, incoming_reqs) - def _update_metrics_and_state_for_batch( + def _update_state_for_batch( self: Scheduler, can_run_list: List[Req], adder: PrefillAdder, running_bs: int ) -> None: - """Update metrics and state for the batch.""" - if self.enable_metrics: - for req in can_run_list: - req.add_latency(RequestStage.PREFILL_WAITING) + """Update state for the batch.""" if adder.preempt_list: for req in adder.preempt_list: @@ -167,14 +165,6 @@ def _update_metrics_and_state_for_batch( self.can_run_list = can_run_list self.running_bs = len(self.running_batch.reqs) - for req in can_run_list: - if req.time_stats.forward_entry_time == 0: - req.time_stats.forward_entry_time = time.perf_counter() - if self.enable_metrics: - self.metrics_collector.observe_queue_time( - req.time_stats.get_queueing_time(), - ) - def _create_dllm_batch( self: Scheduler, can_run_list: List[Req], forward_mode: ForwardMode ) -> ScheduleBatch: @@ -194,7 +184,7 @@ def _create_dllm_batch( new_batch.decoding_reqs = None # Record prefill stats for logging after forward - from sglang.srt.managers.scheduler_metrics_mixin import PrefillStats + from sglang.srt.observability.scheduler_metrics_mixin import PrefillStats new_batch.prefill_stats = PrefillStats( log_input_tokens=self.adder.log_input_tokens, diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 0ed5a1b44b86..da8d56508d41 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -69,8 +69,8 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( parse_remote_instance_transfer_engine_info_from_scheduler_infos, ) +from sglang.srt.observability.trace import process_tracing_init, trace_set_thread_info from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info from sglang.srt.utils import ( MultiprocessingSerializer, assert_pkg_version, diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index a691a29752b3..5b67c1d78c5b 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -46,7 +46,16 @@ import requests import uvicorn import uvloop -from fastapi import Depends, FastAPI, File, Form, HTTPException, Request, UploadFile +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + Query, + Request, + UploadFile, +) from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, Response, StreamingResponse @@ -143,13 +152,17 @@ ) from sglang.srt.managers.template_manager import TemplateManager from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager -from 
sglang.srt.metrics.func_timer import enable_func_timer from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( parse_remote_instance_transfer_engine_info_from_scheduler_infos, ) +from sglang.srt.observability.func_timer import enable_func_timer +from sglang.srt.observability.trace import ( + process_tracing_init, + set_global_trace_level, + trace_set_thread_info, +) from sglang.srt.parser.reasoning_parser import ReasoningParser from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info from sglang.srt.utils import ( add_prometheus_middleware, add_prometheus_track_response_middleware, @@ -871,6 +884,16 @@ async def stop_profile_async(): ) +@app.api_route("/set_trace_level", methods=["GET", "POST"]) +def set_trace_level(level: int = Query(..., ge=0)): + set_global_trace_level(level) + + return Response( + content="success", + status_code=200, + ) + + @app.api_route("/freeze_gc", methods=["GET", "POST"]) @auth_level(AuthLevel.ADMIN_OPTIONAL) async def freeze_gc_async(): diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py index 8fed454a7f33..6c1d3c2866fb 100644 --- a/python/sglang/srt/entrypoints/openai/serving_base.py +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -2,7 +2,6 @@ import json import logging -import time import uuid from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union @@ -14,6 +13,7 @@ from sglang.srt.entrypoints.openai.encoding_dsv32 import DS32EncodingError from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput +from sglang.srt.observability.req_time_stats import monotonic_time from sglang.srt.server_args import ServerArgs if TYPE_CHECKING: @@ -76,14 +76,11 @@ async def handle_request( """Handle the specific request type with common pattern If you want to override this method, you should be careful to record the validation time. 
""" - received_time = time.time() - received_time_perf = time.perf_counter() + received_time = monotonic_time() try: # Validate request - validation_start = time.perf_counter() error_msg = self._validate_request(request) - validation_time = time.perf_counter() - validation_start if error_msg: return self.create_error_response(error_msg) @@ -94,9 +91,7 @@ async def handle_request( if isinstance(adapted_request, (GenerateReqInput, EmbeddingReqInput)): # Only set timing fields if adapted_request supports them - adapted_request.validation_time = validation_time adapted_request.received_time = received_time - adapted_request.received_time_perf = received_time_perf # Note(Xinyuan): raw_request below is only used for detecting the connection of the client if hasattr(request, "stream") and request.stream: @@ -166,7 +161,6 @@ def _convert_to_internal_request( self, request: OpenAIServingRequest, raw_request: Request = None, - validation_time: float = None, ) -> tuple[GenerateReqInput, OpenAIServingRequest]: """Convert OpenAI request to internal format""" pass diff --git a/python/sglang/srt/entrypoints/openai/serving_transcription.py b/python/sglang/srt/entrypoints/openai/serving_transcription.py index 9ba0d9a43007..2b5661f4967a 100644 --- a/python/sglang/srt/entrypoints/openai/serving_transcription.py +++ b/python/sglang/srt/entrypoints/openai/serving_transcription.py @@ -14,6 +14,7 @@ """ OpenAI-compatible transcription endpoint handler for Whisper models. """ + from __future__ import annotations import io diff --git a/python/sglang/srt/eplb/expert_distribution.py b/python/sglang/srt/eplb/expert_distribution.py index 3fa9fcbcee25..1859852af377 100644 --- a/python/sglang/srt/eplb/expert_distribution.py +++ b/python/sglang/srt/eplb/expert_distribution.py @@ -29,8 +29,8 @@ import torch.distributed from sglang.srt.environ import envs -from sglang.srt.metrics.collector import ExpertDispatchCollector from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.observability.metrics_collector import ExpertDispatchCollector from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_int_env_var diff --git a/python/sglang/srt/grpc/grpc_request_manager.py b/python/sglang/srt/grpc/grpc_request_manager.py index 898031af391a..2685070cc669 100644 --- a/python/sglang/srt/grpc/grpc_request_manager.py +++ b/python/sglang/srt/grpc/grpc_request_manager.py @@ -11,7 +11,6 @@ import signal import sys import threading -import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional, Union @@ -19,6 +18,7 @@ import zmq import zmq.asyncio +from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOutput, @@ -29,6 +29,11 @@ TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, ) +from sglang.srt.observability.req_time_stats import ( + APIServerReqTimeStats, + calibrate_time_diff, + real_time, +) from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import get_or_create_event_loop, get_zmq_socket, kill_process_tree from sglang.utils import get_exception_traceback @@ -138,16 +143,9 @@ class GrpcReqState: obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput] # Metrics (same as TokenizerManager's ReqState) - created_time: float - finished_time: float = 0.0 - first_token_time: float = 0.0 - last_time: float = 0.0 + time_stats: APIServerReqTimeStats last_completion_tokens: int = 1 - # perf_counter equivalents for accurate time calculations - 
finished_time_perf: float = 0.0 - first_token_time_perf: float = 0.0 - # Streaming state stream_finished: bool = False input_logprobs_sent: bool = False # Track if input logprobs were sent in streaming @@ -209,12 +207,15 @@ def __init__( self.is_pause_cond = asyncio.Condition() # Metrics - self.last_receive_tstamp = time.time() + self.last_receive_tstamp = real_time() # Crash dump for debugging self.crash_dump_request_list = [] self.crash_dump_performed = False + # disaggregation mode + self.disaggregation_mode = DisaggregationMode(server_args.disaggregation_mode) + # Bootstrap server (passed from serve_grpc, not started here) self.bootstrap_server = bootstrap_server @@ -365,29 +366,15 @@ async def _handle_single_request( obj.rid = request_id - # Create and register request state - # TODO: support log_request - state = GrpcReqState( - request_id=request_id, - grpc_context=grpc_context, - out_queue=asyncio.Queue(), - finished=False, - event=asyncio.Event(), - obj=obj, - created_time=time.time(), - ) - - # Track session if needed - if hasattr(obj, "session_params") and obj.session_params: - state.session_id = obj.session_params.session_id - state.is_session_request = True - - self.rid_to_state[request_id] = state + self._req_stats_init(obj, grpc_context) + state = self.rid_to_state[request_id] self.record_request_for_crash_dump(obj) try: # Send to scheduler - let exceptions bubble up to grpc_server.py + state.time_stats.set_api_server_dispatch_time() await self._send_to_scheduler(obj) + state.time_stats.set_api_server_dispatch_finish_time() is_stream = getattr(obj, "stream", False) @@ -436,26 +423,17 @@ async def embedding_request( obj.rid = request_id - # Create request state - state = GrpcReqState( - request_id=request_id, - grpc_context=None, - out_queue=asyncio.Queue(), - finished=False, - event=asyncio.Event(), - obj=obj, - created_time=time.time(), - ) - - # Register state - self.rid_to_state[request_id] = state + self._req_stats_init(obj) + state = self.rid_to_state[request_id] # Create future for result future = asyncio.Future() # Send to scheduler try: + state.time_stats.set_api_server_dispatch_time() await self._send_to_scheduler(obj) + state.time_stats.set_api_server_dispatch_finish_time() except Exception as e: del self.rid_to_state[request_id] future.set_exception(e) @@ -515,7 +493,7 @@ async def handle_loop(self): try: # Receive from scheduler recv_obj = await self.recv_from_scheduler.recv_pyobj() - self.last_receive_tstamp = time.time() + self.last_receive_tstamp = real_time() # Check for pause (optimized: check flag before acquiring lock) if self.is_pause: @@ -612,8 +590,6 @@ async def _handle_batch_output(self, batch_out: BatchTokenIDOutput): # Collect all queue.put() tasks for parallel execution put_tasks = [] cleanup_tasks = [] - now = time.time() - now_perf_counter = time.perf_counter() # Process each request in the batch for i, rid in enumerate(batch_out.rids): @@ -628,10 +604,10 @@ async def _handle_batch_output(self, batch_out: BatchTokenIDOutput): continue # Update metrics - if state.first_token_time == 0.0: - state.first_token_time = now - state.first_token_time_perf = now_perf_counter - state.last_time = now + if state.time_stats.first_token_time == 0.0: + state.time_stats.set_first_token_time() + else: + state.time_stats.set_last_time() # Extract output for this request output_data = { @@ -728,8 +704,7 @@ def get_part(attr_name): # Handle completion if output_data["finished"]: state.finished = True - state.finished_time = now - state.finished_time_perf = 
now_perf_counter + state.time_stats.set_finished_time() state.stream_finished = True state.event.set() @@ -772,8 +747,7 @@ async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput): # Mark as finished state.finished = True - state.finished_time = time.time() - state.finished_time_perf = time.perf_counter() + state.time_stats.set_finished_time() state.event.set() async def _handle_health_check_output(self, health_out: HealthCheckOutput): @@ -805,8 +779,7 @@ async def _handle_health_check_output(self, health_out: HealthCheckOutput): # Mark as finished state.finished = True - state.finished_time = time.time() - state.finished_time_perf = time.perf_counter() + state.time_stats.set_finished_time() state.event.set() async def _handle_abort_req(self, recv_obj: AbortReq): @@ -883,7 +856,7 @@ def record_request_for_crash_dump(self, obj): if len(self.crash_dump_request_list) < 100: self.crash_dump_request_list.append( { - "time": time.time(), + "time": real_time(), "request_id": getattr(obj, "rid", "unknown"), "type": type(obj).__name__, } @@ -1002,6 +975,34 @@ async def sigterm_watchdog(self): while not self.gracefully_exit: await asyncio.sleep(1.0) + def _req_stats_init( + self, + obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], + grpc_context: Optional[grpc.ServicerContext] = None, + ): + calibrate_time_diff() + # Create and register request state + # TODO: support log_request + # TODO: support request tracing + time_stats = APIServerReqTimeStats(disagg_mode=self.disaggregation_mode) + state = GrpcReqState( + request_id=obj.rid, + grpc_context=grpc_context, + out_queue=asyncio.Queue(), + finished=False, + event=asyncio.Event(), + obj=obj, + time_stats=time_stats, + ) + + # Track session if needed + if hasattr(obj, "session_params") and obj.session_params: + state.session_id = obj.session_params.session_id + state.is_session_request = True + + self.rid_to_state[obj.rid] = state + time_stats.set_created_time() + async def print_exception_wrapper(func): """ diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 14990e5b9da0..62477a1730e4 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -35,22 +35,16 @@ TokenizedGenerateReqInput, WatchLoadUpdateReq, ) -from sglang.srt.managers.schedule_batch import Req, RequestStage +from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.scheduler import run_scheduler_process -from sglang.srt.metrics.cpu_monitor import start_cpu_monitor_thread +from sglang.srt.observability.cpu_monitor import start_cpu_monitor_thread +from sglang.srt.observability.req_time_stats import DPControllerReqTimeStats +from sglang.srt.observability.trace import process_tracing_init, trace_set_thread_info from sglang.srt.server_args import ( DP_ATTENTION_HANDSHAKE_PORT_DELTA, PortArgs, ServerArgs, ) -from sglang.srt.tracing.trace import ( - process_tracing_init, - trace_get_proc_propagate_context, - trace_set_proc_propagate_context, - trace_set_thread_info, - trace_slice_end, - trace_slice_start, -) from sglang.srt.utils import numa_utils from sglang.srt.utils.common import ( bind_port, @@ -197,15 +191,11 @@ def update_active_ranks(self, ranks: ActiveRanksOutput): self.status = ranks.status def dispatching_with_trace(self, req: Req): - if self.server_args.enable_trace: - trace_set_proc_propagate_context(req.rid, req.trace_context) - trace_slice_start(RequestStage.DC_DISPATCH, 
req.rid) - req.trace_context = trace_get_proc_propagate_context(req.rid) + req.time_stats = DPControllerReqTimeStats.new_from_obj(req.time_stats) + req.time_stats.set_dp_dispatch_time() self.dispatching(req) - - if self.server_args.enable_trace: - trace_slice_end(RequestStage.DC_DISPATCH, req.rid, thread_finish_flag=True) + req.time_stats.set_dp_dispatch_finish_time() def init_dispatcher(self): self._request_dispatcher = TypeBasedDispatcher( diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 652227860389..3ebfac7a03c3 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -34,7 +34,7 @@ FreezeGCReq, ) from sglang.srt.managers.multi_tokenizer_mixin import MultiHttpWorkerDetokenizerMixin -from sglang.srt.metrics.cpu_monitor import start_cpu_monitor_thread +from sglang.srt.observability.cpu_monitor import start_cpu_monitor_thread from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( configure_logger, @@ -400,11 +400,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): retraction_counts=recv_obj.retraction_counts, token_steps=recv_obj.token_steps, load=recv_obj.load, - queue_time=recv_obj.queue_time, - forward_entry_time=recv_obj.forward_entry_time, - prefill_launch_delay=recv_obj.prefill_launch_delay, - prefill_launch_latency=recv_obj.prefill_launch_latency, - prefill_finished_ts=recv_obj.prefill_finished_ts, + time_stats=recv_obj.time_stats, ) def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq): diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index ff1774567314..393f1cf06a32 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -30,6 +30,11 @@ from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.multimodal.mm_utils import has_valid_data +from sglang.srt.observability.req_time_stats import ( + APIServerReqTimeStats, + DPControllerReqTimeStats, + SchedulerReqTimeStats, +) from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.utils import ImageData @@ -65,43 +70,6 @@ def regenerate_rids(self): return self.rids -@dataclass -class RequestTimingMetricsMixin: - """ - Mixin class containing common request-level timing metrics. - - This class consolidates the timing metrics that are shared across all batch output types - to avoid code duplication and ensure consistency. - """ - - # Queue duration: time spent waiting in queue before request is scheduled. - queue_time: Optional[List[Optional[float]]] - - # Forward entry time: timestamp when the request enters the forward pass stage. - # This corresponds to `forward_entry_time` in TimeStats. - # In different modes: - # - Unified/PD-colocate: timestamp when forward computation begins (covers prefill + decode) - # - Prefill instance (P): timestamp when prefill forward pass begins - # - Decode instance (D): timestamp when decode forward pass begins - # Note: This is NOT the same as prefill_start_time. There may be a delay between - # forward_entry_time and prefill_start_time (see prefill_launch_delay). - forward_entry_time: Optional[List[Optional[float]]] - - # Prefill launch delay: time spent waiting between forward entry and prefill start. 
- # Calculated as: prefill_start_time - forward_entry_time - # This represents the delay between when the request enters the forward stage - # and when prefill computation actually begins. - prefill_launch_delay: Optional[List[Optional[float]]] - - # Prefill launch latency: time spent during prefill kernel launch. - # Calculated as: prefill_end_time_host - prefill_start_time_host - prefill_launch_latency: Optional[List[Optional[float]]] - - # Prefill finished time: timestamp when prefill phase completes (wall clock time). - # This marks when the prefill computation finishes. - prefill_finished_ts: Optional[List[Optional[float]]] - - @dataclass class SpeculativeDecodingMetricsMixin: """ @@ -124,23 +92,6 @@ class SpeculativeDecodingMetricsMixin: spec_acceptance_histogram: List[List[int]] -@dataclass -class APIServingTimingMixin: - # Validation step duration - validation_time: Optional[float] = None - - # For metrics - received_time: Optional[float] = None - - # Perf_counter equivalents for accurate time calculations - received_time_perf: Optional[float] = None - - -_API_SERVING_TIMING_MIXIN_FIELDS = tuple( - APIServingTimingMixin.__dataclass_fields__.keys() -) - - # Parameters for a session @dataclass class SessionParams: @@ -169,7 +120,7 @@ class SessionParams: @dataclass -class GenerateReqInput(BaseReq, APIServingTimingMixin): +class GenerateReqInput(BaseReq): # The input prompt. It can be a single prompt or a batch of prompts. text: Optional[Union[List[str], str]] = None # The token ids for text; one can specify either text or input_ids @@ -268,6 +219,7 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin): # Propagates trace context via Engine.generate/async_generate external_trace_header: Optional[Dict] = None + received_time: Optional[float] = None # For EPD-disaggregated inference need_wait_for_image: Optional[bool] = None @@ -684,10 +636,7 @@ def __getitem__(self, i): return_entropy=self.return_entropy, external_trace_header=self.external_trace_header, http_worker_ipc=self.http_worker_ipc, - **{ - field: getattr(self, field) - for field in _API_SERVING_TIMING_MIXIN_FIELDS - }, + received_time=self.received_time, ) @@ -759,9 +708,6 @@ class TokenizedGenerateReqInput(BaseReq): # Whether to disallow logging for this request (e.g. due to ZDR) no_logs: bool = False - # tracing context - trace_context: Optional[Dict] = None - # (Internal) Whether to return bytes for image generation return_bytes: bool = False @@ -771,6 +717,9 @@ class TokenizedGenerateReqInput(BaseReq): need_wait_for_image: bool = False num_items_assigned: Optional[List] = None + # For observability + time_stats: Optional[Union[APIServerReqTimeStats, DPControllerReqTimeStats]] = None + @dataclass class BatchTokenizedGenerateReqInput(BaseBatchReq): @@ -788,7 +737,7 @@ def __iter__(self): @dataclass -class EmbeddingReqInput(BaseReq, APIServingTimingMixin): +class EmbeddingReqInput(BaseReq): # The input prompt. It can be a single prompt or a batch of prompts. text: Optional[Union[List[List[str]], List[str], str]] = None # The image input. It can be an image instance, file name, URL, or base64 encoded string. 
@@ -812,8 +761,6 @@ class EmbeddingReqInput(BaseReq, APIServingTimingMixin): log_metrics: bool = True # The modalities of the image data [image, multi-images, video] modalities: Optional[List[str]] = None - # Validation step duration - validation_time: Optional[float] = None # For cross-encoder requests is_cross_encoder_request: bool = False # Priority for the request @@ -826,6 +773,7 @@ class EmbeddingReqInput(BaseReq, APIServingTimingMixin): # Propagates trace context via Engine.encode/async_encode external_trace_header: Optional[Dict] = None + received_time: Optional[float] = None # The number of dimensions the resulting output embeddings should have. It is applicable for Matryoshka Embeddings. dimensions: Optional[int] = None @@ -933,10 +881,7 @@ def __getitem__(self, i): external_trace_header=self.external_trace_header, dimensions=self.dimensions, http_worker_ipc=self.http_worker_ipc, - **{ - field: getattr(self, field) - for field in _API_SERVING_TIMING_MIXIN_FIELDS - }, + received_time=self.received_time, ) @@ -960,6 +905,8 @@ class TokenizedEmbeddingReqInput(BaseReq): dimensions: Optional[int] = None # LoRA related lora_id: Optional[str] = None # None means just use the base model + # For observability + time_stats: Optional[Union[APIServerReqTimeStats, DPControllerReqTimeStats]] = None @dataclass @@ -978,9 +925,7 @@ def __iter__(self): @dataclass -class BatchTokenIDOutput( - BaseBatchReq, RequestTimingMetricsMixin, SpeculativeDecodingMetricsMixin -): +class BatchTokenIDOutput(BaseBatchReq, SpeculativeDecodingMetricsMixin): # The finish reason finished_reasons: List[BaseFinishReason] # For incremental decoding @@ -1040,6 +985,9 @@ class BatchTokenIDOutput( # Detailed breakdown of cached tokens by source (device/host/storage) cached_tokens_details: Optional[List[Optional[Dict[str, Any]]]] = None + # For observability + time_stats: Optional[List[SchedulerReqTimeStats]] = None + @dataclass class BatchMultimodalDecodeReq(BaseBatchReq): @@ -1074,9 +1022,7 @@ class BatchMultimodalDecodeReq(BaseBatchReq): @dataclass -class BatchStrOutput( - BaseBatchReq, RequestTimingMetricsMixin, SpeculativeDecodingMetricsMixin -): +class BatchStrOutput(BaseBatchReq, SpeculativeDecodingMetricsMixin): # The finish reason finished_reasons: List[dict] # The output decoded strings @@ -1131,6 +1077,9 @@ class BatchStrOutput( # Detailed breakdown of cached tokens by source (device/host/storage) cached_tokens_details: Optional[List[Optional[Dict[str, Any]]]] = None + # For observability + time_stats: Optional[List[SchedulerReqTimeStats]] = None + @dataclass class BatchMultimodalOutput(BaseBatchReq): @@ -1158,9 +1107,12 @@ class BatchMultimodalOutput(BaseBatchReq): # Detailed breakdown of cached tokens by source (device/host/storage) cached_tokens_details: Optional[List[Optional[Dict[str, Any]]]] = None + # For observability + time_stats: Optional[List[SchedulerReqTimeStats]] = None + @dataclass -class BatchEmbeddingOutput(BaseBatchReq, RequestTimingMetricsMixin): +class BatchEmbeddingOutput(BaseBatchReq): # The finish reason finished_reasons: List[BaseFinishReason] # The output embedding @@ -1177,6 +1129,9 @@ class BatchEmbeddingOutput(BaseBatchReq, RequestTimingMetricsMixin): # Detailed breakdown of cached tokens by source (device/host/storage) cached_tokens_details: Optional[List[Optional[Dict[str, Any]]]] = None + # For observability + time_stats: Optional[List[SchedulerReqTimeStats]] = None + @dataclass class ClearHiCacheReqInput(BaseReq): diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py 
b/python/sglang/srt/managers/multi_tokenizer_mixin.py index e1236aa0f3aa..9e7679dc81bf 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -131,17 +131,7 @@ def _handle_output_by_index(output, i): spec_acceptance_histogram=_extract_field_by_index( output, "spec_acceptance_histogram", i ), - queue_time=_extract_field_by_index(output, "queue_time", i), - forward_entry_time=_extract_field_by_index(output, "forward_entry_time", i), - prefill_launch_delay=_extract_field_by_index( - output, "prefill_launch_delay", i - ), - prefill_launch_latency=_extract_field_by_index( - output, "prefill_launch_latency", i - ), - prefill_finished_ts=_extract_field_by_index( - output, "prefill_finished_ts", i - ), + time_stats=_extract_field_by_index(output, "time_stats", i), finished_reasons=_extract_field_by_index(output, "finished_reasons", i), decoded_texts=_extract_field_by_index(output, "decoded_texts", i), decode_ids=_extract_field_by_index(output, "decode_ids", i), @@ -228,17 +218,7 @@ def _handle_output_by_index(output, i): spec_acceptance_histogram=_extract_field_by_index( output, "spec_acceptance_histogram", i ), - queue_time=_extract_field_by_index(output, "queue_time", i), - forward_entry_time=_extract_field_by_index(output, "forward_entry_time", i), - prefill_launch_delay=_extract_field_by_index( - output, "prefill_launch_delay", i - ), - prefill_launch_latency=_extract_field_by_index( - output, "prefill_launch_latency", i - ), - prefill_finished_ts=_extract_field_by_index( - output, "prefill_finished_ts", i - ), + time_stats=_extract_field_by_index(output, "time_stats", i), finished_reasons=_extract_field_by_index(output, "finished_reasons", i), output_strs=_extract_field_by_index(output, "output_strs", i), output_ids=_extract_field_by_index(output, "output_ids", i), diff --git a/python/sglang/srt/managers/prefill_delayer.py b/python/sglang/srt/managers/prefill_delayer.py index 8df34fe8ee6f..37a1f63f24e0 100644 --- a/python/sglang/srt/managers/prefill_delayer.py +++ b/python/sglang/srt/managers/prefill_delayer.py @@ -9,7 +9,7 @@ from sglang.srt.utils import get_bool_env_var if TYPE_CHECKING: - from sglang.srt.metrics.collector import SchedulerMetricsCollector + from sglang.srt.observability.metrics_collector import SchedulerMetricsCollector _DEBUG_LOG = get_bool_env_var("SGLANG_PREFILL_DELAYER_DEBUG_LOG") diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index c0799579802e..a7a74e6eecb0 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1,7 +1,5 @@ from __future__ import annotations -import enum - from sglang.srt.dllm.config import DllmConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils.common import ceil_align @@ -41,7 +39,6 @@ import dataclasses import logging import re -import time from enum import Enum, auto from functools import lru_cache from http import HTTPStatus @@ -72,16 +69,20 @@ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator -from sglang.srt.metrics.collector import ( - DPCooperationInfo, - SchedulerMetricsCollector, - TimeStats, -) from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, ForwardMode, ) +from sglang.srt.observability.metrics_collector import ( + 
DPCooperationInfo, + SchedulerMetricsCollector, +) +from sglang.srt.observability.req_time_stats import ( + APIServerReqTimeStats, + DPControllerReqTimeStats, + SchedulerReqTimeStats, +) from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs, get_global_server_args @@ -92,7 +93,7 @@ from typing import Any, Dict from sglang.srt.configs.model_config import ModelConfig - from sglang.srt.managers.scheduler_metrics_mixin import PrefillStats + from sglang.srt.observability.scheduler_metrics_mixin import PrefillStats from sglang.srt.speculative.eagle_info import EagleDraftInput from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm @@ -480,35 +481,6 @@ def merge(self, other: MultimodalInputs): # other args would be kept intact -class RequestStage(str, enum.Enum): - # Tokenizer - TOKENIZE = "tokenize" - TOKENIZER_DISPATCH = "dispatch" - - # DP controller - DC_DISPATCH = "dc_dispatch" - - # common/non-disaggregation - PREFILL_WAITING = "prefill_waiting" - REQUEST_PROCESS = "request_process" - DECODE_LOOP = "decode_loop" - PREFILL_FORWARD = "prefill_forward" - PREFILL_CHUNKED_FORWARD = "chunked_prefill" - - # disaggregation prefill - PREFILL_PREPARE = "prefill_prepare" - PREFILL_BOOTSTRAP = "prefill_bootstrap" - PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache" - - # disaggregation decode - DECODE_PREPARE = "decode_prepare" - DECODE_BOOTSTRAP = "decode_bootstrap" - DECODE_WAITING = "decode_waiting" - DECODE_TRANSFERRED = "decode_transferred" - DECODE_FAKE_OUTPUT = "fake_output" - DECODE_QUICK_FINISH = "quick_finish" - - class Req(ReqDllmMixin): """The input and output status of a request.""" @@ -545,6 +517,9 @@ def __init__( routing_key: Optional[str] = None, dimensions: Optional[int] = None, http_worker_ipc: Optional[str] = None, + time_stats: Optional[ + Union[APIServerReqTimeStats, DPControllerReqTimeStats] + ] = None, ): # Input and output info self.rid = rid @@ -779,11 +754,15 @@ def __init__( self.retraction_count = 0 self.retraction_mb_id = None - # For metrics + # For observability self.metrics_collector = metrics_collector - self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode) + if time_stats is not None: + self.time_stats = SchedulerReqTimeStats.new_from_obj(time_stats) + else: + self.time_stats = SchedulerReqTimeStats(disagg_mode=disagg_mode) + self.time_stats.set_metrics_collector(metrics_collector) + self.time_stats.set_scheduler_recv_time() self.has_log_time_stats: bool = False - self.last_tic = time.monotonic() # For disaggregation self.bootstrap_host: str = bootstrap_host @@ -853,16 +832,6 @@ def pop_overallocated_kv_cache(self) -> Tuple[int, int]: self.kv_overallocated_freed = True return self.kv_committed_len, self.kv_allocated_len - def add_latency(self, stage: RequestStage): - if self.metrics_collector is None: - return - - now = time.monotonic() - self.metrics_collector.observe_per_stage_req_latency( - stage.value, now - self.last_tic - ) - self.last_tic = now - def update_spec_acceptance_histogram(self, accepted_draft_tokens: int): """Update the speculative decoding acceptance histogram. 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 430f5f153507..3b4d264da9ed 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -145,7 +145,6 @@ ModelWorkerBatch, MultimodalInputs, Req, - RequestStage, ScheduleBatch, ) from sglang.srt.managers.schedule_policy import ( @@ -155,11 +154,6 @@ ) from sglang.srt.managers.scheduler_dp_attn_mixin import SchedulerDPAttnMixin from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker -from sglang.srt.managers.scheduler_metrics_mixin import ( - RECORD_STEP_TIME, - PrefillStats, - SchedulerMetricsMixin, -) from sglang.srt.managers.scheduler_output_processor_mixin import ( SchedulerOutputProcessorMixin, ) @@ -180,18 +174,20 @@ from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin +from sglang.srt.observability.req_time_stats import ( + real_time, + set_schedule_time_batch, + set_time_batch, +) +from sglang.srt.observability.scheduler_metrics_mixin import ( + RECORD_STEP_TIME, + PrefillStats, + SchedulerMetricsMixin, +) +from sglang.srt.observability.trace import process_tracing_init, trace_set_thread_info from sglang.srt.parser.reasoning_parser import ReasoningParser from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from sglang.srt.tracing.trace import ( - process_tracing_init, - trace_event_batch, - trace_set_proc_propagate_context, - trace_set_thread_info, - trace_slice_batch, - trace_slice_end, - trace_slice_start, -) from sglang.srt.utils import ( DynamicGradMode, broadcast_pyobj, @@ -316,7 +312,6 @@ def __init__( self.enable_metrics_for_all_schedulers = ( server_args.enable_metrics_for_all_schedulers ) - self.enable_trace = server_args.enable_trace self.stream_interval = server_args.stream_interval self.spec_algorithm = SpeculativeAlgorithm.from_string( server_args.speculative_algorithm @@ -1322,14 +1317,6 @@ def recv_requests( prepare_abort(req, error_msg, status_code=status_code) self.stream_output([req], req.return_logprob) - if self.enable_trace: - for req in recv_reqs: - if isinstance( - req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput) - ): - trace_set_proc_propagate_context(req.rid, req.trace_context) - trace_slice_start("", req.rid, anonymous=True) - return recv_reqs def _split_work_and_control_reqs(self, recv_reqs: List): @@ -1527,6 +1514,7 @@ def handle_generate_request( routing_key=recv_req.routing_key, http_worker_ipc=recv_req.http_worker_ipc, dllm_config=self.dllm_config, + time_stats=recv_req.time_stats, ) req.tokenizer = self.tokenizer @@ -1538,6 +1526,9 @@ def handle_generate_request( f"bootstrap room id. 
{req.rid=}" ) logger.error(error_msg) + recv_req.time_stats.trace_ctx.abort( + abort_info={"reason": error_msg} + ) prepare_abort(req, error_msg, status_code=HTTPStatus.BAD_REQUEST) self.stream_output([req], req.return_logprob) return @@ -1558,6 +1549,9 @@ def handle_generate_request( req = session.create_req( recv_req, self.tokenizer, self.model_config.vocab_size ) + # TODO: set trace context + if self.enable_metrics: + req.time_stats.set_metrics_collector(self.metrics_collector) if isinstance(req.finished_reason, FINISH_ABORT): self.init_req_max_new_tokens(req) self._add_request_to_queue(req) @@ -1684,18 +1678,19 @@ def _add_request_to_queue(self, req: Req, is_retracted: bool = False): return self._prefetch_kvcache(req) self.waiting_queue.append(req) - req.time_stats.wait_queue_entry_time = time.perf_counter() - trace_slice_end(RequestStage.REQUEST_PROCESS, req.rid, auto_next_anon=True) + req.time_stats.set_wait_queue_entry_time() elif self.disaggregation_mode == DisaggregationMode.PREFILL: self._prefetch_kvcache(req) self.disagg_prefill_bootstrap_queue.add( req, self.model_config.num_key_value_heads ) - req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter() + req.time_stats.set_prefill_bootstrap_queue_entry_time() elif self.disaggregation_mode == DisaggregationMode.DECODE: self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted) if not is_retracted: - req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter() + req.time_stats.set_decode_prealloc_queue_entry_time() + else: + req.time_stats.set_retract_time() else: raise ValueError(f"Invalid {self.disaggregation_mode=}") @@ -1719,6 +1714,7 @@ def _set_or_validate_priority(self, req: Req) -> bool: }, rid=req.rid, ) + req.time_stats.trace_ctx.abort(abort_info=abort_req.finished_reason) self.send_to_tokenizer.send_output(abort_req, req) return False return True @@ -1769,6 +1765,7 @@ def _abort_on_queued_limit(self, recv_req: Req) -> bool: ), req_to_abort, ) + req_to_abort.time_stats.trace_ctx.abort(abort_info={"reason": message}) return req_to_abort.rid == recv_req.rid def _abort_on_waiting_timeout(self): @@ -1815,6 +1812,7 @@ def handle_embedding_request( dimensions=recv_req.dimensions, lora_id=recv_req.lora_id, http_worker_ipc=recv_req.http_worker_ipc, + time_stats=recv_req.time_stats, ) req.tokenizer = self.tokenizer @@ -1947,7 +1945,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: ret = self.maybe_prepare_mlp_sync_batch(ret, need_sync=need_mlp_sync) if ret: - trace_event_batch("schedule", ret.reqs) + set_schedule_time_batch(ret) return ret @@ -2118,11 +2116,6 @@ def _get_new_batch_prefill_raw( if len(can_run_list) == 0: return None - if self.enable_metrics: - # only record queue time when enable_metrics is True to avoid overhead - for req in can_run_list: - req.add_latency(RequestStage.PREFILL_WAITING) - self.waiting_queue = [ x for x in self.waiting_queue if x not in set(can_run_list) ] @@ -2143,14 +2136,7 @@ def _get_new_batch_prefill_raw( self.can_run_list = can_run_list self.running_bs = len(self.running_batch.reqs) - # Record metrics - for req in can_run_list: - if req.time_stats.forward_entry_time == 0: - req.time_stats.forward_entry_time = time.perf_counter() - if self.enable_metrics: - self.metrics_collector.observe_queue_time( - req.time_stats.get_queueing_time(), - ) + set_time_batch(can_run_list, "set_forward_entry_time") # Create a new batch new_batch = ScheduleBatch.init_new( @@ -2291,9 +2277,7 @@ def run_batch( # Capture prefill start time for EXTEND mode if 
batch.forward_mode == ForwardMode.EXTEND: - current_time = time.perf_counter() - for req in batch.reqs: - req.time_stats.prefill_start_time_host = current_time + set_time_batch(batch.reqs, "set_prefill_run_batch_start_time") # Place holder handling for pd-disagg decode event loop if batch.forward_mode.is_prebuilt(): @@ -2410,9 +2394,7 @@ def run_batch( # Capture prefill end time for EXTEND mode if batch.forward_mode == ForwardMode.EXTEND: - current_time = time.perf_counter() - for req in batch.reqs: - req.time_stats.prefill_end_time_host = current_time + set_time_batch(batch.reqs, "set_prefill_run_batch_end_time") if ( self.server_args.enable_dp_attention @@ -2451,7 +2433,6 @@ def process_batch_result( ): if batch.forward_mode.is_decode(): self.process_batch_result_decode(batch, result) - trace_slice_batch(RequestStage.DECODE_LOOP, batch.reqs) elif batch.forward_mode.is_extend(): if batch.is_dllm(): self.process_batch_result_dllm(batch, result) @@ -3009,7 +2990,7 @@ class IdleSleeper: def __init__(self, sockets): self.poller = zmq.Poller() - self.last_empty_time = time.time() + self.last_empty_time = real_time() for s in sockets: self.poller.register(s, zmq.POLLIN) @@ -3019,9 +3000,9 @@ def maybe_sleep(self): self.poller.poll(1000) if ( self.empty_cache_interval > 0 - and time.time() - self.last_empty_time > self.empty_cache_interval + and real_time() - self.last_empty_time > self.empty_cache_interval ): - self.last_empty_time = time.time() + self.last_empty_time = real_time() torch.cuda.empty_cache() diff --git a/python/sglang/srt/managers/scheduler_dp_attn_mixin.py b/python/sglang/srt/managers/scheduler_dp_attn_mixin.py index 12ef400dbd75..58122c8e454d 100644 --- a/python/sglang/srt/managers/scheduler_dp_attn_mixin.py +++ b/python/sglang/srt/managers/scheduler_dp_attn_mixin.py @@ -9,8 +9,8 @@ from sglang.srt.distributed.parallel_state import get_tp_group from sglang.srt.environ import envs from sglang.srt.managers.schedule_batch import ScheduleBatch -from sglang.srt.metrics.collector import DPCooperationInfo from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.observability.metrics_collector import DPCooperationInfo from sglang.srt.utils.common import require_mlp_tp_gather if TYPE_CHECKING: diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index bc3449da02ec..b0883ae29cde 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import time from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch @@ -18,12 +17,10 @@ from sglang.srt.managers.schedule_batch import ( BaseFinishReason, Req, - RequestStage, ScheduleBatch, ) from sglang.srt.mem_cache.common import release_kv_cache from sglang.srt.server_args import get_global_server_args -from sglang.srt.tracing.trace import trace_slice, trace_slice_batch, trace_slice_end if TYPE_CHECKING: from sglang.srt.managers.scheduler import ( @@ -86,20 +83,13 @@ def _get_cached_tokens_details(self, req: Req) -> Optional[dict]: def process_batch_result_prebuilt(self: Scheduler, batch: ScheduleBatch): assert self.disaggregation_mode == DisaggregationMode.DECODE for req in batch.reqs: + req.time_stats.set_decode_prebuilt_finish_time() req.check_finished() if req.finished(): - req.time_stats.forward_entry_time = req.time_stats.completion_time = ( - 
time.perf_counter() - ) - trace_slice_end( - RequestStage.DECODE_QUICK_FINISH, - req.rid, - thread_finish_flag=True, - ) + req.time_stats.set_quick_finish_time() release_kv_cache(req, self.tree_cache) # Note: Logprobs should be handled on the prefill engine. - trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs) self.stream_output(batch.reqs, batch.return_logprob) def maybe_collect_routed_experts(self: Scheduler, req: Req): @@ -174,8 +164,7 @@ def process_batch_result_prefill( continue if req.is_chunked <= 0: - if req.time_stats.prefill_finished_ts == 0.0: - req.time_stats.prefill_finished_ts = time.time() + req.time_stats.set_prefill_finished_time() # req output_ids are set here req.output_ids.append(next_token_id) @@ -184,7 +173,7 @@ def process_batch_result_prefill( if req.finished(): self.maybe_collect_routed_experts(req) release_kv_cache(req, self.tree_cache) - req.time_stats.completion_time = time.perf_counter() + req.time_stats.set_completion_time() elif not batch.decoding_reqs or req not in batch.decoding_reqs: # This updates radix so others can match self.tree_cache.cache_unfinished_req(req) @@ -241,13 +230,6 @@ def process_batch_result_prefill( self.abort_request(AbortReq(rid=req.rid)) req.grammar.finished = req.finished() - trace_slice( - RequestStage.PREFILL_FORWARD, - req.rid, - auto_next_anon=not req.finished(), - thread_finish_flag=req.finished(), - ) - else: # being chunked reqs' prefill is not finished req.is_chunked -= 1 @@ -276,11 +258,7 @@ def process_batch_result_prefill( ) logprob_pt += num_input_logprobs - trace_slice( - RequestStage.PREFILL_CHUNKED_FORWARD, - req.rid, - auto_next_anon=True, - ) + req.time_stats.set_last_chunked_prefill_finish_time() else: # embedding or reward model if result.copy_done is not None: @@ -312,24 +290,20 @@ def process_batch_result_prefill( req.embedding = embeddings[i] if req.is_chunked <= 0: + req.time_stats.set_prefill_finished_time() # Dummy output token for embedding models req.output_ids.append(0) req.check_finished() if req.finished(): release_kv_cache(req, self.tree_cache) + req.time_stats.set_completion_time() else: self.tree_cache.cache_unfinished_req(req) else: # being chunked reqs' prefill is not finished req.is_chunked -= 1 - - trace_slice( - RequestStage.PREFILL_FORWARD, - req.rid, - auto_next_anon=not req.finished(), - thread_finish_flag=req.finished(), - ) + req.time_stats.set_last_chunked_prefill_finish_time() self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) @@ -405,7 +379,7 @@ def process_batch_result_dllm( req.check_finished() if req.finished(): release_kv_cache(req, self.tree_cache) - req.time_stats.completion_time = time.perf_counter() + req.time_stats.set_completion_time() break self.tree_cache.cache_unfinished_req(req) @@ -474,6 +448,8 @@ def process_batch_result_decode( # Update Mamba last track seqlen self._mamba_prefix_cache_update(req, batch, result, i) + req.time_stats.set_last_decode_finish_time() + req.check_finished(new_accepted_len) if req.finished(): @@ -486,7 +462,7 @@ def process_batch_result_decode( else: release_kv_cache(req, self.tree_cache) - req.time_stats.completion_time = time.perf_counter() + req.time_stats.set_completion_time() self.maybe_collect_customized_info(i, req, logits_output) @@ -924,11 +900,7 @@ def stream_output_generation( routed_experts = None customized_info = {} - queue_times = [] - forward_entry_times = [] - prefill_launch_delays = [] - prefill_launch_latencies = [] - prefill_finished_timestamps = [] + time_stats = [] if return_logprob: 
input_token_logprobs_val = [] @@ -1034,16 +1006,7 @@ def stream_output_generation( retraction_counts.append(req.retraction_count) - queue_times.append(req.time_stats.get_queueing_time()) - forward_entry_times.append(req.time_stats.forward_entry_time) - - prefill_launch_delays.append(req.time_stats.get_prefill_launch_delay()) - prefill_launch_latencies.append( - req.time_stats.get_prefill_launch_latency() - ) - prefill_finished_timestamps.append( - req.time_stats.get_prefill_finished_ts() - ) + time_stats.append(req.time_stats) if not self.spec_algorithm.is_none(): spec_verify_ct.append(req.spec_verify_ct) @@ -1151,11 +1114,7 @@ def stream_output_generation( spec_verify_ct=spec_verify_ct, spec_accepted_tokens=spec_accepted_tokens, spec_acceptance_histogram=spec_acceptance_histogram, - queue_time=queue_times, - forward_entry_time=forward_entry_times, - prefill_launch_delay=prefill_launch_delays, - prefill_launch_latency=prefill_launch_latencies, - prefill_finished_ts=prefill_finished_timestamps, + time_stats=time_stats, finished_reasons=finished_reasons, decoded_texts=decoded_texts, decode_ids=decode_ids_list, @@ -1200,11 +1159,7 @@ def stream_output_embedding(self: Scheduler, reqs: List[Req]): prompt_tokens = [] cached_tokens = [] cached_tokens_details = [] # Detailed breakdown by cache source - queue_times = [] - forward_entry_times = [] - prefill_launch_delays = [] - prefill_launch_latencies = [] - prefill_finished_timestamps = [] + time_stats = [] retraction_counts = [] for req in reqs: if req.finished(): @@ -1217,27 +1172,13 @@ def stream_output_embedding(self: Scheduler, reqs: List[Req]): # Collect detailed cache breakdown if available cached_tokens_details.append(self._get_cached_tokens_details(req)) - - queue_times.append(req.time_stats.get_queueing_time()) - forward_entry_times.append(req.time_stats.forward_entry_time) - - prefill_launch_delays.append(req.time_stats.get_prefill_launch_delay()) - prefill_launch_latencies.append( - req.time_stats.get_prefill_launch_latency() - ) - prefill_finished_timestamps.append( - req.time_stats.get_prefill_finished_ts() - ) + time_stats.append(req.time_stats) retraction_counts.append(req.retraction_count) self.send_to_detokenizer.send_output( BatchEmbeddingOutput( rids=rids, http_worker_ipcs=http_worker_ipcs, - queue_time=queue_times, - forward_entry_time=forward_entry_times, - prefill_launch_delay=prefill_launch_delays, - prefill_launch_latency=prefill_launch_latencies, - prefill_finished_ts=prefill_finished_timestamps, + time_stats=time_stats, finished_reasons=finished_reasons, embeddings=embeddings, prompt_tokens=prompt_tokens, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index f113e5409c9d..db13ddce18b0 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -24,7 +24,6 @@ import socket import sys import threading -import time from collections import deque from contextlib import nullcontext from datetime import datetime @@ -73,16 +72,26 @@ ) from sglang.srt.managers.mm_utils import TensorTransportMode, wrap_shm_features from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors -from sglang.srt.managers.request_metrics_exporter import RequestMetricsExporterManager -from sglang.srt.managers.schedule_batch import MultimodalDataItem, RequestStage +from sglang.srt.managers.schedule_batch import MultimodalDataItem from sglang.srt.managers.scheduler import is_health_check_generate_req from 
sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin from sglang.srt.managers.tokenizer_manager_multiitem_mixin import ( TokenizerManagerMultiItemMixin, ) -from sglang.srt.metrics.collector import TokenizerMetricsCollector -from sglang.srt.metrics.cpu_monitor import start_cpu_monitor_thread +from sglang.srt.observability.cpu_monitor import start_cpu_monitor_thread +from sglang.srt.observability.metrics_collector import TokenizerMetricsCollector +from sglang.srt.observability.req_time_stats import ( + APIServerReqTimeStats, + calibrate_time_diff, + convert_time_to_realtime, + real_time, + set_time_batch, +) +from sglang.srt.observability.request_metrics_exporter import ( + RequestMetricsExporterManager, +) +from sglang.srt.observability.trace import SpanAttributes, extract_trace_headers from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ( PortArgs, @@ -90,16 +99,6 @@ set_global_server_args_for_tokenizer, ) from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from sglang.srt.tracing.trace import ( - SpanAttributes, - extract_trace_headers, - trace_get_proc_propagate_context, - trace_req_finish, - trace_req_start, - trace_set_remote_propagate_context, - trace_slice_end, - trace_slice_start, -) from sglang.srt.utils import ( configure_gc_warning, freeze_gc, @@ -133,21 +132,9 @@ class ReqState: finished: bool event: asyncio.Event obj: Union[GenerateReqInput, EmbeddingReqInput] - - # For metrics - created_time: float - finished_time: float = 0.0 - first_token_time: float = 0.0 - last_time: float = 0.0 + time_stats: APIServerReqTimeStats last_completion_tokens: int = 1 - # perf_counter equivalents for accurate time calculations - finished_time_perf: float = 0.0 - first_token_time_perf: float = 0.0 - - request_sent_to_scheduler_ts: float = 0.0 - response_sent_to_client_ts: float = 0.0 - # For streaming output last_output_offset: int = 0 @@ -198,7 +185,6 @@ def __init__( self.enable_metrics = server_args.enable_metrics self.preferred_sampling_params = server_args.preferred_sampling_params self.crash_dump_folder = server_args.crash_dump_folder - self.enable_trace = server_args.enable_trace set_global_server_args_for_tokenizer(server_args) # Init model config @@ -347,7 +333,7 @@ def init_running_status(self): # Health check self.server_status = ServerStatus.Starting self.gracefully_exit = False - self.last_receive_tstamp = 0 + self.last_receive_tstamp = real_time() # For load balancing self.current_load = 0 @@ -357,6 +343,7 @@ def init_running_status(self): self.session_futures = {} # session_id -> asyncio event def init_request_logging_and_dumping(self): + # TODO: Refactor and organize the log export code. 
# Request logging self.request_logger = RequestLogger( log_requests=self.server_args.log_requests, @@ -485,20 +472,18 @@ def init_request_dispatcher(self): self.sampling_params_class = SamplingParams self.signal_handler_class = SignalHandler - self.req_state_class = ReqState async def generate_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], request: Optional[fastapi.Request] = None, ): - created_time = obj.received_time if obj.received_time else time.time() self.auto_create_handle_loop() # Normalize the request obj.normalize_batch_and_arguments() - if self.enable_trace: - self._trace_request_start(obj, created_time, request) + + self._req_stats_init(obj, request) if self.server_args.language_only: self._handle_epd_disaggregation_encode_request(obj) if self.server_args.tokenizer_worker_num > 1: @@ -516,13 +501,12 @@ async def generate_request( # Tokenize the request and send it to the scheduler if obj.is_single: tokenized_obj = await self._tokenize_one_request(obj) - state = self._send_one_request(obj, tokenized_obj, created_time) + state = self.rid_to_state[obj.rid] + self._send_one_request(tokenized_obj) async for response in self._wait_one_response(obj, state, request): yield response else: - async for response in self._handle_batch_request( - obj, request, created_time - ): + async for response in self._handle_batch_request(obj, request): yield response def _detect_input_format( @@ -749,7 +733,6 @@ async def _tokenize_one_request( mm_inputs = None self._validate_one_request(obj, input_ids) - trace_slice_end(RequestStage.TOKENIZE, obj.rid) return self._create_tokenized_object( obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids ) @@ -967,6 +950,9 @@ def _create_tokenized_object( http_worker_ipc=obj.http_worker_ipc, ) + tokenized_obj.time_stats = self.rid_to_state[obj.rid].time_stats + self.rid_to_state[obj.rid].time_stats.set_tokenize_finish_time() + return tokenized_obj async def _batch_tokenize_and_process( @@ -1010,7 +996,6 @@ async def _batch_tokenize_and_process( req, req.text, input_ids_list[i], None, None, token_type_ids ) ) - trace_slice_end(RequestStage.TOKENIZE, req.rid) logger.debug(f"Completed batch processing for {batch_size} requests") return tokenized_objs @@ -1062,31 +1047,18 @@ def _should_use_batch_tokenization(self, batch_size, requests) -> bool: def _send_one_request( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], - created_time: Optional[float] = None, ): - trace_slice_start(RequestStage.TOKENIZER_DISPATCH, obj.rid) - tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid) + tokenized_obj.time_stats.set_api_server_dispatch_time() tokenized_obj = wrap_shm_features(tokenized_obj) self.send_to_scheduler.send_pyobj(tokenized_obj) - state = self.req_state_class( - [], False, asyncio.Event(), obj, created_time=created_time - ) - state.request_sent_to_scheduler_ts = time.time() - self.rid_to_state[obj.rid] = state - trace_slice_end( - RequestStage.TOKENIZER_DISPATCH, obj.rid, thread_finish_flag=True - ) - return state + tokenized_obj.time_stats.set_api_server_dispatch_finish_time() def _send_batch_request( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], tokenized_objs: List[ Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput] ], - created_time: Optional[float] = None, ): """Send a batch of tokenized requests as a single batched request to the scheduler.""" if isinstance(tokenized_objs[0], TokenizedGenerateReqInput): @@ 
-1094,14 +1066,9 @@ def _send_batch_request( else: batch_req = BatchTokenizedEmbeddingReqInput(batch=tokenized_objs) + set_time_batch(tokenized_objs, "set_api_server_dispatch_time") self.send_to_scheduler.send_pyobj(batch_req) - # Create states for each individual request in the batch - for i, tokenized_obj in enumerate(tokenized_objs): - tmp_obj = obj[i] - state = self.req_state_class( - [], False, asyncio.Event(), tmp_obj, created_time=created_time - ) - self.rid_to_state[tmp_obj.rid] = state + set_time_batch(tokenized_objs, "set_api_server_dispatch_finish_time") async def _wait_one_response( self, @@ -1135,13 +1102,13 @@ async def _wait_one_response( state.out_list = [] if state.finished: - # For non-streaming cases, response has not been sent yet (`response_sent_to_client_ts` has not been set yet). + # For non-streaming cases, response has not been sent yet (`response_sent_to_client_time` has not been set yet). # Record response sent time right before we log finished results and metrics. - if not state.response_sent_to_client_ts: - state.response_sent_to_client_ts = time.time() + if not state.time_stats.response_sent_to_client_time: + state.time_stats.set_response_sent_to_client_time() out["meta_info"][ "response_sent_to_client_ts" - ] = state.response_sent_to_client_ts + ] = state.time_stats.get_response_sent_to_client_realtime() self.request_logger.log_finished_request( obj, out, @@ -1198,11 +1165,11 @@ async def _wait_one_response( if is_stream: # Record response sent time right before we send response. - if not state.response_sent_to_client_ts: - state.response_sent_to_client_ts = time.time() + if not state.time_stats.response_sent_to_client_time: + state.time_stats.set_response_sent_to_client_time() out["meta_info"][ "response_sent_to_client_ts" - ] = state.response_sent_to_client_ts + ] = state.time_stats.get_response_sent_to_client_realtime() yield out else: if ( @@ -1221,7 +1188,6 @@ async def _handle_batch_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], request: Optional[fastapi.Request] = None, - created_time: Optional[float] = None, ): batch_size = obj.batch_size @@ -1230,16 +1196,14 @@ async def _handle_batch_request( if getattr(obj, "parallel_sample_num", 1) == 1: if self._should_use_batch_tokenization(batch_size, obj): tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj) - self._send_batch_request(obj, tokenized_objs, created_time) + self._send_batch_request(tokenized_objs) # Set up generators for each request in the batch for i in range(batch_size): tmp_obj = obj[i] - generators.append( - self._wait_one_response( - tmp_obj, self.rid_to_state[tmp_obj.rid], request - ) - ) + state = self.rid_to_state[tmp_obj.rid] + state.obj = tmp_obj + generators.append(self._wait_one_response(tmp_obj, state, request)) rids.append(tmp_obj.rid) else: # Sequential tokenization and processing @@ -1251,9 +1215,9 @@ async def _handle_batch_request( for i in range(batch_size): tmp_obj = obj[i] tokenized_obj = await self._tokenize_one_request(tmp_obj) - state = self._send_one_request( - tmp_obj, tokenized_obj, created_time - ) + state = self.rid_to_state[tmp_obj.rid] + state.obj = tmp_obj + self._send_one_request(tokenized_obj) generators.append( self._wait_one_response(tmp_obj, state, request) ) @@ -1281,7 +1245,10 @@ async def _handle_batch_request( tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params) tokenized_obj.sampling_params.max_new_tokens = 0 tokenized_obj.stream = False - state = self._send_one_request(tmp_obj, tokenized_obj, 
created_time) + self._req_stats_init(tmp_obj) + state = self.rid_to_state[tmp_obj.rid] + tokenized_obj.time_stats = state.time_stats + self._send_one_request(tokenized_obj) await self._wait_one_response(tmp_obj, state, request).__anext__() # Expand requests, assign new rids for them, and send them @@ -1290,10 +1257,16 @@ async def _handle_batch_request( tmp_obj = copy.copy(objs[i]) tokenized_obj = copy.copy(tokenized_objs[i]) tokenized_obj.rid = tmp_obj.regenerate_rid() - state = self._send_one_request(tmp_obj, tokenized_obj, created_time) + self._req_stats_init(tmp_obj) + state = self.rid_to_state[tmp_obj.rid] + tokenized_obj.time_stats = state.time_stats + self._send_one_request(tokenized_obj) generators.append(self._wait_one_response(tmp_obj, state, request)) rids.append(tmp_obj.rid) + self.rid_to_state[objs[i].rid].time_stats.set_finished_time() + del self.rid_to_state[objs[i].rid] + # Wait for all requests is_stream = hasattr(obj, "stream") and obj.stream if not is_stream: @@ -1476,7 +1449,7 @@ async def handle_loop(self): with self.soft_watchdog.disable(): recv_obj = await self.recv_from_detokenizer.recv_pyobj() self._result_dispatcher(recv_obj) - self.last_receive_tstamp = time.time() + self.last_receive_tstamp = real_time() self.soft_watchdog.feed() def _handle_batch_output( @@ -1506,16 +1479,9 @@ def _handle_batch_output( } if self.enable_metrics: - self._add_metric_if_present(recv_obj, "queue_time", meta_info, i) - self._add_metric_if_present( - recv_obj, "prefill_launch_delay", meta_info, i - ) - self._add_metric_if_present( - recv_obj, "prefill_launch_latency", meta_info, i - ) - self._add_metric_if_present( - recv_obj, "prefill_finished_ts", meta_info, i - ) + if recv_obj.time_stats is not None: + scheduler_time_stats = recv_obj.time_stats[i] + meta_info.update(scheduler_time_stats.convert_to_output_meta_info()) if getattr(state.obj, "return_logprob", False): self.convert_logprob_style( @@ -1596,20 +1562,30 @@ def _handle_batch_output( state.finished = recv_obj.finished_reasons[i] is not None if state.finished: - state.finished_time = time.time() - state.finished_time_perf = time.perf_counter() - meta_info["e2e_latency"] = state.finished_time - state.created_time + state.time_stats.trace_ctx.trace_set_root_attrs( + self.convert_to_span_attrs(state, recv_obj, i) + ) + state.time_stats.set_finished_time() + meta_info["e2e_latency"] = state.time_stats.get_e2e_latency() if self.server_args.speculative_algorithm: self._calculate_spec_decoding_metrics(meta_info, recv_obj, i) if self.enable_metrics: - self._calculate_timing_metrics(meta_info, state, recv_obj, i) - - trace_req_finish( - rid, - ts=int(state.finished_time * 1e9), - attrs=self.convert_to_span_attrs(state, recv_obj, i), - ) + scheduler_time_stats = ( + recv_obj.time_stats[i] + if recv_obj.time_stats is not None + else None + ) + completion_tokens = ( + recv_obj.completion_tokens[i] + if not isinstance(recv_obj, BatchEmbeddingOutput) + else 0 + ) + meta_info.update( + state.time_stats.convert_to_output_meta_info( + scheduler_time_stats, completion_tokens + ) + ) del self.rid_to_state[rid] @@ -1872,74 +1848,6 @@ def _calculate_spec_decoding_metrics( i ] - def _calculate_timing_metrics( - self, - meta_info: Dict[str, Any], - state: ReqState, - recv_obj: Union[ - BatchStrOutput, - BatchEmbeddingOutput, - BatchMultimodalOutput, - BatchTokenIDOutput, - ], - i: int, - ) -> None: - """Calculate request-level timing metrics, such as inference time, decode throughput, and time per token.""" - # Request timing timestamps. 
- if state.created_time > 0: - meta_info["request_received_ts"] = state.created_time - if state.request_sent_to_scheduler_ts > 0: - meta_info["request_sent_to_scheduler_ts"] = ( - state.request_sent_to_scheduler_ts - ) - if state.response_sent_to_client_ts > 0: - meta_info["response_sent_to_client_ts"] = state.response_sent_to_client_ts - if state.finished_time > 0: - meta_info["decode_finished_ts"] = state.finished_time - - # Inference time calculation. - if ( - hasattr(recv_obj, "forward_entry_time") - and recv_obj.forward_entry_time - and recv_obj.forward_entry_time[i] is not None - and state.finished_time_perf > 0.0 - ): - inference_time = state.finished_time_perf - recv_obj.forward_entry_time[i] - meta_info["inference_time"] = inference_time - - # Decode throughput, time per token calculation. Only calculated if TTFT is available. - if ( - state.first_token_time_perf > 0.0 - and state.finished_time_perf > 0.0 - and not isinstance(recv_obj, BatchEmbeddingOutput) - and recv_obj.completion_tokens[i] > 0 - ): - decode_time = state.finished_time_perf - state.first_token_time_perf - completion_tokens = recv_obj.completion_tokens[i] - meta_info["decode_throughput"] = completion_tokens / decode_time - - def _add_metric_if_present( - self, - recv_obj: Any, - attr_name: str, - meta_info: Dict[str, Any], - index: int, - ) -> None: - """Add a metric to meta_info if it exists and is not None. - - Args: - recv_obj: The received object that may contain the metric attribute - attr_name: The name of the attribute to check - meta_info: The dictionary to add the metric to - index: The index to access the metric value in the attribute list - """ - if ( - hasattr(recv_obj, attr_name) - and getattr(recv_obj, attr_name) - and getattr(recv_obj, attr_name)[index] is not None - ): - meta_info[attr_name] = getattr(recv_obj, attr_name)[index] - def _request_has_grammar(self, obj: GenerateReqInput) -> bool: return ( obj.sampling_params.get("json_schema", None) @@ -1962,26 +1870,23 @@ def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int): else self.metrics_collector.labels ) if ( - state.first_token_time == 0.0 + state.time_stats.first_token_time == 0.0 and self.disaggregation_mode != DisaggregationMode.PREFILL ): - state.first_token_time = state.last_time = time.time() - state.first_token_time_perf = time.perf_counter() + state.time_stats.set_first_token_time() state.last_completion_tokens = completion_tokens self.metrics_collector.observe_time_to_first_token( - labels, state.first_token_time - state.created_time + labels, state.time_stats.get_first_token_latency() ) else: num_new_tokens = completion_tokens - state.last_completion_tokens if num_new_tokens: - new_time = time.time() - interval = new_time - state.last_time self.metrics_collector.observe_inter_token_latency( labels, - interval, + state.time_stats.get_interval(), num_new_tokens, ) - state.last_time = new_time + state.time_stats.set_last_time() state.last_completion_tokens = completion_tokens if state.finished: @@ -2005,7 +1910,7 @@ def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int): recv_obj.prompt_tokens[i], completion_tokens, recv_obj.cached_tokens[i], - state.finished_time - state.created_time, + state.time_stats.get_e2e_latency(), self._request_has_grammar(state.obj), retraction_count, cached_tokens_details, @@ -2013,7 +1918,12 @@ def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int): def dump_requests(self, state: ReqState, out_dict: dict): self.dump_request_list.append( - 
(state.obj, out_dict, state.created_time, time.time()) + ( + state.obj, + out_dict, + convert_time_to_realtime(state.time_stats.created_time), + convert_time_to_realtime(state.time_stats.finished_time), + ) ) if len(self.dump_request_list) >= self.dump_requests_threshold: @@ -2029,9 +1939,14 @@ def dump_requests(self, state: ReqState, out_dict: dict): self.dump_request_list = [] def record_request_for_crash_dump(self, state: ReqState, out_dict: dict): - current_time = time.time() + current_time = real_time() self.crash_dump_request_list.append( - (state.obj, out_dict, state.created_time, current_time) + ( + state.obj, + out_dict, + convert_time_to_realtime(state.time_stats.created_time), + current_time, + ) ) # Remove requests older than 5 minutes based on finish time while ( @@ -2081,12 +1996,13 @@ def dump_requests_before_crash( unfinished_requests = [] for rid, state in self.rid_to_state.items(): if not state.finished: + state.time_stats.set_finished_time() unfinished_requests.append( ( state.obj, state.out_list[-1] if state.out_list else {}, - state.created_time, - time.time(), + convert_time_to_realtime(state.time_stats.created_time), + convert_time_to_realtime(state.time_stats.finished_time), ) ) if unfinished_requests: @@ -2163,7 +2079,7 @@ def _handle_abort_req(self, recv_obj: AbortReq): return state = self.rid_to_state[recv_obj.rid] state.finished = True - state.finished_time = time.time() + state.time_stats.set_finished_time() abort_message = recv_obj.abort_message or "Abort in waiting queue" finish_reason = { @@ -2176,7 +2092,7 @@ def _handle_abort_req(self, recv_obj: AbortReq): "id": recv_obj.rid, "finish_reason": finish_reason, "weight_version": self.server_args.weight_version, - "e2e_latency": state.finished_time - state.created_time, + "e2e_latency": state.time_stats.get_e2e_latency(), } is_stream = getattr(state.obj, "stream", False) if getattr(state.obj, "return_logprob", False): @@ -2288,53 +2204,58 @@ async def _resolve_lora_path(self, obj: Union[GenerateReqInput, EmbeddingReqInpu # Look up the LoRA ID from the registry and start tracking ongoing LoRA requests. 
        obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
 
-    def _trace_request_start(
+    def _req_stats_init(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
-        created_time: Optional[float] = None,
         request: Optional[fastapi.Request] = None,
     ):
+        calibrate_time_diff()
+        created_time = obj.received_time
+
         external_trace_header = None
-        if request:
-            if "trace_context" in request.headers:
-                trace_set_remote_propagate_context(request.headers["trace_context"])
-            else:
+        if self.server_args.enable_trace:
+            if request:
                 external_trace_header = extract_trace_headers(request.headers)
-        elif obj.external_trace_header:
-            # When the request comes form the rust grpc server or Engine there isn't a
-            # real request object but we still need to propagate the trace context from
-            # the trace context that is explicitly passed in
-            external_trace_header = obj.external_trace_header
-
-        if obj.is_single:
-            bootstrap_room = (
-                obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
-            )
-            trace_req_start(
-                obj.rid,
-                bootstrap_room,
-                ts=int(created_time * 1e9),
-                role=self.server_args.disaggregation_mode,
-                external_trace_header=external_trace_header,
-            )
-            trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
-        else:
-            for i in range(len(obj.rid)):
+                obj.external_trace_header = external_trace_header
+            elif obj.external_trace_header:
+                # When the request comes from the Rust gRPC server or the Engine,
+                # there is no real request object, but we still need to propagate
+                # the trace context that is explicitly passed in.
+                external_trace_header = obj.external_trace_header
+
+        if not hasattr(obj, "is_single") or obj.is_single:
+            time_stats = APIServerReqTimeStats(disagg_mode=self.disaggregation_mode)
+            state = ReqState([], False, asyncio.Event(), obj, time_stats)
+            self.rid_to_state[obj.rid] = state
+
+            if self.server_args.enable_trace:
                 bootstrap_room = (
-                    obj.bootstrap_room[i]
-                    if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
-                    else None
+                    obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
                 )
-                trace_req_start(
-                    obj.rid[i],
+                time_stats.init_trace_ctx(
+                    obj.rid,
                     bootstrap_room,
-                    ts=int(created_time * 1e9),
-                    role=self.server_args.disaggregation_mode,
-                    external_trace_header=external_trace_header,
-                )
-                trace_slice_start(
-                    "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
+                    external_trace_header,
                 )
+            time_stats.set_created_time(created_time)
+        else:
+            for i in range(len(obj.rid)):
+                time_stats = APIServerReqTimeStats(disagg_mode=self.disaggregation_mode)
+                state = ReqState([], False, asyncio.Event(), obj[i], time_stats)
+                self.rid_to_state[obj.rid[i]] = state
+
+                if self.server_args.enable_trace:
+                    bootstrap_room = (
+                        obj.bootstrap_room[i]
+                        if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
+                        else None
+                    )
+                    time_stats.init_trace_ctx(
+                        obj.rid[i],
+                        bootstrap_room,
+                        external_trace_header,
+                    )
+                time_stats.set_created_time(created_time)
 
     def _handle_epd_disaggregation_encode_request(
         self, obj: Union[GenerateReqInput, EmbeddingReqInput]
@@ -2361,7 +2282,7 @@ def convert_to_span_attrs(
         """Convert attributes to span attributes."""
         span_attrs = {}
 
-        if not self.enable_trace:
+        if not self.server_args.enable_trace:
             return span_attrs
 
         # Token usage attributes
@@ -2412,30 +2333,7 @@
         )
 
         # Latency attributes
-        if state.first_token_time and state.created_time:
-            span_attrs[SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN] = (
-                state.first_token_time - state.created_time
-            )
-
-        if state.finished_time 
and state.created_time: - span_attrs[SpanAttributes.GEN_AI_LATENCY_E2E] = ( - state.finished_time - state.created_time - ) - - if state.first_token_time_perf and state.finished_time_perf: - span_attrs[SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE] = ( - state.finished_time_perf - state.first_token_time_perf - ) - - if state.request_sent_to_scheduler_ts and state.finished_time: - span_attrs[SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE] = ( - state.finished_time - state.request_sent_to_scheduler_ts - ) - - if state.request_sent_to_scheduler_ts and state.first_token_time: - span_attrs[SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL] = ( - state.first_token_time - state.request_sent_to_scheduler_ts - ) + span_attrs.update(state.time_stats.convert_to_gen_ai_span_attrs()) return span_attrs diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py index a28adc751f17..a383c5aa2c63 100644 --- a/python/sglang/srt/mem_cache/base_prefix_cache.py +++ b/python/sglang/srt/mem_cache/base_prefix_cache.py @@ -17,7 +17,7 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import ReqToTokenPool -from sglang.srt.metrics.collector import RadixCacheMetricsCollector +from sglang.srt.observability.metrics_collector import RadixCacheMetricsCollector if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index d7cd472a988c..b95ff3033009 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -38,7 +38,7 @@ compute_node_hash_values, split_node_hash_value, ) -from sglang.srt.metrics.collector import StorageMetricsCollector +from sglang.srt.observability.metrics_collector import StorageMetricsCollector from sglang.srt.utils import bind_to_closest_numa_node_cuda if TYPE_CHECKING: diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py index 5a5dd0b4ddf1..b202968429f3 100644 --- a/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +++ b/python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py @@ -19,7 +19,7 @@ ) from sglang.srt.mem_cache.memory_pool_host import HostKVCache from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient -from sglang.srt.metrics.collector import StorageMetrics +from sglang.srt.observability.metrics_collector import StorageMetrics logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py index 208e87c8ccd3..c28a80683ae6 100644 --- a/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +++ b/python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py @@ -17,7 +17,7 @@ HiCacheStorageExtraInfo, ) from sglang.srt.mem_cache.memory_pool_host import HostKVCache, HostTensorAllocator -from sglang.srt.metrics.collector import StorageMetrics +from sglang.srt.observability.metrics_collector import StorageMetrics DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024 # 16 MB SETUP_TIMEOUT = 600 # 10min diff --git a/python/sglang/srt/metrics/cpu_monitor.py b/python/sglang/srt/observability/cpu_monitor.py similarity index 100% rename from python/sglang/srt/metrics/cpu_monitor.py rename to python/sglang/srt/observability/cpu_monitor.py diff --git 
a/python/sglang/srt/metrics/func_timer.py b/python/sglang/srt/observability/func_timer.py similarity index 98% rename from python/sglang/srt/metrics/func_timer.py rename to python/sglang/srt/observability/func_timer.py index 51d445ab44e2..a29fc1478fb1 100644 --- a/python/sglang/srt/metrics/func_timer.py +++ b/python/sglang/srt/observability/func_timer.py @@ -20,7 +20,7 @@ from functools import wraps from typing import Any, Callable, Optional -from sglang.srt.metrics.utils import exponential_buckets +from sglang.srt.observability.utils import exponential_buckets enable_metrics = False diff --git a/python/sglang/srt/metrics/label_transform.py b/python/sglang/srt/observability/label_transform.py similarity index 100% rename from python/sglang/srt/metrics/label_transform.py rename to python/sglang/srt/observability/label_transform.py diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/observability/metrics_collector.py similarity index 88% rename from python/sglang/srt/metrics/collector.py rename to python/sglang/srt/observability/metrics_collector.py index 255d41ccc037..cca592aca7c8 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/observability/metrics_collector.py @@ -20,10 +20,9 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Union -from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.environ import envs -from sglang.srt.metrics.utils import exponential_buckets, generate_buckets from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.observability.utils import exponential_buckets, generate_buckets from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_bool_env_var from sglang.srt.utils.gauge_histogram import GaugeHistogram @@ -48,155 +47,6 @@ def get_histogram_conf_from_env(env_var_name: str) -> Optional[List[float]]: return [float(x) for x in env_var_value.split(",")] -@dataclass -class TimeStats: - """ - Store the timestamps for each stage of a request. - - Unified: wait_queue -> forward -> completion - Prefill: bootstrap_queue -> wait_queue -> forward -> transfer_queue -> completion - Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion - """ - - disagg_mode: DisaggregationMode = DisaggregationMode.NULL - lb_entry_time: float = 0.0 - wait_queue_entry_time: float = 0.0 - forward_entry_time: float = 0.0 - completion_time: float = 0.0 - prefill_bootstrap_queue_entry_time: float = 0.0 - prefill_transfer_queue_entry_time: float = 0.0 - decode_prealloc_queue_entry_time: float = 0.0 - decode_transfer_queue_entry_time: float = 0.0 - # TODO: correct set them - bootstrap_duration: float = 0.0 - alloc_waiting_duration: float = 0.0 - prefill_start_time_host: float = 0.0 - prefill_end_time_host: float = 0.0 - transfer_speed_gb_s: float = 0.0 - transfer_total_mb: float = 0.0 - # Number of prefill retries for this request - prefill_retry_count: int = 0 - - # Timestamp when prefill phase finishes, obtained from `time.time()`. - # Note that this differs from the other `_time` fields tracked by the - # `TimeStats` class, which are obtained from `time.perf_counter()`. - # We use `time.time()` instead of `time.perf_counter()` here in order to - # maintain unit consistency with other timestamp fields tracked by the `ReqState` class. 
- prefill_finished_ts: float = 0.0 - - def get_queueing_time(self) -> float: - return self.forward_entry_time - self.wait_queue_entry_time - - def get_prefill_launch_delay(self) -> Optional[float]: - if self.prefill_start_time_host > 0.0: - return self.prefill_start_time_host - self.forward_entry_time - return None - - def get_prefill_launch_latency(self) -> Optional[float]: - if self.prefill_start_time_host > 0.0 and self.prefill_end_time_host > 0.0: - return self.prefill_end_time_host - self.prefill_start_time_host - return None - - def get_prefill_finished_ts(self) -> Optional[float]: - if self.prefill_finished_ts > 0.0: - return self.prefill_finished_ts - return None - - def convert_to_duration(self) -> str: - if self.disagg_mode == DisaggregationMode.NULL: - queue_duration = self.forward_entry_time - self.wait_queue_entry_time - forward_duration = self.completion_time - self.forward_entry_time - - if SGLANG_TEST_REQUEST_TIME_STATS: - assert ( - queue_duration >= 0 and forward_duration >= 0 - ), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0" - - return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time:.3f}" - elif self.disagg_mode == DisaggregationMode.PREFILL: - bootstrap_duration = ( - self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time - ) - queue_duration = self.forward_entry_time - self.wait_queue_entry_time - forward_duration = self.completion_time - self.forward_entry_time - - if SGLANG_TEST_REQUEST_TIME_STATS: - if self.wait_queue_entry_time > 0: - assert ( - bootstrap_duration >= 0 - and queue_duration >= 0 - and forward_duration >= 0 - ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0" - - other = max( - 0.0, - bootstrap_duration - - (self.alloc_waiting_duration + self.bootstrap_duration), - ) - return ( - f"bootstrap_queue_duration({self.format_duration(bootstrap_duration)}) " - f"= alloc_wait({self.format_duration(self.alloc_waiting_duration)}) " - f"+ bootstrap({self.format_duration(self.bootstrap_duration)}) " - f"+ other({self.format_duration(other)}); " - f"queue_duration={self.format_duration(queue_duration)}, " - f"forward_duration={self.format_duration(forward_duration)}, " - f"start={self.prefill_bootstrap_queue_entry_time:.3f}, " - f"transfer_speed={self.transfer_speed_gb_s:.2f}GB/s, " - f"transfer_total={self.transfer_total_mb:.2f}MB, " - f"#retries={self.prefill_retry_count}" - ) - elif self.disagg_mode == DisaggregationMode.DECODE: - prealloc_duration = ( - self.decode_transfer_queue_entry_time - - self.decode_prealloc_queue_entry_time - ) - transfer_duration = ( - self.wait_queue_entry_time - self.decode_transfer_queue_entry_time - ) - queue_duration = self.forward_entry_time - self.wait_queue_entry_time - forward_duration = self.completion_time - self.forward_entry_time - - if SGLANG_TEST_REQUEST_TIME_STATS: - if self.wait_queue_entry_time > 0: - assert ( - prealloc_duration >= 0 - and transfer_duration >= 0 - and queue_duration >= 0 - and forward_duration >= 0 - ), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0. 
{self=}" - - other = max( - 0.0, - prealloc_duration - - (self.alloc_waiting_duration + self.bootstrap_duration), - ) - return ( - f"prealloc_queue_duration({self.format_duration(prealloc_duration)}) " - f"= alloc_wait({self.format_duration(self.alloc_waiting_duration)}) " - f"+ bootstrap({self.format_duration(self.bootstrap_duration)}) " - f"+ other({self.format_duration(other)}); " - f"transfer_duration={self.format_duration(transfer_duration)}; " - f"queue_duration={self.format_duration(queue_duration)}, " - f"forward_duration={self.format_duration(forward_duration)}, " - f"start={self.decode_prealloc_queue_entry_time:.3f}" - ) - else: - return "Unknown Time Stats" - - def format_duration(self, duration: float) -> str: - return f"{duration * 1e3:.2f}ms" - - def disagg_mode_str(self) -> str: - if self.disagg_mode == DisaggregationMode.NULL: - return "unified" - elif self.disagg_mode == DisaggregationMode.DECODE: - return "decode" - elif self.disagg_mode == DisaggregationMode.PREFILL: - return "prefill" - else: - return "unknown" - - @dataclass class SchedulerStats: # Basics diff --git a/python/sglang/srt/observability/req_time_stats.py b/python/sglang/srt/observability/req_time_stats.py new file mode 100644 index 000000000000..bb859aefd6cf --- /dev/null +++ b/python/sglang/srt/observability/req_time_stats.py @@ -0,0 +1,971 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for Request Time Stats.""" + +from __future__ import annotations + +import logging +import time +import uuid +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.observability.metrics_collector import ( + SchedulerMetricsCollector, + TokenizerMetricsCollector, +) +from sglang.srt.observability.trace import ( + SpanAttributes, + TraceNullContext, + TraceReqContext, + TraceSliceContext, + get_global_tracing_enabled, +) +from sglang.srt.utils import get_bool_env_var + +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import ScheduleBatch + +SGLANG_TEST_REQUEST_TIME_STATS = get_bool_env_var("SGLANG_TEST_REQUEST_TIME_STATS") + + +logger = logging.getLogger(__name__) + +# Reduce system time calls by computing time.time() based on calibrated perf_counter() values. 
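+# Example (illustrative numbers): if time.time() == 1_700_000_000.0 and
+# time.perf_counter() == 12.5 at calibration, the stored offset is
+# 1_699_999_987.5, so convert_time_to_realtime(13.0) == 1_700_000_000.5.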
+global_diff_realtime_monotonic = time.time() - time.perf_counter()
+
+
+def calibrate_time_diff():
+    # Due to NTP adjustments, the offset between time.time() and
+    # time.perf_counter() can drift, so recalibrate it periodically.
+    global global_diff_realtime_monotonic
+    global_diff_realtime_monotonic = time.time() - time.perf_counter()
+
+
+def real_time():
+    return time.time()
+
+
+def monotonic_time():
+    return time.perf_counter()
+
+
+def convert_time_to_realtime(time_value: float) -> float:
+    # Note: within the time scale of a single request's latency,
+    # we assume that the offset does not change significantly.
+    return time_value + global_diff_realtime_monotonic
+
+
+def convert_time_to_realtime_ns(time_value: float) -> int:
+    return int((time_value + global_diff_realtime_monotonic) * 1e9)
+
+
+def convert_time_cross_thread(
+    time_value: float, old_diff: float, new_diff: float
+) -> float:
+    # Note: re-basing a timestamp across two calibrations loses some precision.
+    return time_value + old_diff - new_diff
+
+
+@dataclass
+class RequestStageConfig:
+    stage_name: str
+    level: int = 0
+    # Whether to call metrics_collector.observe_per_stage_req_latency
+    metrics_is_observed: bool = False
+
+
+class RequestStage:
+    # Tokenizer/gRPC Server
+    TOKENIZE = RequestStageConfig(
+        "tokenize",
+        level=1,
+    )
+    API_SERVER_DISPATCH = RequestStageConfig(
+        "dispatch",
+        level=2,
+    )
+
+    # DP controller
+    DC_DISPATCH = RequestStageConfig(
+        "dc_dispatch",
+        level=2,
+    )
+
+    # common/non-disaggregation
+    REQUEST_PROCESS = RequestStageConfig(
+        "request_process",
+        level=2,
+        metrics_is_observed=True,
+    )
+    PREFILL_WAITING = RequestStageConfig(
+        "prefill_waiting",
+        level=1,
+        # equivalent to "observe_queue_time"
+        metrics_is_observed=False,
+    )
+    DECODE_FORWARD = RequestStageConfig(
+        "decode_forward",
+        level=1,
+    )
+    DECODE_LOOP = RequestStageConfig(
+        "decode_loop",
+        level=3,
+    )
+    PREFILL_FORWARD = RequestStageConfig(
+        "prefill_forward",
+        level=1,
+        metrics_is_observed=True,
+    )
+    PREFILL_CHUNKED_FORWARD = RequestStageConfig(
+        "chunked_prefill",
+        level=3,
+        metrics_is_observed=True,
+    )
+
+    # disaggregation prefill
+    PREFILL_PREPARE = RequestStageConfig(
+        "prefill_prepare",
+        level=1,
+    )
+    PREFILL_BOOTSTRAP = RequestStageConfig(
+        "prefill_bootstrap",
+        level=1,
+        metrics_is_observed=True,
+    )
+    PREFILL_TRANSFER_KV_CACHE = RequestStageConfig(
+        "prefill_transfer_kv_cache",
+        level=1,
+        metrics_is_observed=True,
+    )
+
+    # disaggregation decode
+    DECODE_PREPARE = RequestStageConfig(
+        "decode_prepare",
+        level=1,
+        metrics_is_observed=True,
+    )
+    DECODE_BOOTSTRAP = RequestStageConfig(
+        "decode_bootstrap",
+        level=1,
+        metrics_is_observed=True,
+    )
+    DECODE_WAITING = RequestStageConfig(
+        "decode_waiting",
+        level=1,
+        metrics_is_observed=True,
+    )
+    DECODE_TRANSFERRED = RequestStageConfig(
+        "decode_transferred",
+        level=1,
+        metrics_is_observed=True,
+    )
+    DECODE_FAKE_OUTPUT = RequestStageConfig(
+        "fake_output",
+        level=3,
+        metrics_is_observed=True,
+    )
+    DECODE_QUICK_FINISH = RequestStageConfig(
+        "quick_finish",
+        level=1,
+        metrics_is_observed=True,
+    )
+
+    # mini lb
+    MINI_LB_LAUNCH = RequestStageConfig(
+        "mini_lb_launch",
+        level=1,
+    )
+
+    WAIT_PD_FINISH = RequestStageConfig(
+        "wait_pd_finish",
+        level=2,
+    )
+
+    # other
+    ANONYMOUS = RequestStageConfig("")
+
+
+@dataclass
+class ReqTimeStatsBase:
+    enable_metrics: bool = False
+    metrics_collector: Optional[
+        Union[SchedulerMetricsCollector, TokenizerMetricsCollector]
+    ] = None
+    trace_ctx: Union[TraceReqContext, TraceNullContext] = field(
+        default_factory=TraceNullContext
+    )
+    disagg_mode: DisaggregationMode = DisaggregationMode.NULL
+    diff_realtime_monotonic: float = 0.0
+
+    @classmethod
+    def new_from_obj(cls, obj: ReqTimeStatsBase, *args, **kwargs) -> "ReqTimeStatsBase":
+        calibrate_time_diff()
+        new_obj = cls(*args, **kwargs)
+        if obj is None:
+            return new_obj
+        for key, value in obj.__dict__.items():
+            if hasattr(new_obj, key):
+                setattr(new_obj, key, value)
+
+        new_obj.trace_ctx.rebuild_thread_context()
+
+        return new_obj
+
+    def disagg_mode_str(self) -> str:
+        if self.disagg_mode == DisaggregationMode.NULL:
+            return "unified"
+        elif self.disagg_mode == DisaggregationMode.DECODE:
+            return "decode"
+        elif self.disagg_mode == DisaggregationMode.PREFILL:
+            return "prefill"
+        else:
+            return "unknown"
+
+    def set_metrics_collector(
+        self, collector: Union[SchedulerMetricsCollector, TokenizerMetricsCollector]
+    ):
+        if collector:
+            self.enable_metrics = True
+            self.metrics_collector = collector
+
+    def observe_per_stage_req_latency(self, stage: RequestStageConfig, latency: float):
+        if self.enable_metrics and stage.metrics_is_observed:
+            self.metrics_collector.observe_per_stage_req_latency(
+                stage.stage_name, latency
+            )
+
+    def init_trace_ctx(
+        self,
+        rid: str,
+        bootstrap_room: Optional[int],
+        external_trace_header: Optional[Dict[str, str]] = None,
+    ):
+        self.trace_ctx = TraceReqContext(
+            rid=rid,
+            bootstrap_room=bootstrap_room,
+            role=self.disagg_mode_str(),
+            module_name="request",
+            external_trace_header=external_trace_header,
+        )
+
+        if not self.trace_ctx.tracing_enable:
+            self.trace_ctx = TraceNullContext()
+
+    def trace_slice(
+        self,
+        stage: RequestStageConfig,
+        start_time: float,
+        end_time: float,
+        attrs: Optional[Dict] = None,
+    ):
+        if self.trace_ctx.tracing_enable:
+            _slice = TraceSliceContext(
+                slice_name=stage.stage_name,
+                start_time_ns=convert_time_to_realtime_ns(start_time),
+                end_time_ns=convert_time_to_realtime_ns(end_time),
+                level=stage.level,
+                attrs=attrs,
+            )
+            self.trace_ctx.trace_slice(_slice)
+
+    def __getstate__(self) -> object:
+        # The object crosses process boundaries via pickling; the metrics
+        # collector cannot travel with it and must be reconfigured by the receiver. 
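+        # Only the fields below cross the process boundary: enable_metrics is
+        # reset to False so the receiving process must call set_metrics_collector()
+        # again, and the sender's clock offset travels along so that __setstate__
+        # can re-base any propagated "*time" timestamp fields.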
+ return { + "disagg_mode": self.disagg_mode, + "enable_metrics": False, + "trace_ctx": self.trace_ctx, + "diff_realtime_monotonic": global_diff_realtime_monotonic, + } + + def __setstate__(self, state: object): + for key in state.keys(): + if key.endswith("time"): + state[key] = convert_time_cross_thread( + state[key], + state["diff_realtime_monotonic"], + global_diff_realtime_monotonic, + ) + self.__dict__.update(state) + + +@dataclass +class APIServerReqTimeStats(ReqTimeStatsBase): + # get by time.perf_counter() + created_time: float = 0.0 + finished_time: float = 0.0 + first_token_time: float = 0.0 + last_time: float = 0.0 + tokenize_finish_time: float = 0.0 + api_server_dispatch_time: float = 0.0 + api_server_dispatch_finish_time: float = 0.0 + response_sent_to_client_time: float = 0.0 + + def __getstate__(self) -> object: + state = {} + # send to DP controller or Scheduler + # If necessary, can propagate the timestamp here, for example: + # state = { + # "created_time": self.created_time, + # "api_server_dispatch_time": self.api_server_dispatch_time, + # } + state.update(super().__getstate__()) + return state + + def set_created_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.created_time = ts + + self.trace_ctx.trace_req_start(convert_time_to_realtime_ns(ts)) + + def set_finished_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.finished_time = ts + + self.trace_ctx.trace_req_finish(convert_time_to_realtime_ns(ts)) + + def set_first_token_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.first_token_time = ts + self.last_time = ts + + def set_last_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.last_time = ts + + def set_tokenize_finish_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.tokenize_finish_time = ts + + stage = RequestStage.TOKENIZE + self.trace_slice(stage, self.created_time, ts) + + def set_api_server_dispatch_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.api_server_dispatch_time = ts + + self.trace_ctx.trace_slice_start( + RequestStage.API_SERVER_DISPATCH.stage_name, + RequestStage.API_SERVER_DISPATCH.level, + convert_time_to_realtime_ns(ts), + ) + + def set_api_server_dispatch_finish_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.api_server_dispatch_finish_time = ts + + self.trace_ctx.trace_slice_end( + RequestStage.API_SERVER_DISPATCH.stage_name, + RequestStage.API_SERVER_DISPATCH.level, + convert_time_to_realtime_ns(ts), + thread_finish_flag=True, + ) + + def set_response_sent_to_client_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.response_sent_to_client_time = ts + + def get_interval(self): + return time.perf_counter() - self.last_time + + def get_first_token_latency(self): + return self.first_token_time - self.created_time + + def get_e2e_latency(self): + return self.finished_time - self.created_time + + def get_decode_latency(self): + return self.finished_time - self.first_token_time + + def get_response_sent_to_client_realtime(self): + return convert_time_to_realtime(self.response_sent_to_client_time) + + def convert_to_output_meta_info( + self, scheduler_time_stats=None, completion_tokens=0 + ): + meta_info = {} + if self.created_time > 0.0: + meta_info["request_received_ts"] = convert_time_to_realtime( + self.created_time + ) + if self.api_server_dispatch_finish_time > 0.0: + meta_info["api_server_dispatch_finish_ts"] = convert_time_to_realtime( + 
self.api_server_dispatch_finish_time + ) + if self.response_sent_to_client_time > 0.0: + meta_info["response_sent_to_client_ts"] = convert_time_to_realtime( + self.response_sent_to_client_time + ) + if self.finished_time > 0.0: + meta_info["request_finished_ts"] = convert_time_to_realtime( + self.finished_time + ) + + if ( + scheduler_time_stats + and hasattr(scheduler_time_stats, "forward_entry_time") + and self.finished_time > 0.0 + ): + meta_info["inference_time"] = ( + self.finished_time - scheduler_time_stats.forward_entry_time + ) + + decode_latency = self.get_decode_latency() + if decode_latency > 0.0 and completion_tokens > 0: + meta_info["decode_throughput"] = completion_tokens / decode_latency + return meta_info + + def convert_to_gen_ai_span_attrs(self): + span_attrs = {} + if self.first_token_time and self.created_time: + span_attrs[SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN] = ( + self.first_token_time - self.created_time + ) + + if self.finished_time and self.created_time: + span_attrs[SpanAttributes.GEN_AI_LATENCY_E2E] = ( + self.finished_time - self.created_time + ) + + if self.first_token_time and self.finished_time: + span_attrs[SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE] = ( + self.finished_time - self.first_token_time + ) + + if self.api_server_dispatch_finish_time and self.finished_time: + span_attrs[SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE] = ( + self.finished_time - self.api_server_dispatch_finish_time + ) + + if self.api_server_dispatch_finish_time and self.first_token_time: + span_attrs[SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL] = ( + self.first_token_time - self.api_server_dispatch_finish_time + ) + + return span_attrs + + +@dataclass +class DPControllerReqTimeStats(ReqTimeStatsBase): + # propagated from tokenizer/grpc_server, get by time.perf_counter() + created_time: float = 0.0 + api_server_dispatch_time: float = 0.0 + + # new timestamp, get by time.perf_counter() + dc_dispatch_time: float = 0.0 + dc_dispatch_finish_time: float = 0.0 + + def __getstate__(self) -> object: + state = {} + # send to Scheduler + # If necessary, can propagate the timestamp here, for example: + # state = { + # "created_time": self.created_time, + # "api_server_dispatch_time": self.api_server_dispatch_time, + # "dc_dispatch_time": self.dc_dispatch_time, + # } + state.update(super().__getstate__()) + return state + + def set_dp_dispatch_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.dc_dispatch_time = ts + + self.trace_ctx.trace_slice_start( + RequestStage.DC_DISPATCH.stage_name, + RequestStage.DC_DISPATCH.level, + convert_time_to_realtime_ns(ts), + ) + + def set_dp_dispatch_finish_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.dc_dispatch_finish_time = ts + + self.trace_ctx.trace_slice_end( + RequestStage.DC_DISPATCH.stage_name, + RequestStage.DC_DISPATCH.level, + convert_time_to_realtime_ns(ts), + thread_finish_flag=True, + ) + + +@dataclass +class SchedulerReqTimeStats(ReqTimeStatsBase): + """ + Store the timestamps for each stage of a request. 
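+    Roughly one trace slice is emitted per stage transition in the flows below.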
+ + Unified: wait_queue -> forward -> completion + Prefill: bootstrap_queue -> wait_queue -> forward -> transfer_queue -> completion + Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion + """ + + # Placeholder: not used currently + # propagated from tokenizer/grpc_server or dp controller + created_time: float = 0.0 + api_server_dispatch_time: float = 0.0 + dc_dispatch_time: float = 0.0 + + # common, get by time.perf_counter() + wait_queue_entry_time: float = 0.0 + forward_entry_time: float = 0.0 + prefill_run_batch_start_time: float = 0.0 + prefill_run_batch_end_time: float = 0.0 + prefill_finished_time: float = 0.0 + completion_time: float = 0.0 + + # prefill node, get by time.perf_counter() + prefill_bootstrap_queue_entry_time: float = 0.0 + prefill_transfer_queue_entry_time: float = 0.0 + prefill_kv_transfer_finish_time: float = 0.0 + + # decode node, get by time.perf_counter() + decode_prealloc_queue_entry_time: float = 0.0 + decode_transfer_queue_entry_time: float = 0.0 + decode_prebuilt_finish_time: float = 0.0 + + # only for request tracing + scheduler_recv_time: float = 0.0 + last_chunked_prefill_finish_time: float = 0.0 + last_decode_finish_time: float = 0.0 + decode_ct: int = 0 + last_decode_scheduled_time: float = 0.0 + last_forward_entry_time: float = 0.0 + last_prefill_finished_time: float = 0.0 + + # other + transfer_speed_gb_s: float = 0.0 + transfer_total_mb: float = 0.0 + # Number of prefill retries for this request + prefill_retry_count: int = 0 + + def __getstate__(self) -> object: + # send to detokenizer/tokenizer + if not self.enable_metrics: + return {} + + state = { + "wait_queue_entry_time": self.wait_queue_entry_time, + "forward_entry_time": self.forward_entry_time, + "prefill_run_batch_start_time": self.prefill_run_batch_start_time, + "prefill_run_batch_end_time": self.prefill_run_batch_end_time, + "prefill_finished_time": self.prefill_finished_time, + "diff_realtime_monotonic": global_diff_realtime_monotonic, + } + return state + + def set_scheduler_recv_time(self, ts=None): + calibrate_time_diff() + if ts is None: + ts = time.perf_counter() + self.scheduler_recv_time = ts + + def set_retract_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + # retract + self.last_forward_entry_time = 0.0 + self.last_prefill_finished_time = 0.0 + self.last_chunked_prefill_finish_time = 0.0 + self.last_decode_finish_time = 0.0 + self.last_decode_scheduled_time = 0.0 + + self.trace_ctx.trace_event("retract", 1, convert_time_to_realtime_ns(ts)) + + def set_wait_queue_entry_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + if self.wait_queue_entry_time == 0.0: + if self.enable_metrics or self.trace_ctx.tracing_enable: + if self.disagg_mode == DisaggregationMode.PREFILL: + stage = RequestStage.PREFILL_BOOTSTRAP + slice_start_time = self.prefill_bootstrap_queue_entry_time + elif self.disagg_mode == DisaggregationMode.DECODE: + stage = RequestStage.DECODE_TRANSFERRED + slice_start_time = self.decode_transfer_queue_entry_time + else: + stage = RequestStage.REQUEST_PROCESS + slice_start_time = self.scheduler_recv_time + + self.observe_per_stage_req_latency(stage, ts - slice_start_time) + self.trace_slice(stage, slice_start_time, ts) + else: + self.set_retract_time(ts) + + self.wait_queue_entry_time = ts + + def set_forward_entry_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + if self.forward_entry_time == 0.0: + self.forward_entry_time = ts + self.last_forward_entry_time = ts + + if self.enable_metrics: + 
self.metrics_collector.observe_queue_time(self.get_queueing_time()) + + if self.enable_metrics or self.trace_ctx.tracing_enable: + if self.disagg_mode == DisaggregationMode.DECODE: + stage = RequestStage.DECODE_WAITING + else: + stage = RequestStage.PREFILL_WAITING + slice_start_time = self.wait_queue_entry_time + + self.observe_per_stage_req_latency(stage, ts - slice_start_time) + self.trace_slice(stage, slice_start_time, ts) + + if self.disagg_mode == DisaggregationMode.DECODE: + self.trace_ctx.trace_slice_start( + RequestStage.DECODE_FORWARD.stage_name, + RequestStage.DECODE_FORWARD.level, + convert_time_to_realtime_ns(ts), + ) + else: + self.trace_ctx.trace_slice_start( + RequestStage.PREFILL_FORWARD.stage_name, + RequestStage.PREFILL_FORWARD.level, + convert_time_to_realtime_ns(ts), + ) + elif self.last_forward_entry_time == 0.0: + self.last_forward_entry_time = ts + + def set_prefill_run_batch_start_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.prefill_run_batch_start_time = ts + + def set_prefill_run_batch_end_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.prefill_run_batch_end_time = ts + + def set_last_chunked_prefill_finish_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + last_time = self.last_chunked_prefill_finish_time + self.last_chunked_prefill_finish_time = ts + + if last_time == 0.0: + last_time = self.last_forward_entry_time + + stage = RequestStage.PREFILL_CHUNKED_FORWARD + self.observe_per_stage_req_latency(stage, ts - last_time) + self.trace_slice(stage, last_time, ts) + + def set_prefill_finished_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + if self.prefill_finished_time == 0.0: + self.prefill_finished_time = ts + self.last_prefill_finished_time = ts + + stage = RequestStage.PREFILL_FORWARD + self.observe_per_stage_req_latency(stage, ts - self.last_forward_entry_time) + + if self.trace_ctx.tracing_enable: + if self.last_chunked_prefill_finish_time > 0: + self.trace_slice( + RequestStage.PREFILL_CHUNKED_FORWARD, + self.last_chunked_prefill_finish_time, + ts, + ) + + self.trace_ctx.trace_slice_end( + stage.stage_name, stage.level, convert_time_to_realtime_ns(ts) + ) + if ( + self.disagg_mode == DisaggregationMode.NULL + and self.last_decode_scheduled_time > 0 + ): + self.trace_ctx.trace_slice_start( + RequestStage.DECODE_FORWARD.stage_name, + RequestStage.DECODE_FORWARD.level, + convert_time_to_realtime_ns(ts), + ) + elif self.last_prefill_finished_time == 0.0: + # retract + self.last_prefill_finished_time = ts + if self.last_chunked_prefill_finish_time > 0: + self.trace_slice( + RequestStage.PREFILL_CHUNKED_FORWARD, + self.last_chunked_prefill_finish_time, + ts, + ) + else: + self.trace_slice( + RequestStage.PREFILL_FORWARD, self.last_forward_entry_time, ts + ) + + def set_last_decode_finish_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + last_time = self.last_decode_finish_time + self.last_decode_finish_time = ts + + if self.enable_metrics or self.trace_ctx.tracing_enable: + if last_time == 0.0: + if self.disagg_mode == DisaggregationMode.DECODE: + last_time = self.decode_prebuilt_finish_time + else: + if ( + self.last_decode_scheduled_time + < self.last_prefill_finished_time + ): + last_time = self.last_prefill_finished_time + else: + last_time = self.last_decode_scheduled_time + stage = RequestStage.DECODE_LOOP + self.observe_per_stage_req_latency(stage, ts - last_time) + attrs = {"decode_ct": self.decode_ct} + self.trace_slice(stage, last_time, ts, attrs) + 
self.decode_ct += 1 + + def set_last_scheduled_time(self, forward_mode: ForwardMode, ts=None, attrs=None): + if ts is None: + ts = time.perf_counter() + + if self.trace_ctx.tracing_enable: + if ( + self.disagg_mode == DisaggregationMode.NULL + and forward_mode.is_decode() + and self.last_decode_scheduled_time == 0.0 + and self.last_prefill_finished_time > 0 + ): + self.trace_slice( + RequestStage.DECODE_WAITING, self.last_prefill_finished_time, ts + ) + self.trace_ctx.trace_slice_start( + RequestStage.DECODE_FORWARD.stage_name, + RequestStage.DECODE_FORWARD.level, + convert_time_to_realtime_ns(ts), + ) + self.last_decode_finish_time = ts + + self.trace_ctx.trace_event( + "schedule", 3, convert_time_to_realtime_ns(ts), attrs + ) + + if forward_mode.is_decode(): + self.last_decode_scheduled_time = ts + + def set_completion_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.completion_time = ts + + self.trace_ctx.abort() + + def set_quick_finish_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.set_completion_time(ts) + self.forward_entry_time = ts + + def set_prefill_bootstrap_queue_entry_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.prefill_bootstrap_queue_entry_time = ts + + stage = RequestStage.PREFILL_PREPARE + self.observe_per_stage_req_latency(stage, ts - self.scheduler_recv_time) + self.trace_slice(stage, self.scheduler_recv_time, ts) + + def set_prefill_transfer_queue_entry_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.prefill_transfer_queue_entry_time = ts + + def set_prefill_kv_transfer_finish_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.prefill_kv_transfer_finish_time = ts + + stage = RequestStage.PREFILL_TRANSFER_KV_CACHE + self.observe_per_stage_req_latency( + stage, ts - self.prefill_transfer_queue_entry_time + ) + self.trace_slice(stage, self.prefill_transfer_queue_entry_time, ts) + + def set_decode_prealloc_queue_entry_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.decode_prealloc_queue_entry_time = ts + + stage = RequestStage.DECODE_PREPARE + self.observe_per_stage_req_latency(stage, ts - self.scheduler_recv_time) + self.trace_slice(stage, self.scheduler_recv_time, ts) + + def set_decode_transfer_queue_entry_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.decode_transfer_queue_entry_time = ts + + stage = RequestStage.DECODE_BOOTSTRAP + self.observe_per_stage_req_latency( + stage, ts - self.decode_prealloc_queue_entry_time + ) + self.trace_slice(stage, self.decode_prealloc_queue_entry_time, ts) + + def set_decode_prebuilt_finish_time(self, ts=None): + if ts is None: + ts = time.perf_counter() + self.decode_prebuilt_finish_time = ts + + stage = RequestStage.DECODE_FAKE_OUTPUT + self.observe_per_stage_req_latency(stage, ts - self.last_forward_entry_time) + self.trace_slice(stage, self.last_forward_entry_time, ts) + + def get_queueing_time(self) -> float: + return self.forward_entry_time - self.wait_queue_entry_time + + def get_prefill_waiting_latency(self) -> Optional[float]: + if self.prefill_run_batch_start_time > 0.0: + return self.prefill_run_batch_start_time - self.forward_entry_time + return None + + def get_prefill_launch_latency(self) -> Optional[float]: + if ( + self.prefill_run_batch_start_time > 0.0 + and self.prefill_run_batch_end_time > 0.0 + ): + return self.prefill_run_batch_end_time - self.prefill_run_batch_start_time + return None + + def convert_to_duration(self) -> str: + if self.disagg_mode 
== DisaggregationMode.NULL: + queue_duration = self.forward_entry_time - self.wait_queue_entry_time + forward_duration = self.completion_time - self.forward_entry_time + + if SGLANG_TEST_REQUEST_TIME_STATS: + assert ( + queue_duration >= 0 and forward_duration >= 0 + ), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0" + + return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time:.3f}" + elif self.disagg_mode == DisaggregationMode.PREFILL: + bootstrap_duration = ( + self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time + ) + queue_duration = self.forward_entry_time - self.wait_queue_entry_time + forward_duration = self.completion_time - self.forward_entry_time + + if SGLANG_TEST_REQUEST_TIME_STATS: + if self.wait_queue_entry_time > 0: + assert ( + bootstrap_duration >= 0 + and queue_duration >= 0 + and forward_duration >= 0 + ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0" + + return ( + f"bootstrap_queue_duration({self.format_duration(bootstrap_duration)}) " + f"queue_duration={self.format_duration(queue_duration)}, " + f"forward_duration={self.format_duration(forward_duration)}, " + f"start={self.prefill_bootstrap_queue_entry_time:.3f}, " + f"transfer_speed={self.transfer_speed_gb_s:.2f}GB/s, " + f"transfer_total={self.transfer_total_mb:.2f}MB, " + f"#retries={self.prefill_retry_count}" + ) + elif self.disagg_mode == DisaggregationMode.DECODE: + prealloc_duration = ( + self.decode_transfer_queue_entry_time + - self.decode_prealloc_queue_entry_time + ) + transfer_duration = ( + self.wait_queue_entry_time - self.decode_transfer_queue_entry_time + ) + queue_duration = self.forward_entry_time - self.wait_queue_entry_time + forward_duration = self.completion_time - self.forward_entry_time + + if SGLANG_TEST_REQUEST_TIME_STATS: + if self.wait_queue_entry_time > 0: + assert ( + prealloc_duration >= 0 + and transfer_duration >= 0 + and queue_duration >= 0 + and forward_duration >= 0 + ), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0. 
{self=}" + + return ( + f"prealloc_queue_duration({self.format_duration(prealloc_duration)}) " + f"transfer_duration={self.format_duration(transfer_duration)}; " + f"queue_duration={self.format_duration(queue_duration)}, " + f"forward_duration={self.format_duration(forward_duration)}, " + f"start={self.decode_prealloc_queue_entry_time:.3f}" + ) + else: + return "Unknown Time Stats" + + def convert_to_output_meta_info(self): + meta_data = {} + if self.forward_entry_time > 0.0: + meta_data["forward_entry_time"] = convert_time_to_realtime( + self.forward_entry_time + ) + if self.prefill_finished_time > 0.0: + meta_data["prefill_finished_time"] = convert_time_to_realtime( + self.prefill_finished_time + ) + meta_data.update( + { + "queue_time": self.get_queueing_time(), + "prefill_waiting_latency": self.get_prefill_waiting_latency(), + "prefill_launch_latency": self.get_prefill_launch_latency(), + } + ) + return meta_data + + def format_duration(self, duration: float) -> str: + return f"{duration * 1e3:.2f}ms" + + +def set_schedule_time_batch(batch: ScheduleBatch): + # only for tracing + if not get_global_tracing_enabled(): + return + + ts = time.perf_counter() + bid = uuid.uuid4().hex[:8] + _attrs = {"bid": bid, "batch_size": len(batch.reqs)} + if batch.forward_mode.is_decode(): + _attrs["forward_mode"] = "decode" + elif batch.forward_mode.is_prefill(): + _attrs["forward_mode"] = "prefill" + elif batch.forward_mode.is_prebuilt(): + _attrs["forward_mode"] = "prebuilt" + + for req in batch.reqs: + req.time_stats.set_last_scheduled_time(batch.forward_mode, ts, _attrs) + + +def set_time_batch(reqs: List[Any], set_func: str): + if reqs is None or len(reqs) == 0: + return + + ts = time.perf_counter() + for req in reqs: + method = getattr(req.time_stats, set_func) + method(ts) diff --git a/python/sglang/srt/managers/request_metrics_exporter.py b/python/sglang/srt/observability/request_metrics_exporter.py similarity index 100% rename from python/sglang/srt/managers/request_metrics_exporter.py rename to python/sglang/srt/observability/request_metrics_exporter.py diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/observability/scheduler_metrics_mixin.py similarity index 99% rename from python/sglang/srt/managers/scheduler_metrics_mixin.py rename to python/sglang/srt/observability/scheduler_metrics_mixin.py index 30b2732b9f5a..46899179508c 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/observability/scheduler_metrics_mixin.py @@ -23,7 +23,7 @@ ) from sglang.srt.managers.scheduler import ScheduleBatch from sglang.srt.managers.utils import GenerationBatchResult -from sglang.srt.metrics.collector import ( +from sglang.srt.observability.metrics_collector import ( DPCooperationInfo, SchedulerMetricsCollector, SchedulerStats, diff --git a/python/sglang/srt/metrics/startup_func_log_and_timer.py b/python/sglang/srt/observability/startup_func_log_and_timer.py similarity index 100% rename from python/sglang/srt/metrics/startup_func_log_and_timer.py rename to python/sglang/srt/observability/startup_func_log_and_timer.py diff --git a/python/sglang/srt/observability/trace.py b/python/sglang/srt/observability/trace.py new file mode 100644 index 000000000000..5f7139c6ee93 --- /dev/null +++ b/python/sglang/srt/observability/trace.py @@ -0,0 +1,701 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""package for sglang requests tracing""" + +from __future__ import annotations + +import logging +import os +import random +import threading +import time +import uuid +from dataclasses import dataclass +from typing import Any, Dict, List, Mapping, Optional + +from sglang.srt.utils import get_int_env_var + +logger = logging.getLogger(__name__) +opentelemetry_imported = False +opentelemetry_initialized = False +_trace_context_propagator = None +tracer: Optional[trace.Tracer] = None + +global_trace_level = 3 + +TRACE_HEADERS = ["traceparent", "tracestate"] + +try: + from opentelemetry import context, propagate, trace + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( + OTLPSpanExporter as GRPCSpanExporter, + ) + from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( + OTLPSpanExporter as HTTPSpanExporter, + ) + from opentelemetry.sdk.environment_variables import ( + OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, + ) + from opentelemetry.sdk.resources import SERVICE_NAME, Resource + from opentelemetry.sdk.trace import TracerProvider, id_generator + from opentelemetry.sdk.trace.export import BatchSpanProcessor + from opentelemetry.trace import Status, StatusCode + from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator, + ) + + _trace_context_propagator = TraceContextTextMapPropagator() + + opentelemetry_imported = True +except ImportError: + + class id_generator: + class IdGenerator: + pass + + logger.debug("opentelemetry package is not installed, tracing disabled") + + +def extract_trace_headers(headers: Mapping[str, str]) -> Optional[Dict]: + return {h: headers[h] for h in TRACE_HEADERS if h in headers} + + +def set_global_trace_level(level: int): + global global_trace_level + global_trace_level = level + + +@dataclass +class TraceThreadInfo: + host_id: str + pid: int + thread_label: str + tp_rank: int + dp_rank: int + + +@dataclass +class TraceEvent: + event_name: str + ts: int + attrs: Dict[str, Any] + + +@dataclass +class TraceSliceContext: + slice_name: str + start_time_ns: int + end_time_ns: Optional[int] = None + span: Optional[trace.span.Span] = None + level: int = 1 + attrs: Optional[Dict[str, Any]] = None + events: Optional[List[TraceEvent]] = None + + +@dataclass +class TraceThreadContext: + thread_info: TraceThreadInfo + cur_slice_stack: Optional[List[TraceSliceContext]] = None + thread_span: Optional[trace.span.Span] = None + + +class TraceCustomIdGenerator(id_generator.IdGenerator): + """ + The default IdGenerator may produce duplicate trace IDs across multiple TP scheduler processes, + hence a custom IdGenerator is implemented. 
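+    Each process seeds its own random.Random with the wall-clock time at
+    construction, which makes cross-process ID collisions unlikely.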
+ """ + + def __init__(self): + super().__init__() + self.local_random = random.Random() + self.local_random.seed(time.time()) + + def generate_trace_id(self) -> int: + return self.local_random.getrandbits(64) + + def generate_span_id(self) -> int: + return self.local_random.getrandbits(64) + + +# global variables +threads_info: Dict[int, TraceThreadInfo] = {} + +get_cur_time_ns = lambda: int(time.time() * 1e9) +if hasattr(time, "time_ns"): + get_cur_time_ns = lambda: int(time.time_ns()) + + +def __get_host_id() -> str: + """ + In distributed tracing systems, obtain a unique node identifier + and inject it into all subsequently generated spans + to prevent PID conflicts between threads on different nodes. + """ + if os.path.exists("/etc/machine-id"): + try: + with open("/etc/machine-id", "r") as f: + return f.read().strip() + except: + pass + + mac = uuid.getnode() + if mac != 0: + return uuid.UUID(int=mac).hex + + return "unknown" + + +# Should be called by each tracked process. +def process_tracing_init(otlp_endpoint, server_name): + global opentelemetry_initialized + global get_cur_time_ns + global tracer + if not opentelemetry_imported: + opentelemetry_initialized = False + raise RuntimeError( + "opentelemetry package is not installed!!! Please not enable tracing or install opentelemetry" + ) + + try: + resource = Resource.create( + attributes={ + SERVICE_NAME: server_name, + } + ) + tracer_provider = TracerProvider( + resource=resource, id_generator=TraceCustomIdGenerator() + ) + + schedule_delay_millis = get_int_env_var( + "SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS", 500 + ) + max_export_batch_size = get_int_env_var( + "SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE", 64 + ) + + processor = BatchSpanProcessor( + span_exporter=get_otlp_span_exporter(otlp_endpoint), + schedule_delay_millis=schedule_delay_millis, + max_export_batch_size=max_export_batch_size, + ) + tracer_provider.add_span_processor(processor) + trace.set_tracer_provider(tracer_provider) + except Exception as e: + opentelemetry_initialized = False + raise RuntimeError( + f"initialize opentelemetry error:{e}. Please set correct otlp endpoint." + ) + + opentelemetry_initialized = True + tracer = trace.get_tracer("sglang server") + + +def get_global_tracing_enabled(): + return opentelemetry_initialized + + +def get_otlp_span_exporter(endpoint): + protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc") + supported_protocols = {"grpc", "http/protobuf"} + + if protocol not in supported_protocols: + raise ValueError( + f"Unsupported OTLP protocol '{protocol}' configured. " + f"Supported protocols are: {', '.join(sorted(supported_protocols))}" + ) + + if protocol == "grpc": + return GRPCSpanExporter(endpoint=endpoint, insecure=True) + elif protocol == "http/protobuf": + return HTTPSpanExporter(endpoint=endpoint) + + +# Should be called by each tracked thread. 
+def trace_set_thread_info( + thread_label: str, tp_rank: Optional[int] = None, dp_rank: Optional[int] = None +): + if not opentelemetry_initialized: + return + + pid = threading.get_native_id() + if pid in threads_info: + return + + threads_info[pid] = TraceThreadInfo( + host_id=__get_host_id(), + pid=pid, + thread_label=thread_label, + tp_rank=tp_rank, + dp_rank=dp_rank, + ) + + +class TraceReqContext: + def __init__( + self, + rid, + bootstrap_room=None, + role="unified", + module_name="", + external_trace_header: Optional[Dict[str, str]] = None, + ): + self.rid: str = str(rid) + self.trace_level = global_trace_level + self.tracing_enable: bool = opentelemetry_initialized and self.trace_level > 0 + + if not self.tracing_enable: + return + + self.start_time_ns: Optional[int] = None + self.thread_context: Optional[TraceThreadContext] = None + self.bootstrap_room: Optional[int] = bootstrap_room + self.role: str = role + self.module_name = module_name + + # Indicates whether this instance is a replica from the main process. + # When True, root_span is None and only root_span_context is preserved. + self.is_copy: bool = False + self.root_span: Optional[trace.span.Span] = None + self.root_span_context: Optional[context.Context] = None + # Record the most recently completed span as the previous span for the next span to be created. + self.last_span_context: Optional[trace.span.SpanContext] = None + self.external_trace_header: Optional[Dict[str, str]] = external_trace_header + + self.events_cache: List[TraceEvent] = [] + + self.pid: int = threading.get_native_id() + + def is_tracing_enabled(self) -> bool: + return self.tracing_enable + + def __create_thread_context(self, ts: int): + if self.pid not in threads_info: + trace_set_thread_info("unknown") + + thread_info = threads_info[self.pid] + thread_context = TraceThreadContext( + thread_info=thread_info, + cur_slice_stack=[], + ) + + thread_name = f"{thread_info.thread_label}" + if thread_info.tp_rank is not None: + thread_name += f" [TP {thread_info.tp_rank}] " + thread_name += f"(host:{thread_info.host_id[:8]} | pid:{self.pid})" + thread_context.thread_span = tracer.start_span( + name=thread_name, + start_time=ts, + context=self.root_span_context, + ) + + if thread_info.tp_rank is not None: + thread_context.thread_span.set_attributes({"tp_rank": thread_info.tp_rank}) + + thread_context.thread_span.set_attributes( + { + "host_id": thread_info.host_id, + "pid": thread_info.pid, + "thread_label": thread_info.thread_label, + } + ) + + return thread_context + + def __getstate__(self) -> Optional[Dict[str, Any]]: + if not self.tracing_enable: + return {"tracing_enable": False} + + if not self.root_span_context: + return {"tracing_enable": False} + + state = { + "tracing_enable": self.tracing_enable, + "rid": self.rid, + "bootstrap_room": self.bootstrap_room, + "start_time_ns": self.start_time_ns, + "role": self.role, + "trace_level": self.trace_level, + "module_name": self.module_name, + "is_copy": self.is_copy, + "pid": self.pid, + "thread_context": None, + "root_span": None, + "last_span_context": None, + } + + carrier: dict[str, str] = {} + propagate.inject(carrier, self.root_span_context) + state["root_span_context"] = carrier + + prev_span_context = self.last_span_context + if self.thread_context and self.thread_context.cur_slice_stack: + cur_slice = self.thread_context.cur_slice_stack[0] + if cur_slice.span: + prev_span_context = cur_slice.span.get_span_context() + + if prev_span_context: + state["last_span_context"] = { + "span_id": 
prev_span_context.span_id, + "trace_id": prev_span_context.trace_id, + } + + return state + + def __setstate__(self, state: Dict[str, Any]): + self.__dict__.update(state) + if not opentelemetry_initialized: + self.tracing_enable = False + if not self.tracing_enable: + return + + self.is_copy = True + self.pid = threading.get_native_id() + self.root_span_context = propagate.extract(self.root_span_context) + if self.last_span_context: + self.last_span_context = trace.span.SpanContext( + trace_id=self.last_span_context["trace_id"], + span_id=self.last_span_context["span_id"], + is_remote=True, + ) + self.events_cache = [] + + def rebuild_thread_context(self, ts: Optional[int] = None): + if not self.tracing_enable: + return + + ts = ts or get_cur_time_ns() + self.thread_context = self.__create_thread_context(ts) + + def trace_req_start( + self, + ts: Optional[int] = None, + ): + if not self.tracing_enable: + return + + ts = ts or get_cur_time_ns() + + # create req context and root span + self.start_time_ns = ts + + external_trace_context = _trace_context_propagator.extract( + self.external_trace_header or {} + ) + + # Drop the worker_id added by MultiTokenizer + orig_rid = self.rid.split("_")[-1] + role = "" if self.role == "unified" else self.role + attrs = {"rid": orig_rid, "module": f"sglang::{self.module_name}"} + if self.bootstrap_room: + attrs["bootstrap_room"] = str(hex(self.bootstrap_room)) + root_span = tracer.start_span( + name=f"{role} Req {orig_rid[:8]}", + start_time=ts, + context=external_trace_context, + attributes=attrs, + ) + + self.root_span = root_span + self.root_span_context = trace.set_span_in_context(root_span) + + # create thread context and thread span + self.thread_context = self.__create_thread_context(ts) + + def trace_req_finish( + self, ts: Optional[int] = None, attrs: Optional[Dict[str, Any]] = None + ): + if not self.tracing_enable: + return + + if not self.root_span: + return + + ts = ts or get_cur_time_ns() + + # End all unclosed thread spans. 
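+        # abort() also flushes any cached events and dangling slice spans into
+        # the thread span before it is ended.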
+        self.abort()
+
+        if attrs:
+            self.root_span.set_attributes(attrs)
+
+        self.root_span.end(end_time=ts)
+        self.root_span = None
+
+    def __check_fast_return(self, level=None):
+        if not self.tracing_enable:
+            return True
+
+        if not self.thread_context:
+            return True
+
+        if level and level > self.trace_level:
+            return True
+
+        return False
+
+    def trace_slice_start(
+        self,
+        name: str,
+        level: int,
+        ts: Optional[int] = None,
+    ):
+        if self.__check_fast_return(level):
+            return
+
+        ts = ts or get_cur_time_ns()
+
+        cur_slice = TraceSliceContext(
+            slice_name=name,
+            start_time_ns=ts,
+            level=level,
+            attrs={},
+            events=[],
+        )
+
+        parent_span = self.thread_context.thread_span
+        prev_span_context = None
+        if not self.thread_context.cur_slice_stack:
+            if self.last_span_context:
+                prev_span_context = self.last_span_context
+        else:
+            parent_span = self.thread_context.cur_slice_stack[-1].span
+
+        parent_span_context = trace.set_span_in_context(parent_span)
+
+        span = tracer.start_span(
+            name=cur_slice.slice_name,
+            start_time=cur_slice.start_time_ns,
+            context=parent_span_context,
+        )
+        cur_slice.span = span
+
+        if prev_span_context:
+            span.add_link(prev_span_context)
+
+        self.thread_context.cur_slice_stack.append(cur_slice)
+
+    def trace_slice_end(
+        self,
+        name: str,
+        level: int,
+        ts: Optional[int] = None,
+        attrs: Optional[Dict[str, Any]] = None,
+        thread_finish_flag: bool = False,
+    ):
+        if self.__check_fast_return(level):
+            return
+
+        if not self.thread_context.cur_slice_stack:
+            logger.warning(
+                f"trace_slice_end('{name}') called without a matching trace_slice_start; ignored."
+            )
+            return
+
+        cur_slice = self.thread_context.cur_slice_stack[-1]
+        ts = ts or get_cur_time_ns()
+
+        # Check that the slice name and level match the slice on top of the stack.
+        # Unlikely path; reached only on incorrect API usage.
+        if cur_slice.slice_name != name or cur_slice.level != level:
+            logger.warning(
+                f"Slice name mismatch: {name} != {cur_slice.slice_name} or level mismatch: {level} != {cur_slice.level}"
+            )
+            self.thread_context.cur_slice_stack.pop()
+            return
+
+        span = cur_slice.span
+
+        if attrs:
+            span.set_attributes(attrs)
+
+        if self.events_cache:
+            new_events_cache = []
+            for event in self.events_cache:
+                if event.ts >= cur_slice.start_time_ns and event.ts < ts:
+                    span.add_event(
+                        name=event.event_name,
+                        timestamp=event.ts,
+                        attributes=event.attrs,
+                    )
+                else:
+                    new_events_cache.append(event)
+            self.events_cache = new_events_cache
+
+        span.end(end_time=ts)
+
+        self.thread_context.cur_slice_stack.pop()
+        # Only for a first-level (outermost) slice.
+        if not self.thread_context.cur_slice_stack:
+            self.last_span_context = span.get_span_context()
+
+        if thread_finish_flag:
+            self.abort(ts)
+
+    def trace_slice(
+        self,
+        slice: TraceSliceContext,
+        thread_finish_flag: bool = False,
+    ):
+        if self.__check_fast_return(slice.level):
+            return
+
+        parent_span = self.thread_context.thread_span
+        prev_span_context = None
+        if not self.thread_context.cur_slice_stack:
+            if self.last_span_context:
+                prev_span_context = self.last_span_context
+        else:
+            parent_span = self.thread_context.cur_slice_stack[-1].span
+
+        parent_span_context = trace.set_span_in_context(parent_span)
+
+        span = tracer.start_span(
+            name=slice.slice_name,
+            start_time=slice.start_time_ns,
+            context=parent_span_context,
+        )
+
+        if prev_span_context:
+            span.add_link(prev_span_context)
+
+        if slice.attrs:
+            span.set_attributes(slice.attrs)
+
+        if slice.events:
+            for event in slice.events:
+                span.add_event(
+                    name=event.event_name, timestamp=event.ts, attributes=event.attrs
+                )
+
+        if self.events_cache:
+            new_events_cache = []
+            for event in self.events_cache:
+                if event.ts >= slice.start_time_ns and event.ts < slice.end_time_ns:
+                    span.add_event(
+                        name=event.event_name,
+                        timestamp=event.ts,
+                        attributes=event.attrs,
+                    )
+                else:
+                    new_events_cache.append(event)
+            self.events_cache = new_events_cache
+
+        span.end(end_time=slice.end_time_ns)
+
+        # Only for a first-level (outermost) slice.
+        if not self.thread_context.cur_slice_stack:
+            self.last_span_context = span.get_span_context()
+
+        if thread_finish_flag:
+            self.abort(slice.end_time_ns)
+
+    # Record an event for this request; the event is attached later to the
+    # slice span whose time range contains its timestamp.
+    def trace_event(
+        self,
+        name: str,
+        level: int,
+        ts: Optional[int] = None,
+        attrs: Optional[Dict[str, Any]] = None,
+    ):
+        if self.__check_fast_return(level):
+            return
+
+        ts = ts or get_cur_time_ns()
+
+        if attrs is None:
+            attrs = {}
+        self.events_cache.append(TraceEvent(name, ts, attrs))
+
+    def trace_set_root_attrs(self, attrs: Dict[str, Any]):
+        if not self.tracing_enable:
+            return
+
+        if self.root_span:
+            self.root_span.set_attributes(attrs)
+
+    def trace_set_thread_attrs(self, attrs: Dict[str, Any]):
+        if self.__check_fast_return():
+            return
+
+        if self.thread_context.thread_span:
+            self.thread_context.thread_span.set_attributes(attrs)
+
+    def abort(self, ts=None, abort_info: Optional[Dict] = None):
+        if self.__check_fast_return():
+            return
+
+        # Close all open slice spans (unlikely; happens only on incorrect API usage).
+        ts = ts or get_cur_time_ns()
+        while len(self.thread_context.cur_slice_stack) > 0:
+            if self.thread_context.cur_slice_stack[-1].span:
+                self.thread_context.cur_slice_stack[-1].span.end(end_time=ts)
+            self.thread_context.cur_slice_stack.pop()
+
+        # Record the abort info on the thread span.
+        if self.thread_context.thread_span:
+            if abort_info:
+                from sglang.srt.managers.schedule_batch import BaseFinishReason
+
+                if isinstance(abort_info, BaseFinishReason):
+                    abort_info = abort_info.to_json()
+                self.thread_context.thread_span.set_status(Status(StatusCode.ERROR))
+                self.thread_context.thread_span.set_attributes(abort_info)
+
+            if self.events_cache:
+                for event in self.events_cache:
+                    self.thread_context.thread_span.add_event(
+                        name=event.event_name,
+                        timestamp=event.ts,
+                        attributes=event.attrs,
+                    )
+                self.events_cache = []
+
+            self.thread_context.thread_span.end(end_time=ts)
+        self.thread_context = None
+
+    def __del__(self):
+        self.abort(abort_info={"reason": "unclosed span; auto-closed"})
+
+
+@dataclass
+class TraceNullContext:
+    tracing_enable: bool = False
+
+    def __getattr__(self, name):
+        return self
+
+    def __call__(self, *args, **kwargs):
+        return self
+
+
+class SpanAttributes:
+    # Attribute names copied from here to avoid version conflicts:
+    # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-spans.md
+    GEN_AI_USAGE_COMPLETION_TOKENS = "gen_ai.usage.completion_tokens"
+    GEN_AI_USAGE_PROMPT_TOKENS = "gen_ai.usage.prompt_tokens"
+    GEN_AI_USAGE_CACHED_TOKENS = "gen_ai.usage.cached_tokens"
+    GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens"
+    GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p"
+    GEN_AI_REQUEST_TOP_K = "gen_ai.request.top_k"
+    GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature"
+    GEN_AI_RESPONSE_MODEL = "gen_ai.response.model"
+    GEN_AI_RESPONSE_FINISH_REASONS = "gen_ai.response.finish_reasons"
+    GEN_AI_REQUEST_ID = "gen_ai.request.id"
+    GEN_AI_REQUEST_N = "gen_ai.request.n"
+    GEN_AI_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"
+    GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token"
+    GEN_AI_LATENCY_E2E
= "gen_ai.latency.e2e" + GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = "gen_ai.latency.time_in_model_prefill" + GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode" + GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = "gen_ai.latency.time_in_model_inference" diff --git a/python/sglang/srt/metrics/utils.py b/python/sglang/srt/observability/utils.py similarity index 100% rename from python/sglang/srt/metrics/utils.py rename to python/sglang/srt/observability/utils.py diff --git a/python/sglang/srt/tracing/trace.py b/python/sglang/srt/tracing/trace.py deleted file mode 100644 index 6ae1c87389a6..000000000000 --- a/python/sglang/srt/tracing/trace.py +++ /dev/null @@ -1,761 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""package for sglang requests tracing""" - -from __future__ import annotations - -import base64 -import json -import logging -import os -import random -import threading -import time -import uuid -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional - -from sglang.srt.utils import get_int_env_var - -if TYPE_CHECKING: - from sglang.srt.managers.scheduler import Req -from typing import Any, Dict, List, Mapping, Optional - -logger = logging.getLogger(__name__) -opentelemetry_imported = False -tracing_enabled = False -_trace_context_propagator = None - -TRACE_HEADERS = ["traceparent", "tracestate"] - -try: - from opentelemetry import context, propagate, trace - from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( - OTLPSpanExporter as GRPCSpanExporter, - ) - from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( - OTLPSpanExporter as HTTPSpanExporter, - ) - from opentelemetry.sdk.environment_variables import ( - OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, - ) - from opentelemetry.sdk.resources import SERVICE_NAME, Resource - from opentelemetry.sdk.trace import TracerProvider, id_generator - from opentelemetry.sdk.trace.export import BatchSpanProcessor - from opentelemetry.trace.propagation.tracecontext import ( - TraceContextTextMapPropagator, - ) - - _trace_context_propagator = TraceContextTextMapPropagator() - - opentelemetry_imported = True -except ImportError: - - class id_generator: - class IdGenerator: - pass - - logger.debug("opentelemetry package is not installed, tracing disabled") - - -def is_tracing_enabled() -> bool: - return tracing_enabled - - -def extract_trace_headers(headers: Mapping[str, str]) -> Optional[Dict]: - return {h: headers[h] for h in TRACE_HEADERS if h in headers} - - -@dataclass -class SglangTraceThreadInfo: - host_id: str - pid: int - thread_label: str - tp_rank: int - dp_rank: int - tracer: trace.Tracer - - -@dataclass -class SglangTraceSliceContext: - slice_name: str - span: Optional[trace.span.Span] = None - # When True, defers slice_name assignment until trace_slice_end() - anonymous: bool = False - - -@dataclass -class SglangTraceThreadContext: - thread_info: 
SglangTraceThreadInfo - cur_slice_stack: List[SglangTraceSliceContext] - thread_span: Optional[trace.span.Span] = None - # Record the most recently completed span as the previous span for the next span to be created. - last_span_context: Optional[trace.span.SpanContext] = None - - -@dataclass -class SglangTraceReqContext: - rid: str - start_time_ns: int - threads_context: Dict[int, SglangTraceThreadContext] - bootstrap_room: Optional[int] = None - - # Indicates whether this instance is a replica from the main process. - # When True, root_span is None and only root_span_context is preserved. - is_copy: bool = False - bootstrap_room_span: Optional[trace.span.Span] = None - bootstrap_room_span_context: Optional[context.Context] = None - root_span: Optional[trace.span.Span] = None - root_span_context: Optional[context.Context] = None - - -@dataclass -class SglangTracePropagateContext: - root_span_context: context.Context - prev_span_context: Optional[trace.span.SpanContext] - - def to_dict(self): - carrier: dict[str, str] = {} - propagate.inject(carrier, self.root_span_context) - - if self.prev_span_context: - return { - "root_span": carrier, - "prev_span": { - "span_id": self.prev_span_context.span_id, - "trace_id": self.prev_span_context.trace_id, - }, - } - else: - return {"root_span": carrier, "prev_span": "None"} - - @classmethod - def instance_from_dict(cls, d): - if "root_span" not in d or "prev_span" not in d: - return None - - carrier = d["root_span"] - root_span_context = propagate.extract(carrier) - - if d["prev_span"] == "None": - prev_span_context = None - else: - prev_span_context = trace.span.SpanContext( - trace_id=d["prev_span"]["trace_id"], - span_id=d["prev_span"]["span_id"], - is_remote=True, - ) - - return cls(root_span_context, prev_span_context) - - -class SglangTraceCustomIdGenerator(id_generator.IdGenerator): - """ - The default IdGenerator may produce duplicate trace IDs across multiple TP scheduler processes, - hence a custom IdGenerator is implemented. - """ - - def __init__(self): - super().__init__() - self.local_random = random.Random() - self.local_random.seed(time.time()) - - def generate_trace_id(self) -> int: - return self.local_random.getrandbits(64) - - def generate_span_id(self) -> int: - return self.local_random.getrandbits(64) - - -# global variables -remote_trace_contexts: Dict[str, SglangTracePropagateContext] = {} -threads_info: Dict[int, SglangTraceThreadInfo] = {} -reqs_context: Dict[str, SglangTraceReqContext] = {} - -__get_cur_time_ns = lambda: int(time.time() * 1e9) - - -def __get_host_id() -> str: - """ - In distributed tracing systems, obtain a unique node identifier - and inject it into all subsequently generated spans - to prevent PID conflicts between threads on different nodes. - """ - if os.path.exists("/etc/machine-id"): - try: - with open("/etc/machine-id", "r") as f: - return f.read().strip() - except: - pass - - mac = uuid.getnode() - if mac != 0: - return uuid.UUID(int=mac).hex - - return "unknown" - - -# Should be called by each tracked process. 
-def process_tracing_init(otlp_endpoint, server_name): - global tracing_enabled - global __get_cur_time_ns - if not opentelemetry_imported: - logger.warning(f"Tracing is disabled because the packages cannot be imported.") - tracing_enabled = False - return - - try: - resource = Resource.create( - attributes={ - SERVICE_NAME: server_name, - } - ) - tracer_provider = TracerProvider( - resource=resource, id_generator=SglangTraceCustomIdGenerator() - ) - - schedule_delay_millis = get_int_env_var( - "SGLANG_OTLP_EXPORTER_SCHEDULE_DELAY_MILLIS", 500 - ) - max_export_batch_size = get_int_env_var( - "SGLANG_OTLP_EXPORTER_MAX_EXPORT_BATCH_SIZE", 64 - ) - - processor = BatchSpanProcessor( - span_exporter=get_otlp_span_exporter(otlp_endpoint), - schedule_delay_millis=schedule_delay_millis, - max_export_batch_size=max_export_batch_size, - ) - tracer_provider.add_span_processor(processor) - trace.set_tracer_provider(tracer_provider) - except Exception as e: - logger.error( - f"Initialize OpenTelemetry error: {e}. Please set correct otlp endpoint." - ) - tracing_enabled = False - return - - if hasattr(time, "time_ns"): - __get_cur_time_ns = lambda: int(time.time_ns()) - - tracing_enabled = True - - -def get_otlp_span_exporter(endpoint): - protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc") - supported_protocols = {"grpc", "http/protobuf"} - - if protocol not in supported_protocols: - raise ValueError( - f"Unsupported OTLP protocol '{protocol}' configured. " - f"Supported protocols are: {', '.join(sorted(supported_protocols))}" - ) - - if protocol == "grpc": - return GRPCSpanExporter(endpoint=endpoint, insecure=True) - elif protocol == "http/protobuf": - return HTTPSpanExporter(endpoint=endpoint) - - -# Should be called by each tracked thread. -def trace_set_thread_info( - thread_label: str, tp_rank: Optional[int] = None, dp_rank: Optional[int] = None -): - if not tracing_enabled: - return - - pid = threading.get_native_id() - if pid in threads_info: - return - - threads_info[pid] = SglangTraceThreadInfo( - host_id=__get_host_id(), - pid=pid, - thread_label=thread_label, - tp_rank=tp_rank, - dp_rank=dp_rank, - tracer=trace.get_tracer("sglang server"), - ) - - -def __create_thread_context(pid, req_span_context, ts: Optional[int] = None): - if pid not in threads_info: - trace_set_thread_info("unknown") - - thread_info = threads_info[pid] - thread_context = SglangTraceThreadContext( - thread_info=thread_info, - cur_slice_stack=[], - ) - - thread_name = f"{thread_info.thread_label}" - if thread_info.tp_rank is not None: - thread_name += f" [TP {thread_info.tp_rank}] " - thread_name += f"(host:{thread_info.host_id[:8]} | pid:{pid})" - ts = ts or __get_cur_time_ns() - thread_context.thread_span = thread_context.thread_info.tracer.start_span( - name=thread_name, - start_time=ts, - context=req_span_context, - ) - - if thread_info.tp_rank is not None: - thread_context.thread_span.set_attributes({"tp_rank": thread_info.tp_rank}) - - thread_context.thread_span.set_attributes( - { - "host_id": thread_info.host_id, - "pid": thread_info.pid, - "thread_label": thread_info.thread_label, - } - ) - - return thread_context - - -def trace_get_proc_propagate_context( - rid, remote_propagate=False -) -> Optional[Dict[str, Any]]: - if not tracing_enabled: - return None - - rid = str(rid) - if rid not in reqs_context or not reqs_context[rid].root_span_context: - return None - - pid = threading.get_native_id() - prev_span_context = None - thread_context = reqs_context[rid].threads_context[pid] - if 
thread_context.cur_slice_stack: - cur_slice_info = thread_context.cur_slice_stack[0] - prev_span_context = cur_slice_info.span.get_span_context() - elif thread_context.last_span_context: - prev_span_context = thread_context.last_span_context - - root_span_context = reqs_context[rid].root_span_context - if remote_propagate: - root_span_context = reqs_context[rid].bootstrap_room_span_context - - trace_context = SglangTracePropagateContext(root_span_context, prev_span_context) - return trace_context.to_dict() - - -def trace_set_proc_propagate_context(rid, trace_context: Optional[Dict[str, Any]]): - if not tracing_enabled: - return - if not trace_context: - return - - trace_context = SglangTracePropagateContext.instance_from_dict(trace_context) - if not trace_context: - return - - rid = str(rid) - # Create a copy of the request context - if rid not in reqs_context: - reqs_context[rid] = SglangTraceReqContext( - rid=rid, - start_time_ns=__get_cur_time_ns(), - threads_context={}, - root_span_context=trace_context.root_span_context, - is_copy=True, - ) - - pid = threading.get_native_id() - - if pid in reqs_context[rid].threads_context: - return - - # Create new thread context. - reqs_context[rid].threads_context[pid] = __create_thread_context( - pid, - trace_context.root_span_context, - reqs_context[rid].start_time_ns, - ) - - reqs_context[rid].threads_context[ - pid - ].last_span_context = trace_context.prev_span_context - - -def trace_get_remote_propagate_context(bootstrap_room_list: List[str]): - if not tracing_enabled: - return "" - - reqs_trace_contexts = {} - for bootstrap_room in bootstrap_room_list: - # In the router, rid is also the bootstrap room. - bootstrap_room = str(bootstrap_room) - - if bootstrap_room not in reqs_context: - continue - - _context = trace_get_proc_propagate_context( - bootstrap_room, remote_propagate=True - ) - reqs_trace_contexts[bootstrap_room] = _context - - json_str = json.dumps(reqs_trace_contexts, ensure_ascii=False) - return base64.b64encode(json_str.encode("utf-8")).decode("utf-8") - - -def trace_set_remote_propagate_context(base64_str): - if not tracing_enabled: - return - - if base64_str is None or base64_str == "" or base64_str == "None": - return - - base64_bytes = base64.b64decode(base64_str) - json_str = base64_bytes.decode("utf-8") - remote_reqs_trace_contexts = json.loads(json_str) - - for bootstrap_room in remote_reqs_trace_contexts: - if bootstrap_room in remote_trace_contexts: - continue - - remote_trace_contexts[bootstrap_room] = ( - SglangTracePropagateContext.instance_from_dict( - remote_reqs_trace_contexts[bootstrap_room] - ) - ) - - -def trace_req_start( - rid: str, - bootstrap_room: Optional[int] = None, - ts: Optional[int] = None, - role: Optional[str] = "null", - external_trace_header: Optional[Dict[str, str]] = None, -): - if not tracing_enabled: - return - - rid = str(rid) - - ts = ts or __get_cur_time_ns() - - pid = threading.get_native_id() - if pid not in threads_info: - return - - # create req context and root span - bootstrap_room = 0 if bootstrap_room is None else bootstrap_room - reqs_context[rid] = SglangTraceReqContext( - rid=rid, - start_time_ns=ts, - threads_context={}, - bootstrap_room=bootstrap_room, - is_copy=False, - ) - - # create bootstrap room span - tracer = threads_info[pid].tracer - if str(bootstrap_room) not in remote_trace_contexts: - attrs = {"bootstrap_room": str(hex(bootstrap_room))} - external_trace_context = _trace_context_propagator.extract( - external_trace_header or {} - ) - bootstrap_room_span = 
tracer.start_span( - name=f"Bootstrap Room {hex(bootstrap_room)}", - start_time=ts, - attributes=attrs, - context=external_trace_context, - ) - reqs_context[rid].bootstrap_room_span = bootstrap_room_span - bootstrap_room_span_context = trace.set_span_in_context(bootstrap_room_span) - else: - bootstrap_room_span_context = remote_trace_contexts[ - str(bootstrap_room) - ].root_span_context - - # Drop the worker_id added by MultiTokenizer - orig_rid = rid.split("_")[-1] - role = "" if role == "null" else role - attrs = {"rid": orig_rid} - root_span = tracer.start_span( - name=f"{role} Req {orig_rid[:8]}", - start_time=ts, - context=bootstrap_room_span_context, - attributes=attrs, - ) - - root_span.set_attributes( - { - "rid": rid, - } - ) - - reqs_context[rid].root_span = root_span - reqs_context[rid].root_span_context = trace.set_span_in_context(root_span) - reqs_context[rid].bootstrap_room_span_context = bootstrap_room_span_context - - # create thread context and thread span - reqs_context[rid].threads_context[pid] = __create_thread_context( - pid, - reqs_context[rid].root_span_context, - ts, - ) - if str(bootstrap_room) in remote_trace_contexts: - reqs_context[rid].threads_context[pid].last_span_context = ( - remote_trace_contexts[str(bootstrap_room)].prev_span_context - ) - - -def trace_req_finish( - rid: str, ts: Optional[int] = None, attrs: Optional[Dict[str, Any]] = None -): - if not tracing_enabled: - return - - rid = str(rid) - if rid not in reqs_context: - return - - req_context = reqs_context[rid] - ts = ts or __get_cur_time_ns() - - # End all unclosed thread spans. - for thread_context in req_context.threads_context.values(): - thread_context.thread_span.end(end_time=ts) - - if attrs: - req_context.root_span.set_attributes(attrs) - - req_context.root_span.end(end_time=ts) - if str(req_context.bootstrap_room) in remote_trace_contexts: - del remote_trace_contexts[str(req_context.bootstrap_room)] - elif req_context.bootstrap_room_span: - req_context.bootstrap_room_span.end(end_time=ts) - - del reqs_context[rid] - - -def trace_slice_start( - name: str, - rid: str, - ts: Optional[int] = None, - anonymous: bool = False, -): - if not tracing_enabled: - return - - rid = str(rid) - if rid not in reqs_context: - return - - pid = threading.get_native_id() - if pid not in reqs_context[rid].threads_context: - return - - thread_context = reqs_context[rid].threads_context[pid] - - ts = ts or __get_cur_time_ns() - - slice_info = SglangTraceSliceContext( - slice_name=name, - anonymous=anonymous, - ) - - # find prev slice - prev_span_context = None - if not thread_context.cur_slice_stack: - if thread_context.last_span_context: - prev_span_context = thread_context.last_span_context - - parent_span = thread_context.thread_span - if thread_context.cur_slice_stack: - parent_span = thread_context.cur_slice_stack[-1].span - - parent_span_context = trace.set_span_in_context(parent_span) - span = thread_context.thread_info.tracer.start_span( - name=slice_info.slice_name, - start_time=ts, - context=parent_span_context, - ) - - if prev_span_context: - span.add_link(prev_span_context) - - slice_info.span = span - - thread_context.cur_slice_stack.append(slice_info) - - -def trace_slice_end( - name: str, - rid: str, - ts: Optional[int] = None, - attrs: Optional[Dict[str, Any]] = None, - auto_next_anon: bool = False, - thread_finish_flag: bool = False, -): - if not tracing_enabled: - return - - rid = str(rid) - if rid not in reqs_context: - return - - pid = threading.get_native_id() - if pid not in 
reqs_context[rid].threads_context: - return - - thread_context = reqs_context[rid].threads_context[pid] - - if not thread_context.cur_slice_stack: - logger.warning(f"No matching with the SLICE_START event{name} is required.") - return - - ts = ts or __get_cur_time_ns() - slice_info = thread_context.cur_slice_stack[-1] - span = slice_info.span - - if slice_info.anonymous: - span.update_name(name) - else: - span = slice_info.span - if slice_info.slice_name != name: - span.set_status(trace.Status(trace.StatusCode.ERROR)) - logger.warning(f"Slice name mismatch: {name} != {slice_info.slice_name}") - - if attrs: - span.set_attributes(attrs) - - span.end(end_time=ts) - - thread_context.cur_slice_stack.pop() - if len(thread_context.cur_slice_stack) == 0: - thread_context.last_span_context = span.get_span_context() - - # If this is the last slice in the thread, - # release the thread context and check whether to release the request context. - if thread_finish_flag: - thread_context.thread_span.end(end_time=ts) - del reqs_context[rid].threads_context[pid] - if reqs_context[rid].is_copy and not reqs_context[rid].threads_context: - del reqs_context[rid] - return - - if auto_next_anon: - trace_slice_start("", rid, ts, True) - - -# alias -trace_slice = trace_slice_end - - -# Add event to the current slice on the same thread with the same rid. -def trace_event( - name: str, rid: str, ts: Optional[int] = None, attrs: Dict[str, Any] = None -): - if not tracing_enabled: - return - - rid = str(rid) - if rid not in reqs_context: - return - - pid = threading.get_native_id() - if pid not in reqs_context[rid].threads_context: - return - - thread_context = reqs_context[rid].threads_context[pid] - - if not thread_context.cur_slice_stack: - logger.warning(f"No slice is currently being traced.") - return - - ts = ts or __get_cur_time_ns() - - slice_info = thread_context.cur_slice_stack[-1] - slice_info.span.add_event(name=name, timestamp=ts, attributes=attrs) - - -# Add attrs to the current slice on the same thread with the same rid. 
-def trace_slice_add_attr(rid: str, attrs: Dict[str, Any]):
-    if not tracing_enabled:
-        return
-
-    rid = str(rid)
-    if rid not in reqs_context:
-        return
-
-    pid = threading.get_native_id()
-    if pid not in reqs_context[rid].threads_context:
-        return
-
-    thread_context = reqs_context[rid].threads_context[pid]
-
-    if not thread_context.cur_slice_stack:
-        logger.warning(f"No slice is currently being traced.")
-        return
-
-    slice_info = thread_context.cur_slice_stack[-1]
-    slice_info.span.set_attributes(attrs)
-
-
-def trace_slice_batch(
-    name: str,
-    reqs: List[Req],
-):
-    if not tracing_enabled:
-        return
-
-    for req in reqs:
-        trace_slice(
-            name,
-            req.rid,
-            auto_next_anon=not req.finished(),
-            thread_finish_flag=req.finished(),
-        )
-
-
-def trace_event_batch(
-    name: str,
-    reqs: List[Req],
-    ts: Optional[int] = None,
-    attrs: Dict[str, Any] = {},
-):
-    if not tracing_enabled:
-        return
-
-    bid = uuid.uuid4().hex[:8]
-    _attrs = {"bid": bid, "batch_size": len(reqs)}
-    _attrs.update(attrs)
-
-    for req in reqs:
-        trace_event(name, req.rid, ts=ts, attrs=_attrs)
-
-
-class SpanAttributes:
-    # Attribute names copied from here to avoid version conflicts:
-    # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-spans.md
-    GEN_AI_USAGE_COMPLETION_TOKENS = "gen_ai.usage.completion_tokens"
-    GEN_AI_USAGE_PROMPT_TOKENS = "gen_ai.usage.prompt_tokens"
-    GEN_AI_USAGE_CACHED_TOKENS = "gen_ai.usage.cached_tokens"
-    GEN_AI_REQUEST_MAX_TOKENS = "gen_ai.request.max_tokens"
-    GEN_AI_REQUEST_TOP_P = "gen_ai.request.top_p"
-    GEN_AI_REQUEST_TOP_K = "gen_ai.request.top_k"
-    GEN_AI_REQUEST_TEMPERATURE = "gen_ai.request.temperature"
-    GEN_AI_RESPONSE_MODEL = "gen_ai.response.model"
-    GEN_AI_RESPONSE_FINISH_REASONS = "gen_ai.response.finish_reasons"
-    GEN_AI_REQUEST_ID = "gen_ai.request.id"
-    GEN_AI_REQUEST_N = "gen_ai.request.n"
-    GEN_AI_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"
-    GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token"
-    GEN_AI_LATENCY_E2E = "gen_ai.latency.e2e"
-    GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = "gen_ai.latency.time_in_model_prefill"
-    GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode"
-    GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = "gen_ai.latency.time_in_model_inference"
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index 4636128fa72f..30367a6ccfd1 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -93,7 +93,7 @@
 from typing_extensions import Literal
 
 from sglang.srt.environ import envs
-from sglang.srt.metrics.func_timer import enable_func_timer
+from sglang.srt.observability.func_timer import enable_func_timer
 
 if TYPE_CHECKING:
     # Apparently importing this here is necessary to avoid a segfault, see comment in load_video below
diff --git a/test/manual/test_tracing.py b/test/manual/test_tracing.py
index 4e3763ac414e..bdb4a14a6c3e 100644
--- a/test/manual/test_tracing.py
+++ b/test/manual/test_tracing.py
@@ -4,13 +4,14 @@
 import time
 import unittest
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Optional
 
 import requests
 import zmq
 
 from sglang import Engine
-from sglang.srt.tracing.trace import *
+from sglang.srt.observability.trace import *
+from sglang.srt.observability.trace import get_cur_time_ns, set_global_trace_level
 from sglang.srt.utils import get_zmq_socket, kill_process_tree
 from sglang.test.test_utils import (
     DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
@@ -24,7 +25,7 @@
 @dataclass
 class Req:
     rid: int
-    trace_context: Optional[Dict[str, Any]] = None
+    req_context: Optional[TraceReqContext] = None
 
 
 class TestTrace(CustomTestCase):
@@ -65,22 +66,33 @@ def __clear_trace_file(self):
         except:
             pass
 
-    def test_trace_enable(self):
+    def __test_trace_enable(self, trace_level, expect_export_data):
         self.__clear_trace_file()
         assert self.__launch_otel_jaeger()
+        self.addCleanup(self.__stop_otel_jaeger)
 
         process = popen_launch_server(
             DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
             DEFAULT_URL_FOR_TEST,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--enable-trace", "--otlp-traces-endpoint", "0.0.0.0:4317"],
+            other_args=[
+                "--enable-trace",
+                "--otlp-traces-endpoint",
+                "0.0.0.0:4317",
+            ],
         )
 
         try:
-            # Make some requests to generate trace data
             response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
             self.assertEqual(response.status_code, 200)
 
+            # set trace level
+            response = requests.get(
+                f"{DEFAULT_URL_FOR_TEST}/set_trace_level?level={trace_level}"
+            )
+            self.assertEqual(response.status_code, 200)
+
+            # Make some requests to generate trace data
             response = requests.post(
                 f"{DEFAULT_URL_FOR_TEST}/generate",
                 json={
@@ -101,15 +113,34 @@ def test_trace_enable(self):
 
             # check trace file
             assert os.path.isfile("/tmp/otel_trace.json"), "trace file not exist"
-            assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
+            if expect_export_data:
+                assert (
+                    os.path.getsize("/tmp/otel_trace.json") > 0
+                ), "trace file is empty"
+            else:
+                assert (
+                    os.path.getsize("/tmp/otel_trace.json") == 0
+                ), "trace file is not empty"
         finally:
             kill_process_tree(process.pid)
-            assert self.__stop_otel_jaeger()
+
+    def test_trace_enable_level_1(self):
+        self.__test_trace_enable("1", True)
+
+    def test_trace_enable_level_2(self):
+        self.__test_trace_enable("2", True)
+
+    def test_trace_enable_level_3(self):
+        self.__test_trace_enable("3", True)
+
+    def test_trace_enable_level_0(self):
+        self.__test_trace_enable("0", False)
 
     def test_trace_engine_enable(self):
         self.__clear_trace_file()
         assert self.__launch_otel_jaeger()
+        self.addCleanup(self.__stop_otel_jaeger)
 
         prompt = "Today is a sunny day and I like"
         model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -134,11 +165,11 @@ def test_trace_engine_enable(self):
             assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
         finally:
             engine.shutdown()
-            assert self.__stop_otel_jaeger()
 
     def test_trace_engine_encode(self):
         self.__clear_trace_file()
         assert self.__launch_otel_jaeger()
+        self.addCleanup(self.__stop_otel_jaeger)
 
         prompt = "Today is a sunny day and I like"
         model_path = "Qwen/Qwen2-7B"
@@ -162,19 +193,21 @@ def test_trace_engine_encode(self):
             assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
         finally:
             engine.shutdown()
-            assert self.__stop_otel_jaeger()
 
     def test_slice_trace_simple(self):
         self.__clear_trace_file()
         assert self.__launch_otel_jaeger()
+        self.addCleanup(self.__stop_otel_jaeger)
 
         try:
             process_tracing_init("0.0.0.0:4317", "test")
             trace_set_thread_info("Test")
-            trace_req_start(0)
-            trace_slice_start("test slice", 0)
+            set_global_trace_level(3)
+            req_context = TraceReqContext(0)
+            req_context.trace_req_start()
+            req_context.trace_slice_start("test slice", level=1)
             time.sleep(1)
-            trace_slice_end("test slice", 0)
-            trace_req_finish(0)
+            req_context.trace_slice_end("test slice", level=1)
+            req_context.trace_req_finish()
 
             # sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
             time.sleep(10)
@@ -182,23 +215,29 @@
             assert os.path.isfile("/tmp/otel_trace.json"), "trace file not exist"
             assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
         finally:
-            assert self.__stop_otel_jaeger()
+            pass
 
     def test_slice_trace_complex(self):
         self.__clear_trace_file()
         assert self.__launch_otel_jaeger()
+        self.addCleanup(self.__stop_otel_jaeger)
 
         try:
             process_tracing_init("0.0.0.0:4317", "test")
             trace_set_thread_info("Test")
-            trace_req_start(0)
-            trace_slice_start("", 0, anonymous=True)
-            time.sleep(1)
-            trace_slice_end("slice A", 0, auto_next_anon=True)
+            set_global_trace_level(3)
+            req_context = TraceReqContext(0)
+            req_context.trace_req_start()
+            t1 = get_cur_time_ns()
             time.sleep(1)
-            trace_slice_end("slice B", 0, auto_next_anon=True)
+            req_context.trace_event("event test", 1)
+            t2 = get_cur_time_ns()
             time.sleep(1)
-            trace_slice_end("slice C", 0, thread_finish_flag=True)
-            trace_req_finish(0)
+            t3 = get_cur_time_ns()
+            slice1 = TraceSliceContext("slice A", t1, t2)
+            slice2 = TraceSliceContext("slice B", t2, t3)
+            req_context.trace_slice(slice1)
+            req_context.trace_slice(slice2, thread_finish_flag=True)
+            req_context.trace_req_finish()
 
             # sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
             time.sleep(10)
@@ -206,7 +245,7 @@ def test_slice_trace_complex(self):
             assert os.path.isfile("/tmp/otel_trace.json"), "trace file not exist"
             assert os.path.getsize("/tmp/otel_trace.json") > 0, "trace file is empty"
         finally:
-            assert self.__stop_otel_jaeger()
+            pass
 
     def test_trace_context_propagete(self):
         def __process_work():
@@ -220,16 +259,19 @@ def __process_work():
 
             try:
                 req = recv_from_main.recv_pyobj()
-                trace_set_proc_propagate_context(req.rid, req.trace_context)
-                trace_slice_start("work", req.rid)
+                req.req_context.rebuild_thread_context()
+                req.req_context.trace_slice_start("work", level=1)
                 time.sleep(1)
-                trace_slice_end("work", req.rid, thread_finish_flag=True)
+                req.req_context.trace_slice_end(
+                    "work", level=1, thread_finish_flag=True
+                )
             finally:
                 recv_from_main.close()
                 context.term()
 
         self.__clear_trace_file()
         assert self.__launch_otel_jaeger()
+        self.addCleanup(self.__stop_otel_jaeger)
 
         context = zmq.Context(2)
         send_to_subproc = get_zmq_socket(
@@ -246,15 +288,15 @@ def __process_work():
         time.sleep(1)
 
         req = Req(rid=0)
-        trace_req_start(req.rid)
-        trace_slice_start("dispatch", req.rid)
+        req.req_context = TraceReqContext(0)
+        req.req_context.trace_req_start()
+        req.req_context.trace_slice_start("dispatch", level=1)
         time.sleep(1)
 
-        req.trace_context = trace_get_proc_propagate_context(req.rid)
         send_to_subproc.send_pyobj(req)
-        trace_slice_end("dispatch", req.rid)
+        req.req_context.trace_slice_end("dispatch", level=1)
 
         subproc.join()
-        trace_req_finish(req.rid)
+        req.req_context.trace_req_finish()
 
         # sleep for a few seconds to wait for opentelemetry collector to asynchronously export data to file.
             time.sleep(10)
@@ -265,7 +307,6 @@ def __process_work():
         finally:
             send_to_subproc.close()
             context.term()
-            assert self.__stop_otel_jaeger()
 
 
 if __name__ == "__main__":
diff --git a/test/registered/metrics/test_cpu_monitor.py b/test/registered/metrics/test_cpu_monitor.py
index d12142f64db8..2d2fb234097c 100644
--- a/test/registered/metrics/test_cpu_monitor.py
+++ b/test/registered/metrics/test_cpu_monitor.py
@@ -10,7 +10,7 @@ class TestCpuMonitor(unittest.TestCase):
     def test_cpu_monitor(self):
         from prometheus_client import REGISTRY
 
-        from sglang.srt.metrics.cpu_monitor import start_cpu_monitor_thread
+        from sglang.srt.observability.cpu_monitor import start_cpu_monitor_thread
 
         thread = start_cpu_monitor_thread("test", interval=0.1)
         self.assertTrue(thread.is_alive())
diff --git a/test/registered/metrics/test_metrics.py b/test/registered/metrics/test_metrics.py
index 5a02d5703dce..409ebdbcb636 100644
--- a/test/registered/metrics/test_metrics.py
+++ b/test/registered/metrics/test_metrics.py
@@ -6,7 +6,7 @@
 from prometheus_client.samples import Sample
 
 from sglang.srt.environ import envs
-from sglang.srt.metrics.collector import (
+from sglang.srt.observability.metrics_collector import (
     ROUTING_KEY_REQ_COUNT_BUCKET_BOUNDS,
     compute_routing_key_stats,
 )
diff --git a/test/registered/metrics/test_metrics_utils.py b/test/registered/metrics/test_metrics_utils.py
index b88656626ede..fc2c8bc74ee4 100644
--- a/test/registered/metrics/test_metrics_utils.py
+++ b/test/registered/metrics/test_metrics_utils.py
@@ -1,6 +1,9 @@
 import unittest
 
-from sglang.srt.metrics.utils import generate_buckets, two_sides_exponential_buckets
+from sglang.srt.observability.utils import (
+    generate_buckets,
+    two_sides_exponential_buckets,
+)
 from sglang.test.ci.ci_register import register_cpu_ci
 
 register_cpu_ci(est_time=1, suite="stage-a-cpu-only")