diff --git a/.claude/commands/benchmark-guide.md b/.claude/commands/benchmark-guide.md index 9a7f571b6..59d9dc09b 100644 --- a/.claude/commands/benchmark-guide.md +++ b/.claude/commands/benchmark-guide.md @@ -75,6 +75,12 @@ curl -X POST http://127.0.0.1:8000/stop_profile | E2EL | End-to-End Latency | | Throughput | Output tokens/s | +## Critical Rules + +- **Never run accuracy and performance tests simultaneously** — they interfere with each other's results. Always finish one before starting the other. +- **Report Total Token throughput (tok/s)**, not Output token throughput — Total includes input+output. +- **MTP models require `--use-chat-template`** — tokenizer mismatch without it causes inaccurate results. + ## CI Benchmark Workflow - File: `.github/workflows/atom-benchmark.yaml` diff --git a/.claude/commands/debug-guide.md b/.claude/commands/debug-guide.md index 87e92f9d1..62e5bd224 100644 --- a/.claude/commands/debug-guide.md +++ b/.claude/commands/debug-guide.md @@ -80,7 +80,7 @@ - **[COMMON] "Failed to allocate kv cache":** `num_blocks` returned 0. Model + KV cache don't fit in VRAM. **Fix:** Reduce `max_num_seqs` or use `--kv_cache_dtype fp8` - **Block exhaustion during serving:** Scheduler preempts last running sequence when blocks run out - **Block leak (ref_count never reaches 0):** Verify `seq.block_table` is cleared on sequence completion -- **Mamba/GDN block table:** Hybrid models use separate `mamba_block_table` for linear attention state. No prefix caching support +- **Mamba/GDN state slots:** Hybrid models use `mamba_state_slot` (per-request slot from unified pool) for recurrent state. 
Pool tracks mamba memory via equiv-block accounting ## Scheduler diff --git a/atom/model_engine/block_manager.py b/atom/model_engine/block_manager.py index 649c45305..8689ad425 100644 --- a/atom/model_engine/block_manager.py +++ b/atom/model_engine/block_manager.py @@ -41,6 +41,15 @@ def __init__(self, config: Config): self.used_block_ids: set[int] = set() self.enable_prefix_caching = config.enable_prefix_caching + # Mamba/GDN recurrent state: per-request slot groups + equiv-block accounting + # Each slot group contains (1+num_spec) contiguous tensor indices. + # free_mamba_slots tracks group indices (0..num_groups-1). + self.mamba_equiv_per_req: int = getattr(config, "mamba_equiv_per_req", 0) + num_mamba_groups: int = getattr(config, "num_mamba_groups", 0) + self.free_mamba_slots: list[int] = list(range(num_mamba_groups)) + # seq_id → list of accounting block_ids + self.mamba_accounting: dict[int, list[int]] = {} + @classmethod def compute_hash(cls, token_ids: list[int], prefix: int = -1): h = xxhash.xxh64() @@ -76,8 +85,13 @@ def _deallocate_block(self, block_id: int): self.free_block_ids_set.add(block_id) def can_allocate(self, seq: Sequence) -> bool: + mamba_cost = self.mamba_equiv_per_req if seq.mamba_enabled else 0 + mamba_slot_ok = (not seq.mamba_enabled) or len(self.free_mamba_slots) > 0 if not self.enable_prefix_caching: - return len(self.free_block_ids_set) >= seq.num_blocks + seq.num_mamba_blocks + return ( + len(self.free_block_ids_set) >= seq.num_blocks + mamba_cost + and mamba_slot_ok + ) # Dry-run: count how many blocks would be cache hits h = -1 cache_miss = False @@ -94,7 +108,9 @@ def can_allocate(self, seq: Sequence) -> bool: cache_miss = True if cache_miss: needed_free += 1 - return len(self.free_block_ids_set) >= needed_free + seq.num_mamba_blocks + return ( + len(self.free_block_ids_set) >= needed_free + mamba_cost and mamba_slot_ok + ) def allocate(self, seq: Sequence): assert not seq.block_table @@ -127,15 +143,15 @@ def allocate(self, seq: 
Sequence): self.hash_to_block_id[h] = block_id seq.block_table.append(block_id) - # handle mamba-like model + # Mamba/GDN recurrent state: allocate equiv blocks (accounting) + slot (indexing) if seq.mamba_enabled: - # For mamba, we need to ensure the last block is always allocated - # even if it has less than block_size tokens - for i in range(seq.num_mamba_blocks): + accounting_blocks = [] + for _ in range(self.mamba_equiv_per_req): block_id = self._pop_free_block() self._allocate_block(block_id) - # No prefix caching support for mamba arch - seq.mamba_block_table.append(block_id) + accounting_blocks.append(block_id) + self.mamba_accounting[seq.id] = accounting_blocks + seq.mamba_state_slot = self.free_mamba_slots.pop() def deallocate(self, seq: Sequence): for block_id in reversed(seq.block_table): @@ -145,13 +161,13 @@ def deallocate(self, seq: Sequence): self._deallocate_block(block_id) seq.num_cached_tokens = 0 seq.block_table.clear() - if seq.mamba_enabled: - for block_id in reversed(seq.mamba_block_table): + if seq.mamba_enabled and seq.mamba_state_slot >= 0: + for block_id in self.mamba_accounting.pop(seq.id, []): block = self.blocks[block_id] - # just in case - block.ref_count = 0 + block.ref_count = 0 # accounting blocks bypass ref-counting self._deallocate_block(block_id) - seq.mamba_block_table.clear() + self.free_mamba_slots.append(seq.mamba_state_slot) + seq.mamba_state_slot = -1 def can_append(self, seq: Sequence, num_new_tokens: int = 1) -> bool: seq_len = len(seq) diff --git a/atom/model_engine/engine_core.py b/atom/model_engine/engine_core.py index a81736a32..c1879ef21 100644 --- a/atom/model_engine/engine_core.py +++ b/atom/model_engine/engine_core.py @@ -77,7 +77,10 @@ def __init__(self, config: Config, input_address: str, output_address: str): "atom.model_engine.model_runner.ModelRunner", config, ) - num_blocks = self.runner_mgr.call_func("get_num_blocks", wait_out=True) + block_info = self.runner_mgr.call_func("get_num_blocks", wait_out=True) + 
num_blocks = block_info["num_kvcache_blocks"] + config.mamba_equiv_per_req = block_info.get("mamba_equiv_per_req", 0) + config.num_mamba_groups = block_info.get("num_mamba_groups", 0) ret = self.runner_mgr.call_func( "allocate_kv_cache", num_blocks, wait_out=True ) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 3565e685a..af36b8e88 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -1076,24 +1076,9 @@ def _compute_block_bytes(self): * 4 # float32 ) - # gdn attn bytes - mamba_shape = self.gated_delta_net_state_shape( - get_tp_group().world_size, - hf_config.linear_num_key_heads, - hf_config.linear_num_value_heads, - hf_config.linear_key_head_dim, - hf_config.linear_value_head_dim, - hf_config.linear_conv_kernel_dim, - self.num_spec_tokens, - ) - mamba_dtypes = self.gated_delta_net_state_dtypes() - one_layer_byte = ( - math.prod(mamba_shape[0]) - * torch.tensor([], dtype=mamba_dtypes[0]).element_size() - + math.prod(mamba_shape[1]) - * torch.tensor([], dtype=mamba_dtypes[1]).element_size() - ) - block_bytes += self.num_gdn_attn_state * one_layer_byte + # GDN recurrent state is per-request (not per-block). + # It is accounted for separately via _compute_mamba_per_slot_bytes(). + # Do NOT add it to block_bytes. else: # Standard attention: kv_cache [2, num_hidden_layers, blocks, ...] # Note: allocate_kv_cache uses hf_config.num_hidden_layers for @@ -1116,6 +1101,31 @@ def _compute_block_bytes(self): ) return block_bytes + def _compute_mamba_per_slot_bytes(self) -> int: + """Compute per-slot recurrent state bytes (all GDN layers, one slot). + + A slot holds one request's state (or one spec token's state). + Returns 0 for non-GDN models. 
+ """ + if not self.is_qwen_next(): + return 0 + hf_config = self.config.hf_config + mamba_shape = self.gated_delta_net_state_shape( + get_tp_group().world_size, + hf_config.linear_num_key_heads, + hf_config.linear_num_value_heads, + hf_config.linear_key_head_dim, + hf_config.linear_value_head_dim, + hf_config.linear_conv_kernel_dim, + self.num_spec_tokens, + ) + mamba_dtypes = self.gated_delta_net_state_dtypes() + one_layer_byte = ( + math.prod(mamba_shape[0]) * mamba_dtypes[0].itemsize + + math.prod(mamba_shape[1]) * mamba_dtypes[1].itemsize + ) + return self.num_gdn_attn_state * one_layer_byte + def _estimate_cudagraph_overhead(self): """Estimate GPU memory consumed by CUDA graph capture. @@ -1135,7 +1145,7 @@ def _estimate_cudagraph_overhead(self): # memory due to pooling across multiple captured batch sizes. return int(activation_bytes * 0.2) - def get_num_blocks(self): + def get_num_blocks(self) -> dict[str, int]: torch.set_default_device(self.device) config = self.config hf_config = config.hf_config @@ -1169,7 +1179,33 @@ def get_num_blocks(self): torch.set_default_device("cpu") block_bytes = self._compute_block_bytes() - num_kvcache_blocks = available_for_kv // block_bytes + + # GDN recurrent state: deduct mamba tensor memory from pool budget + mamba_per_slot = self._compute_mamba_per_slot_bytes() + slots_per_req = 1 + self.num_spec_tokens + max_mamba_slots = ( + config.max_num_seqs * slots_per_req if mamba_per_slot > 0 else 0 + ) + mamba_tensor_bytes = max_mamba_slots * mamba_per_slot + available_for_pool = available_for_kv - mamba_tensor_bytes + if available_for_pool <= 0: + raise RuntimeError( + f"GDN mamba tensor ({mamba_tensor_bytes / (1 << 30):.2f}GB for " + f"{max_mamba_slots} slots) exceeds available KV budget " + f"({available_for_kv / (1 << 30):.2f}GB). " + f"Reduce --max-num-seqs or increase gpu_memory_utilization." 
+ ) + mamba_equiv = ( + math.ceil(mamba_per_slot / block_bytes) if mamba_per_slot > 0 else 0 + ) + + # Store for BlockManager and allocate_kv_cache + config.mamba_equiv_per_req = mamba_equiv + config.max_mamba_slots = max_mamba_slots + config.num_mamba_groups = config.max_num_seqs if mamba_per_slot > 0 else 0 + self.max_mamba_slots = max_mamba_slots + + num_kvcache_blocks = available_for_pool // block_bytes logger.info( f"Memory budget: total_gpu={total / (1 << 30):.2f}GB, " @@ -1183,6 +1219,14 @@ def get_num_blocks(self): f"block_bytes={block_bytes}, " f"num_kvcache_blocks={num_kvcache_blocks}" ) + if mamba_per_slot > 0: + logger.info( + f"GDN state pool: mamba_per_slot={mamba_per_slot / (1 << 20):.2f}MB, " + f"max_mamba_slots={max_mamba_slots}, " + f"mamba_tensor={mamba_tensor_bytes / (1 << 30):.2f}GB, " + f"mamba_equiv_blocks_per_req={mamba_equiv}, " + f"pool_blocks={num_kvcache_blocks}" + ) assert num_kvcache_blocks > 0, ( f"Not enough memory for KV cache with block size({self.block_size}). 
" @@ -1194,7 +1238,11 @@ def get_num_blocks(self): f"safety={safety_margin / (1 << 30):.2f}GB, " f"free={free / (1 << 30):.2f}GB)" ) - return num_kvcache_blocks + return { + "num_kvcache_blocks": num_kvcache_blocks, + "mamba_equiv_per_req": mamba_equiv, + "num_mamba_groups": config.max_num_seqs if mamba_per_slot > 0 else 0, + } def allocate_kv_cache(self, num_kvcache_blocks): pre_alloc = torch.cuda.memory_stats()["allocated_bytes.all.current"] @@ -1281,14 +1329,12 @@ def allocate_kv_cache(self, num_kvcache_blocks): ) mamba_dtypes = self.gated_delta_net_state_dtypes() self.mamba_k_cache = torch.zeros( - (self.num_gdn_attn_state, self.num_physical_kvcache_blocks) - + mamba_shape[0], + (self.num_gdn_attn_state, self.max_mamba_slots) + mamba_shape[0], dtype=mamba_dtypes[0], device="cuda", ) self.mamba_v_cache = torch.zeros( - (self.num_gdn_attn_state, self.num_physical_kvcache_blocks) - + mamba_shape[1], + (self.num_gdn_attn_state, self.max_mamba_slots) + mamba_shape[1], dtype=mamba_dtypes[1], device="cuda", ) diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 06b35494e..b4953918a 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -214,8 +214,10 @@ def __init__( self.num_bonus = np.asarray( [seq.num_bonus_tokens for seq in seqs.values()], dtype=np.int32 ) - self.mamba_block_tables = [ - seq.mamba_block_table for seq in seqs.values() if seq.mamba_block_table + self.mamba_state_slots = [ + seq.mamba_state_slot + for seq in seqs.values() + if seq.mamba_enabled and seq.mamba_state_slot >= 0 ] self.top_ks = np.asarray([seq.top_k for seq in seqs.values()], dtype=np.int32) self.top_ps = np.asarray([seq.top_p for seq in seqs.values()], dtype=np.float32) diff --git a/atom/model_engine/sequence.py b/atom/model_engine/sequence.py index 617db9726..0cbd25d19 100644 --- a/atom/model_engine/sequence.py +++ b/atom/model_engine/sequence.py @@ -56,7 +56,7 @@ def __init__( self.num_rejected = 0 self.num_cached_tokens = 0 
self.block_table = [] - self.mamba_block_table = [] + self.mamba_state_slot = -1 # per-request recurrent state slot index self.temperature = sampling_params.temperature self.top_k = sampling_params.top_k self.top_p = sampling_params.top_p @@ -96,11 +96,6 @@ def num_tokens(self): def num_tokens(self, value): self._num_tokens = value self.num_blocks = (value + self.block_size - 1) // self.block_size - # for mamba-like arch, we need to make sure there are always 1 + spec number of blocks - if self.mamba_enabled: - self.num_mamba_blocks = 1 + self.num_draft_tokens - else: - self.num_mamba_blocks = 0 self.last_block_num_tokens = ( self._num_tokens - (self.num_blocks - 1) * self.block_size ) diff --git a/atom/model_ops/attentions/gdn_attn.py b/atom/model_ops/attentions/gdn_attn.py index f58af7f80..d4ccfccb0 100644 --- a/atom/model_ops/attentions/gdn_attn.py +++ b/atom/model_ops/attentions/gdn_attn.py @@ -138,16 +138,18 @@ def __init__( def prepare_state_indices(self, batch: ScheduledBatch, with_spec: bool = False): non_spec_state_indices = self.non_spec_state_indices_tensor.np spec_state_indices = self.spec_state_indices_tensor.np - for idx, mamba_block_table in enumerate(batch.mamba_block_tables): + slots_per_group = 1 + self.num_spec + for idx, slot_group in enumerate(batch.mamba_state_slots): non_spec_state_indices[idx] = 0 spec_state_indices[idx] = 0 + base = slot_group * slots_per_group if not with_spec: - non_spec_state_indices[idx] = mamba_block_table[0] + non_spec_state_indices[idx] = base else: - spec_state_indices[idx, : 1 + self.num_spec] = mamba_block_table[ - : 1 + self.num_spec - ] + spec_state_indices[idx, : 1 + self.num_spec] = np.arange( + base, base + 1 + self.num_spec + ) def prepare_num_accepted_tokens(self, batch: ScheduledBatch): self.num_accepted_tokens.fill_(1) diff --git a/docs/architecture_guide.md b/docs/architecture_guide.md index 5f1ae903c..f23bee516 100644 --- a/docs/architecture_guide.md +++ b/docs/architecture_guide.md @@ -192,6 +192,7 @@ 
The `Sequence` class (in `atom/model_engine/sequence.py`) is the central data st | `num_prompt_tokens` | `int` | Length of the original prompt | | `num_tokens` | `int` (property) | Total length including generated tokens | | `block_table` | `list[int]` | KV cache block IDs allocated to this sequence | +| `mamba_state_slot` | `int` | Per-request GDN recurrent state slot index for hybrid models (Qwen3-Next, Qwen3.5); `-1` if unallocated or not a hybrid model | | `status` | `SequenceStatus` | Current lifecycle state | | `type` | `SequenceType` | Current execution type | | `temperature` | `float` | Sampling temperature | diff --git a/docs/configuration_guide.md b/docs/configuration_guide.md index c4592cdfa..05f2ebed4 100644 --- a/docs/configuration_guide.md +++ b/docs/configuration_guide.md @@ -61,12 +61,15 @@ Defined in `atom/config.py`. The root dataclass that the engine consumes. | `eos_token_id` | `int` | `-1` | End-of-sequence token ID (`-1` = use model default) | | `stop_token_ids` | `list[int]` | `[]` | Additional stop token IDs; populated from `GenerationConfig.eos_token_id` during init | -**Auto-derived fields** (set in `__post_init__`, not user-supplied): +**Auto-derived fields** (set in `__post_init__` or by `ModelRunner.get_num_blocks()`, not user-supplied): | Field | Type | Description | |---|---|---| | `hf_config` | `PretrainedConfig` | Loaded automatically via `get_hf_config(model)` | | `generation_config` | `GenerationConfig` | Loaded automatically via `get_generation_config(model)` | +| `mamba_equiv_per_req` | `int` | Number of KV cache block equivalents reserved per request for GDN recurrent state (hybrid models only); computed by `ModelRunner.get_num_blocks()` | +| `num_mamba_groups` | `int` | Number of per-request GDN state slot groups available (= `max_num_seqs` for hybrid models, 0 otherwise); computed by `ModelRunner.get_num_blocks()` | +| `max_mamba_slots` | `int` | Maximum number of GDN state slots (including slots for speculative tokens); 
computed by `ModelRunner.get_num_blocks()` | --- diff --git a/docs/model_support_guide.md b/docs/model_support_guide.md index 002a3acc4..1b558836d 100644 --- a/docs/model_support_guide.md +++ b/docs/model_support_guide.md @@ -133,6 +133,7 @@ ATOM resolves the HuggingFace `architectures` field from a model's `config.json` - **Architecture:** Hybrid MoE transformer with two attention types: full attention (`Qwen3NextAttention`) and Gated DeltaNet linear attention (`Qwen3NextGatedDeltaNet`). Layer type is determined by `config.layer_types`. - **Layer structure:** `Qwen3NextDecoderLayer` containing either full attention or linear attention, plus either `Qwen3NextSparseMoeBlock` (MoE layers) or `Qwen3NextMLP` (dense layers). - **Attention:** Full attention layers use `QKVParallelLinear` with QK norm, RoPE, GQA. Linear attention layers use `QKVZBAParallelLinear` for fused QKVZ+BA projections with Gated DeltaNet recurrence. +- **GDN Recurrent State:** The Gated DeltaNet linear attention layers maintain per-request recurrent state. ATOM manages this state via a dedicated per-request slot pool (separate from KV cache blocks). Each sequence is assigned a `mamba_state_slot` index during allocation, and the state memory is accounted for dynamically as block equivalents within the unified KV pool. - **MoE:** `Qwen3NextSparseMoeBlock` with `FusedMoE`, shared expert fusion support. - **Normalization:** Uses `GemmaRMSNorm` (aliased as `Qwen3NextRMSNorm`). - **MTP:** Separate draft model in `atom/models/qwen3_next_mtp.py` (`Qwen3NextMTP`). @@ -142,6 +143,7 @@ ATOM resolves the HuggingFace `architectures` field from a model's `config.json` - **Architecture:** Hybrid transformer with two attention types: full attention and Gated DeltaNet linear attention. Layer type is determined by `config.layer_types`. Dense or MoE variants. 
- **Layer structure:** `Qwen3_5DecoderLayer` containing either full attention or linear attention, plus either `Qwen3_5SparseMoeBlock` (MoE variants) or `Qwen3_5MLP` (dense variants). - **Attention:** Full attention layers use `QKVParallelLinear` with QK norm, RoPE, GQA. Linear attention layers use `QKVZBAParallelLinear` for fused QKVZ+BA projections with Gated DeltaNet. +- **GDN Recurrent State:** Like Qwen3-Next, the Gated DeltaNet layers maintain per-request recurrent state managed via the slot pool. Qwen3.5 models (both dense and MoE variants) use the same unified memory management as Qwen3-Next. - **MoE:** `Qwen3_5SparseMoeBlock` with `FusedMoE`, shared expert fusion support. - **Normalization:** RMSNorm with optional fused allreduce for MoE models. - **MTP:** Separate draft model in `atom/models/qwen3_5_mtp.py` (`Qwen3_5MTP`). The MTP predictor uses only full attention layers (no Gated DeltaNet) for efficiency, supporting both MTP1 and MTP3 variants via `num_speculative_tokens`. diff --git a/docs/scheduling_kv_cache_guide.md b/docs/scheduling_kv_cache_guide.md index 6cac8a670..f3eaea027 100644 --- a/docs/scheduling_kv_cache_guide.md +++ b/docs/scheduling_kv_cache_guide.md @@ -233,10 +233,24 @@ class BlockManager: self.free_block_ids: deque[int] = deque(range(num_blocks)) self.used_block_ids: set[int] = set() self.enable_prefix_caching = config.enable_prefix_caching + + # Mamba/GDN recurrent state: per-request slot groups + equiv-block accounting + # Each slot group contains (1+num_spec) contiguous tensor indices. + # free_mamba_slots tracks group indices (0..num_groups-1). + self.mamba_equiv_per_req: int = getattr(config, "mamba_equiv_per_req", 0) + num_mamba_groups: int = getattr(config, "num_mamba_groups", 0) + self.free_mamba_slots: list[int] = list(range(num_mamba_groups)) + # seq_id → list of accounting block_ids + self.mamba_accounting: dict[int, list[int]] = {} ``` The block pool is pre-allocated at startup. 
`free_block_ids` is a deque for O(1) pop/push, `used_block_ids` tracks active blocks, and `hash_to_block_id` maps content hashes to block IDs for prefix caching. +**Mamba/GDN State Pools (Hybrid Models):** For models with Gated DeltaNet (GDN) recurrent attention (Qwen3-Next, Qwen3.5): +- `free_mamba_slots` -- list of free per-request slot-group indices (0 to `num_mamba_groups - 1`). Each group maps to `1 + num_spec` contiguous state-tensor indices: one for the request's recurrent state plus one per speculative token. +- `mamba_accounting` -- maps sequence ID to a list of equivalent block IDs used for memory accounting. The unified pool manages both KV cache blocks and GDN state through dynamic competition; GDN memory is accounted for as block equivalents. +- `mamba_equiv_per_req` -- number of KV cache block equivalents reserved per request for its GDN state. + ### 3.3 Allocation (`allocate`) Called during prefill scheduling for new sequences: @@ -245,6 +259,8 @@ Called during prefill scheduling for new sequences: def allocate(self, seq: Sequence): ``` +**KV Cache allocation:** + 1. Iterates over `seq.num_blocks` blocks. 2. For each block, computes hash if the block is full (`len(token_ids) == block_size`). Partial (last) blocks get `hash = -1`. 3. If prefix caching is enabled, looks up `hash_to_block_id`: @@ -252,6 +268,12 @@ def allocate(self, seq: Sequence): - **Cache miss:** Allocates from `free_block_ids[0]`. 4. Full blocks are registered in `hash_to_block_id`. +**Mamba/GDN state allocation (if `seq.mamba_enabled`):** + +1. Allocates `mamba_equiv_per_req` accounting blocks from the free pool (for memory accounting). +2. Stores these block IDs in `mamba_accounting[seq.id]` to track GDN memory usage. +3. Pops one slot index from `free_mamba_slots` and assigns it to `seq.mamba_state_slot` (per-request state indexing). 
+ ### 3.4 Deallocation (`deallocate`) Called when a sequence finishes or is preempted: @@ -265,22 +287,49 @@ def deallocate(self, seq: Sequence): self._deallocate_block(block_id) seq.num_cached_tokens = 0 seq.block_table.clear() + if seq.mamba_enabled and seq.mamba_state_slot >= 0: + for block_id in self.mamba_accounting.pop(seq.id, []): + block = self.blocks[block_id] + block.ref_count = 0 # accounting blocks bypass ref-counting + self._deallocate_block(block_id) + self.free_mamba_slots.append(seq.mamba_state_slot) + seq.mamba_state_slot = -1 ``` -Blocks are released in reverse order. Shared blocks (with `ref_count > 1` from prefix caching) are not freed until all referencing sequences release them. +**KV Cache deallocation:** Blocks are released in reverse order. Shared blocks (with `ref_count > 1` from prefix caching) are not freed until all referencing sequences release them. + +**Mamba/GDN state deallocation (if `seq.mamba_enabled`):** + +1. Releases all accounting blocks for this sequence from `mamba_accounting[seq.id]` directly (bypassing ref-counting, as they are internal to the accounting system). +2. Returns the slot index `seq.mamba_state_slot` to `free_mamba_slots` for reuse. +3. Clears `seq.mamba_state_slot` to `-1` to mark it as released. ### 3.5 Can-Allocate and Can-Append Checks ```python def can_allocate(self, seq: Sequence) -> bool: - return len(self.free_block_ids) >= seq.num_blocks - -def can_append(self, seq: Sequence) -> bool: - return len(self.free_block_ids) >= (len(seq) % self.block_size == 1) + mamba_cost = self.mamba_equiv_per_req if seq.mamba_enabled else 0 + mamba_slot_ok = (not seq.mamba_enabled) or len(self.free_mamba_slots) > 0 + if not self.enable_prefix_caching: + return ( + len(self.free_block_ids_set) >= seq.num_blocks + mamba_cost + and mamba_slot_ok + ) + # ... 
(prefix caching dry-run logic with mamba_cost included) + +def can_append(self, seq: Sequence, num_new_tokens: int = 1) -> bool: + seq_len = len(seq) + current_blocks = len(seq.block_table) + needed_blocks = (seq_len + num_new_tokens + self.block_size - 1) // self.block_size + new_blocks_needed = max(0, needed_blocks - current_blocks) + return len(self.free_block_ids_set) >= new_blocks_needed ``` -- `can_allocate` checks that enough free blocks exist for the full sequence. -- `can_append` checks whether a decode step needs a new block. A new block is needed only when `len(seq) % block_size == 1` (the previous block just filled up), requiring exactly 1 free block. +- `can_allocate` checks that: + - Enough free blocks exist for the sequence's KV blocks (`seq.num_blocks`) plus, for hybrid models, `mamba_cost` accounting blocks for GDN state. + - At least one mamba slot is available if the sequence has `mamba_enabled=True`. + +- `can_append` checks whether a decode step needs a new block. Calculates the required block count given `num_new_tokens` (typically `mtp_k + 1` for speculative decode) and returns whether enough free blocks remain. 
### 3.6 May-Append (Decode Extension) @@ -517,6 +566,8 @@ class Sequence: | `num_prompt_tokens` | `int` | Number of prompt tokens (fixed at init) | | `num_cached_tokens` | `int` | Tokens served from prefix cache | | `block_table` | `list[int]` | Ordered list of block IDs assigned to this sequence | +| `mamba_enabled` | `bool` | Whether the model uses Gated DeltaNet recurrent attention (set at sequence init) | +| `mamba_state_slot` | `int` | Per-request GDN recurrent state slot index (assigned by BlockManager during allocation, `-1` if unallocated) | | `last_token` | `int` | Most recently appended token ID | | `temperature` | `float` | Sampling temperature (from `SamplingParams`) | | `max_tokens` | `int` | Max completion tokens (from `SamplingParams`, default 64) | diff --git a/tests/test_mamba_state_decoupling.py b/tests/test_mamba_state_decoupling.py new file mode 100644 index 000000000..f2512526c --- /dev/null +++ b/tests/test_mamba_state_decoupling.py @@ -0,0 +1,307 @@ +# SPDX-License-Identifier: MIT +# Tests for GDN recurrent state decoupling: unified block pool + mamba slot management. 
+ + +from conftest import MockConfig +from atom.model_engine.block_manager import BlockManager +from atom.model_engine.sequence import Sequence +from atom.model_engine.scheduler import Scheduler, ScheduledBatch + +# ── helpers ──────────────────────────────────────────────────────────────── + + +def mamba_config(**overrides): + """Config with mamba state management enabled.""" + defaults = dict( + kv_cache_block_size=4, + num_kvcache_blocks=100, + enable_prefix_caching=False, + max_num_seqs=8, + max_num_batched_tokens=256, + bos_token_id=1, + eos_token_id=2, + stop_token_ids=[], + scheduler_delay_factor=0.0, + speculative_config=None, + mamba_equiv_per_req=5, # each mamba request costs 5 equiv blocks + num_mamba_groups=8, # max 8 concurrent mamba requests + mamba_slots_per_group=1, # no spec decode + ) + defaults.update(overrides) + return MockConfig(**defaults) + + +def mamba_seq(token_ids, block_size=4, **kwargs): + return Sequence(token_ids, block_size, mamba_enabled=True, **kwargs) + + +def plain_seq(token_ids, block_size=4, **kwargs): + return Sequence(token_ids, block_size, mamba_enabled=False, **kwargs) + + +# ── BlockManager: mamba slot allocation ──────────────────────────────────── + + +class TestBlockManagerMambaSlots: + + def test_mamba_disabled_no_slots(self): + """Non-mamba config: no slots allocated, behaves like before.""" + bm = BlockManager(MockConfig(num_kvcache_blocks=50)) + assert len(bm.free_mamba_slots) == 0 + assert bm.mamba_equiv_per_req == 0 + + def test_mamba_enabled_has_slots(self): + bm = BlockManager(mamba_config()) + assert len(bm.free_mamba_slots) == 8 + assert bm.mamba_equiv_per_req == 5 + + def test_allocate_assigns_slot(self): + bm = BlockManager(mamba_config()) + seq = mamba_seq([1, 2, 3, 4]) + bm.allocate(seq) + assert seq.mamba_state_slot >= 0 + assert seq.mamba_state_slot < 8 + assert len(bm.free_mamba_slots) == 7 + + def test_allocate_deducts_equiv_blocks(self): + bm = BlockManager(mamba_config()) + initial_free = 
len(bm.free_block_ids_set) + seq = mamba_seq([1, 2, 3, 4]) # 1 KV block + bm.allocate(seq) + # 1 KV block + 5 equiv blocks = 6 total deducted + assert len(bm.free_block_ids_set) == initial_free - 6 + assert seq.id in bm.mamba_accounting + assert len(bm.mamba_accounting[seq.id]) == 5 + + def test_deallocate_returns_slot_and_blocks(self): + bm = BlockManager(mamba_config()) + initial_free = len(bm.free_block_ids_set) + seq = mamba_seq([1, 2, 3, 4]) + bm.allocate(seq) + bm.deallocate(seq) + assert seq.mamba_state_slot == -1 + assert len(bm.free_block_ids_set) == initial_free + assert len(bm.free_mamba_slots) == 8 + assert seq.id not in bm.mamba_accounting + + def test_can_allocate_checks_both_kv_and_mamba(self): + """can_allocate must check KV blocks AND mamba slots.""" + bm = BlockManager(mamba_config(num_kvcache_blocks=100)) + seq = mamba_seq([1, 2, 3, 4]) + assert bm.can_allocate(seq) is True + + def test_can_allocate_fails_not_enough_blocks(self): + """Not enough free blocks for KV + mamba equiv.""" + bm = BlockManager(mamba_config(num_kvcache_blocks=5)) + seq = mamba_seq([1, 2, 3, 4]) # needs 1 KV + 5 equiv = 6 blocks + assert bm.can_allocate(seq) is False + + def test_can_allocate_fails_no_mamba_slots(self): + """All mamba slots exhausted.""" + bm = BlockManager(mamba_config(num_mamba_groups=1)) + seq1 = mamba_seq([1, 2, 3, 4]) + bm.allocate(seq1) + seq2 = mamba_seq([5, 6, 7, 8]) + assert bm.can_allocate(seq2) is False + + def test_plain_seq_ignores_mamba(self): + """Non-mamba sequence should not use mamba slots.""" + bm = BlockManager(mamba_config()) + initial_slots = len(bm.free_mamba_slots) + seq = plain_seq([1, 2, 3, 4]) + bm.allocate(seq) + assert seq.mamba_state_slot == -1 + assert len(bm.free_mamba_slots) == initial_slots + assert seq.id not in bm.mamba_accounting + + def test_multiple_allocate_deallocate(self): + """Allocate and deallocate multiple mamba sequences.""" + bm = BlockManager(mamba_config(num_kvcache_blocks=200)) + seqs = [mamba_seq([1, 2, 3, 
4], id=i + 100) for i in range(8)] + slots = set() + for seq in seqs: + bm.allocate(seq) + slots.add(seq.mamba_state_slot) + # All 8 slots used + assert len(slots) == 8 + assert len(bm.free_mamba_slots) == 0 + + # Deallocate all + for seq in seqs: + bm.deallocate(seq) + assert len(bm.free_mamba_slots) == 8 + + def test_slot_reuse_after_dealloc(self): + """Freed slots can be reused.""" + bm = BlockManager(mamba_config(num_mamba_groups=2, num_kvcache_blocks=200)) + s1 = mamba_seq([1, 2, 3, 4]) + s2 = mamba_seq([5, 6, 7, 8]) + bm.allocate(s1) + bm.allocate(s2) + assert len(bm.free_mamba_slots) == 0 + + slot1 = s1.mamba_state_slot + bm.deallocate(s1) + assert len(bm.free_mamba_slots) == 1 + + s3 = mamba_seq([9, 10, 11, 12]) + bm.allocate(s3) + assert s3.mamba_state_slot == slot1 # reused + + def test_dynamic_competition(self): + """KV and mamba compete for same pool — long sequence reduces mamba capacity.""" + bm = BlockManager(mamba_config(num_kvcache_blocks=20, mamba_equiv_per_req=5)) + # Allocate a long plain sequence (16 tokens → 4 KV blocks) + long_seq = plain_seq(list(range(16))) + bm.allocate(long_seq) + # 20 - 4 = 16 free blocks + # mamba seq needs 1 KV + 5 equiv = 6 blocks + assert bm.can_allocate(mamba_seq([1, 2, 3, 4])) + s1 = mamba_seq([1, 2, 3, 4]) + bm.allocate(s1) # 16 - 6 = 10 free + s2 = mamba_seq([1, 2, 3, 4]) + bm.allocate(s2) # 10 - 6 = 4 free + s3 = mamba_seq([1, 2, 3, 4]) + assert bm.can_allocate(s3) is False # 4 < 6 + + +# ── Sequence: mamba_state_slot field ────────────────────────────────────── + + +class TestSequenceMambaSlot: + + def test_default_slot_negative(self): + seq = Sequence([1, 2, 3], 4, mamba_enabled=True) + assert seq.mamba_state_slot == -1 + assert seq.mamba_enabled is True + + def test_plain_seq_no_slot(self): + seq = Sequence([1, 2, 3], 4, mamba_enabled=False) + assert seq.mamba_state_slot == -1 + assert seq.mamba_enabled is False + + def test_no_num_mamba_blocks(self): + """num_mamba_blocks should no longer exist on 
Sequence.""" + seq = Sequence([1, 2, 3], 4, mamba_enabled=True) + assert not hasattr(seq, "num_mamba_blocks") + + +# ── ScheduledBatch: mamba_state_slots ───────────────────────────────────── + + +class TestScheduledBatchMambaSlots: + + def test_mamba_state_slots_collected(self): + s1 = mamba_seq([1, 2, 3, 4]) + s1.mamba_state_slot = 3 + s1.status = s1.status # keep as WAITING + s2 = plain_seq([5, 6, 7, 8]) + seqs = {s1.id: s1, s2.id: s2} + batch = ScheduledBatch( + seqs=seqs, + num_scheduled_tokens=[4, 4], + total_tokens_num=8, + total_seqs_num=2, + total_seqs_num_prefill=2, + ) + assert batch.mamba_state_slots == [3] + + def test_no_mamba_seqs(self): + s1 = plain_seq([1, 2, 3, 4]) + seqs = {s1.id: s1} + batch = ScheduledBatch( + seqs=seqs, + num_scheduled_tokens=[4], + total_tokens_num=4, + total_seqs_num=1, + total_seqs_num_prefill=1, + ) + assert batch.mamba_state_slots == [] + + +# ── State index mapping ────────────────────────────────────────────────── + + +class TestStateIndexMapping: + """Verify the slot_group → tensor index mapping logic used in gdn_attn.""" + + def test_non_spec_mapping(self): + """Non-spec: tensor_index = slot_group * slots_per_group.""" + slots_per_group = 4 # 1 + 3 spec + slot_group = 7 + base = slot_group * slots_per_group + assert base == 28 + + def test_spec_mapping(self): + """Spec: contiguous indices [base, base+1, ..., base+num_spec].""" + num_spec = 3 + slots_per_group = 1 + num_spec + slot_group = 5 + base = slot_group * slots_per_group + indices = list(range(base, base + 1 + num_spec)) + assert indices == [20, 21, 22, 23] + + def test_all_indices_in_range(self): + """All generated indices must be < max_mamba_slots.""" + max_num_seqs = 256 + num_spec = 3 + slots_per_group = 1 + num_spec + max_mamba_slots = max_num_seqs * slots_per_group + # Check the last group + last_group = max_num_seqs - 1 + base = last_group * slots_per_group + indices = list(range(base, base + 1 + num_spec)) + assert all(0 <= i < max_mamba_slots for i in 
indices) + assert indices[-1] == max_mamba_slots - 1 + + +# ── Scheduler integration ──────────────────────────────────────────────── + + +class TestSchedulerMambaIntegration: + + def test_prefill_mamba_seq(self): + """Scheduler prefill allocates mamba slot via block_manager.""" + sched = Scheduler(mamba_config(num_kvcache_blocks=100)) + seq = mamba_seq([1, 2, 3, 4]) + sched.add(seq) + batch, _ = sched.schedule() + assert batch.total_seqs_num_prefill == 1 + assert seq.mamba_state_slot >= 0 + assert len(batch.mamba_state_slots) == 1 + + def test_preempt_releases_mamba_slot(self): + """Preempted mamba sequence releases its slot.""" + sched = Scheduler(mamba_config(num_kvcache_blocks=100)) + seq = mamba_seq([1, 2, 3, 4]) + sched.add(seq) + sched.schedule() + assert seq.mamba_state_slot >= 0 + initial_slots = len(sched.block_manager.free_mamba_slots) + sched.preempt(seq) + assert seq.mamba_state_slot == -1 + assert len(sched.block_manager.free_mamba_slots) == initial_slots + 1 + + def test_mamba_slot_exhaustion_blocks_prefill(self): + """When all mamba slots are used, new mamba requests wait.""" + sched = Scheduler(mamba_config(num_kvcache_blocks=200, num_mamba_groups=2)) + s1 = mamba_seq([1, 2, 3, 4]) + s2 = mamba_seq([5, 6, 7, 8]) + s3 = mamba_seq([9, 10, 11, 12]) + sched.extend([s1, s2, s3]) + batch, _ = sched.schedule() + # Only 2 slots → only 2 prefilled + assert batch.total_seqs_num_prefill == 2 + assert sched.get_num_unfinished_requests() == 3 + + def test_mixed_mamba_and_plain(self): + """Mamba and plain sequences coexist — plain doesn't consume slots.""" + sched = Scheduler(mamba_config(num_kvcache_blocks=200, num_mamba_groups=2)) + s1 = mamba_seq([1, 2, 3, 4]) + s2 = plain_seq([5, 6, 7, 8]) + s3 = mamba_seq([9, 10, 11, 12]) + s4 = plain_seq([13, 14, 15, 16]) + sched.extend([s1, s2, s3, s4]) + batch, _ = sched.schedule() + # All 4 should prefill — only 2 mamba slots needed + assert batch.total_seqs_num_prefill == 4