Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .github/benchmark/models_accuracy.json
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,18 @@
"accuracy_baseline_model": "moonshotai/Kimi-K2.5",
"_baseline_note": "HF: amd/Kimi-K2.5-MXFP4 card shows Kimi-K2.5 baseline=0.9409"
},
{
"model_name": "Kimi-K2.5-MXFP4 Eagle3",
"model_path": "amd/Kimi-K2.5-MXFP4",
"extraArgs": "--kv_cache_dtype fp8 -tp 8 --trust-remote-code --method eagle3 --draft-model lightseekorg/kimi-k2.5-eagle3 --num-speculative-tokens 3",
"env_vars": "HSA_NO_SCRATCH_RECLAIM=1",
"runner": "atom-mi355-8gpu.predownload",
"test_level": "nightly",
"accuracy_threshold": 0.91,
"accuracy_baseline": 0.9257,
"accuracy_baseline_model": "amd/Kimi-K2.5-MXFP4 + lightseekorg/kimi-k2.5-eagle3",
"_baseline_note": "Eagle3 spec decode on Kimi-K2.5-MXFP4. Local case_verify_v9_gluon GSM8K 5-shot flexible-extract=0.9257 (vLLM=0.9280, within ±0.71% se). Threshold 0.91 leaves ~1.5pp headroom for noise. -tp 8 (vs base entry's tp=4) because Eagle3 draft KV needs the full 8-rank sharding."
},
{
"model_name": "GLM-5-FP8",
"model_path": "zai-org/GLM-5-FP8",
Expand Down
28 changes: 28 additions & 0 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,8 @@ class SpeculativeConfig:
model: Optional[str] = None
num_speculative_tokens: Optional[int] = None
draft_model_hf_config: Optional[PretrainedConfig] = None
use_aux_hidden_state: bool = False
eagle3_aux_layer_ids: list[int] = field(default_factory=list)

# model_type → mtp_model_type mapping
_MTP_TYPE_MAP: ClassVar[dict[str, str]] = {
Expand Down Expand Up @@ -754,8 +756,34 @@ def __post_init__(self):
self.draft_model_hf_config = self.draft_model_hf_config.text_config
self.hf_config_override(self.draft_model_hf_config)

if self.method == "eagle3":
if getattr(self.draft_model_hf_config, "kv_lora_rank", None):
raise NotImplementedError(
"Eagle3 draft model with MLA attention is not supported"
)
# Aux hidden state layers: prefer the draft checkpoint's
# eagle_config; if absent or the list is empty, ModelRunner
# falls back to model.get_eagle3_aux_hidden_state_layers(),
# which defaults to 3 layers — early / middle / late
# (see DeepseekV2ForCausalLM.get_eagle3_aux_hidden_state_layers,
# returns `(2, num_layers // 2, num_layers - 3)`, aligned with vLLM).
eagle_cfg = getattr(self.draft_model_hf_config, "eagle_config", None)
if eagle_cfg:
self.use_aux_hidden_state = eagle_cfg.get("use_aux_hidden_state", False)
if self.use_aux_hidden_state and not self.eagle3_aux_layer_ids:
self.eagle3_aux_layer_ids = eagle_cfg.get(
"eagle_aux_hidden_state_layer_ids", []

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed, add comment

)
else:
self.use_aux_hidden_state = True

@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> None:
# Eagle3 architecture mapping (architecture-level, not model_type)
arch = (getattr(hf_config, "architectures", None) or [""])[0]
if arch == "LlamaForCausalLMEagle3":
hf_config.architectures = ["Eagle3LlamaModel"]

# Step 1: resolve model_type → mtp model_type
mtp_type = SpeculativeConfig._MTP_TYPE_MAP.get(hf_config.model_type)
if mtp_type is not None:
Expand Down
30 changes: 24 additions & 6 deletions atom/model_engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class EngineArgs:
method: Optional[str] = None
num_speculative_tokens: int = 1
kv_transfer_config: str = "{}"
draft_model: Optional[str] = None
mark_trace: bool = False

@staticmethod
Expand Down Expand Up @@ -163,7 +164,7 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"--method",
type=str,
default=None,
choices=["mtp"],
choices=["mtp", "eagle3"],
help="Speculative method",
)
parser.add_argument(
Expand All @@ -172,6 +173,12 @@ def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
default=1,
help="Number of speculative tokens to generate per iteration (draft model runs this many times autoregressively)",
)
parser.add_argument(
"--draft-model",
type=str,
default=None,
help="Path to external Eagle3 draft model. Required when --method eagle3.",
)
parser.add_argument(
"--max-num-batched-tokens",
type=int,
Expand Down Expand Up @@ -243,14 +250,25 @@ def _get_engine_kwargs(self) -> dict:
),
)
if self.method and self.num_speculative_tokens > 0:
kwargs["speculative_config"] = SpeculativeConfig(
method=kwargs.pop("method"),
model=self.model,
num_speculative_tokens=kwargs.pop("num_speculative_tokens"),
)
method = kwargs.pop("method")
num_spec_tokens = kwargs.pop("num_speculative_tokens")
draft_model = kwargs.pop("draft_model")
if method == "eagle3":
kwargs["speculative_config"] = SpeculativeConfig(
method=method,
model=draft_model,
num_speculative_tokens=num_spec_tokens,
)
else:
kwargs["speculative_config"] = SpeculativeConfig(
method=method,
model=self.model,
num_speculative_tokens=num_spec_tokens,
)
else:
kwargs.pop("method")
kwargs.pop("num_speculative_tokens")
kwargs.pop("draft_model")
kwargs["speculative_config"] = None

