Fix mamba radix cache eviction logic in alloc_req_slots#11616
Fix mamba radix cache eviction logic in alloc_req_slots#11616merrymercy merged 7 commits intosgl-project:mainfrom
alloc_req_slots#11616Conversation
Signed-off-by: rogeryoungh <rogeryoungh@foxmail.com>
Summary of ChangesHello @rogeryoungh, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a critical regression in the Mamba radix cache eviction mechanism. A previous refactoring had inadvertently disconnected the eviction logic from its invocation point, leading to potential cache overflow issues. This change rectifies the problem by relocating the Mamba radix cache eviction logic to the Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request correctly addresses a bug where the Mamba radix cache eviction logic was not being called after a recent refactoring. The fix involves moving the eviction logic from the now-deleted alloc_req_slots method into its caller, alloc_for_extend, ensuring it executes at the correct time. The change is logical and well-contained. I have one minor suggestion to improve code conciseness by simplifying a conditional check.
| if mamba_available_size < bs: | ||
| if batch.tree_cache is not None and isinstance( | ||
| batch.tree_cache, MambaRadixCache | ||
| ): | ||
| mamba_num = max(0, bs - mamba_available_size) | ||
| batch.tree_cache.evict_mamba(mamba_num) |
There was a problem hiding this comment.
The check batch.tree_cache is not None is redundant because isinstance(None, MambaRadixCache) evaluates to False. You can combine the nested if statements for conciseness and better readability.
if mamba_available_size < bs and isinstance(batch.tree_cache, MambaRadixCache):
mamba_num = max(0, bs - mamba_available_size)
batch.tree_cache.evict_mamba(mamba_num)Signed-off-by: rogeryoungh <rogeryoungh@foxmail.com>
Signed-off-by: rogeryoungh <rogeryoungh@foxmail.com>
Signed-off-by: rogeryoungh <rogeryoungh@foxmail.com>
Motivation
This pull request addresses an issue with the cache eviction logic introduced in #11214. The eviction step is implemented in
schedule_batch.py, but due to the recent refactoring in #11313 (which movedalloc_req_slotsand its associated logic tocommon.py), the eviction logic wasn't being called. This issue was also reported in #11214 (comment) .This patch ensures that evictions are correctly called from the refactored alloc_req_slots function.
Modifications
Moving eviction logic to
common.py.Accuracy Tests
Benchmarking and Profiling
Checklist