[Bugfix][P/D] Fix throughput stats in disaggregated setup #27569
NickLucche wants to merge 7 commits into vllm-project:main from
Conversation
Code Review
This pull request addresses a bug in the prompt throughput stats calculation for disaggregated setups, where tokens copied from the primary to the decoder were incorrectly counted as prefilled tokens. The changes involve modifying the scheduler to track locally prefilled tokens and updating the logging to reflect the corrected throughput. The review focuses on ensuring the correctness of the fix and the clarity of the code changes.
    # Prefill is to be recomputed locally.
    request.num_external_computed_tokens = 0
@sdavidbd can you please double check this, my understanding is that we have to re-compute the whole prefill now so we can track prompt throughput
Even from the docstring, it seems clear we don't always re-compute the whole prefill:

> This method scans the given requests, detects those with invalid blocks and adjusts their `num_computed_tokens` to the longest valid prefix.

A few things:
- I think `num_external_computed_tokens = 0` should only happen inside the `not marked_invalid_block` clause - here we're saying all externally computed blocks are invalid
- (Unrelated to this PR - an observation) Setting `request.num_computed_tokens = request.num_cached_tokens` on line 1489 doesn't make sense to me - since `num_cached_tokens` includes both local and external computed tokens?
- We should update `num_external_computed_tokens` at `# Truncate the computed tokens at the first failed block` - to something like `request.num_computed_tokens - local_computed_tokens` (but not obvious how we calculated `local_computed_tokens`)
@NickLucche — as @markmc noted, we don’t recompute the entire prefill. Only the externally computed tokens starting from the first failed block are recomputed.
To correctly update num_external_computed_tokens, we should first determine how many externally computed tokens are affected. This can be derived from the delta between the original and truncated num_computed_tokens — the same tokens already aggregated in total_affected_tokens (lines 1473–1477):
# Truncate the computed tokens at the first failed block
request.num_computed_tokens = idx * self.block_size
num_affected_tokens = req_num_computed_tokens - request.num_computed_tokens
total_affected_tokens += num_affected_tokens
request.num_external_computed_tokens -= num_affected_tokens
@markmc — regarding your points:
- The `not marked_invalid_block` condition covers the sync-loading edge case where a request is affected by externally computed tokens that failed to load but are shared with preceding requests that will handle their recomputation. In this situation, the affected request still treats those tokens as locally computed, so its `num_external_computed_tokens` remains unchanged.
  For example, assuming `block_size = 1` and the following prompts (with R1 preceding R2 in the batch):
R1: t1 t2 t3
R2: t1 t2 t4 t5
Suppose t1 is locally computed, t2 and t4 are externally computed, and t2 fails to load while t4 succeeds. Then, before the failed load is handled:
| Request | num_computed_tokens | num_external_computed_tokens |
|---|---|---|
| R1 | 2 | 1 |
| R2 | 3 | 1 |

After handling the invalid block:

| Request | num_computed_tokens | num_external_computed_tokens |
|---|---|---|
| R1 | 1 | 0 |
| R2 | 3 | 1 |
Both R1 and R2 are affected and will recompute t2–t3 and t5, respectively, but R2's total number of computed tokens remains unchanged (see the sketch after this list).
- Correct — `num_cached_tokens` represents the total number of computed tokens (both local and external). Setting `num_computed_tokens = num_cached_tokens` ensures that all new tokens are recomputed in the current iteration, since the previous `num_computed_tokens` value already included them.
- Agreed — see my suggested code changes above for how we update `num_external_computed_tokens` accordingly.
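For concreteness, a standalone sketch of the bookkeeping in the example above (illustrative field names, not the actual scheduler code):

```python
# Standalone illustration of the R1/R2 example (block_size = 1).
from dataclasses import dataclass


@dataclass
class Req:
    num_computed_tokens: int
    num_external_computed_tokens: int


BLOCK_SIZE = 1

# State before the failed load of t2 is handled.
r1 = Req(num_computed_tokens=2, num_external_computed_tokens=1)  # t1 local, t2 external
r2 = Req(num_computed_tokens=3, num_external_computed_tokens=1)  # t1, t2 treated as local; t4 external

# R1 marks the invalid block: truncate at the first failed block (t2 at idx 1)
# and shrink the external count by the number of affected tokens.
idx = 1
affected = r1.num_computed_tokens - idx * BLOCK_SIZE
r1.num_computed_tokens = idx * BLOCK_SIZE
r1.num_external_computed_tokens -= affected

# R2 shares the failed block with R1 (the `not marked_invalid_block` case),
# so its counters stay untouched: R1 will recompute t2 into the shared block.
assert (r1.num_computed_tokens, r1.num_external_computed_tokens) == (1, 0)
assert (r2.num_computed_tokens, r2.num_external_computed_tokens) == (3, 1)
```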
@@ -121,6 +121,8 @@ class EngineCoreOutput(
    trace_headers: Mapping[str, str] | None = None
    # The number of tokens with prefix cache hits.
Yeah, this comment looks incorrect ... assuming "prefix cache" refers to the local cache?
# Total computed tokens (local + external).
num_computed_tokens = (
num_new_local_computed_tokens + num_external_computed_tokens
)
...
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
I am not familiar with it cc @chaunceyjiang
@@ -221,6 +221,8 @@ def __init__(self):
    self.num_generation_tokens = 0
    self.num_prompt_tokens = 0
    self.num_preempted_reqs = 0
    # Num of prompt tokens that have been computed locally.
Is the naming here a bit confusing? By "computed locally" here we mean both computed and locally cached?
If you just tracked num_external_computed_tokens and then subtracted it in _track_iteration_stats(), would that be more clear?
By "computed locally" here we mean both computed and locally cached?
Yes the behavior is unchanged, cached ones would still result in higher throughput even in regular aggregated setup.
If you just tracked num_external_computed_tokens and then subtracted it in _track_iteration_stats() would that be more clear?
I think looking at the diff
self.num_prompt_tokens += iteration_stats.num_prompt_tokens
-->
self.num_prompt_tokens += iteration_stats.num_local_prompt_tokens
this is pretty clear that I just want to rule out the remote tokens ie I assume the semantic was the intended one from the beginning, it's just "local" used to be redundant
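For illustration, a minimal sketch of the two approaches discussed above, with approximate field names (not the exact vLLM classes):

```python
# Approximate field names; not the exact vLLM classes.
class IterationStats:
    def __init__(self):
        self.num_prompt_tokens = 0             # all scheduled prompt tokens
        self.num_local_prompt_tokens = 0       # prompt tokens prefilled on this instance
        self.num_external_computed_tokens = 0  # tokens whose KV came from a connector


class LoggingStatLogger:
    def __init__(self):
        self.num_prompt_tokens = 0

    # Approach in this PR: the per-iteration stats already exclude remote tokens.
    def _track_iteration_stats(self, iteration_stats: IterationStats) -> None:
        self.num_prompt_tokens += iteration_stats.num_local_prompt_tokens

    # Alternative suggested above: keep the raw count and subtract here.
    def _track_iteration_stats_alt(self, iteration_stats: IterationStats) -> None:
        self.num_prompt_tokens += (
            iteration_stats.num_prompt_tokens
            - iteration_stats.num_external_computed_tokens
        )
```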
@@ -121,6 +121,8 @@ class EngineCoreOutput(
    trace_headers: Mapping[str, str] | None = None
    # The number of tokens with prefix cache hits.
    num_cached_tokens: int = 0
    # The number of tokens that have been computed remotely.
    num_external_computed_tokens: int = 0
I'd be tempted to refactor these two into a PrefillStats object ... and only include that in the ECO when the prefill completes ... especially if we ever wanted to also send like num_locally_cached_tokens too
I don't have a strong opinion on this tbh, we can probably wait to have a few more things to bundle before executing the suggestion
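If the refactor were picked up later, it might look roughly like the following; the shape and names are hypothetical, not an agreed-on API.

```python
from dataclasses import dataclass


# Hypothetical shape for the suggested refactor; not an agreed-on API.
@dataclass
class PrefillStats:
    num_cached_tokens: int = 0             # local prefix-cache hits
    num_external_computed_tokens: int = 0  # tokens loaded via the KV connector
    # room to grow later, e.g. num_locally_cached_tokens


@dataclass
class EngineCoreOutput:
    request_id: str = ""
    # Only attached to the output that completes the prefill.
    prefill_stats: PrefillStats | None = None
```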
@@ -113,7 +113,7 @@ def _reset(self, now):

    def _track_iteration_stats(self, iteration_stats: IterationStats):
Presumably you want to update the Prometheus metric too?
@markmc which one? I intentionally left self.counter_prompt_tokens unchanged to avoid replacing the actual prompt count.
Should I just make a new one for local tokens?
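One possible shape for that, as a hedged sketch: keep the existing counter and add a separate one for locally computed prompt tokens. The vllm:local_prompt_tokens_total name and the record_prefill() helper are made up for illustration.

```python
from prometheus_client import Counter

# Existing total prompt-token counter (name as used by vLLM today).
counter_prompt_tokens = Counter(
    "vllm:prompt_tokens_total",
    "Number of prefill tokens processed.",
)
# Hypothetical extra counter for tokens actually prefilled on this instance.
counter_local_prompt_tokens = Counter(
    "vllm:local_prompt_tokens_total",
    "Number of prefill tokens computed locally (excludes KV-connector loads).",
)


def record_prefill(num_prompt_tokens: int, num_external_computed_tokens: int) -> None:
    counter_prompt_tokens.inc(num_prompt_tokens)
    counter_local_prompt_tokens.inc(
        num_prompt_tokens - num_external_computed_tokens
    )
```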
Sorry for the delay in coming back to this. I see more clearly now where you're coming from. On a decode instance, you want to see this:

i.e. "this vLLM instance isn't doing any prefill computation".

So, you're proposing that we subtract any prefilling done by a KV connector from the prompt throughput reported on the console. If that's the desired behavior, the code changes lgtm now. However, the whole thing raises a lot of other questions for me!

On the other hand though, it looks like you found a clear bug here:

I have a fix for that, just need to update tests before submitting.
I think these are all good questions, but as you noted I was really just proposing a fix for a very specific use-case.
Don't have a strong opinion on Grafana. Open to address that in this or a follow-up PR if need be. In general my approach was: counting prompt tokens is not wrong, but it's wrong to try to derive throughput from them, since the result is bonkers.
To be honest I don't know and don't have a strong opinion on it; I see how in practice you want to show a signal rather than 0.0 throughput in the regular colocated setup. Perhaps it's better discussed in a separate issue, I was just trying to fix the disaggregated case.
I think in this case the rationale is actually the same as in this PR: even offloaded KV caches are not necessarily computed locally.
@markmc let me know if you have any other comments on this fix.
The throughput figure is just expressing a counter as a rate. In a Grafana dashboard, we do the same with e.g.

That's why I think "Prompt Throughput" in Prometheus/Grafana should mean the same thing as in the console log. And that's just to say I think it's important that we can articulate a consistent mental model for what all of these metrics mean, whether in the console log or Prometheus.

Here's my mental model for how we're counting things now:

With that mental model, this:

is not bonkers at all? 53.2 prompt token/s came in at the top, none were found in the local prefix cache, 100% were retrieved from the connector, and we generated 15 tokens/s.

If you see it differently, could you draw your mental model for these counters, in such a way that the same model can be applied consistently across use cases?
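For the "counter as a rate" point, a rough sketch of how the console figure can be derived from a windowed counter; this is illustrative only, not the actual LoggingStatLogger code.

```python
import time

# Illustrative only: the console throughput is a counter delta divided by the
# logging interval, i.e. the same shape as rate(...) over a Prometheus counter.
class ThroughputWindow:
    def __init__(self) -> None:
        self.last_time = time.monotonic()
        self.num_prompt_tokens = 0

    def observe(self, num_tokens: int) -> None:
        self.num_prompt_tokens += num_tokens

    def rate_and_reset(self) -> float:
        now = time.monotonic()
        elapsed = now - self.last_time
        rate = self.num_prompt_tokens / elapsed if elapsed > 0 else 0.0
        self.last_time = now
        self.num_prompt_tokens = 0
        return rate
```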
Currently we are recording async connector prefix cache queries and hits twice, for the same reason that update_state_after_alloc() is called twice:

> If get_num_new_matched_tokens previously returned True for a
> request, this function may be called twice for that same request -
> first when blocks are allocated for the connector tokens to be
> asynchronously loaded into, and second when any additional blocks
> are allocated, after the load/transfer is complete.

Worse, the second time we are recording with `num_external_computed_tokens=0` so effectively we are halving the hit rate.

Before:
```
External prefix cache hit rate: 100.0%
```

After:
```
External prefix cache hit rate: 50.0%
```

Borrows part of vllm-project#27569 to track `num_external_computed_tokens` for use when the KV transfer completes. Will use vllm-project#28550 for testing this scenario.

Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
From an end user perspective, I would argue the top-level outcome is: I want a "prompt throughput" Prometheus metric with various sources (a category / type / source label), and I want the sum of the series under that label to equal the observed throughput in terms of prompt tokens processed by this server. I want the label to be sufficiently scoped that I can identify:

Then the metrics printed are a derivative of that end-user observability outcome.
Claude helped me out with this version of your proposal ... does it match? (There's definitely still some information duplication with the

**Proposal: Labeled Prompt Token Counters**

**Prometheus Metrics**

Replace single

**Grafana Queries**

Total throughput (all tokens flowing through):

Compute load (actual prefill work):

KV transfer verification:

Cache effectiveness:

**CLI Logger**

Display breakdown:

**Benefits**
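A rough sketch of what the labeled counter from this proposal could look like in prometheus_client; the metric name follows the existing vllm:prompt_tokens_total, while the source label and its values are placeholders, not settled naming.

```python
from prometheus_client import Counter

# Placeholder naming for the labeled-counter proposal; the "source" label and
# its values are not settled.
prompt_tokens = Counter(
    "vllm:prompt_tokens_total",
    "Prompt tokens processed, broken down by where their KV came from.",
    labelnames=["source"],
)

prompt_tokens.labels(source="computed").inc(1200)    # prefilled on this instance
prompt_tokens.labels(source="external").inc(5300)    # loaded via a KV connector (P->D)
prompt_tokens.labels(source="local_cache").inc(800)  # local prefix-cache hits

# Total throughput:  sum(rate(vllm:prompt_tokens_total[1m]))
# Compute load only: rate(vllm:prompt_tokens_total{source="computed"}[1m])
```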
Roughly, yes. We are likely to have more cache hit sources in the future. I'd suggest two labels,

Otherwise, that structure looks like what I would expect.

EDIT: One note, we should pick values for
I'm ok with fleshing out the current metric, I think end-user clarity would benefit from it. We can repurpose this PR to implement at least the compute/transfer split, I suppose.
Re: cache source, why would we move away from internal/external, where internal is co-located HBM w.r.t. the current EngineCore, and external is anything else?
Re: cache source: An operator who has multiple external sources would want to know which source it is. I'm not arguing for changing labels that exist, just ensuring that we can attribute the external contribution appropriately. I agree that preserving existing labels without disruption is better than removing them.
This pull request has merge conflicts that must be resolved before it can be merged.
#33290 has merged, so I think we can close this now. Cool!
Fix prompt throughput stats in CLI logger by only accounting for tokens that were prefilled locally.
In a P/D setup, the KV cache is copied over from P to D, and this currently results in outputting the following on the decoder side:
which is plain wrong, given we have not actually prefilled those tokens but just "copied" them over.
After this PR:
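Illustrative arithmetic for the decode-side fix (made-up numbers, approximate field names):

```python
# Made-up numbers, approximate field names. A 4096-token prompt lands on the
# decoder and its whole prefill KV is transferred from the prefill instance.
num_prompt_tokens = 4096
num_external_computed_tokens = 4096  # copied from P, not computed here

# Before this PR the logger accumulated every prompt token, so the decoder
# reported a large "prompt throughput" despite doing no prefill work.
reported_before = num_prompt_tokens                                # 4096

# After this PR only the locally prefilled portion counts toward throughput.
reported_after = num_prompt_tokens - num_external_computed_tokens  # 0
```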