-
Notifications
You must be signed in to change notification settings - Fork 5.2k
[AMD][AITER] Dynamically define max_num_total_tokens to avoid OOMs in AITER attention backend buffers allocation
#18263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a8afc58
1a867aa
a9d2bbd
b153f13
b2d06b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,6 +44,85 @@ | |
|
|
||
|
|
||
| class ModelRunnerKVCacheMixin: | ||
| def _solve_max_tokens_with_aiter_workspace( | ||
| self: ModelRunner, rest_memory_bytes: int, cell_size: int, num_layers: int | ||
| ) -> int: | ||
| """ | ||
| Solve for max_total_num_tokens accounting for aiter attention workspace memory. | ||
|
|
||
| We need to satisfy: | ||
| - kv_memory + workspace_memory = rest_memory | ||
| - kv_memory = max_total_num_tokens * cell_size | ||
| - workspace_memory = max_num_reqs * workspace_constant | ||
| - max_num_reqs = clamp(max_total_num_tokens / context_len * 512, 2048, 4096) | ||
|
|
||
| The `max_num_reqs` function is piecewise, so we solve for each piece and pick the valid solution. | ||
| """ | ||
| from sglang.srt.configs.model_config import AttentionArch | ||
| from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend | ||
|
|
||
| # Get attention parameters | ||
| num_head = self.model_config.num_attention_heads // get_attention_tp_size() | ||
| head_dim = self.model_config.head_dim | ||
| context_len = self.model_config.context_len | ||
| use_mla = self.model_config.attention_arch == AttentionArch.MLA | ||
|
|
||
| # For MLA, workspace is allocated dynamically, not during init | ||
| if use_mla: | ||
| return rest_memory_bytes // cell_size | ||
|
|
||
| max_num_partitions = AiterAttnBackend.get_max_num_partitions(context_len) | ||
|
|
||
| # Resolve `max_total_num_tokens` based on the `workspace_size` required by aiter_backend.py: | ||
| # workspace_size = (max_num_reqs * num_head * max_num_partitions * head_dim) * 4 | ||
| # + 2 * (max_num_reqs * num_head * max_num_partitions) * 4 | ||
|
|
||
| # i.e. `workspace_size = max_num_reqs * W` introducing the known constant W. | ||
| W = num_head * max_num_partitions * (head_dim * 4 + 8) | ||
|
|
||
| # We then have from `ModelRunnerKVCacheMixin.init_memory_pool`: | ||
| # max_num_reqs = clamp(max_total_num_tokens / context_len * 512, 2048, 4096) | ||
|
|
||
| # With the constraint: rest_memory_bytes = kv_memory + aiter_memory | ||
| # = max_total_num_tokens * cell_size + max_num_reqs * W | ||
|
|
||
| # This creates three cases: | ||
|
|
||
| # Case 2: Linear region in-between, typically where we'll be. | ||
| # 2048 <= max_total_num_tokens / context_len * 512 <= 4096 | ||
| # <=> max_num_reqs = max_total_num_tokens * 512 / context_len | ||
| # Injecting in rest_memory_bytes = kv_memory + aiter_memory we get: | ||
| # max_total_num_tokens * cell_size + (max_total_num_tokens * 512 / context_len) * W = rest_memory_bytes | ||
| # <=> max_total_num_tokens * (cell_size + 512 * W / context_len) = rest_memory_bytes | ||
| # <=> max_total_num_tokens = rest_memory_bytes / (cell_size + 512 * W / context_len) | ||
| candidate_max_total_num_tokens = rest_memory_bytes / ( | ||
| cell_size + 512 * W / context_len | ||
| ) | ||
| if 2048 <= candidate_max_total_num_tokens / context_len * 512 <= 4096: | ||
| return int(candidate_max_total_num_tokens) | ||
|
|
||
| # Case 1: max_total_num_tokens / context_len * 512 <= 2048 | ||
| # <=> max_num_reqs = 2048 | ||
| # Injecting in rest_memory_bytes = kv_memory + aiter_memory: | ||
| # max_total_num_tokens * cell_size + 2048 * W = rest_memory_bytes | ||
| # <=> max_total_num_tokens = (rest_memory_bytes - 2048 * W) / cell_size | ||
| candidate_max_total_num_tokens = (rest_memory_bytes - 2048 * W) / cell_size | ||
| if candidate_max_total_num_tokens / context_len * 512 <= 2048: | ||
| return int(candidate_max_total_num_tokens) | ||
|
|
||
| # Case 3: max_total_num_tokens / context_len * 512 >= 4096 | ||
| # <=> max_num_reqs = 4096 | ||
| # Injecting in rest_memory_bytes = kv_memory + aiter_memory we get: | ||
| # max_total_num_tokens * cell_size + 4096 * W = rest_memory_bytes | ||
| # <=> max_total_num_tokens = (rest_memory_bytes - 4096 * W) / cell_size | ||
| candidate_max_total_num_tokens = (rest_memory_bytes - 4096 * W) / cell_size | ||
| if candidate_max_total_num_tokens / context_len * 512 >= 4096: | ||
| return int(candidate_max_total_num_tokens) | ||
|
|
||
| raise ValueError( | ||
| "Something went wrong in the memory allocation for KV cache. Please open an issue." | ||
| ) | ||
|
|
||
| def get_cell_size_per_token(self: ModelRunner, num_layers: int) -> int: | ||
| kv_size = torch._utils._element_size(self.kv_cache_dtype) | ||
| if self.use_mla_backend: | ||
|
|
@@ -146,7 +225,25 @@ def profile_max_num_token(self: ModelRunner, total_gpu_memory: int): | |
| if self.mambaish_config is not None: | ||
| rest_memory = self.handle_max_mamba_cache(rest_memory) | ||
|
|
||
| return int(rest_memory * (1 << 30)) // cell_size | ||
| rest_memory_bytes = int(rest_memory * (1 << 30)) | ||
|
|
||
| # NOTE: No special handling for the cases `self.mambaish_config is not None` and when `max_running_requests` is specified. | ||
| if ( | ||
| self.server_args.attention_backend == "aiter" | ||
| and self.mambaish_config is None | ||
| and self.server_args.max_running_requests is None | ||
| ): | ||
| # `max_total_num_tokens` is used in `ModelRunnerKVCacheMixin.init_memory_pool` to define | ||
| # `max_num_reqs`, which is in turn used in AITER attention backend to define GPU HBM buffers for the attention. | ||
| # The default strategy below to resolve `max_total_num_tokens` does NOT take into account the memory required for the attention backend, potentially resulting in OOM errors in AITER buffers allocation. | ||
| max_total_num_tokens = self._solve_max_tokens_with_aiter_workspace( | ||
| rest_memory_bytes, cell_size, num_layers | ||
| ) | ||
| else: | ||
| # No workspace overhead for other backends | ||
| max_total_num_tokens = rest_memory_bytes // cell_size | ||
|
Comment on lines
+231
to
+244
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The current logic correctly handles the dynamic calculation of `max_total_num_tokens` when `max_running_requests` is unset. To make this more robust, I suggest also handling the case where `max_running_requests` is fixed: if (
self.server_args.attention_backend == "aiter"
and self.mambaish_config is None
):
from sglang.srt.configs.model_config import AttentionArch
if self.model_config.attention_arch == AttentionArch.MLA:
# For MLA, workspace is allocated dynamically, not during init
max_total_num_tokens = rest_memory_bytes // cell_size
elif self.server_args.max_running_requests is None:
# `max_total_num_tokens` is used in `ModelRunnerKVCacheMixin.init_memory_pool` to define
# `max_num_reqs`, which is in turn used in AITER attention backend to define GPU HBM buffers for the attention.
# The default strategy below to resolve `max_total_num_tokens` does NOT take into account the memory required for the attention backend, potentially resulting in OOM errors in AITER buffers allocation.
max_total_num_tokens = self._solve_max_tokens_with_aiter_workspace(
rest_memory_bytes, cell_size, num_layers
)
else:
# When max_running_requests is set, max_num_reqs is fixed.
# We need to account for the AITER workspace memory.
from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
num_head = self.model_config.num_attention_heads // get_attention_tp_size()
head_dim = self.model_config.head_dim
max_context_len = self.model_config.context_len
max_num_partitions = AiterAttnBackend.get_max_num_partitions(max_context_len)
W = num_head * max_num_partitions * (head_dim * 4 + 8)
max_num_reqs = self.server_args.max_running_requests
aiter_workspace = max_num_reqs * W
max_total_num_tokens = (rest_memory_bytes - aiter_workspace) // cell_size
else:
# No workspace overhead for other backends
max_total_num_tokens = rest_memory_bytes // cell_size
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For simplicity, not included in this PR. It could be included in a future PR. |
||
|
|
||
| return max_total_num_tokens | ||
|
|
||
| def handle_max_mamba_cache(self: ModelRunner, total_rest_memory): | ||
| config = self.mambaish_config | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From line #110 to line #122.
Do we need these two conditions?
if candidate_max_total_num_tokens / context_len * 512 <= 2048:
return int(candidate_max_total_num_tokens)
if candidate_max_total_num_tokens / context_len * 512 >= 4096:
return int(candidate_max_total_num_tokens)
Our purpose is to get the rest_memory_bytes when max_num_requests < 2048 or > 4096, and to automatically clamp max_num_requests to the minimum (2048) or maximum (4096) when calculating the maximum workspace size of the aiter backend.
When we use the real rest_memory_bytes to calculate the max_total_num_tokens of the kv_buffer, do we still need to check that this value meets these conditions (<= 2048 or >= 4096)?
From my point of view, we should return candidate_max_total_num_tokens directly; these condition checks are not needed.
Could you explain more why we need these two conditions?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, what would you return directly, exactly?
The three cases correspond to the three regions of
sglang/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
Lines 304 to 313 in 079fc8f
We don't know which region we are in until we solve for `max_total_num_tokens`. The three checks assume a value of `max_num_reqs` (2048, 4096, or `max_total_num_tokens * 512 / context_len`), resolve a candidate `max_total_num_tokens`, and verify which of the three regions we are in.