Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions python/sglang/srt/layers/attention/aiter_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ class ForwardMetadata:


class AiterAttnBackend(AttentionBackend):
@staticmethod
def get_max_num_partitions(max_context_len: int) -> int:
    """Return how many aiter attention partitions cover ``max_context_len``.

    This is the ceiling of ``max_context_len / _AITER_PARTITION_SIZE_ROCM``
    (a module-level constant), computed in pure integer arithmetic.
    """
    # -(-a // b) == ceil(a / b) for integers with a positive divisor.
    return -(-max_context_len // _AITER_PARTITION_SIZE_ROCM)

def __init__(
self,
model_runner: ModelRunner,
Expand Down Expand Up @@ -154,9 +160,9 @@ def __init__(
)

# aiter kernel related initialization
self.max_num_partitions = (
self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
) // _AITER_PARTITION_SIZE_ROCM
self.max_num_partitions = AiterAttnBackend.get_max_num_partitions(
self.max_context_len
)

nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,85 @@


class ModelRunnerKVCacheMixin:
def _solve_max_tokens_with_aiter_workspace(
    self: ModelRunner, rest_memory_bytes: int, cell_size: int, num_layers: int
) -> int:
    """
    Solve for max_total_num_tokens accounting for aiter attention workspace memory.

    We need to satisfy:
    - kv_memory + workspace_memory = rest_memory
    - kv_memory = max_total_num_tokens * cell_size
    - workspace_memory = max_num_reqs * workspace_constant
    - max_num_reqs = clamp(max_total_num_tokens / context_len * 512, 2048, 4096)

    The `max_num_reqs` function is piecewise, so we solve for each piece and pick
    the valid solution: for each of the three clamp regions we assume a value of
    max_num_reqs, solve for a candidate max_total_num_tokens, and keep the
    candidate only if it actually lies in the assumed region.

    Args:
        rest_memory_bytes: Memory (bytes) left to split between the KV cache and
            the aiter attention workspace.
        cell_size: KV-cache bytes consumed per token.
        num_layers: Number of attention layers. Currently unused here (cell_size
            already aggregates per-layer cost); kept for interface stability.

    Returns:
        The solved max_total_num_tokens (truncated to int).

    Raises:
        ValueError: If no candidate falls into its assumed clamp region
            (should not happen for sane inputs).
    """
    # Local imports to avoid import cycles at module load time.
    from sglang.srt.configs.model_config import AttentionArch
    from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

    # Get attention parameters.
    num_head = self.model_config.num_attention_heads // get_attention_tp_size()
    head_dim = self.model_config.head_dim
    context_len = self.model_config.context_len
    use_mla = self.model_config.attention_arch == AttentionArch.MLA

    # For MLA, workspace is allocated dynamically, not during init, so all of
    # the remaining memory can go to the KV cache.
    if use_mla:
        return rest_memory_bytes // cell_size

    max_num_partitions = AiterAttnBackend.get_max_num_partitions(context_len)

    # Resolve `max_total_num_tokens` based on the `workspace_size` required by
    # aiter_backend.py:
    #   workspace_size = (max_num_reqs * num_head * max_num_partitions * head_dim) * 4
    #                  + 2 * (max_num_reqs * num_head * max_num_partitions) * 4
    # i.e. `workspace_size = max_num_reqs * W` introducing the known constant W.
    W = num_head * max_num_partitions * (head_dim * 4 + 8)

    # From `ModelRunnerKVCacheMixin.init_memory_pool`:
    #   max_num_reqs = clamp(max_total_num_tokens / context_len * 512, 2048, 4096)
    # With the constraint:
    #   rest_memory_bytes = kv_memory + aiter_memory
    #                     = max_total_num_tokens * cell_size + max_num_reqs * W
    # This creates three cases. We check the linear (un-clamped) region first
    # since that is typically where we land.

    # Case 2: Linear region in-between.
    #   2048 <= max_total_num_tokens / context_len * 512 <= 4096
    #   <=> max_num_reqs = max_total_num_tokens * 512 / context_len
    # Injecting into the memory constraint:
    #   max_total_num_tokens * (cell_size + 512 * W / context_len) = rest_memory_bytes
    #   <=> max_total_num_tokens = rest_memory_bytes / (cell_size + 512 * W / context_len)
    candidate_max_total_num_tokens = rest_memory_bytes / (
        cell_size + 512 * W / context_len
    )
    if 2048 <= candidate_max_total_num_tokens / context_len * 512 <= 4096:
        return int(candidate_max_total_num_tokens)

    # Case 1: max_total_num_tokens / context_len * 512 <= 2048
    #   <=> max_num_reqs = 2048 (clamped to the minimum)
    #   => max_total_num_tokens = (rest_memory_bytes - 2048 * W) / cell_size
    candidate_max_total_num_tokens = (rest_memory_bytes - 2048 * W) / cell_size
    if candidate_max_total_num_tokens / context_len * 512 <= 2048:
        # NOTE(review): if rest_memory_bytes < 2048 * W this candidate is
        # negative and the caller will fail downstream — presumably such tiny
        # budgets are rejected earlier; TODO confirm.
        return int(candidate_max_total_num_tokens)

    # Case 3: max_total_num_tokens / context_len * 512 >= 4096
    #   <=> max_num_reqs = 4096 (clamped to the maximum)
    #   => max_total_num_tokens = (rest_memory_bytes - 4096 * W) / cell_size
    candidate_max_total_num_tokens = (rest_memory_bytes - 4096 * W) / cell_size
    if candidate_max_total_num_tokens / context_len * 512 >= 4096:
        return int(candidate_max_total_num_tokens)

    # No candidate landed in its assumed clamp region: the piecewise solve is
    # inconsistent, which indicates a bug or pathological inputs.
    raise ValueError(
        "Something went wrong in the memory allocation for KV cache. Please open an issue."
    )

def get_cell_size_per_token(self: ModelRunner, num_layers: int) -> int:
kv_size = torch._utils._element_size(self.kv_cache_dtype)
if self.use_mla_backend:
Expand Down Expand Up @@ -146,7 +225,25 @@ def profile_max_num_token(self: ModelRunner, total_gpu_memory: int):
if self.mambaish_config is not None:
rest_memory = self.handle_max_mamba_cache(rest_memory)

return int(rest_memory * (1 << 30)) // cell_size
rest_memory_bytes = int(rest_memory * (1 << 30))

# NOTE: No special handling for the cases `self.mambaish_config is not None` and when `max_running_requests` is specified.
if (
self.server_args.attention_backend == "aiter"
and self.mambaish_config is None
and self.server_args.max_running_requests is None
):
# `max_total_num_tokens` is used in `ModelRunnerKVCacheMixin.init_memory_pool` to define
# `max_num_reqs`, which is in turn used in AITER attention backend to define GPU HBM buffers for the attention.
# The default strategy below to resolve `max_total_num_tokens` does NOT take into account the memory required for the attention backend, potentially resulting in OOM errors in AITER buffers allocation.
max_total_num_tokens = self._solve_max_tokens_with_aiter_workspace(
rest_memory_bytes, cell_size, num_layers
)
else:
# No workspace overhead for other backends
max_total_num_tokens = rest_memory_bytes // cell_size
Comment on lines +231 to +244
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current logic correctly handles the dynamic calculation of max_total_num_tokens when max_running_requests is not set. However, when max_running_requests is specified, the code falls through to the else block, which does not account for the AITER workspace memory. This could still lead to Out-Of-Memory errors, which this pull request aims to fix.

To make this more robust, I suggest handling the fixed max_running_requests case explicitly for the aiter backend by calculating the workspace size and subtracting it from the available memory. This ensures memory is correctly provisioned in all scenarios for the aiter backend.

        if (
            self.server_args.attention_backend == "aiter"
            and self.mambaish_config is None
        ):
            from sglang.srt.configs.model_config import AttentionArch
            if self.model_config.attention_arch == AttentionArch.MLA:
                # For MLA, workspace is allocated dynamically, not during init
                max_total_num_tokens = rest_memory_bytes // cell_size
            elif self.server_args.max_running_requests is None:
                # `max_total_num_tokens` is used in `ModelRunnerKVCacheMixin.init_memory_pool` to define
                # `max_num_reqs`, which is in turn used in AITER attention backend to define GPU HBM buffers for the attention.
                # The default strategy below to resolve `max_total_num_tokens` does NOT take into account the memory required for the attention backend, potentially resulting in OOM errors in AITER buffers allocation.
                max_total_num_tokens = self._solve_max_tokens_with_aiter_workspace(
                    rest_memory_bytes, cell_size, num_layers
                )
            else:
                # When max_running_requests is set, max_num_reqs is fixed.
                # We need to account for the AITER workspace memory.
                from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
                num_head = self.model_config.num_attention_heads // get_attention_tp_size()
                head_dim = self.model_config.head_dim
                max_context_len = self.model_config.context_len
                max_num_partitions = AiterAttnBackend.get_max_num_partitions(max_context_len)
                W = num_head * max_num_partitions * (head_dim * 4 + 8)
                max_num_reqs = self.server_args.max_running_requests
                aiter_workspace = max_num_reqs * W
                max_total_num_tokens = (rest_memory_bytes - aiter_workspace) // cell_size
        else:
            # No workspace overhead for other backends
            max_total_num_tokens = rest_memory_bytes // cell_size

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For simplicity, not included in this PR. It could be included in a future PR.


return max_total_num_tokens

def handle_max_mamba_cache(self: ModelRunner, total_rest_memory):
config = self.mambaish_config
Expand Down
5 changes: 0 additions & 5 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1881,11 +1881,6 @@ def _handle_attention_backend_compatibility(self):
)
self.page_size = 128

# AMD platforms backends
if self.attention_backend == "aiter":
if model_config.context_len > 8192:
self.mem_fraction_static *= 0.85

# Other platforms backends
if (
self.attention_backend == "intel_amx"
Expand Down
2 changes: 1 addition & 1 deletion test/registered/hicache/test_hicache_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class TestHiCacheMLA(HiCacheBaseServer, HiCacheEvalMixin, HiCacheMGSMEvalMixin):
hicache_args = [
"--trust-remote-code",
"--enable-hierarchical-cache",
] + (["--hicache-size", 200] if _is_hip else ["--hicache-ratio", 2])
] + (["--hicache-size", 250] if _is_hip else ["--hicache-ratio", 2])
expected_mmlu_score = 0.5


Expand Down
Loading