diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index a99e230e25ac..6599b17c2c69 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -87,6 +87,12 @@ class ForwardMetadata: class AiterAttnBackend(AttentionBackend): + @staticmethod + def get_max_num_partitions(max_context_len: int) -> int: + return ( + max_context_len + _AITER_PARTITION_SIZE_ROCM - 1 + ) // _AITER_PARTITION_SIZE_ROCM + def __init__( self, model_runner: ModelRunner, @@ -154,9 +160,9 @@ def __init__( ) # aiter kernel related initialization - self.max_num_partitions = ( - self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1 - ) // _AITER_PARTITION_SIZE_ROCM + self.max_num_partitions = AiterAttnBackend.get_max_num_partitions( + self.max_context_len + ) nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8 diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index c7dcd57667c1..71769f75f5fd 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -44,6 +44,85 @@ class ModelRunnerKVCacheMixin: + def _solve_max_tokens_with_aiter_workspace( + self: ModelRunner, rest_memory_bytes: int, cell_size: int, num_layers: int + ) -> int: + """ + Solve for max_total_num_tokens accounting for aiter attention workspace memory. + + We need to satisfy: + - kv_memory + workspace_memory = rest_memory + - kv_memory = max_total_num_tokens * cell_size + - workspace_memory = max_num_reqs * workspace_constant + - max_num_reqs = clamp(max_total_num_tokens / context_len * 512, 2048, 4096) + + The `max_num_reqs` function is piecewise, so we solve for each piece and pick the valid solution. 
+ """ + from sglang.srt.configs.model_config import AttentionArch + from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend + + # Get attention parameters + num_head = self.model_config.num_attention_heads // get_attention_tp_size() + head_dim = self.model_config.head_dim + context_len = self.model_config.context_len + use_mla = self.model_config.attention_arch == AttentionArch.MLA + + # For MLA, workspace is allocated dynamically, not during init + if use_mla: + return rest_memory_bytes // cell_size + + max_num_partitions = AiterAttnBackend.get_max_num_partitions(context_len) + + # Resolve `max_total_num_tokens` based on the `workspace_size` required by aiter_backend.py: + # workspace_size = (max_num_reqs * num_head * max_num_partitions * head_dim) * 4 + # + 2 * (max_num_reqs * num_head * max_num_partitions) * 4 + + # i.e. `workspace_size = max_num_reqs * W` introducing the known constant W. + W = num_head * max_num_partitions * (head_dim * 4 + 8) + + # We then have from `ModelRunnerKVCacheMixin.init_memory_pool`: + # max_num_reqs = clamp(max_total_num_tokens / context_len * 512, 2048, 4096) + + # With the constraint: rest_memory_bytes = kv_memory + aiter_memory + # = max_total_num_tokens * cell_size + max_num_reqs * W + + # This creates three cases: + + # Case 2: Linear region in-between, typically where we'll be. 
+        # 2048 <= max_total_num_tokens / context_len * 512 <= 4096
+        # <=> max_num_reqs = max_total_num_tokens * 512 / context_len
+        # Injecting in rest_memory_bytes = kv_memory + aiter_memory we get:
+        # max_total_num_tokens * cell_size + (max_total_num_tokens * 512 / context_len) * W = rest_memory_bytes
+        # <=> max_total_num_tokens * (cell_size + 512 * W / context_len) = rest_memory_bytes
+        # <=> max_total_num_tokens = rest_memory_bytes / (cell_size + 512 * W / context_len)
+        candidate_max_total_num_tokens = rest_memory_bytes / (
+            cell_size + 512 * W / context_len
+        )
+        if 2048 <= candidate_max_total_num_tokens / context_len * 512 <= 4096:
+            return int(candidate_max_total_num_tokens)
+
+        # Case 1: max_total_num_tokens / context_len * 512 <= 2048
+        # <=> max_num_reqs = 2048
+        # Injecting in rest_memory_bytes = kv_memory + aiter_memory:
+        # max_total_num_tokens * cell_size + 2048 * W = rest_memory_bytes
+        # <=> max_total_num_tokens = (rest_memory_bytes - 2048 * W) / cell_size
+        candidate_max_total_num_tokens = (rest_memory_bytes - 2048 * W) / cell_size
+        if candidate_max_total_num_tokens / context_len * 512 <= 2048:
+            return int(candidate_max_total_num_tokens)
+
+        # Case 3: max_total_num_tokens / context_len * 512 >= 4096
+        # <=> max_num_reqs = 4096
+        # Injecting in rest_memory_bytes = kv_memory + aiter_memory we get:
+        # max_total_num_tokens * cell_size + 4096 * W = rest_memory_bytes
+        # <=> max_total_num_tokens = (rest_memory_bytes - 4096 * W) / cell_size
+        candidate_max_total_num_tokens = (rest_memory_bytes - 4096 * W) / cell_size
+        if candidate_max_total_num_tokens / context_len * 512 >= 4096:
+            return int(candidate_max_total_num_tokens)
+
+        raise ValueError(
+            "Something went wrong in the memory allocation for KV cache. Please open an issue."
+ ) + def get_cell_size_per_token(self: ModelRunner, num_layers: int) -> int: kv_size = torch._utils._element_size(self.kv_cache_dtype) if self.use_mla_backend: @@ -146,7 +225,25 @@ def profile_max_num_token(self: ModelRunner, total_gpu_memory: int): if self.mambaish_config is not None: rest_memory = self.handle_max_mamba_cache(rest_memory) - return int(rest_memory * (1 << 30)) // cell_size + rest_memory_bytes = int(rest_memory * (1 << 30)) + + # NOTE: No special handling for the cases `self.mambaish_config is not None` and when `max_running_requests` is specified. + if ( + self.server_args.attention_backend == "aiter" + and self.mambaish_config is None + and self.server_args.max_running_requests is None + ): + # `max_total_num_tokens` is used in `ModelRunnerKVCacheMixin.init_memory_pool` to define + # `max_num_reqs`, which is in turn used in AITER attention backend to define GPU HBM buffers for the attention. + # The default strategy below to resolve `max_total_num_tokens` does NOT take into account the memory required for the attention backend, potentially resulting in OOM errors in AITER buffers allocation. 
+ max_total_num_tokens = self._solve_max_tokens_with_aiter_workspace( + rest_memory_bytes, cell_size, num_layers + ) + else: + # No workspace overhead for other backends + max_total_num_tokens = rest_memory_bytes // cell_size + + return max_total_num_tokens def handle_max_mamba_cache(self: ModelRunner, total_rest_memory): config = self.mambaish_config diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6a1eb50d60ef..5e4e6a5d29ac 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1881,11 +1881,6 @@ def _handle_attention_backend_compatibility(self): ) self.page_size = 128 - # AMD platforms backends - if self.attention_backend == "aiter": - if model_config.context_len > 8192: - self.mem_fraction_static *= 0.85 - # Other platforms backends if ( self.attention_backend == "intel_amx" diff --git a/test/registered/hicache/test_hicache_variants.py b/test/registered/hicache/test_hicache_variants.py index 57e6edbbd30b..5500340753af 100644 --- a/test/registered/hicache/test_hicache_variants.py +++ b/test/registered/hicache/test_hicache_variants.py @@ -110,7 +110,7 @@ class TestHiCacheMLA(HiCacheBaseServer, HiCacheEvalMixin, HiCacheMGSMEvalMixin): hicache_args = [ "--trust-remote-code", "--enable-hierarchical-cache", - ] + (["--hicache-size", 200] if _is_hip else ["--hicache-ratio", 2]) + ] + (["--hicache-size", 250] if _is_hip else ["--hicache-ratio", 2]) expected_mmlu_score = 0.5