Merged
1 change: 1 addition & 0 deletions docs/advanced_features/server_arguments.md
@@ -185,6 +185,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--crash-dump-folder` | Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled. | `None` | Type: str |
| `--show-time-cost` | Show time cost of custom marks. | `False` | bool flag (set to enable) |
| `--enable-metrics` | Enable logging of Prometheus metrics. | `False` | bool flag (set to enable) |
| `--enable-mfu-metrics` | Enable estimated MFU-related Prometheus metrics. | `False` | bool flag (set to enable) |
Collaborator

@sufeng-buaa Does this metric have any side effect?

If not, should we just reuse the flag above?

Collaborator

I think calculating it once per batch still costs time in every round, even though the overlap scheduling mechanism may hide that overhead. We already have a lot of metric-related code, and if it keeps growing, it might eventually turn into a performance problem. Honestly, I think even the existing metrics should be managed at different levels.

Contributor Author


@Kangyan-Zhou @sufeng-buaa any next steps on this? Would love to help out / contribute in any way

Collaborator


I don't have any more comments. Does @Kangyan-Zhou think anything else needs to be changed?

| `--enable-metrics-for-all-schedulers` | Enable this flag when you want schedulers on all TP ranks (not just TP 0) to record request metrics separately. This is especially useful when dp_attention is enabled, as otherwise all metrics appear to come from TP 0. | `False` | bool flag (set to enable) |
| `--tokenizer-metrics-custom-labels-header` | Specify the HTTP header for passing custom labels for tokenizer metrics. | `x-custom-labels` | Type: str |
| `--tokenizer-metrics-allowed-custom-labels` | The custom labels allowed for tokenizer metrics. The labels are specified via a dict in '--tokenizer-metrics-custom-labels-header' field in HTTP requests, e.g., {'label1': 'value1', 'label2': 'value2'} is allowed if '--tokenizer-metrics-allowed-custom-labels label1 label2' is set. | `None` | List[str] |
38 changes: 37 additions & 1 deletion docs/references/production_metrics.md
@@ -142,7 +142,8 @@ This section describes how to set up the monitoring stack (Prometheus + Grafana)
python -m sglang.launch_server \
--model-path <your_model_path> \
--port 30000 \
--enable-metrics
--enable-metrics \
--enable-mfu-metrics
```
Replace `<your_model_path>` with the actual path to your model (e.g., `meta-llama/Meta-Llama-3.1-8B-Instruct`). Ensure the server is accessible from the monitoring stack (you might need `--host 0.0.0.0` if running in Docker). By default, the metrics endpoint will be available at `http://<sglang_server_host>:30000/metrics`.

@@ -229,3 +230,38 @@ python3 -m sglang.bench_serving \
to generate some requests.

Then you should be able to see the metrics in the Grafana dashboard.

## Estimated Performance Metrics (MFU-related)

SGLang exports the following estimated per-GPU counters that can be used to derive
Model FLOPs Utilization (MFU)-related signals:

- `sglang:estimated_flops_per_gpu_total`: Estimated floating-point operations.
- `sglang:estimated_read_bytes_per_gpu_total`: Estimated bytes read from memory.
- `sglang:estimated_write_bytes_per_gpu_total`: Estimated bytes written to memory.

These metrics are available when both `--enable-metrics` and
`--enable-mfu-metrics` are enabled.

These are cumulative counters. Use Prometheus `rate(...)` to get per-second values.

### PromQL examples

Average TFLOPS per GPU:

```promql
rate(sglang:estimated_flops_per_gpu_total[1m]) / 1e12
```

Average estimated memory bandwidth in GB/s:

```promql
(rate(sglang:estimated_read_bytes_per_gpu_total[1m]) +
rate(sglang:estimated_write_bytes_per_gpu_total[1m])) / 1e9
```
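
With a peak-throughput figure for your hardware, these rates can be turned into an MFU estimate. The query below assumes a peak of 989 dense BF16 TFLOPS per GPU (an H100 SXM figure); substitute the value for your accelerator:

```promql
# Estimated MFU (%), assuming a 989 TFLOPS peak per GPU (adjust for your hardware)
100 * rate(sglang:estimated_flops_per_gpu_total[1m]) / (989 * 1e12)
```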

### Notes

- These metrics are estimates intended for observability and trend analysis.
- Estimated memory bytes reflect modeled traffic and are not a direct hardware
counter from GPU profilers.
43 changes: 43 additions & 0 deletions python/sglang/srt/observability/metrics_collector.py
@@ -702,6 +702,30 @@ def __init__(
),
labelnames=list(labels.keys()) + ["category"],
)
self.estimated_flops_per_gpu_total = Counter(
name="sglang:estimated_flops_per_gpu_total",
documentation=(
"Estimated number of floating point operations per GPU "
"(for Model FLOPs Utilization calculations)."
),
labelnames=labels.keys(),
)
self.estimated_read_bytes_per_gpu_total = Counter(
name="sglang:estimated_read_bytes_per_gpu_total",
documentation=(
"Estimated number of bytes read from memory per GPU "
"(for Model FLOPs Utilization calculations)."
),
labelnames=labels.keys(),
)
self.estimated_write_bytes_per_gpu_total = Counter(
name="sglang:estimated_write_bytes_per_gpu_total",
documentation=(
"Estimated number of bytes written to memory per GPU "
"(for Model FLOPs Utilization calculations)."
),
labelnames=labels.keys(),
)

self.dp_cooperation_realtime_tokens_total = Counter(
name="sglang:dp_cooperation_realtime_tokens_total",
@@ -928,6 +952,25 @@ def increment_gpu_execution_seconds(
**dp_cooperation_info.to_labels(),
).inc(t)

def increment_estimated_perf(
self,
num_flops_per_gpu: float = 0.0,
num_read_bytes_per_gpu: float = 0.0,
num_write_bytes_per_gpu: float = 0.0,
) -> None:
if num_flops_per_gpu > 0:
self.estimated_flops_per_gpu_total.labels(**self.labels).inc(
num_flops_per_gpu
)
if num_read_bytes_per_gpu > 0:
self.estimated_read_bytes_per_gpu_total.labels(**self.labels).inc(
num_read_bytes_per_gpu
)
if num_write_bytes_per_gpu > 0:
self.estimated_write_bytes_per_gpu_total.labels(**self.labels).inc(
num_write_bytes_per_gpu
)

def log_stats(self, stats: SchedulerStats) -> None:
self._log_gauge_queue_count(self.num_running_reqs, stats.num_running_reqs)
self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
187 changes: 185 additions & 2 deletions python/sglang/srt/observability/scheduler_metrics_mixin.py
@@ -5,7 +5,7 @@
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
from sglang.srt.disaggregation.utils import DisaggregationMode
@@ -114,6 +114,7 @@ def init_metrics(
self.stats = SchedulerStats()

# Metrics
self.enable_mfu_metrics = False
self.enable_metrics = self.server_args.enable_metrics
self.is_stats_logging_rank = self.attn_tp_rank == 0
self.current_scheduler_metrics_enabled = self.enable_metrics and (
@@ -148,6 +149,12 @@
enable_hierarchical_cache=self.enable_hierarchical_cache,
server_args=self.server_args,
)
self.enable_mfu_metrics = bool(self.server_args.enable_mfu_metrics)
if self.enable_mfu_metrics:
self._init_estimated_perf_constants()
self._mfu_log_flops = 0.0
self._mfu_log_read_bytes = 0.0
self._mfu_log_write_bytes = 0.0

if ENABLE_METRICS_DEVICE_TIMER:
self.forward_pass_device_timer = DeviceTimer(
@@ -175,6 +182,139 @@ def update_spec_metrics(self: Scheduler, bs: int, num_accepted_tokens: int):
self.spec_num_forward_ct += bs
self.num_generated_tokens += num_accepted_tokens

def _init_estimated_perf_constants(self: Scheduler) -> None:
model_config = self.model_config
hf_text_config = model_config.hf_text_config

hidden_size = float(model_config.hidden_size)
num_layers = float(getattr(model_config, "num_attention_layers", 0))
head_dim = float(getattr(model_config, "head_dim", 0))
num_attn_heads = float(model_config.get_num_attention_heads(self.tp_size))
num_kv_heads = float(model_config.get_num_kv_heads(self.tp_size))
intermediate_size = getattr(hf_text_config, "intermediate_size", None)
if intermediate_size is None:
intermediate_size = getattr(hf_text_config, "ffn_hidden_size", 0)
intermediate_size = float(intermediate_size)

dtype_num_bytes = getattr(model_config.dtype, "itemsize", None)
if dtype_num_bytes is None:
dtype_num_bytes = 2
# Keep this estimator lightweight and consistent with current server dtype.
# KV cache quantization-aware bytes can be added in a follow-up.
act_bytes = float(dtype_num_bytes)
w_bytes = float(dtype_num_bytes)
cache_bytes = float(dtype_num_bytes)

# Linear-layer FLOPs per token on one GPU.
attn_linear_flops = (
2.0 * hidden_size * head_dim * (num_attn_heads + 2.0 * num_kv_heads)
+ 2.0 * hidden_size * head_dim * num_attn_heads
)
mlp_flops = (
6.0 * hidden_size * intermediate_size if intermediate_size > 0 else 0.0
)
self._linear_flops_per_token = max(
0.0, (attn_linear_flops + mlp_flops) * num_layers
)

# Attention dot-product FLOPs coefficient to multiply token-context product.
# attn_qk + attn_av = 4 * q * TC * d * L
self._attn_dot_flops_coeff = 4.0 * num_attn_heads * head_dim * num_layers

# KV cache bytes (write one K and one V vector per generated token).
self._kv_cache_bytes_per_token = (
2.0 * num_layers * num_kv_heads * head_dim * cache_bytes
)

# Weight read bytes per token.
self._weight_read_bytes_per_token = (
hidden_size
* head_dim
* (num_attn_heads + 2.0 * num_kv_heads)
* w_bytes
* num_layers
+ hidden_size * head_dim * num_attn_heads * w_bytes * num_layers
+ (
3.0 * hidden_size * intermediate_size * w_bytes * num_layers
if intermediate_size > 0
else 0.0
)
)

# Activation movement bytes per token (coarse approximation).
self._qkv_act_bytes_per_token = (
hidden_size * act_bytes * num_layers
+ (num_attn_heads + 2.0 * num_kv_heads) * head_dim * act_bytes * num_layers
+ head_dim * num_attn_heads * act_bytes * num_layers
+ hidden_size * act_bytes * num_layers
)
self._ffn_act_bytes_per_token = (
3.0 * intermediate_size * act_bytes * num_layers
if intermediate_size > 0
else 0.0
)

# Prefill reads Q/K/V activations from on-device memory.
self._prefill_attn_act_read_per_token = (
(num_attn_heads + 2.0 * num_kv_heads) * head_dim * act_bytes * num_layers
)

# Decode reads Q from activation memory; K/V reads are from KV cache.
self._decode_q_read_bytes_per_token = (
num_attn_heads * head_dim * act_bytes * num_layers
)

def _estimate_prefill_perf(
self: Scheduler, num_tokens: int
) -> Tuple[float, float, float]:
tokens = max(0, int(num_tokens))
if tokens == 0:
return 0.0, 0.0, 0.0

# Causal prefill token-context product.
context_product = tokens * (tokens + 1) / 2.0
flops = (
tokens * self._linear_flops_per_token
+ self._attn_dot_flops_coeff * context_product
)

read_bytes = (
tokens * self._weight_read_bytes_per_token
+ tokens * self._qkv_act_bytes_per_token
+ tokens * self._prefill_attn_act_read_per_token
)
write_bytes = (
tokens * self._kv_cache_bytes_per_token
+ tokens * self._qkv_act_bytes_per_token
+ tokens * self._ffn_act_bytes_per_token
)
return flops, read_bytes, write_bytes

def _estimate_decode_perf(
self: Scheduler, batch: ScheduleBatch, num_tokens: int
) -> Tuple[float, float, float]:
tokens = max(0, int(num_tokens))
if tokens == 0:
return 0.0, 0.0, 0.0

total_context = float(batch.seq_lens_cpu.sum().item())
flops = (
tokens * self._linear_flops_per_token
+ self._attn_dot_flops_coeff * total_context
)
read_bytes = (
tokens * self._weight_read_bytes_per_token
+ tokens * self._qkv_act_bytes_per_token
+ tokens * self._decode_q_read_bytes_per_token
+ total_context * self._kv_cache_bytes_per_token
)
write_bytes = (
tokens * self._kv_cache_bytes_per_token
+ tokens * self._qkv_act_bytes_per_token
+ tokens * self._ffn_act_bytes_per_token
)
return flops, read_bytes, write_bytes
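
The two estimators differ only in their attention-context term: prefill uses the causal product T·(T+1)/2, while decode attends each new token to its sequence's full current context. A minimal standalone sketch of just that term, with assumed toy shapes rather than the scheduler's actual config:

```python
# Minimal sketch of the attention dot-product FLOPs term only, contrasting
# the prefill causal token-context product with the decode total-context sum.
# Coefficient: 4 * heads * head_dim * layers (assumed toy values).
attn_dot_flops_coeff = 4.0 * 32 * 128 * 32  # = 524288.0

def prefill_attn_flops(num_tokens: int) -> float:
    # Each new token attends to itself and all earlier tokens: sum_{t=1..T} t
    context_product = num_tokens * (num_tokens + 1) / 2.0
    return attn_dot_flops_coeff * context_product

def decode_attn_flops(seq_lens: list[int]) -> float:
    # One new token per sequence attends to that sequence's full context.
    return attn_dot_flops_coeff * float(sum(seq_lens))

# Prefilling 1024 tokens vs. one decode step over 8 sequences of length 1024:
print(prefill_attn_flops(1024) / 1e9)      # ~275.15 GFLOPs
print(decode_attn_flops([1024] * 8) / 1e9) # ~4.29 GFLOPs
```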

def reset_metrics(self: Scheduler):
self.forward_ct_decode = 0
self.num_generated_tokens = 0
@@ -277,6 +417,11 @@ def report_prefill_stats(

msg += f"{graph_backend[self.device]}: {can_run_cuda_graph}"

if self.enable_mfu_metrics and gap_latency > 0:
flops, _, _ = self._estimate_prefill_perf(prefill_stats.log_input_tokens)
tflops_per_s = flops / gap_latency / 1e12
msg += f", est. prefill TFLOPS/s (per GPU): {tflops_per_s:.2f}"

if self.is_stats_logging_rank:
logger.info(msg)

@@ -289,6 +434,15 @@ def report_prefill_stats(
prefill_cache_tokens=prefill_stats.log_hit_tokens,
dp_cooperation_info=dp_cooperation_info,
)
if self.enable_mfu_metrics:
flops, read_bytes, write_bytes = self._estimate_prefill_perf(
prefill_stats.log_input_tokens
)
self.metrics_collector.increment_estimated_perf(
num_flops_per_gpu=flops,
num_read_bytes_per_gpu=read_bytes,
num_write_bytes_per_gpu=write_bytes,
)

# Basics
total_tokens = prefill_stats.log_input_tokens + prefill_stats.log_hit_tokens
@@ -356,11 +510,24 @@ def report_decode_stats(

# Every-iteration work: realtime token counting + status logger
if self.current_scheduler_metrics_enabled:
decode_tokens = batch.batch_size() + num_accepted_tokens
self.metrics_collector.increment_realtime_tokens(
# TODO unify this w/ the bumping logic in `Scheduler.num_generated_tokens` accumulator
decode_tokens=batch.batch_size() + num_accepted_tokens,
decode_tokens=decode_tokens,
dp_cooperation_info=batch.dp_cooperation_info,
)
if self.enable_mfu_metrics:
flops, read_bytes, write_bytes = self._estimate_decode_perf(
batch, decode_tokens
)
self.metrics_collector.increment_estimated_perf(
num_flops_per_gpu=flops,
num_read_bytes_per_gpu=read_bytes,
num_write_bytes_per_gpu=write_bytes,
)
self._mfu_log_flops += flops
self._mfu_log_read_bytes += read_bytes
self._mfu_log_write_bytes += write_bytes

if x := self.scheduler_status_logger:
x.maybe_dump(batch, self.waiting_queue)
@@ -492,6 +659,22 @@ def report_decode_stats(
f"#queue-req: {len(self.waiting_queue)}"
)

if self.enable_mfu_metrics and gap_latency > 0:
flops_per_s = self._mfu_log_flops / gap_latency
read_bytes_per_s = self._mfu_log_read_bytes / gap_latency
write_bytes_per_s = self._mfu_log_write_bytes / gap_latency
tflops_per_s = flops_per_s / 1e12
read_gb_per_s = read_bytes_per_s / 1e9
write_gb_per_s = write_bytes_per_s / 1e9
msg += (
f", est. decode TFLOPS/s (per GPU): {tflops_per_s:.2f}, "
f"est. read BW (GB/s per GPU): {read_gb_per_s:.2f}, "
f"est. write BW (GB/s per GPU): {write_gb_per_s:.2f}"
)
self._mfu_log_flops = 0.0
self._mfu_log_read_bytes = 0.0
self._mfu_log_write_bytes = 0.0

if self.is_stats_logging_rank:
logger.info(msg)
if self.current_scheduler_metrics_enabled:
6 changes: 6 additions & 0 deletions python/sglang/srt/server_args.py
@@ -395,6 +395,7 @@ class ServerArgs:
crash_dump_folder: Optional[str] = None
show_time_cost: bool = False
enable_metrics: bool = False
enable_mfu_metrics: bool = False
enable_metrics_for_all_schedulers: bool = False
tokenizer_metrics_custom_labels_header: str = "x-custom-labels"
tokenizer_metrics_allowed_custom_labels: Optional[List[str]] = None
@@ -4037,6 +4038,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Enable logging of Prometheus metrics.",
)
parser.add_argument(
"--enable-mfu-metrics",
action="store_true",
help="Enable estimated MFU-related Prometheus metrics.",
)
parser.add_argument(
"--enable-metrics-for-all-schedulers",
action="store_true",