
[AMD][AITER] Dynamically define max_num_total_tokens to avoid OOMs in AITER attention backend buffers allocation#18263

Open
fxmarty-amd wants to merge 5 commits into sgl-project:main from fxmarty-amd:fix-kv-cache-aiter-memory-allocation

Conversation

@fxmarty-amd
Contributor

As per title.

Fixes #18262

See context and details in #18262

@gemini-code-assist
Contributor

Summary of Changes

Hello @fxmarty-amd, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the memory management strategy within the SGLang runtime, specifically for the AITER attention backend. It resolves critical Out-Of-Memory errors by implementing a sophisticated dynamic calculation for max_num_total_tokens. This new approach ensures that the memory footprint of AITER's workspace buffers is properly accounted for, leading to more stable and efficient operation when processing large language models, particularly on AMD platforms. The update replaces a static memory adjustment with a more precise, conditional calculation.

Highlights

  • Dynamic max_num_total_tokens Calculation: Introduced a new method, _solve_max_tokens_with_aiter_workspace, to dynamically determine max_num_total_tokens for the AITER attention backend. This calculation now accurately accounts for the memory required by AITER's internal workspace buffers, preventing Out-Of-Memory (OOM) errors.
  • OOM Prevention for AITER Backend: The changes directly address and fix Out-Of-Memory issues that previously occurred during AITER attention backend buffer allocation, especially when dealing with larger context lengths or specific memory configurations.
  • Conditional Memory Management Logic: The profile_max_num_token function has been updated to intelligently apply the AITER-specific memory calculation only when the AITER backend is active and certain conditions (e.g., no Mambaish config, no max_running_requests specified) are met. Otherwise, it defaults to a general memory calculation.
  • Removed Static Memory Adjustment: A previous static adjustment to mem_fraction_static for the AITER backend when context_len exceeded 8192 has been removed. The new dynamic calculation renders this static adjustment redundant and potentially less accurate.


Changelog
  • python/sglang/srt/layers/attention/aiter_backend.py
    • Added a new static method get_max_num_partitions to AiterAttnBackend to calculate the maximum number of partitions based on context length, which is used in the new memory calculation logic.
  • python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
    • Introduced a new private method _solve_max_tokens_with_aiter_workspace to dynamically compute max_total_num_tokens by considering the AITER attention workspace memory requirements, solving a piecewise function for max_num_reqs.
    • Modified the profile_max_num_token method to conditionally invoke _solve_max_tokens_with_aiter_workspace when the AITER backend is in use and specific conditions are met, otherwise falling back to the default calculation.
  • python/sglang/srt/server_args.py
    • Removed the static mem_fraction_static adjustment that was previously applied to the AITER backend when model_config.context_len exceeded 8192, as this is now handled by the dynamic memory allocation logic.
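For intuition, the `get_max_num_partitions` helper mentioned in the changelog is, in paged-attention-style kernels, typically a ceiling division of the context length by a fixed partition size. A minimal sketch of that idea (the 256-token partition size is an assumption for illustration, not a value taken from this PR; the real constant lives in `aiter_backend.py`):

```python
def get_max_num_partitions(max_context_len: int, partition_size: int = 256) -> int:
    """Number of fixed-size partitions needed to cover the longest sequence."""
    # Ceiling division without floating point.
    return (max_context_len + partition_size - 1) // partition_size
```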

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a dynamic calculation for max_num_total_tokens to prevent Out-Of-Memory errors when using the AITER attention backend. The core of the change is a new method, _solve_max_tokens_with_aiter_workspace, which correctly accounts for the AITER workspace memory by solving a piecewise function. This is a significant improvement over the previous static memory allocation. The removal of the related heuristic in server_args.py is also a good cleanup. I've included one suggestion to address a potential oversight where the workspace memory is not accounted for when max_running_requests is explicitly set, which could still lead to OOMs.

Comment on lines +228 to +241
if (
    self.server_args.attention_backend == "aiter"
    and self.mambaish_config is None
    and self.server_args.max_running_requests is None
):
    # `max_total_num_tokens` is used in `ModelRunnerKVCacheMixin.init_memory_pool` to define
    # `max_num_reqs`, which is in turn used in AITER attention backend to define GPU HBM buffers for the attention.
    # The default strategy below to resolve `max_total_num_tokens` does NOT take into account the memory required for the attention backend, potentially resulting in OOM errors in AITER buffers allocation.
    max_total_num_tokens = self._solve_max_tokens_with_aiter_workspace(
        rest_memory_bytes, cell_size, num_layers
    )
else:
    # No workspace overhead for other backends
    max_total_num_tokens = rest_memory_bytes // cell_size
Contributor

Severity: high

The current logic correctly handles the dynamic calculation of max_total_num_tokens when max_running_requests is not set. However, when max_running_requests is specified, the code falls through to the else block, which does not account for the AITER workspace memory. This could still lead to Out-Of-Memory errors, which this pull request aims to fix.

To make this more robust, I suggest handling the fixed max_running_requests case explicitly for the aiter backend by calculating the workspace size and subtracting it from the available memory. This ensures memory is correctly provisioned in all scenarios for the aiter backend.

        if (
            self.server_args.attention_backend == "aiter"
            and self.mambaish_config is None
        ):
            from sglang.srt.configs.model_config import AttentionArch
            if self.model_config.attention_arch == AttentionArch.MLA:
                # For MLA, workspace is allocated dynamically, not during init
                max_total_num_tokens = rest_memory_bytes // cell_size
            elif self.server_args.max_running_requests is None:
                # `max_total_num_tokens` is used in `ModelRunnerKVCacheMixin.init_memory_pool` to define
                # `max_num_reqs`, which is in turn used in AITER attention backend to define GPU HBM buffers for the attention.
                # The default strategy below to resolve `max_total_num_tokens` does NOT take into account the memory required for the attention backend, potentially resulting in OOM errors in AITER buffers allocation.
                max_total_num_tokens = self._solve_max_tokens_with_aiter_workspace(
                    rest_memory_bytes, cell_size, num_layers
                )
            else:
                # When max_running_requests is set, max_num_reqs is fixed.
                # We need to account for the AITER workspace memory.
                from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
                num_head = self.model_config.num_attention_heads // get_attention_tp_size()
                head_dim = self.model_config.head_dim
                max_context_len = self.model_config.context_len
                max_num_partitions = AiterAttnBackend.get_max_num_partitions(max_context_len)
                W = num_head * max_num_partitions * (head_dim * 4 + 8)
                max_num_reqs = self.server_args.max_running_requests
                aiter_workspace = max_num_reqs * W
                max_total_num_tokens = (rest_memory_bytes - aiter_workspace) // cell_size
        else:
            # No workspace overhead for other backends
            max_total_num_tokens = rest_memory_bytes // cell_size
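To gauge the magnitude of this workspace, here is a rough back-of-the-envelope computation using the `W` formula from the suggestion above. The model shape and the 256-token partition size below are assumptions for illustration, not values from the PR:

```python
import math

# Hypothetical shape: 32 query heads per TP rank, head_dim 128, 32k context.
num_head, head_dim, max_context_len = 32, 128, 32768
partition_size = 256  # assumed AITER paged-attention partition size
max_num_partitions = math.ceil(max_context_len / partition_size)  # 128

# Per-request workspace bytes, per the suggestion's formula:
# an fp32 intermediate output (head_dim * 4) plus 8 bytes of per-partition stats.
W = num_head * max_num_partitions * (head_dim * 4 + 8)
print(W)                 # ~2 MiB per request
print(4096 * W / 2**30)  # ~8 GiB if max_num_reqs is clamped to its 4096 ceiling
```

Under these assumptions the workspace alone can consume several GiB of HBM, which is why ignoring it when sizing the KV cache can OOM.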

Contributor Author


For simplicity, not included in this PR. It could be included in a future PR.

@fxmarty-amd fxmarty-amd changed the title [AITER] Dynamically define max_num_total_tokens to avoid OOMs in AITER attention backend buffers allocation [AMD][AITER] Dynamically define max_num_total_tokens to avoid OOMs in AITER attention backend buffers allocation Feb 4, 2026
Comment on lines 1862 to 1865
Contributor Author


Older workaround for this issue (not working in all cases)

candidate_max_total_num_tokens = (rest_memory_bytes - 4096 * W) / cell_size
if candidate_max_total_num_tokens / context_len * 512 >= 4096:
    return int(candidate_max_total_num_tokens)

Collaborator


From line #110 to line #122.

Do we need these two conditions?

if candidate_max_total_num_tokens / context_len * 512 <= 2048:
    return int(candidate_max_total_num_tokens)
if candidate_max_total_num_tokens / context_len * 512 >= 4096:
    return int(candidate_max_total_num_tokens)

Our purpose is to take rest_memory_bytes, automatically clamp max_num_reqs between the minimum of 2048 and the maximum of 4096, and use that to calculate the maximum workspace size of the aiter backend.

Once we have the real rest_memory_bytes to calculate the max_total_num_tokens of the kv_buffer, do we still need to check that this value meets these conditions (<= 2048 or >= 4096)?

From my point of view, we should return candidate_max_total_num_tokens directly; these condition checks should not be needed.

Could you explain more why we need these two conditions?

Contributor Author

@fxmarty-amd fxmarty-amd Feb 5, 2026


Hmm, what exactly would you return directly?

The three cases correspond to the three regions of

if max_num_reqs is None:
    max_num_reqs = min(
        max(
            int(
                self.max_total_num_tokens / self.model_config.context_len * 512
            ),
            2048,
        ),
        4096,
    )

We don't know which region we are in until we solve for max_total_num_tokens. Each of the three checks assumes a value of max_num_reqs (2048, 4096, or max_total_num_tokens * 512 / context_len), resolves a candidate max_total_num_tokens under that assumption, and then verifies that the candidate indeed falls in the assumed region.
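This solve-and-verify strategy can be sketched as a standalone function (a hypothetical simplification of `_solve_max_tokens_with_aiter_workspace`, not the PR's actual code; it assumes the memory balance is `rest = tokens * cell_size + max_num_reqs * W`, with `W` the per-request workspace bytes):

```python
def solve_max_total_num_tokens(rest_memory_bytes, cell_size, context_len, W):
    """Solve rest = tokens * cell_size + max_num_reqs * W, where
    max_num_reqs = clamp(tokens / context_len * 512, 2048, 4096)."""
    # Region 1: assume max_num_reqs clamps to the 2048 floor, then verify.
    candidate = (rest_memory_bytes - 2048 * W) / cell_size
    if candidate / context_len * 512 <= 2048:
        return int(candidate)
    # Region 3: assume max_num_reqs clamps to the 4096 ceiling, then verify.
    candidate = (rest_memory_bytes - 4096 * W) / cell_size
    if candidate / context_len * 512 >= 4096:
        return int(candidate)
    # Region 2: max_num_reqs = tokens / context_len * 512, so
    # rest = tokens * (cell_size + 512 * W / context_len).
    return int(rest_memory_bytes / (cell_size + 512 * W / context_len))
```

Exactly one of the three regions is consistent with the memory balance, which is why each candidate must be checked against the region it was derived under.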

@fxmarty-amd
Contributor Author

fxmarty-amd commented Feb 5, 2026

CI summary:

In https://github.com/sgl-project/sglang/actions/runs/21694948459/job/62564157114?pr=18263:

build failed (likely unrelated to this PR):

2026-02-05T02:03:29.1377914Z -->Runtime error: Error building extension 'module_rmsnorm': [1/1365] /opt/rocm/bin/hipcc  -DWITH_HIP -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1016\" -D_GLIBCXX_USE_CXX11_ABI=1 -DTORCH_EXTENSION_NAME=module_rmsnorm -I/sgl-workspace/aiter/3rdparty/ck_helper -I/sgl-workspace/aiter/3rdparty/composable_kernel/include -I/sgl-workspace/aiter/3rdparty/composable_kernel/library/include -I/sgl-workspace/aiter/csrc/include -I/sgl-workspace/aiter/aiter/jit/build/module_rmsnorm/blob -I/sgl-workspace/aiter/3rdparty/composable_kernel/example/ck_tile/10_rmsnorm2d -I/sgl-workspace/aiter/csrc/include/torch -I/opt/venv/lib/python3.10/site-packages/pybind11/include -isystem /opt/rocm/include -isystem /opt/venv/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /opt/venv/lib/python3.10/site-packages/torch/include/THC -isystem /opt/venv/lib/python3.10/site-packages/torch/include/TH -isystem /opt/venv/lib/python3.10/site-packages/torch/include -isystem /opt/venv/lib/python3.10/site-packages/torch/include/THH -isystem /usr/include/python3.10 -fPIC -std=c++20 -O3 -std=c++20 -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -mcmodel=large -fno-unique-section-names -ffunction-sections -fdata-sections -fvisibility=hidden -fvisibility-inlines-hidden --offload-arch=native -DDLLVM_MAIN_REVISION=554785 -DLEGACY_HIPBLAS_DIRECT -DTORCH_Float4_e2m1fn_x2 -DUSE_PROF_API=1 -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1 -U__HIP_NO_HALF_CONVERSIONS__ -U__HIP_NO_HALF_OPERATORS__ -Wno-macro-redefined -Wno-missing-template-arg-list-after-template-kw -Wno-switch-bool -Wno-undefined-func-template -Wno-unused-result -Wno-vla-cxx-extension -fgpu-flush-denormals-to-zero -fno-offload-uniform-block -mllvm --amdgpu-kernarg-preload-count=16 -mllvm --lsr-drop-solution=1 -mllvm -amdgpu-coerce-illegal-types=1 
-mllvm -amdgpu-early-inline-all=true -mllvm -amdgpu-function-calls=false -mllvm -enable-post-misched=0 -fno-gpu-rdc -c /sgl-workspace/aiter/aiter/jit/build/module_rmsnorm/blob/rmsnorm2d_fwd_bf16_fp8_n1024_dquant_t5ml.cpp -o rmsnorm2d_fwd_bf16_fp8_n1024_dquant_t5ml.cuda.o 
2026-02-05T02:03:30.9762698Z clang++: error: cannot determine amdgcn architecture: /opt/rocm-7.0.0/lib/llvm/bin/amdgpu-arch: Child timed out: ; consider passing it via '--offload-arch'; environment variable CLANG_TOOLCHAIN_PROGRAM_TIMEOUT specifies the tool timeout (integer secs, <=0 is infinite)
ninja: build stopped: subcommand failed.
...
2026-02-05T02:03:31.1828695Z [aiter] build [module_rmsnorm] under /sgl-workspace/aiter/aiter/jit/build/module_rmsnorm/build failed !!!!!!

in https://github.com/sgl-project/sglang/actions/runs/21694948459/job/62564157059?pr=18263

2026-02-05T01:55:46.4448010Z ✗ FAILED:
2026-02-05T01:55:46.4448219Z   registered/hicache/test_hicache_variants.py (exit code 1)

caused by

  File "/felmarty/repos/sglang/python/sglang/srt/mem_cache/memory_pool_host.py", line 167, in __init__
    self.size > device_pool.size
AssertionError: The host memory should be larger than the device memory with the current protocol

=> fails locally as well on MI355X. We fix it in b2d06b0 by increasing --hicache-size in the test from 200 to 250, as the KV size on MI355X went from 192.93 GB on main to 232.18 GB with this fix (since we remove the mem_fraction_static *= 0.85 logic).

in https://github.com/sgl-project/sglang/actions/runs/21694948474/job/62563234452?pr=18263

unrelated to this PR it seems:

2026-02-05T01:35:03.9849945Z ✗ FAILED:
2026-02-05T01:35:03.9850184Z   registered/vlm/test_vision_openai_server_a.py (exit code 1)

caused by

2026-02-05T01:26:13.1246887Z   File "/usr/local/lib/python3.10/dist-packages/torchaudio/__init__.py", line 7, in <module>
2026-02-05T01:26:13.1247357Z     from . import _extension  # noqa  # usort: skip
2026-02-05T01:26:13.1247856Z   File "/usr/local/lib/python3.10/dist-packages/torchaudio/_extension/__init__.py", line 41, in <module>
2026-02-05T01:26:13.1248365Z     _check_cuda_version()
2026-02-05T01:26:13.1248806Z   File "/usr/local/lib/python3.10/dist-packages/torchaudio/_extension/utils.py", line 121, in _check_cuda_version
2026-02-05T01:26:13.1249441Z     version = torchaudio.lib._torchaudio.cuda_version()
2026-02-05T01:26:13.1250039Z AttributeError: partially initialized module 'torchaudio' has no attribute 'lib' (most likely due to a circular import)

In https://github.com/sgl-project/sglang/actions/runs/21694948474/job/62563234146?pr=18263 test_pp_long_context_prefill

2026-02-05T01:34:17.5724601Z ✗ FAILED:
2026-02-05T01:34:17.5724831Z   registered/perf/test_bench_serving_2gpu.py (exit code 1)

with the error

2026-02-05T01:33:22.8743598Z   File "/public_sglang_ci/runner-l2-4gt7x-gpu-23/_work/sglang/sglang/test/registered/perf/test_bench_serving_2gpu.py", line 104, in test_pp_long_context_prefill
2026-02-05T01:33:22.8744383Z     self.assertGreater(res["input_throughput"], 4000)
2026-02-05T01:33:22.8744740Z   File "/usr/lib/python3.10/unittest/case.py", line 1244, in assertGreater
2026-02-05T01:33:22.8745105Z     self.fail(self._formatMessage(msg, standardMsg))
2026-02-05T01:33:22.8745415Z   File "/usr/lib/python3.10/unittest/case.py", line 675, in fail
2026-02-05T01:33:22.8745718Z     raise self.failureException(msg)
2026-02-05T01:33:22.8745999Z AssertionError: 3398.341704887685 not greater than 4000

which I think is unrelated to this PR and should not be triggered in AMD CI, given:

if is_in_amd_ci():
    self.assertGreater(res["input_throughput"], 3000)
else:
    self.assertGreater(res["input_throughput"], 4000)

@github-actions github-actions bot added the hicache Hierarchical Caching for SGLang label Feb 5, 2026
@fxmarty-amd fxmarty-amd requested a review from kkHuang-amd March 10, 2026 13:35
@fxmarty-amd
Contributor Author

fxmarty-amd commented Mar 10, 2026

Hi @kkHuang-amd is there anything needed from me to get this merged?

@fxmarty-amd
Contributor Author

fxmarty-amd commented Mar 13, 2026

For example, sglang serve --model-path Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --tensor-parallel-size 1 is crashing by default due to this issue on MI355.


Labels

amd hicache Hierarchical Caching for SGLang run-ci

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] [AITER attention] OOM errors as resolved max_total_num_tokens does not take into account the memory requirements from attention backends

2 participants