# --enable-tbo [prefill|all] → enable_tbo + enable_tbo_decode
Expand Down
135 changes: 113 additions & 22 deletions atom/model_engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,13 @@ def __init__(self, rank: int, config: Config):
self.num_spec_tokens = (
self.config.speculative_config.num_speculative_tokens if use_spec else 0
)
self.eagle3_mode = (
self.config.speculative_config is not None
and self.config.speculative_config.method == "eagle3"
)

self.use_aux_hidden_state_outputs = False
self._aux_hidden_states = None
self.tokenID_processor = tokenIDProcessor(
self,
self.config.max_num_batched_tokens,
Expand Down Expand Up @@ -621,6 +628,18 @@ def __init__(self, rank: int, config: Config):
torch.set_default_device(None)
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)

if self.eagle3_mode and self.config.speculative_config.use_aux_hidden_state:
aux_ids = self.config.speculative_config.eagle3_aux_layer_ids
if not aux_ids and hasattr(
self.model, "get_eagle3_aux_hidden_state_layers"
):
aux_ids = list(self.model.get_eagle3_aux_hidden_state_layers())
if aux_ids:
self.model.set_aux_hidden_state_layers(tuple(aux_ids))
self.use_aux_hidden_state_outputs = True
logger.info(f"Eagle3 aux hidden state layers: {aux_ids}")

torch.set_default_device(self.device)
self.async_execute_stream = torch.cuda.Stream(self.device)
self.allocate_forward_vars()
Expand Down Expand Up @@ -1075,24 +1094,35 @@ def _get_num_kv_heads(self):
return 1

def _get_total_num_layers(self):
"""Return total layer count including draft (MTP) layers."""
"""Return total layer count including draft (MTP) layers.

Drafts that own an independent KV cache via their own builder
(e.g. Eagle3 MHA draft on an MLA target) account for their layers
through that builder, so they are NOT added here. Only MTP-style
drafts that share the target's KV pool contribute.
"""
total = self.config.hf_config.num_hidden_layers
if self.config.speculative_config and hasattr(self, "drafter"):
draft_hf = self.config.speculative_config.draft_model_hf_config
total += getattr(draft_hf, "num_nextn_predict_layers", 1)
if not hasattr(self, "eagle3_draft_builder"):
draft_hf = self.config.speculative_config.draft_model_hf_config
total += getattr(draft_hf, "num_nextn_predict_layers", 1)
return total

def _compute_block_bytes(self):
"""Per-block bytes for the unified KV pool budget.

Delegates to the attention builder, which knows its own tensor
layout (MLA 576-dim packed, GDN-hybrid full-attn-only, MiMo-V2
per-layer-type, standard MHA split-K/V). Mirror of
`attn_metadata_builder.allocate_kv_cache_tensors()` so the budget
math matches what's actually allocated. Per-request cache bytes
are accounted for separately via `compute_per_req_cache_bytes()`.
Sum across all attention builders attached to this runner: the
target builder always, plus an optional `eagle3_draft_builder`
when a heterogeneous spec-decode draft owns its own KV pool. Each
builder knows its own tensor layout (MLA 576-dim packed, GDN-hybrid
full-attn-only, MiMo-V2 per-layer-type, standard MHA split-K/V,
Eagle3 independent MHA). Per-request cache bytes are accounted
for separately via `compute_per_req_cache_bytes()`.
"""
return self.attn_metadata_builder.compute_block_bytes()
block_bytes = self.attn_metadata_builder.compute_block_bytes()
if hasattr(self, "eagle3_draft_builder"):
block_bytes += self.eagle3_draft_builder.compute_block_bytes()
return block_bytes

