[AMD][AITER] Dynamically define max_num_total_tokens to avoid OOMs in AITER attention backend buffers allocation#18263
Conversation
clean fix
Summary of Changes: Hello @fxmarty-amd, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the memory management strategy within the SGLang runtime, specifically for the AITER attention backend. It resolves critical Out-Of-Memory errors by implementing a dynamic calculation for `max_num_total_tokens`.
Code Review
This pull request introduces a dynamic calculation for max_num_total_tokens to prevent Out-Of-Memory errors when using the AITER attention backend. The core of the change is a new method, _solve_max_tokens_with_aiter_workspace, which correctly accounts for the AITER workspace memory by solving a piecewise function. This is a significant improvement over the previous static memory allocation. The removal of the related heuristic in server_args.py is also a good cleanup. I've included one suggestion to address a potential oversight where the workspace memory is not accounted for when max_running_requests is explicitly set, which could still lead to OOMs.
```python
if (
    self.server_args.attention_backend == "aiter"
    and self.mambaish_config is None
    and self.server_args.max_running_requests is None
):
    # `max_total_num_tokens` is used in `ModelRunnerKVCacheMixin.init_memory_pool` to define
    # `max_num_reqs`, which is in turn used in AITER attention backend to define GPU HBM buffers for the attention.
    # The default strategy below to resolve `max_total_num_tokens` does NOT take into account the memory
    # required for the attention backend, potentially resulting in OOM errors in AITER buffers allocation.
    max_total_num_tokens = self._solve_max_tokens_with_aiter_workspace(
        rest_memory_bytes, cell_size, num_layers
    )
else:
    # No workspace overhead for other backends
    max_total_num_tokens = rest_memory_bytes // cell_size
```
The current logic correctly handles the dynamic calculation of max_total_num_tokens when max_running_requests is not set. However, when max_running_requests is specified, the code falls through to the else block, which does not account for the AITER workspace memory. This could still lead to Out-Of-Memory errors, which this pull request aims to fix.
To make this more robust, I suggest handling the fixed max_running_requests case explicitly for the aiter backend by calculating the workspace size and subtracting it from the available memory. This ensures memory is correctly provisioned in all scenarios for the aiter backend.
```python
if (
    self.server_args.attention_backend == "aiter"
    and self.mambaish_config is None
):
    from sglang.srt.configs.model_config import AttentionArch

    if self.model_config.attention_arch == AttentionArch.MLA:
        # For MLA, workspace is allocated dynamically, not during init
        max_total_num_tokens = rest_memory_bytes // cell_size
    elif self.server_args.max_running_requests is None:
        # `max_total_num_tokens` is used in `ModelRunnerKVCacheMixin.init_memory_pool` to define
        # `max_num_reqs`, which is in turn used in AITER attention backend to define GPU HBM buffers for the attention.
        # The default strategy below to resolve `max_total_num_tokens` does NOT take into account the memory
        # required for the attention backend, potentially resulting in OOM errors in AITER buffers allocation.
        max_total_num_tokens = self._solve_max_tokens_with_aiter_workspace(
            rest_memory_bytes, cell_size, num_layers
        )
    else:
        # When max_running_requests is set, max_num_reqs is fixed.
        # We need to account for the AITER workspace memory.
        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

        num_head = self.model_config.num_attention_heads // get_attention_tp_size()
        head_dim = self.model_config.head_dim
        max_context_len = self.model_config.context_len
        max_num_partitions = AiterAttnBackend.get_max_num_partitions(max_context_len)
        W = num_head * max_num_partitions * (head_dim * 4 + 8)
        max_num_reqs = self.server_args.max_running_requests
        aiter_workspace = max_num_reqs * W
        max_total_num_tokens = (rest_memory_bytes - aiter_workspace) // cell_size
else:
    # No workspace overhead for other backends
    max_total_num_tokens = rest_memory_bytes // cell_size
```
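To get a sense of the scale involved, the per-request workspace `W` from the snippet above can be evaluated with plausible numbers. All concrete values below are assumptions for illustration only; the real ones depend on the model config and on AITER's partition sizing.

```python
# Illustration of the per-request AITER workspace size, using the formula
# quoted above: W = num_head * max_num_partitions * (head_dim * 4 + 8).
# All concrete numbers here are assumed for illustration only.
num_head = 32             # attention heads per tensor-parallel rank (assumed)
head_dim = 128            # head dimension (assumed)
max_num_partitions = 512  # depends on context length and AITER partition size (assumed)

W = num_head * max_num_partitions * (head_dim * 4 + 8)  # bytes per request
print(f"W = {W} bytes per request ({W / 2**20:.1f} MiB)")
print(f"4096 requests -> {4096 * W / 2**30:.1f} GiB of workspace")
```

With these (hypothetical) numbers the workspace for 4096 concurrent requests alone reaches tens of GiB, which is why ignoring it when sizing the KV cache can OOM.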
For simplicity, this is not included in this PR. It could be added in a future PR.
Older workaround for this issue (not working in all cases)
python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
```python
candidate_max_total_num_tokens = (rest_memory_bytes - 4096 * W) / cell_size
if candidate_max_total_num_tokens / context_len * 512 >= 4096:
    return int(candidate_max_total_num_tokens)
```
Do we need these two conditions?
```python
if candidate_max_total_num_tokens / context_len * 512 <= 2048:
    return int(candidate_max_total_num_tokens)
if candidate_max_total_num_tokens / context_len * 512 >= 4096:
    return int(candidate_max_total_num_tokens)
```
Our purpose is to use `rest_memory_bytes` when `max_num_requests` is below 2048 or above 4096, and to automatically clamp `max_num_requests` to the minimum (2048) or maximum (4096) when calculating the maximum workspace size of the aiter backend.
Once we have the real `rest_memory_bytes` to calculate `max_total_num_tokens` for the `kv_buffer`, do we still need to check that this value meets these conditions (<= 2048 or >= 4096)?
In my view, we should return `candidate_max_total_num_tokens` directly; these condition checks are not needed.
Could you explain more about why we need these two conditions?
Humm, what would you return directly exactly?
The three cases correspond to the three regions of the piecewise function in `python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py` (lines 304 to 313 at 079fc8f).
We don't know which region we are in until we solve for max_total_num_tokens. The three checks assume a value of max_num_reqs (2048, 4096, or max_total_num_tokens * 512 / context_len), resolve a candidate max_total_num_tokens, and verify which of the three regions we are in.
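This assume-solve-verify approach can be sketched as a standalone function. This is a minimal sketch, not the actual `_solve_max_tokens_with_aiter_workspace` implementation; the clamp bounds 2048/4096 and the `tokens * 512 / context_len` relation between tokens and requests are taken from the snippets quoted in this thread, and the toy numbers in the usage lines are invented.

```python
def solve_max_tokens_with_workspace(rest_memory_bytes, cell_size, W, context_len):
    """Solve the piecewise equation
        rest_memory_bytes = tokens * cell_size + reqs * W
    where reqs = clamp(tokens * 512 / context_len, 2048, 4096),
    by assuming each region in turn, solving for tokens, and
    verifying that the solution actually falls in that region.
    """
    # Region 1: reqs clamped to the minimum (2048).
    candidate = (rest_memory_bytes - 2048 * W) / cell_size
    if candidate / context_len * 512 <= 2048:
        return int(candidate)
    # Region 3: reqs clamped to the maximum (4096).
    candidate = (rest_memory_bytes - 4096 * W) / cell_size
    if candidate / context_len * 512 >= 4096:
        return int(candidate)
    # Region 2: no clamping active, reqs = tokens * 512 / context_len,
    # so rest_memory_bytes = tokens * (cell_size + 512 * W / context_len).
    return int(rest_memory_bytes / (cell_size + 512 * W / context_len))

# Toy numbers (context_len=4096, so reqs = tokens / 8) hitting each region:
print(solve_max_tokens_with_workspace(102_048_000, 10_000, 1_000, 4096))  # 10000 (region 1)
print(solve_max_tokens_with_workspace(202_500_000, 10_000, 1_000, 4096))  # 20000 (region 2)
print(solve_max_tokens_with_workspace(404_096_000, 10_000, 1_000, 4096))  # 40000 (region 3)
```

The point of the two checks is exactly the verification step: a candidate solved under one region's assumption is only valid if the resulting request count actually lands in that region.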
CI summary:
- In https://github.com/sgl-project/sglang/actions/runs/21694948459/job/62564157114?pr=18263: build failed (likely unrelated to this PR).
- In https://github.com/sgl-project/sglang/actions/runs/21694948459/job/62564157059?pr=18263: fails locally as well on MI355X. We fix in b2d06b0 by increasing in the test.
- In https://github.com/sgl-project/sglang/actions/runs/21694948474/job/62563234452?pr=18263: unrelated to this PR, it seems.
- In https://github.com/sgl-project/sglang/actions/runs/21694948474/job/62563234146?pr=18263
Hi @kkHuang-amd, is there anything needed from me to get this merged?
As per title.
Fixes #18262
See context and details in #18262