@JialinOuyang-Meta (Contributor) commented Jul 15, 2025

Summary:

Optimizations

A common trick for doubly linked list implementations is to introduce fake head and tail nodes. Doing so significantly reduces the implementation overhead and lets us get rid of the dataclass.__eq__ comparison easily.

  • No dataclass.__eq__ invocation
  • Shorter code
  • Branchless

Combined, these should yield a significant perf improvement for this piece of code.
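
The fake head/tail idea described above can be sketched roughly as follows. This is a minimal illustration, not the PR's actual code: `Block` and `SentinelFreeQueue` are hypothetical names standing in for `KVCacheBlock` and `FreeKVCacheBlockQueue`, and only the list links are modeled.

```python
from __future__ import annotations

from typing import Optional


class Block:
    """Hypothetical stand-in for KVCacheBlock; only the list links matter here."""

    def __init__(self, block_id: int) -> None:
        self.block_id = block_id
        self.prev: Optional[Block] = None
        self.next: Optional[Block] = None


class SentinelFreeQueue:
    """Doubly linked free list bracketed by fake head and tail nodes."""

    def __init__(self, blocks: list[Block]) -> None:
        self.head = Block(-1)  # fake head, never handed out
        self.tail = Block(-1)  # fake tail, never handed out
        self.head.next = self.tail
        self.tail.prev = self.head
        self.num_free = 0
        for block in blocks:
            self.append(block)

    def append(self, block: Block) -> None:
        # Insert just before the fake tail.
        last = self.tail.prev
        assert last is not None
        last.next = block
        block.prev = last
        block.next = self.tail
        self.tail.prev = block
        self.num_free += 1

    def remove(self, block: Block) -> None:
        # A real block always has both neighbours, so there are no
        # "is this the head / is this the tail" special cases.
        assert block.prev is not None and block.next is not None
        block.prev.next = block.next
        block.next.prev = block.prev
        block.prev = block.next = None
        self.num_free -= 1

    def popleft(self) -> Block:
        first = self.head.next
        assert first is not None
        if first is self.tail:  # identity check, never ==
            raise ValueError("No free blocks available")
        self.remove(first)
        return first
```

Because every comparison uses `is`, no generated `__eq__` is ever called, and `remove` needs no head/tail branches since the sentinels guarantee both neighbours exist.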

Observations

Per vLLM profiling, kv_cache_manager.allocate_slots accounts for a non-negligible cost on each prefill.

[Screenshot 2025-07-14 at 10:26:07 AM]

Zooming in, we can see that the FreeKVCacheBlockQueue.popleft stack is non-trivial: popleft -> remove -> string.__eq__, which mainly comes from the dataclass (i.e. KVCacheBlock) equality comparison.

Per the Python dataclasses library doc (https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass):

dataclasses.dataclass(*, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False)

eq: If true (the default), an __eq__() method will be generated. This method compares the class as if it were a tuple of its fields, in order. Both instances in the comparison must be of the identical type.

If the class already defines __eq__(), this parameter is ignored.
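
To make the quoted behaviour concrete, here is a small sketch contrasting the generated `__eq__` with an identity check. The class and field names (`FieldCompared`, `IdentityCompared`, `block_id`, `block_hash`) are made up for illustration and are not the real `KVCacheBlock` definition.

```python
from dataclasses import dataclass


@dataclass  # eq=True by default: __eq__ is generated
class FieldCompared:
    block_id: int
    block_hash: str = ""


@dataclass(eq=False)  # no generated __eq__
class IdentityCompared:
    block_id: int
    block_hash: str = ""


a, b = FieldCompared(1, "abc"), FieldCompared(1, "abc")
# The generated __eq__ compares (block_id, block_hash) as tuples,
# which includes a string comparison for block_hash.
fields_equal = a == b  # True
same_object = a is b   # False: pointer comparison only, no __eq__ call

c, d = IdentityCompared(1, "abc"), IdentityCompared(1, "abc")
# With eq=False, == falls back to object.__eq__, i.e. identity.
identity_equal = c == d  # False
```

In a linked list, "is this node the head?" is an identity question, so `is` gives the right answer without touching any fields.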

Test Plan:

Result

Typically, block_size is set to 16, so in production usage we would likely allocate 10-1000 blocks. In this range, the optimization gave us up to ~1ms of TTFT savings (the improvements are more significant on long inputs).
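
As a rough sanity check on that range, the arithmetic can be sketched as below. It assumes each block holds exactly block_size tokens and that the block count is a plain ceiling division, ignoring any prefix-cache reuse; the helper names are hypothetical.

```python
BLOCK_SIZE = 16  # tokens per KV cache block, as stated above


def blocks_needed(num_tokens: int, block_size: int = BLOCK_SIZE) -> int:
    """Blocks required for num_tokens, assuming plain ceiling division."""
    return -(-num_tokens // block_size)


def tokens_covered(num_blocks: int, block_size: int = BLOCK_SIZE) -> int:
    """Maximum number of tokens that num_blocks can hold."""
    return num_blocks * block_size
```

Under these assumptions, 10 blocks cover 160 tokens and 1000 blocks cover 16000 tokens, so the quoted range corresponds to prompts of roughly 160 to 16000 tokens.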

Benchmark

After: [Screenshot 2025-07-15 at 10:25:28 AM]
Before: [Screenshot 2025-07-15 at 10:23:56 AM]

Stack

After: [Screenshot 2025-07-14 at 10:25:04 AM]
Before: [Screenshot 2025-07-14 at 10:26:07 AM]

Rollback Plan:

Reviewed By: CuiCoco

Differential Revision: D78292345

@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D78292345

@mergify bot added the performance (Performance-related issues) and v1 labels on Jul 15, 2025

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a significant performance optimization to the FreeKVCacheBlockQueue by implementing a doubly linked list with sentinel nodes. This change effectively removes expensive __eq__ comparisons on KVCacheBlock dataclasses, which should improve performance as demonstrated by the new benchmark. The implementation is a classic and well-executed approach.

My review focuses on ensuring the robustness of this new implementation. I've identified a couple of areas where adding validation checks could prevent potential crashes from state inconsistencies, making the system more resilient. These changes should have a negligible performance impact while significantly improving debuggability and correctness guarantees.

@JialinOuyang-Meta changed the title from "Avoid KVCacheBlock.__eq__ invocations in FreeKVCacheBlockQueue" to "[Core] Avoid KVCacheBlock.__eq__ invocations in FreeKVCacheBlockQueue" on Jul 15, 2025
…project#21005)

Summary:
Pull Request resolved: vllm-project#21005


Signed-off-by: Jialin Ouyang <[email protected]>

@yeqcharlotte (Collaborator) left a comment

LGTM

@JialinOuyang-Meta some graphs in your test plan are broken, could you fix them?

cc: @njhill @WoosukKwon could you take another look?

@houseroad added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Jul 16, 2025

@Jialin (Collaborator) commented Jul 17, 2025

resolve #21141

Signed-off-by: Jialin Ouyang <[email protected]>
Signed-off-by: Jialin Ouyang <[email protected]>
@simon-mo merged commit 0f199f1 into vllm-project:main on Jul 18, 2025
63 of 65 checks passed

@njhill (Member) left a comment

Thanks @JialinOuyang-Meta! I have a few small comments, perhaps could be done as a follow-on.

Sorry for the delay, I was on vacation for the last week.

-        if not self.free_list_head:
+        if (self.fake_free_list_head.next_free_block
+                is self.fake_free_list_tail
+                or self.fake_free_list_head.next_free_block is None):

Member

Why is this second check needed? Would self.fake_free_list_head.next_free_block ever be None?

Collaborator

Logically, it should NOT be needed, but it keeps pyre happy; otherwise pyre complains that we can't assign Optional[KVCacheBlock] to KVCacheBlock in L256.

Just curious, what's the typical way in vLLM to suppress pyre without such extra checks?

Member

Using asserts or selective ignore directives
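
The two options mentioned here can be sketched roughly as follows. `Node` and both helper functions are hypothetical, and the exact ignore directive depends on the type checker in use; the PEP 484 `# type: ignore` comment is shown only as one common form.

```python
from __future__ import annotations

from typing import Optional


class Node:
    """Hypothetical node with an Optional link, like the blocks discussed above."""

    def __init__(self) -> None:
        self.next_node: Optional[Node] = None


def first_real_node_assert(fake_head: Node) -> Node:
    nxt = fake_head.next_node
    # Integrity check that also narrows Optional[Node] to Node for the checker.
    assert nxt is not None, "next_node of fake head should always exist"
    return nxt


def first_real_node_ignore(fake_head: Node) -> Node:
    # No runtime check; the comment asks the checker to skip this line.
    return fake_head.next_node  # type: ignore
```

The assert documents the invariant and costs one comparison at runtime (none under `python -O`, which strips asserts), while the ignore directive has no runtime cost but silences the checker on that line entirely.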

Comment on lines +302 to +304
        if self.fake_free_list_tail.prev_free_block is None:
            raise RuntimeError(
                "prev_free_block of fake_free_list_tail should always exist")

Member

I don't think this check is needed, or it should be an assert

Collaborator

Ditto, this was only added to make pyre stop complaining about Optional[KVCacheBlock] -> KVCacheBlock assignments in L305.

Member

That's fine but it should be an assert in this case.

Comment on lines +324 to +326
        if self.fake_free_list_head.next_free_block is None:
            raise RuntimeError(
                "next_free_block of fake_free_list_head should always exist")

Member

Same comment. If we want an integrity check here, it should be an assert.

Collaborator

Similar reply as above. I would love to hear the best practice for suppressing pyre without extra checks or computations.

And I will definitely address all your comments once I get some guidance on this point :)

Member

assert should work the same way and is more correct since this is an integrity check.

Comment on lines +222 to +227
        # Create a fake head and a tail block for the doubly linked list to
        # reduce branching in the code
        #
        # The implementation garenteed that the fake head and tail
        # are NEVER got popped, so we could safely assume each real blocks
        # in the queue has prev and next blocks.

Member

suggested rewording of the comment

        # Create fake head and tail blocks for the doubly linked list to
        # reduce branching in the code.
        #
        # The implementation guarantees that the fake head and tail
        # are NEVER popped, so we can safely assume each real block
        # in the queue has prev and next blocks.

Collaborator

Appreciate that! Will address all your comments in a followup PR.

@Jialin (Collaborator) commented Jul 19, 2025

> Thanks @JialinOuyang-Meta! I have a few small comments, perhaps could be done as a follow-on.
>
> Sorry for the delay, I was on vacation for the last week.

No worries, and I appreciate your input!