def _estimate_cudagraph_overhead(self):
"""Estimate GPU memory consumed by CUDA graph capture.
Expand Down Expand Up @@ -1255,13 +1285,24 @@ def allocate_kv_cache(self, num_kvcache_blocks):
num_draft_layers = 0
if self.config.speculative_config and hasattr(self, "drafter"):
draft_hf_config = self.config.speculative_config.draft_model_hf_config
# For MTP, use num_nextn_predict_layers instead of num_hidden_layers
num_draft_layers = getattr(draft_hf_config, "num_nextn_predict_layers", 1)
total_num_layers += num_draft_layers
logger.info(
f"Allocating KV cache for {hf_config.num_hidden_layers} target layers + "
f"{num_draft_layers} draft (MTP) layers = {total_num_layers} total layers"
)
if hasattr(self, "eagle3_draft_builder"):
# Heterogeneous draft (e.g. Eagle3 MHA on MLA target) owns
# its own KV pool via its builder; don't add to target's count.
num_draft_layers = draft_hf_config.num_hidden_layers
logger.info(
f"Allocating KV cache for {hf_config.num_hidden_layers} target layers + "
f"{num_draft_layers} Eagle3 draft layers (separate non-MLA cache)"
)
else:
# For MTP, use num_nextn_predict_layers instead of num_hidden_layers
num_draft_layers = getattr(
draft_hf_config, "num_nextn_predict_layers", 1
)
total_num_layers += num_draft_layers
logger.info(
f"Allocating KV cache for {hf_config.num_hidden_layers} target layers + "
f"{num_draft_layers} draft (MTP) layers = {total_num_layers} total layers"
)

# Primary KV cache allocation (model-agnostic, delegated to the
# attention builder). Each builder owns its tensor layout: MLA →
Expand All @@ -1277,6 +1318,16 @@ def allocate_kv_cache(self, num_kvcache_blocks):
for name, value in main_kv.items():
setattr(self, name, value)

# Heterogeneous draft (e.g. Eagle3 MHA alongside an MLA target) owns
# its own KV pool through a sibling builder; same protocol as above,
# tensors land under namespaced keys (eagle3_kv_cache, eagle3_kv_scale).
if hasattr(self, "eagle3_draft_builder"):
draft_kv = self.eagle3_draft_builder.allocate_kv_cache_tensors(
num_kv_heads, num_draft_layers
)
for name, value in draft_kv.items():
setattr(self, name, value)

# Per-request cache allocation (model-agnostic, delegated to the
# attention metadata builder). For GDN this returns
# `{"mamba_k_cache": ..., "mamba_v_cache": ...}`; for stateless
Expand All @@ -1302,10 +1353,12 @@ def allocate_kv_cache(self, num_kvcache_blocks):
kv_cache_tensors = []
layer_id = 0
# Promote to self so the attention builder's build_kv_cache_tensor()
# can access it without recomputing from drafter state.
# can access it without recomputing from drafter state. Heterogeneous
# drafts (Eagle3) own their own layer space via their builder, so
# leave mtp_start_layer_idx at hf_config.num_hidden_layers in that mode.
self.mtp_start_layer_idx = (
self.drafter.model.model.mtp_start_layer_idx
if hasattr(self, "drafter")
if hasattr(self, "drafter") and not hasattr(self, "eagle3_draft_builder")
else hf_config.num_hidden_layers
)
for model_name, model in models_to_bind:
Expand All @@ -1314,6 +1367,18 @@ def allocate_kv_cache(self, num_kvcache_blocks):
)

for module in model.modules():
# Drafts that own an independent KV pool (Eagle3) bind through
# their sibling builder first; for unrecognized modules it
# returns None and we fall through to the target builder.
if model_name == "draft" and hasattr(self, "eagle3_draft_builder"):
kv_cache_tensor = self.eagle3_draft_builder.build_kv_cache_tensor(
layer_id, module
)
if kv_cache_tensor is not None:
kv_cache_tensors.append(kv_cache_tensor)
layer_id += 1
continue

# Per-attention-type binding is owned by the attention
# metadata builder; ModelRunner only walks modules and
# collects the resulting KVCacheTensor entries. The builder
Expand Down Expand Up @@ -1625,7 +1690,12 @@ def run_model(
label += f" tok={batch.total_tokens_num} ctx={ctx_str}"
label += "]"
with record_function(label):
hidden_states = self.model(input_ids, positions)
model_output = self.model(input_ids, positions)
if self.use_aux_hidden_state_outputs:
hidden_states, self._aux_hidden_states = model_output
else:
hidden_states = model_output
self._aux_hidden_states = None
logits = self.model.compute_logits(hidden_states)
else:
# decode[bs=128 tok=128 d=128] or decode[bs=128 tok=128 p=2 d=126 spec=3]
Expand All @@ -1645,6 +1715,12 @@ def run_model(
self.graphs[graph_key].replay()
num_tokens = context.batch_size * max_q_len
hidden_states = self.forward_vars["outputs"][:num_tokens]
if graph_key in self.graph_aux_hidden:
self._aux_hidden_states = [
aux[:num_tokens] for aux in self.graph_aux_hidden[graph_key]
]
else:
self._aux_hidden_states = None
if self.logits_in_graph:
logits = self.graph_logits[graph_key][:num_tokens]
else:
Expand Down Expand Up @@ -1833,6 +1909,7 @@ def propose_draft_token_ids(
num_reject_tokens=num_reject_tokens,
next_token_ids=next_token_ids,
last_token_indices=last_token_indices,
aux_hidden_states=self._aux_hidden_states,
)
return self.tokenID_processor.prepare_draft_ids(batch, draft_token)

Expand Down Expand Up @@ -1882,6 +1959,7 @@ def capture_cudagraph(self):

self.graphs: dict[tuple[int, int], torch.cuda.CUDAGraph] = dict()
self.graph_logits: dict[tuple[int, int], torch.Tensor] = dict()
self.graph_aux_hidden: dict[tuple[int, int], list[torch.Tensor]] = dict()
self.graph_pool = None
is_tbo = self.config.enable_tbo and isinstance(self.model, UBatchWrapper)
# TBO graphs don't capture compute_logits, so disable logits_in_graph.
Expand Down Expand Up @@ -1932,9 +2010,13 @@ def capture_cudagraph(self):
)

# Warmup
outputs[:num_tokens] = self.model(
model_output = self.model(
input_ids[:num_tokens], positions[:num_tokens]
)
if self.use_aux_hidden_state_outputs:
outputs[:num_tokens] = model_output[0]
else:
outputs[:num_tokens] = model_output
if self.logits_in_graph:
self.model.compute_logits(outputs[:num_tokens])

Expand All @@ -1953,13 +2035,20 @@ def capture_cudagraph(self):
gc.stream,
output_buffer=outputs[:num_tokens],
)
graph_aux = None
else:
# Standard single-stream capture
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, self.graph_pool, stream=gc.stream):
outputs[:num_tokens] = self.model(
model_output = self.model(
input_ids[:num_tokens], positions[:num_tokens]
)
if self.use_aux_hidden_state_outputs:
outputs[:num_tokens] = model_output[0]
graph_aux = model_output[1]
else:
outputs[:num_tokens] = model_output
graph_aux = None
if self.logits_in_graph:
graph_logits = self.model.compute_logits(
outputs[:num_tokens]
Expand All @@ -1969,6 +2058,8 @@ def capture_cudagraph(self):
self.graphs[(bs, max_q_len)] = graph
if self.logits_in_graph and ubatch_slices is None:
self.graph_logits[(bs, max_q_len)] = graph_logits
if graph_aux is not None:
self.graph_aux_hidden[(bs, max_q_len)] = graph_aux
torch.cuda.synchronize()
self.graph_bs.sort(reverse=False)

Expand Down
13 changes: 11 additions & 2 deletions atom/model_engine/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,12 +676,21 @@ def postprocess(
continue
token_ids = prev_token_ids[idx]
num_new_token = len(token_ids)
if self.spec_stats:
self.spec_stats.update(num_new_token)
if is_deferred_out or self.use_spec:
num_rejected = fwd_output.num_rejected[idx]
num_bonus = fwd_output.num_bonus[idx]
offset = 0 if (num_new_token + num_rejected) == 1 else self.mtp_k
# Align stats with vLLM: only count steps that actually ran
# speculation (drafts proposed and validated). Skip the
# prefill-only step where no draft tokens were scored against
# the target — vLLM gates this via
# `if scheduled_spec_token_ids and generated_token_ids`.
if (
self.spec_stats
and num_new_token > 0
and (num_new_token + num_rejected) > 1
):
self.spec_stats.update(num_new_token)
seq.num_rejected = num_rejected
seq.num_bonus_tokens = num_bonus
for i, el in enumerate(token_ids):
Expand Down
Loading
Loading