
[Metrics] Add labeled prompt token metrics for P/D disaggregation#33290

Merged
markmc merged 2 commits into vllm-project:main from ZhanqiuHu:feature/labeled-prompt-token-metrics
Feb 4, 2026

Conversation

ZhanqiuHu (Contributor) commented Jan 28, 2026

Summary

Add labeled Prometheus metrics for P/D disaggregation to distinguish prompt token sources (local compute, KV transfer, cache hit).

Purpose

Add labeled Prometheus metrics to distinguish where prompt tokens come from in P/D disaggregated deployments.

In P/D disaggregation, decode instances receive KV cache from prefill instances. Currently, decode reports inflated prompt throughput because it counts all prompt tokens as "computed", even though most were transferred.

This PR adds labeled metrics so users can understand actual compute work vs transferred work:

vllm:prompt_tokens_by_source_total{source="local_compute"}        # Tokens prefilled locally
vllm:prompt_tokens_by_source_total{source="external_kv_transfer"} # Tokens received via KV transfer  
vllm:prompt_tokens_by_source_total{source="local_cache_hit"}      # Tokens from local prefix cache
vllm:prompt_tokens_cached_total                                    # Total cached (local + external, -1 when all cached)

Note: The -1 adjustment is applied by the scheduler when all prompt tokens are cached (from local cache or KV transfer). This forces the model to recompute the last prompt token locally, since the model needs at least one input token to run a forward pass.
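
For illustration, here is the breakdown arithmetic with hypothetical numbers for a decode instance that received its entire prompt via KV transfer (a minimal sketch; field names follow the stats.py snippet discussed in the review below):

```python
# Hypothetical decode-side request: prompt_len = 10, all 10 tokens arrived
# via KV transfer, no local prefix cache hits.
prompt_len = 10
num_external_computed_tokens = 10  # raw count reported by the KV connector
num_cached_tokens = 9              # scheduler applied the -1 adjustment

# The adjustment is detectable because cached + 1 == prompt_len
recomputed = 1 if num_cached_tokens + 1 == prompt_len else 0   # -> 1

computed = prompt_len - num_cached_tokens                      # -> 1
external_kv_transfer = num_external_computed_tokens            # -> 10
local_cache_hit = (num_cached_tokens + recomputed
                   - num_external_computed_tokens)             # -> 0

# The recomputed token is counted both as transferred and as computed,
# so the sources sum to prompt_len + recomputed, not prompt_len.
assert computed + local_cache_hit + external_kv_transfer == prompt_len + recomputed
```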

Fixes #33289
Related: PR #27569

Test Plan

  1. P/D disaggregation setup using examples/online_serving/disaggregated_prefill.sh
    • Model: Qwen/Qwen3-0.6B
  2. Send a request:
    • Prompt: 10 tokens
    • Max output tokens: 20
  3. Verify metrics via Prometheus endpoint:
# Check prefill instance metrics
curl localhost:<prefill-port>/metrics 

# Check decode instance metrics  
curl localhost:<decode-port>/metrics 
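
If you prefer scripting the check, a small Python sketch can filter the new counters out of the exposition output (the port and the prometheus_client dependency are assumptions, not part of this PR):

```python
# Sketch: scrape a vLLM instance and print only the new prompt-token counters.
# Port 8200 is a placeholder for your prefill/decode instance port.
from urllib.request import urlopen

from prometheus_client.parser import text_string_to_metric_families

text = urlopen("http://localhost:8200/metrics").read().decode()
for family in text_string_to_metric_families(text):
    for sample in family.samples:
        if sample.name.startswith("vllm:prompt_tokens"):
            print(sample.name, sample.labels, sample.value)
```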

Test Result

(Screenshot: Prometheus metrics output, 2026-01-28 5:43 PM)
Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces valuable new Prometheus metrics for disaggregating prompt token sources in P/D deployments. The changes are well-structured, adding a new PromptTokenStats dataclass and integrating the metric collection into the existing stats and logging infrastructure.

I've found a correctness issue in the token breakdown calculation that leads to an overcounting of tokens when a prompt is fully cached. I've left a detailed comment with a suggested fix. Additionally, I've recommended adding unit tests for this new complex logic to ensure its correctness and prevent future regressions.

Once these points are addressed, this will be a great addition for improving observability.

Comment thread vllm/v1/metrics/stats.py Outdated
Comment on lines +293 to +314
if is_prefilling:
    # Compute breakdown from existing fields:
    # - num_cached_tokens: total cached (local + external), has -1 adjustment
    #   when all tokens were cached (to force recomputation of last token)
    # - num_external_computed_tokens: original kv transfer count from connector
    #
    # Check if -1 adjustment was applied (all tokens were cached)
    adjustment = 1 if (output.num_cached_tokens + 1 == prompt_len) else 0

    computed = prompt_len - output.num_cached_tokens
    external_kv_transfer = output.num_external_computed_tokens
    local_cache_hit = (
        output.num_cached_tokens
        + adjustment
        - output.num_external_computed_tokens
    )
    cached_tokens = output.num_cached_tokens

    self.prompt_token_stats.computed += computed
    self.prompt_token_stats.local_cache_hit += local_cache_hit
    self.prompt_token_stats.external_kv_transfer += external_kv_transfer
    self.prompt_token_stats.cached_tokens += cached_tokens

gemini-code-assist (Bot) commented (severity: high)

This new logic for calculating the prompt token breakdown is non-trivial and has some subtleties, especially around the -1 adjustment for fully cached prompts. To ensure its correctness and prevent future regressions, it would be highly beneficial to add unit tests covering various scenarios:

  • No caching
  • Partial local caching
  • Partial external KV transfer
  • A mix of local and external caching
  • Fully cached prompt (triggering the -1 adjustment)

These tests would help verify the logic and make the code more robust.
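
For example, the fully cached case could look roughly like this (a sketch only: the update_from_output(output, prompt_len) signature and field names assume the refactoring adopted later in this thread, and the fake output object stands in for EngineCoreOutput):

```python
from dataclasses import dataclass

from vllm.v1.metrics.stats import PromptTokenStats


@dataclass
class FakeOutput:
    """Stand-in for the EngineCoreOutput fields used by the breakdown."""

    num_cached_tokens: int
    num_external_computed_tokens: int


def test_fully_cached_prompt_recomputes_last_token():
    stats = PromptTokenStats()
    # prompt_len=10, fully cached locally: the scheduler reports 9 cached
    # tokens (-1 adjustment) so the model still has one input token.
    stats.update_from_output(
        FakeOutput(num_cached_tokens=9, num_external_computed_tokens=0),
        prompt_len=10,
    )
    assert stats.computed == 1  # the recomputed last token
```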

markmc (Member):

Agree. Unit tests for PromptTokenStats.update_from_output() would be pretty straightforward, and useful for documenting your intent

ZhanqiuHu (Author):

Updated the PR with:

  1. Refactored the breakdown logic into PromptTokenStats.update_from_output(), and
  2. Added unit tests in tests/v1/metrics/test_stats.py covering the suggested scenarios: all computed, partial local cache, external KV transfer, mixed sources, and full cache with recompute.

Comment thread vllm/v1/metrics/stats.py Outdated
Comment on lines +304 to +308
local_cache_hit = (
    output.num_cached_tokens
    + adjustment
    - output.num_external_computed_tokens
)

gemini-code-assist (Bot) commented (severity: high)

There appears to be a miscalculation in the prompt token breakdown. The sum of computed, local_cache_hit, and external_kv_transfer should equal prompt_len. However, with the current logic, it sums to prompt_len + adjustment, which is incorrect when adjustment is 1 (i.e., when all prompt tokens are cached).

When adjustment is 1, it means one token that was originally from the cache is being recomputed. This token is correctly counted in computed, but it is not subtracted from the cached token counts, leading to it being double-counted in the total.

To fix this, the local_cache_hit calculation should not add the adjustment. The number of local cache hits is the total number of cached tokens (which already includes the -1 adjustment) minus the tokens from external transfer.

            local_cache_hit = (
                output.num_cached_tokens - output.num_external_computed_tokens
            )
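
The invariant behind this suggestion can be checked in a couple of lines (a sketch with hypothetical numbers for a fully locally cached prompt):

```python
# With the suggested fix, the three sources partition the prompt exactly.
prompt_len, num_cached_tokens, num_external = 10, 9, 0  # fully local cache hit
computed = prompt_len - num_cached_tokens               # -> 1
local_cache_hit = num_cached_tokens - num_external      # -> 9 (no +adjustment)
assert computed + local_cache_hit + num_external == prompt_len
```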

ZhanqiuHu (Author):

Note: external_kv_transfer reports the actual number of tokens transferred (e.g., prompt length N), while prompt_tokens_cached_total reports the adjusted count (e.g., N-1). The last token is both transferred AND recomputed locally, so there's overlap.

markmc (Member):

I think it's fair to say any future person looking at this is going to be super confused

Maybe instead of adjustment, you could have recomputed_tokens with a comment above it why it's 1 in this case

Basically, I think it's good feedback that the LLM wasn't able to understand the subtleties here ... play around a little bit more with variable naming and comments, etc.

But you're pretty close to having made it as easy as possible to understand IMO

ZhanqiuHu (Author):

Renamed to recomputed_tokens and added a comment explaining the scheduler behavior in vllm/v1/metrics/stats.py.
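
A rough sketch of what the renamed logic might look like (the merged code may word the comment differently):

```python
# The scheduler caps cache hits at prompt_len - 1: when every prompt token
# is cached, the last one is recomputed so the model has at least one input
# token to produce logits from. Detect that case here.
recomputed_tokens = 1 if output.num_cached_tokens + 1 == prompt_len else 0

computed = prompt_len - output.num_cached_tokens  # includes any recomputed token
local_cache_hit = (
    output.num_cached_tokens
    + recomputed_tokens
    - output.num_external_computed_tokens
)
```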

ZhanqiuHu (Author) commented:

@markmc Would appreciate your feedback on this. It implements the labeled prompt token metrics from the #27569 discussion.

markmc (Member) left a comment

This is looking really good to me, thanks @ZhanqiuHu!

FYI, if you're using Claude or whatever, it's good to add:

Co-authored-by: Claude <noreply@anthropic.com>

in your commit message

Comment thread vllm/v1/metrics/stats.py Outdated
    external_kv_transfer: int = 0  # From P/D KV transfer
    cached_tokens: int = 0  # Total cached (has -1 adjustment when all cached)


class IterationStats:
    """Stats associated with a single set of EngineCoreOutputs."""

    def __init__(self):
        self.iteration_timestamp = time.time()
        self.num_generation_tokens = 0
        self.num_prompt_tokens = 0

markmc (Member):

Move num_prompt_tokens into PromptTokenStats

Comment thread vllm/v1/metrics/stats.py Outdated
@@ -267,6 +291,30 @@ def update_from_output(

        self.num_generation_tokens += num_new_generation_tokens
        if is_prefilling:
            # Compute breakdown from existing fields:

markmc (Member):

Let's move this to a PromptTokenStats.update_from_output(output, prompt_len) method - it's complex enough to benefit from being standalone

ZhanqiuHu (Author):

Done.


Comment thread vllm/v1/metrics/stats.py Outdated
    vllm:prompt_tokens_by_source_total{source="external_kv_transfer"}
    """

    computed: int = 0  # Locally computed (actual prefill work)

markmc (Member):

Overall, I think a lot of comments in this PR could be dropped without affecting how understandable the code is

e.g. here you're repeating what's in the docstring a few lines above

Comment thread vllm/v1/metrics/loggers.py Outdated
    counter_prompt_tokens_by_source.labels(
        model_name, str(idx), "local_compute"
    )
)

markmc (Member):

Hmm, we're basically repeating the sources = ["local_compute", "local_cache_hit", "external_kv_transfer"] list

Might be worth doing something like:

class PromptTokenStats:
    ALL_SOURCES = ["local_compute", "local_cache_hit", "external_kv_transfer"]

and using that here and in observe()

(I'm sure something a bit more elegant is possible though, and if it turns out to make things less readable, feel free to push back)
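
For instance, a sketch along those lines (names follow the reviewer's example and are assumptions about the final code):

```python
class PromptTokenStats:
    # Single source of truth for the label values used by the counters.
    ALL_SOURCES = ["local_compute", "local_cache_hit", "external_kv_transfer"]


# In loggers.py, build one labeled child per source from the shared list:
counters_by_source = {
    source: counter_prompt_tokens_by_source.labels(model_name, str(idx), source)
    for source in PromptTokenStats.ALL_SOURCES
}
```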

Comment thread vllm/v1/metrics/loggers.py Outdated
@@ -590,6 +590,43 @@ def __init__(
        counter_prompt_tokens, engine_indexes, model_name
    )

# Labeled prompt token counters by source (for P/D disaggregation)

markmc (Member):

Not just for P/D - even the computed vs local is useful on its own

ZhanqiuHu (Author):

Updated the comments to remove "(for P/D disaggregation)".

markmc commented Jan 30, 2026

> Update CLI logger to show breakdown
> Update console prompt throughput calculation to reflect different sources, so decode nodes show accurate "local compute" throughput instead of inflated total

I think I'd like to see this used in CLI logging somehow in this PR, if even just to validate the approach - let's bear in mind Nicolo's original take that the console logging is "plain wrong" 👍

@markmc markmc moved this from P1 to In Review in Metrics & Tracing Jan 30, 2026
@ZhanqiuHu ZhanqiuHu force-pushed the feature/labeled-prompt-token-metrics branch from 7cb28dc to e552fca on February 2, 2026 16:04
@ZhanqiuHu ZhanqiuHu requested a review from orozery as a code owner February 2, 2026 16:04
@ZhanqiuHu ZhanqiuHu force-pushed the feature/labeled-prompt-token-metrics branch 2 times, most recently from 9671177 to d829d03 on February 2, 2026 16:10
ZhanqiuHu (Author) commented:

> Update CLI logger to show breakdown
> Update console prompt throughput calculation to reflect different sources, so decode nodes show accurate "local compute" throughput instead of inflated total
>
> I think I'd like to see this used in CLI logging somehow in this PR, if even just to validate the approach - let's bear in mind Nicolo's original take that the console logging is "plain wrong" 👍

Done. The updated PR changes LoggingStatLogger to use locally computed tokens for the prompt throughput calculation.
The decode instance now logs the following (it processes 1 prompt token per request):

Engine 000: Avg prompt throughput: 0.9 tokens/s, Avg generation throughput: 88.7 tokens/s, 
Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%, 
External prefix cache hit rate: 100.0%
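
The throughput change amounts to dividing by computed tokens rather than all prompt tokens; a minimal sketch (names illustrative, not the exact merged code):

```python
# Prompt throughput on a decode node now reflects local prefill work only.
elapsed = now - last_log_time
prompt_throughput = prompt_token_stats.computed / elapsed  # not num_prompt_tokens
```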

@ZhanqiuHu ZhanqiuHu requested a review from markmc February 3, 2026 16:09

markmc (Member) left a comment

Excellent work!

@markmc markmc moved this from In Review to Ready in Metrics & Tracing Feb 3, 2026
@markmc markmc added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 3, 2026
Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Zhanqiu Hu <zh338@cornell.edu>
@markmc markmc force-pushed the feature/labeled-prompt-token-metrics branch from d829d03 to 69e3a6c on February 3, 2026 19:23
@markmc markmc merged commit 4403e3e into vllm-project:main Feb 4, 2026
42 checks passed
@github-project-automation github-project-automation Bot moved this from Ready to Done in Metrics & Tracing Feb 4, 2026
markmc added a commit to markmc/vllm that referenced this pull request Feb 4, 2026
Needed by vllm-project#33290

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
gameofdimension pushed a commit to gameofdimension/vllm that referenced this pull request Feb 5, 2026
…lm-project#33290)

Add labeled Prometheus metrics to distinguish where prompt tokens come
from in P/D disaggregated deployments.

In P/D disaggregation, decode instances receive KV cache from prefill instances.
Currently, decode reports inflated prompt throughput because it counts all
prompt tokens as "computed", even though most were transferred.

This PR adds labeled metrics so users can understand actual compute work vs
transferred work:

vllm:prompt_tokens_by_source_total{source="local_compute"}        # Tokens prefilled locally
vllm:prompt_tokens_by_source_total{source="external_kv_transfer"} # Tokens received via KV transfer
vllm:prompt_tokens_by_source_total{source="local_cache_hit"}      # Tokens from local prefix cache
vllm:prompt_tokens_cached_total                                    # Total cached (local + external, -1 when all cached)

Signed-off-by: Zhanqiu Hu <zh338@cornell.edu>
Signed-off-by: felix01.yu <felix01.yu@vipshop.com>
Prowindy added a commit to Prowindy/vllm that referenced this pull request Feb 8, 2026
In P/D (Prefill/Decode) disaggregated deployments, the local_cache_hit
metric could become negative when external KV transfer tokens exceed
locally cached tokens. This caused Prometheus counter increment failures
with ValueError: "Counters can only be incremented by non-negative amounts."

The fix clamps local_cache_hit to non-negative values using max(0, ...).

Root cause:
- In P/D disagg, decode receives tokens via external KV transfer
- The calculation: local_cache_hit = num_cached_tokens - num_external_computed_tokens
- When external > cached, this goes negative
- Prometheus counters reject negative increments

Example scenario:
- Prefill sends 7000 tokens to decode via NIXL
- Decode has 0 local cache
- Old: local_cache_hit = 0 - 7000 = -7000 (CRASH!)
- New: local_cache_hit = max(0, 0 - 7000) = 0 (OK)

Fixes the regression introduced in vllm-project#33290.

Signed-off-by: Cong Chen <congc@meta.com>

simon-mo (Collaborator) commented Feb 9, 2026

@ZhanqiuHu PTAL at #34079

markmc added a commit to markmc/vllm that referenced this pull request Apr 10, 2026
In the case of a full local prefix cache hit (prompt length N),
we actually only use N-1 tokens. The `vllm:prompt_tokens_recomputed`
was intended to count how many cached tokens we are effectively
discarding because of this.

```
KVCacheManager.get_computed_blocks():
    ...
    # NOTE: When all tokens hit the cache, we must recompute the last token
    # to obtain logits. [...]
    max_cache_hit_length = request.num_tokens - 1
```

However, even here, we can't assume the last token would have been
a cache hit and should be counted as "recomputed". Given this, the
metric seems quite misguided, in retrospect.

The metric was added as a side-effect in vllm-project#33290 in order to make
sense of the fact that:

```
vllm:prompt_tokens_by_source_total{source="external_kv_transfer"}
```

will include a token that is recomputed. See this comment:

> Note: external_kv_transfer reports the actual number of tokens
> transferred (e.g., prompt length N), while prompt_tokens_cached_total
> reports the adjusted count (e.g., N-1). The last token is both
> transferred AND recomputed locally, so there's overlap.

However, it makes more sense for the `external_kv_transfer` count to
reflect only tokens we actually used, not any recomputed tokens. This
will be done in vllm-project#37460.

I'm not aware of any user demand for this metric, or anyone relying
on it now. So it seems safe to remove it, rather than go through
a deprecation period.

Signed-off-by: Mark McLoughlin <markmc@redhat.com>

Labels

kv-connector ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Feature]: [Metrics] Labeled prompt token metrics for P/D disaggregation (Follow-up on PR #27569)

3 participants