@NickLucche NickLucche commented Aug 12, 2025

Let's try to keep get_kv_cache_layout as simple as possible; if we start adding per-backend if/else checks, this is going to become a mess very quickly.

We have two ways to set the kv cache layout right now:

  • envs.VLLM_KV_CACHE_LAYOUT, which should allow the user to specify their preferred layout
  • _KV_CACHE_LAYOUT_OVERRIDE for all other runtime use cases that require force-setting a layout from code, with higher priority.

I would argue we could probably get away with a single variable to model this behavior, but this PR keeps both options for now, as well as the priority between them.
Also, the layout should be set once during startup and then become read-only.
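The intended resolution order can be sketched as follows. This is a minimal illustration, not vLLM's actual code: the `"NHD"` default and the exact variable handling are assumptions, but it shows the priority (code override beats env var) and the set-once, read-only behavior described above.

```python
import os
from functools import cache

# Hypothetical module-level override; in vLLM this is set from code at startup.
_KV_CACHE_LAYOUT_OVERRIDE: "str | None" = None


@cache  # resolved on the first call, effectively read-only afterwards
def get_kv_cache_layout() -> str:
    # Code-level override takes priority over the user-facing env var.
    if _KV_CACHE_LAYOUT_OVERRIDE is not None:
        return _KV_CACHE_LAYOUT_OVERRIDE
    # "NHD" default is an assumption for this sketch.
    return os.environ.get("VLLM_KV_CACHE_LAYOUT", "NHD")
```

Because the result is cached, later changes to the environment no longer affect the layout, which matches the "set once during startup, then read-only" requirement.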

This PR moves VLLM_USE_TRTLLM_ATTENTION out of get_kv_cache_layout.
Since this variable is only used within FA on b200, and FA already forces the layout to HND for b200 here https://github.com/vllm-project/vllm/blob/main/vllm/platforms/cuda.py#L281C21-L281C40, the check should be redundant.
cc @mgoin for checking correctness here.

The second change is a small refactor to use_trtllm_attention since it's used in every forward, to avoid checking for environment state every single time and just cache it.
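The shape of that refactor can be sketched like this. The name supports_trtllm_attention and the use of lru_cache come from the review discussion below; the env-var semantics and the num_prefill_tokens parameter are hypothetical stand-ins for the real per-request checks.

```python
import os
from functools import lru_cache


@lru_cache(maxsize=1)
def supports_trtllm_attention() -> bool:
    # Environment/platform state does not change after startup,
    # so this bit is computed once and cached as a plain bool.
    return os.environ.get("VLLM_USE_TRTLLM_ATTENTION", "0") == "1"


def use_trtllm_attention(num_prefill_tokens: int) -> bool:
    # Called on every forward pass: only the cheap, request-dependent
    # part is evaluated here; the env lookup above is already cached.
    return supports_trtllm_attention() and num_prefill_tokens > 0
```

The split keeps the hot path free of environment lookups while leaving the per-request logic untouched.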

A clearer refactor will probably be due once things are more stable. In particular, we should probably extend it to support multiple layouts, just like the hybrid memory allocator.

cc @tdoublep @LucasWilkinson

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Aug 12, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the logic for handling TRTLLM attention and simplifies how the KV cache layout is determined. The caching of the environment variable check in use_trtllm_attention is a good performance improvement. The simplification of get_kv_cache_layout is also a positive change for maintainability. I've identified a couple of areas for improvement: an incorrect type hint in the new supports_trtllm_attention function, and a logic flaw in get_kv_cache_layout that can lead to misleading log messages. My review includes suggestions to address these points.

@NickLucche NickLucche changed the title cache use_trtllm_attention and simplify get_kv_cache_layout [Kernel] Simplify get_kv_cache_layout and cache use_trtllm_attention env-dependent bit Aug 12, 2025
@NickLucche

/gemini review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces two main improvements. First, it refactors the logic for determining whether to use TensorRT-LLM attention by caching the environment-dependent checks. This is a good performance optimization as it avoids repeated lookups of environment variables and platform capabilities on every forward pass. The new supports_trtllm_attention function with @lru_cache is a clean way to achieve this. Second, it simplifies the get_kv_cache_layout function by removing a redundant check for VLLM_USE_TRTLLM_ATTENTION. The justification that this is handled by the _KV_CACHE_LAYOUT_OVERRIDE mechanism is sound and makes the code easier to understand. The new implementation also makes the priority of different configuration sources clearer. Overall, the changes are well-reasoned and improve both performance and maintainability.


@mgoin mgoin left a comment


This makes sense to me, thanks for the refactor.

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 12, 2025
@mgoin mgoin enabled auto-merge (squash) August 12, 2025 16:16
@DarkLight1337

Need to fix merge conflict

@mergify

mergify bot commented Aug 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 13, 2025
auto-merge was automatically disabled August 13, 2025 13:48

Head branch was pushed to by a user without write access

@NickLucche NickLucche force-pushed the cache-use-trt-attention branch from 912ea26 to 6cf8c3c Compare August 13, 2025 13:48
@mergify mergify bot removed the needs-rebase label Aug 13, 2025
@NickLucche

done

Signed-off-by: NickLucche <[email protected]>
@NickLucche NickLucche force-pushed the cache-use-trt-attention branch from 7835044 to 989ed93 Compare August 14, 2025 13:10
@mgoin mgoin enabled auto-merge (squash) August 14, 2025 13:53
@NickLucche

Still seeing some OOMs in recent tests


@pavanimajety pavanimajety left a comment


LGTM, thanks for the clean check

@pavanimajety

Eventually, it may also make sense to drop the dependency on KV cache layout altogether, because trtllm natively supports both HND and NHD layouts. The cubins for NHD would have to be added, though.

@mgoin mgoin merged commit 070da66 into vllm-project:main Aug 16, 2025
40 checks passed
666even666 pushed a commit to 666even666/vllm that referenced this pull request Aug 18, 2025
…on` env-dependent bit (vllm-project#22735)

Signed-off-by: NickLucche <[email protected]>
Signed-off-by: Yiwen Chen <[email protected]>
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Aug 19, 2025
divakar-amd pushed a commit to divakar-amd/vllm_upstream that referenced this pull request Aug 20, 2025
djmmoss pushed a commit to djmmoss/vllm that referenced this pull request Aug 21, 2025
…on` env-dependent bit (vllm-project#22735)

Signed-off-by: NickLucche <[email protected]>
Signed-off-by: Duncan Moss <[email protected]>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025