6 changes: 6 additions & 0 deletions .claude/commands/benchmark-guide.md
@@ -75,6 +75,12 @@ curl -X POST http://127.0.0.1:8000/stop_profile
| E2EL | End-to-End Latency |
| Throughput | Output tokens/s |

## Critical Rules

- **Never run accuracy and performance tests simultaneously** — they interfere with each other's results. Always finish one before starting the other.
- **Report Total Token throughput (tok/s)**, not Output token throughput — Total includes input+output.
- **MTP models require `--use-chat-template`** — tokenizer mismatch without it causes inaccurate results.

## CI Benchmark Workflow

- File: `.github/workflows/atom-benchmark.yaml`
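The Total-vs-Output rule above reduces to simple arithmetic; here is a minimal sketch (the function and field names are illustrative, not ATOM's API):

```python
# Illustrative only: the distinction the rule above draws, not ATOM code.
def throughput(num_input_tokens: int, num_output_tokens: int,
               elapsed_s: float) -> dict[str, float]:
    return {
        "output_tok_s": num_output_tokens / elapsed_s,  # do not report this
        "total_tok_s": (num_input_tokens + num_output_tokens) / elapsed_s,  # report this
    }

# 1M input + 100K output tokens over 50s -> output 2000.0, total 22000.0 tok/s
print(throughput(1_000_000, 100_000, 50.0))
```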
2 changes: 1 addition & 1 deletion .claude/commands/debug-guide.md
@@ -80,7 +80,7 @@
- **[COMMON] "Failed to allocate kv cache":** `num_blocks` returned 0. Model + KV cache don't fit in VRAM. **Fix:** Reduce `max_num_seqs` or use `--kv_cache_dtype fp8`
- **Block exhaustion during serving:** Scheduler preempts last running sequence when blocks run out
- **Block leak (ref_count never reaches 0):** Verify `seq.block_table` is cleared on sequence completion
- **Mamba/GDN block table:** Hybrid models use separate `mamba_block_table` for linear attention state. No prefix caching support
- **Mamba/GDN state slots:** Hybrid models use `mamba_state_slot` (per-request slot from unified pool) for recurrent state. Pool tracks mamba memory via equiv-block accounting

## Scheduler

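The equiv-block accounting mentioned above is a ceiling division of per-slot state bytes over KV block bytes; a rough sketch with invented sizes:

```python
import math

# Invented sizes for illustration; real values depend on the model config.
mamba_per_slot_bytes = 45 * (1 << 20)  # ~45 MB of GDN recurrent state per slot
block_bytes = 16 * (1 << 20)           # ~16 MB per KV cache block

# Each request is charged this many KV-block equivalents in the unified pool:
mamba_equiv_per_req = math.ceil(mamba_per_slot_bytes / block_bytes)  # -> 3
```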
42 changes: 29 additions & 13 deletions atom/model_engine/block_manager.py
@@ -41,6 +41,15 @@ def __init__(self, config: Config):
self.used_block_ids: set[int] = set()
self.enable_prefix_caching = config.enable_prefix_caching

# Mamba/GDN recurrent state: per-request slot groups + equiv-block accounting
# Each slot group contains (1+num_spec) contiguous tensor indices.
# free_mamba_slots tracks group indices (0..num_groups-1).
self.mamba_equiv_per_req: int = getattr(config, "mamba_equiv_per_req", 0)
num_mamba_groups: int = getattr(config, "num_mamba_groups", 0)
self.free_mamba_slots: list[int] = list(range(num_mamba_groups))
# seq_id → list of accounting block_ids
self.mamba_accounting: dict[int, list[int]] = {}

@classmethod
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
h = xxhash.xxh64()
@@ -76,8 +85,13 @@ def _deallocate_block(self, block_id: int):
self.free_block_ids_set.add(block_id)

def can_allocate(self, seq: Sequence) -> bool:
mamba_cost = self.mamba_equiv_per_req if seq.mamba_enabled else 0
mamba_slot_ok = (not seq.mamba_enabled) or len(self.free_mamba_slots) > 0
if not self.enable_prefix_caching:
return len(self.free_block_ids_set) >= seq.num_blocks + seq.num_mamba_blocks
return (
len(self.free_block_ids_set) >= seq.num_blocks + mamba_cost
and mamba_slot_ok
)
# Dry-run: count how many blocks would be cache hits
h = -1
cache_miss = False
@@ -94,7 +108,9 @@ def can_allocate(self, seq: Sequence) -> bool:
cache_miss = True
if cache_miss:
needed_free += 1
return len(self.free_block_ids_set) >= needed_free + seq.num_mamba_blocks
return (
len(self.free_block_ids_set) >= needed_free + mamba_cost and mamba_slot_ok
)

def allocate(self, seq: Sequence):
assert not seq.block_table
@@ -127,15 +143,15 @@ def allocate(self, seq: Sequence):
self.hash_to_block_id[h] = block_id
seq.block_table.append(block_id)

# handle mamba-like model
# Mamba/GDN recurrent state: allocate equiv blocks (accounting) + slot (indexing)
if seq.mamba_enabled:
# For mamba, we need to ensure the last block is always allocated
# even if it has less than block_size tokens
for i in range(seq.num_mamba_blocks):
accounting_blocks = []
for _ in range(self.mamba_equiv_per_req):
block_id = self._pop_free_block()
self._allocate_block(block_id)
# No prefix caching support for mamba arch
seq.mamba_block_table.append(block_id)
accounting_blocks.append(block_id)
self.mamba_accounting[seq.id] = accounting_blocks
seq.mamba_state_slot = self.free_mamba_slots.pop()

def deallocate(self, seq: Sequence):
for block_id in reversed(seq.block_table):
@@ -145,13 +161,13 @@ def deallocate(self, seq: Sequence):
self._deallocate_block(block_id)
seq.num_cached_tokens = 0
seq.block_table.clear()
if seq.mamba_enabled:
for block_id in reversed(seq.mamba_block_table):
if seq.mamba_enabled and seq.mamba_state_slot >= 0:
for block_id in self.mamba_accounting.pop(seq.id, []):
block = self.blocks[block_id]
# just in case
block.ref_count = 0
block.ref_count = 0 # accounting blocks bypass ref-counting
self._deallocate_block(block_id)
seq.mamba_block_table.clear()
self.free_mamba_slots.append(seq.mamba_state_slot)
seq.mamba_state_slot = -1

def can_append(self, seq: Sequence, num_new_tokens: int = 1) -> bool:
seq_len = len(seq)
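A stripped-down sketch of the slot-pool lifecycle this file now implements, using standalone stand-ins rather than the real `BlockManager`/`Sequence` classes:

```python
class TinySlotPool:
    """Illustrative stand-in: per-request state slots + equiv-block accounting."""

    def __init__(self, num_groups: int, equiv_per_req: int, num_blocks: int):
        self.free_slots = list(range(num_groups))   # mirrors free_mamba_slots
        self.equiv_per_req = equiv_per_req          # mirrors mamba_equiv_per_req
        self.free_blocks = set(range(num_blocks))   # mirrors free_block_ids_set
        self.accounting: dict[int, list[int]] = {}  # mirrors mamba_accounting

    def can_allocate(self) -> bool:
        return bool(self.free_slots) and len(self.free_blocks) >= self.equiv_per_req

    def allocate(self, seq_id: int) -> int:
        # Equiv blocks are reserved purely for accounting; the returned slot
        # indexes the pre-allocated state tensors.
        self.accounting[seq_id] = [
            self.free_blocks.pop() for _ in range(self.equiv_per_req)
        ]
        return self.free_slots.pop()

    def deallocate(self, seq_id: int, slot: int) -> None:
        self.free_blocks.update(self.accounting.pop(seq_id, []))
        self.free_slots.append(slot)


pool = TinySlotPool(num_groups=4, equiv_per_req=3, num_blocks=100)
slot = pool.allocate(seq_id=0)        # charges 3 accounting blocks, takes 1 slot
pool.deallocate(seq_id=0, slot=slot)  # returns both to the pool
```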
5 changes: 4 additions & 1 deletion atom/model_engine/engine_core.py
@@ -77,7 +77,10 @@ def __init__(self, config: Config, input_address: str, output_address: str):
"atom.model_engine.model_runner.ModelRunner",
config,
)
num_blocks = self.runner_mgr.call_func("get_num_blocks", wait_out=True)
block_info = self.runner_mgr.call_func("get_num_blocks", wait_out=True)
num_blocks = block_info["num_kvcache_blocks"]
config.mamba_equiv_per_req = block_info.get("mamba_equiv_per_req", 0)
config.num_mamba_groups = block_info.get("num_mamba_groups", 0)
ret = self.runner_mgr.call_func(
"allocate_kv_cache", num_blocks, wait_out=True
)
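Since `get_num_blocks` now returns a dict rather than a bare int, the contract looks roughly like this (values invented):

```python
# Illustrative shape of the new get_num_blocks() return value:
block_info = {
    "num_kvcache_blocks": 4096,  # pool size after the mamba deduction
    "mamba_equiv_per_req": 3,    # 0 for non-hybrid models
    "num_mamba_groups": 256,     # == max_num_seqs for hybrid models, else 0
}
```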
96 changes: 71 additions & 25 deletions atom/model_engine/model_runner.py
@@ -1076,24 +1076,9 @@ def _compute_block_bytes(self):
* 4 # float32
)

# gdn attn bytes
mamba_shape = self.gated_delta_net_state_shape(
get_tp_group().world_size,
hf_config.linear_num_key_heads,
hf_config.linear_num_value_heads,
hf_config.linear_key_head_dim,
hf_config.linear_value_head_dim,
hf_config.linear_conv_kernel_dim,
self.num_spec_tokens,
)
mamba_dtypes = self.gated_delta_net_state_dtypes()
one_layer_byte = (
math.prod(mamba_shape[0])
* torch.tensor([], dtype=mamba_dtypes[0]).element_size()
+ math.prod(mamba_shape[1])
* torch.tensor([], dtype=mamba_dtypes[1]).element_size()
)
block_bytes += self.num_gdn_attn_state * one_layer_byte
# GDN recurrent state is per-request (not per-block).
# It is accounted for separately via _compute_mamba_per_slot_bytes().
# Do NOT add it to block_bytes.
else:
# Standard attention: kv_cache [2, num_hidden_layers, blocks, ...]
# Note: allocate_kv_cache uses hf_config.num_hidden_layers for
@@ -1116,6 +1101,31 @@
)
return block_bytes

def _compute_mamba_per_slot_bytes(self) -> int:
"""Compute per-slot recurrent state bytes (all GDN layers, one slot).

A slot holds one request's state (or one spec token's state).
Returns 0 for non-GDN models.
"""
if not self.is_qwen_next():
return 0
hf_config = self.config.hf_config
mamba_shape = self.gated_delta_net_state_shape(
get_tp_group().world_size,
hf_config.linear_num_key_heads,
hf_config.linear_num_value_heads,
hf_config.linear_key_head_dim,
hf_config.linear_value_head_dim,
hf_config.linear_conv_kernel_dim,
self.num_spec_tokens,
)
mamba_dtypes = self.gated_delta_net_state_dtypes()
one_layer_byte = (
math.prod(mamba_shape[0]) * mamba_dtypes[0].itemsize
+ math.prod(mamba_shape[1]) * mamba_dtypes[1].itemsize
)
return self.num_gdn_attn_state * one_layer_byte

def _estimate_cudagraph_overhead(self):
"""Estimate GPU memory consumed by CUDA graph capture.

@@ -1135,7 +1145,7 @@ def _estimate_cudagraph_overhead(self):
# memory due to pooling across multiple captured batch sizes.
return int(activation_bytes * 0.2)

def get_num_blocks(self):
def get_num_blocks(self) -> dict[str, int]:
torch.set_default_device(self.device)
config = self.config
hf_config = config.hf_config
@@ -1169,7 +1179,33 @@
torch.set_default_device("cpu")

block_bytes = self._compute_block_bytes()
num_kvcache_blocks = available_for_kv // block_bytes

# GDN recurrent state: deduct mamba tensor memory from pool budget
mamba_per_slot = self._compute_mamba_per_slot_bytes()
slots_per_req = 1 + self.num_spec_tokens
max_mamba_slots = (
config.max_num_seqs * slots_per_req if mamba_per_slot > 0 else 0
)
mamba_tensor_bytes = max_mamba_slots * mamba_per_slot
available_for_pool = available_for_kv - mamba_tensor_bytes
if available_for_pool <= 0:
raise RuntimeError(
f"GDN mamba tensor ({mamba_tensor_bytes / (1 << 30):.2f}GB for "
f"{max_mamba_slots} slots) exceeds available KV budget "
f"({available_for_kv / (1 << 30):.2f}GB). "
f"Reduce --max-num-seqs or increase gpu_memory_utilization."
)
mamba_equiv = (
math.ceil(mamba_per_slot / block_bytes) if mamba_per_slot > 0 else 0
)

# Store for BlockManager and allocate_kv_cache
config.mamba_equiv_per_req = mamba_equiv
config.max_mamba_slots = max_mamba_slots
config.num_mamba_groups = config.max_num_seqs if mamba_per_slot > 0 else 0
self.max_mamba_slots = max_mamba_slots

num_kvcache_blocks = available_for_pool // block_bytes

logger.info(
f"Memory budget: total_gpu={total / (1 << 30):.2f}GB, "
@@ -1183,6 +1219,14 @@
f"block_bytes={block_bytes}, "
f"num_kvcache_blocks={num_kvcache_blocks}"
)
if mamba_per_slot > 0:
logger.info(
f"GDN state pool: mamba_per_slot={mamba_per_slot / (1 << 20):.2f}MB, "
f"max_mamba_slots={max_mamba_slots}, "
f"mamba_tensor={mamba_tensor_bytes / (1 << 30):.2f}GB, "
f"mamba_equiv_blocks_per_req={mamba_equiv}, "
f"pool_blocks={num_kvcache_blocks}"
)

assert num_kvcache_blocks > 0, (
f"Not enough memory for KV cache with block size({self.block_size}). "
@@ -1194,7 +1238,11 @@
f"safety={safety_margin / (1 << 30):.2f}GB, "
f"free={free / (1 << 30):.2f}GB)"
)
return num_kvcache_blocks
return {
"num_kvcache_blocks": num_kvcache_blocks,
"mamba_equiv_per_req": mamba_equiv,
"num_mamba_groups": config.max_num_seqs if mamba_per_slot > 0 else 0,
}

def allocate_kv_cache(self, num_kvcache_blocks):
pre_alloc = torch.cuda.memory_stats()["allocated_bytes.all.current"]
@@ -1281,14 +1329,12 @@ def allocate_kv_cache(self, num_kvcache_blocks):
)
mamba_dtypes = self.gated_delta_net_state_dtypes()
self.mamba_k_cache = torch.zeros(
(self.num_gdn_attn_state, self.num_physical_kvcache_blocks)
+ mamba_shape[0],
(self.num_gdn_attn_state, self.max_mamba_slots) + mamba_shape[0],
dtype=mamba_dtypes[0],
device="cuda",
)
self.mamba_v_cache = torch.zeros(
(self.num_gdn_attn_state, self.num_physical_kvcache_blocks)
+ mamba_shape[1],
(self.num_gdn_attn_state, self.max_mamba_slots) + mamba_shape[1],
dtype=mamba_dtypes[1],
device="cuda",
)
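To make the new budget math concrete, a worked example with invented numbers (real values depend on the model and `gpu_memory_utilization`):

```python
import math

available_for_kv = 40 * (1 << 30)  # 40 GB left for the unified pool (invented)
block_bytes = 16 * (1 << 20)       # 16 MB per KV block (invented)
mamba_per_slot = 45 * (1 << 20)    # 45 MB recurrent state per slot (invented)
max_num_seqs, num_spec_tokens = 256, 1

slots_per_req = 1 + num_spec_tokens                    # 2
max_mamba_slots = max_num_seqs * slots_per_req         # 512
mamba_tensor_bytes = max_mamba_slots * mamba_per_slot  # 22.5 GB, pre-allocated
available_for_pool = available_for_kv - mamba_tensor_bytes  # 17.5 GB
num_kvcache_blocks = available_for_pool // block_bytes      # 1120 blocks
mamba_equiv = math.ceil(mamba_per_slot / block_bytes)       # 3 equiv blocks/req
```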
6 changes: 4 additions & 2 deletions atom/model_engine/scheduler.py
@@ -214,8 +214,10 @@ def __init__(
self.num_bonus = np.asarray(
[seq.num_bonus_tokens for seq in seqs.values()], dtype=np.int32
)
self.mamba_block_tables = [
seq.mamba_block_table for seq in seqs.values() if seq.mamba_block_table
self.mamba_state_slots = [
seq.mamba_state_slot
for seq in seqs.values()
if seq.mamba_enabled and seq.mamba_state_slot >= 0
]
self.top_ks = np.asarray([seq.top_k for seq in seqs.values()], dtype=np.int32)
self.top_ps = np.asarray([seq.top_p for seq in seqs.values()], dtype=np.float32)
7 changes: 1 addition & 6 deletions atom/model_engine/sequence.py
@@ -56,7 +56,7 @@ def __init__(
self.num_rejected = 0
self.num_cached_tokens = 0
self.block_table = []
self.mamba_block_table = []
self.mamba_state_slot = -1 # per-request recurrent state slot index
self.temperature = sampling_params.temperature
self.top_k = sampling_params.top_k
self.top_p = sampling_params.top_p
@@ -96,11 +96,6 @@ def num_tokens(self):
def num_tokens(self, value):
self._num_tokens = value
self.num_blocks = (value + self.block_size - 1) // self.block_size
# for mamba-like arch, we need to make sure there are always 1 + spec number of blocks
if self.mamba_enabled:
self.num_mamba_blocks = 1 + self.num_draft_tokens
else:
self.num_mamba_blocks = 0
self.last_block_num_tokens = (
self._num_tokens - (self.num_blocks - 1) * self.block_size
)
12 changes: 7 additions & 5 deletions atom/model_ops/attentions/gdn_attn.py
@@ -138,16 +138,18 @@ def __init__(
def prepare_state_indices(self, batch: ScheduledBatch, with_spec: bool = False):
non_spec_state_indices = self.non_spec_state_indices_tensor.np
spec_state_indices = self.spec_state_indices_tensor.np
for idx, mamba_block_table in enumerate(batch.mamba_block_tables):
slots_per_group = 1 + self.num_spec
for idx, slot_group in enumerate(batch.mamba_state_slots):
non_spec_state_indices[idx] = 0
spec_state_indices[idx] = 0
base = slot_group * slots_per_group

if not with_spec:
non_spec_state_indices[idx] = mamba_block_table[0]
non_spec_state_indices[idx] = base
else:
spec_state_indices[idx, : 1 + self.num_spec] = mamba_block_table[
: 1 + self.num_spec
]
spec_state_indices[idx, : 1 + self.num_spec] = np.arange(
base, base + 1 + self.num_spec
)

def prepare_num_accepted_tokens(self, batch: ScheduledBatch):
self.num_accepted_tokens.fill_(1)
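With the new scheme, a slot group maps to a contiguous run of rows in the state tensors; for example with `num_spec = 2`:

```python
import numpy as np

num_spec = 2
slots_per_group = 1 + num_spec               # 3 tensor rows per request

slot_group = 5                               # value held in seq.mamba_state_slot
base = slot_group * slots_per_group          # 15
non_spec_index = base                        # row 15 holds the main state
spec_indices = np.arange(base, base + 1 + num_spec)  # rows [15, 16, 17]
```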
1 change: 1 addition & 0 deletions docs/architecture_guide.md
@@ -192,6 +192,7 @@ The `Sequence` class (in `atom/model_engine/sequence.py`) is the central data structure
| `num_prompt_tokens` | `int` | Length of the original prompt |
| `num_tokens` | `int` (property) | Total length including generated tokens |
| `block_table` | `list[int]` | KV cache block IDs allocated to this sequence |
| `mamba_state_slot` | `int` | Per-request GDN recurrent state slot index for hybrid models (Qwen3-Next, Qwen3.5); `-1` if unallocated or not a hybrid model |
| `status` | `SequenceStatus` | Current lifecycle state |
| `type` | `SequenceType` | Current execution type |
| `temperature` | `float` | Sampling temperature |
5 changes: 4 additions & 1 deletion docs/configuration_guide.md
@@ -61,12 +61,15 @@ Defined in `atom/config.py`. The root dataclass that the engine consumes.
| `eos_token_id` | `int` | `-1` | End-of-sequence token ID (`-1` = use model default) |
| `stop_token_ids` | `list[int]` | `[]` | Additional stop token IDs; populated from `GenerationConfig.eos_token_id` during init |

**Auto-derived fields** (set in `__post_init__`, not user-supplied):
**Auto-derived fields** (set in `__post_init__` or by `ModelRunner.get_num_blocks()`, not user-supplied):

| Field | Type | Description |
|---|---|---|
| `hf_config` | `PretrainedConfig` | Loaded automatically via `get_hf_config(model)` |
| `generation_config` | `GenerationConfig` | Loaded automatically via `get_generation_config(model)` |
| `mamba_equiv_per_req` | `int` | Number of KV cache block equivalents reserved per request for GDN recurrent state (hybrid models only); computed by `ModelRunner.get_num_blocks()` |
| `num_mamba_groups` | `int` | Number of per-request GDN state slot groups available (= `max_num_seqs` for hybrid models, 0 otherwise); computed by `ModelRunner.get_num_blocks()` |
| `max_mamba_slots` | `int` | Maximum number of GDN state slots (including slots for speculative tokens); computed by `ModelRunner.get_num_blocks()` |

---

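For hybrid models the three new derived fields are linked; schematically, with `max_num_seqs=256` and one speculative token (invented values):

```python
num_mamba_groups = 256           # == max_num_seqs
max_mamba_slots = 256 * (1 + 1)  # == max_num_seqs * (1 + num_spec_tokens) = 512
mamba_equiv_per_req = 3          # == ceil(per_slot_state_bytes / block_bytes)
```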
2 changes: 2 additions & 0 deletions docs/model_support_guide.md
@@ -133,6 +133,7 @@ ATOM resolves the HuggingFace `architectures` field from a model's `config.json`
- **Architecture:** Hybrid MoE transformer with two attention types: full attention (`Qwen3NextAttention`) and Gated DeltaNet linear attention (`Qwen3NextGatedDeltaNet`). Layer type is determined by `config.layer_types`.
- **Layer structure:** `Qwen3NextDecoderLayer` containing either full attention or linear attention, plus either `Qwen3NextSparseMoeBlock` (MoE layers) or `Qwen3NextMLP` (dense layers).
- **Attention:** Full attention layers use `QKVParallelLinear` with QK norm, RoPE, GQA. Linear attention layers use `QKVZBAParallelLinear` for fused QKVZ+BA projections with Gated DeltaNet recurrence.
- **GDN Recurrent State:** The Gated DeltaNet linear attention layers maintain per-request recurrent state. ATOM manages this state via a dedicated per-request slot pool (separate from KV cache blocks). Each sequence is assigned a `mamba_state_slot` index during allocation, and the state memory is accounted for dynamically as block equivalents within the unified KV pool.
- **MoE:** `Qwen3NextSparseMoeBlock` with `FusedMoE`, shared expert fusion support.
- **Normalization:** Uses `GemmaRMSNorm` (aliased as `Qwen3NextRMSNorm`).
- **MTP:** Separate draft model in `atom/models/qwen3_next_mtp.py` (`Qwen3NextMTP`).
@@ -142,6 +143,7 @@ ATOM resolves the HuggingFace `architectures` field from a model's `config.json`
- **Architecture:** Hybrid transformer with two attention types: full attention and Gated DeltaNet linear attention. Layer type is determined by `config.layer_types`. Dense or MoE variants.
- **Layer structure:** `Qwen3_5DecoderLayer` containing either full attention or linear attention, plus either `Qwen3_5SparseMoeBlock` (MoE variants) or `Qwen3_5MLP` (dense variants).
- **Attention:** Full attention layers use `QKVParallelLinear` with QK norm, RoPE, GQA. Linear attention layers use `QKVZBAParallelLinear` for fused QKVZ+BA projections with Gated DeltaNet.
- **GDN Recurrent State:** Like Qwen3-Next, the Gated DeltaNet layers maintain per-request recurrent state managed via the slot pool. Qwen3.5 models (both dense and MoE variants) use the same unified memory management as Qwen3-Next.
- **MoE:** `Qwen3_5SparseMoeBlock` with `FusedMoE`, shared expert fusion support.
- **Normalization:** RMSNorm with optional fused allreduce for MoE models.
- **MTP:** Separate draft model in `atom/models/qwen3_5_mtp.py` (`Qwen3_5MTP`). The MTP predictor uses only full attention layers (no Gated DeltaNet) for efficiency, supporting both MTP1 and MTP3 variants via `num_speculative_tokens`.