
[Speculative decoding] feat: add EAGLE3 speculative decoding support #18039

Draft
ichbinhandsome wants to merge 12 commits into ggml-org:master from ichbinhandsome:eagle3-adapt-new-arch

Conversation

@ichbinhandsome commented Dec 14, 2025

As discussed in #15902, Eagle3 represents the current SOTA in speculative decoding and is widely adopted across the industry. Integrating Eagle3 into llama.cpp improves its performance and strengthens its competitiveness among leading inference frameworks: with Eagle3 speculative decoding integrated, inference performance improves significantly, achieving a 2–3× speedup.
This enhancement is the result of close collaboration between the NVIDIA and GGML teams, showcasing a strong technical partnership.

The following provides a brief overview of this PR:

EAGLE3 is an encoder-decoder based speculative decoding method:

  • Extracts features from target model at specific layers
  • Uses feature fusion layer to compress target features
  • Generates draft tokens with single-layer decoder
  • Maps draft vocabulary to target vocabulary via d2t tensor
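For reference, a minimal sketch of the d2t mapping mentioned in the last bullet, assuming the convention used by the reference EAGLE3 checkpoints where d2t stores per-token offsets from the draft vocabulary into the target vocabulary (the helper below is illustrative only, not code from this PR):

#include <cstdint>
#include <vector>

// Hypothetical helper: convert a draft-vocab token id to a target-vocab token id.
// Assumes d2t[i] holds an offset such that target_id = i + d2t[i]; verify against
// the actual d2t tensor layout before relying on this convention.
static int32_t eagle3_draft_to_target(int32_t draft_id, const std::vector<int32_t> & d2t) {
    return draft_id + d2t[draft_id];
}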

Key changes:

  • Add LLM_ARCH_EAGLE3 architecture
  • Add EAGLE3 encoder/decoder graph (src/models/eagle3.cpp)
  • Add feature extraction from target model layers
  • Add g_embeddings handling for decoder input
  • Add GGML_TENSOR_FLAG_SYNC for GPU synchronization
  • Add --eagle3 flag for speculative-simple example
  • Add EAGLE3 model conversion in convert_hf_to_gguf.py

EAGLE3 Architecture Overview:

┌─────────────────────────────────────────────────────────────────┐
│                    EAGLE3 Overview                              │
└─────────────────────────────────────────────────────────────────┘

  Target Model          EAGLE3 Encoder         EAGLE3 Decoder
  (LLaMA 8B)              (FC Layer)           (1-layer Transformer)
       │                      │                       │
       │                      │                       │
       ▼                      ▼                       ▼
┌─────────────┐        ┌─────────────┐        ┌─────────────────┐
│  Generate   │        │  Compress   │        │  Generate Draft │
│  Features   │───────►│  Features   │───────►│  Tokens Fast    │
│  [12288]    │        │  [4096]     │        │  [k tokens]     │
└─────────────┘        └─────────────┘        └────────┬────────┘
                                                       │
                                                       ▼
                                              ┌─────────────────┐
                                              │  Verify Drafts  │
                                              │  with Target    │
                                              └─────────────────┘
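A rough host-side sketch of how these three stages connect, using the API added in this PR (llama_get_eagle3_target_features / llama_set_eagle3_g_embeddings). This is an illustrative fragment, not the actual example code; how the extracted features are fed into the encoder batch is simplified away.

#include "llama.h"

// Sketch only: sampling, acceptance and batch construction are omitted
// (see the speculative-simple example for the real flow). The contexts below are
// assumed to be the target model, the EAGLE3 encoder, and the EAGLE3 decoder.
static void eagle3_step(llama_context * ctx_tgt, llama_context * ctx_enc, llama_context * ctx_dft,
                        llama_batch batch_tgt, llama_batch batch_enc, llama_batch batch_dft,
                        int32_t n_embd, int32_t n_tokens, int32_t n_draft) {
    llama_decode(ctx_tgt, batch_tgt);                                  // run / verify with the target model

    const float * feats = llama_get_eagle3_target_features(ctx_tgt);   // [3*n_embd, n_tokens]
    (void) feats; // assumption: these features become the encoder's input embeddings (not shown)

    llama_decode(ctx_enc, batch_enc);                                  // feature fusion (FC layer)
    const float * g_embd = llama_get_embeddings(ctx_enc);              // assumption: fused g_embeddings read back as embeddings

    llama_set_eagle3_g_embeddings(ctx_dft, g_embd, n_embd, n_tokens);  // hand g_embeddings to the decoder
    for (int32_t i = 0; i < n_draft; ++i) {
        llama_decode(ctx_dft, batch_dft);                              // autoregressive draft step
        // sample from llama_get_logits(ctx_dft), map the draft id via d2t, extend batch_dft
    }
}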

How to run EAGLE3 in llama.cpp

Requirements

This PR currently supports only the following two EAGLE3 models:

The following eagle3 models should also work out of the box, though they haven’t been tested yet:

Step 1: Convert Models to GGUF Format

  • Convert Target Model
TARGET_MODEL_HF="${MODELS_DIR}/Meta-Llama-3.1-8B-Instruct"
TARGET_MODEL_GGUF="${MODELS_DIR}/Meta-Llama-3.1-8B-Instruct_bf16.gguf"

python convert_hf_to_gguf.py \
    "${TARGET_MODEL_HF}" \
    --outtype bf16 \
    --outfile "${TARGET_MODEL_GGUF}"
  • Convert EAGLE3 Draft Model
TARGET_MODEL_HF="${MODELS_DIR}/Meta-Llama-3.1-8B-Instruct"
EAGLE3_MODEL_HF="${MODELS_DIR}/EAGLE3-LLaMA3.1-Instruct-8B"
EAGLE3_MODEL_GGUF="${MODELS_DIR}/EAGLE3-LLaMA3.1-Instruct-8B_fp16.gguf"

python convert_hf_to_gguf.py \
    "${EAGLE3_MODEL_HF}" \
    --outtype f16 \
    --target-model-dir "${TARGET_MODEL_HF}" \
    --outfile "${EAGLE3_MODEL_GGUF}"

Step 2: Compile llama.cpp

cmake -B build -DGGML_CUDA=ON
cmake --build build --config Release

[Optional] Step 3: Quantize the GGUF model

./build/bin/llama-quantize \
  ${TARGET_MODEL_GGUF} \
  ${TARGET_MODEL_GGUF}_Q4_K_M.gguf \
  Q4_K_M
 
./build/bin/llama-quantize \
  ${EAGLE3_MODEL_GGUF} \
  ${EAGLE3_MODEL_GGUF}_Q4_K_M.gguf \
  Q4_K_M

Step 4: Run EAGLE3 Speculative Decoding

for prompt in \
    "Write a quicksort algorithm in Python. Write code only." \
    "Explain the Pythagorean theorem" \
    "Plan a 1 day trip to DC"; do
  echo "=== Prompt: $prompt ==="
    ./build/bin/llama-speculative-simple \
      -m "${TARGET_MODEL_GGUF}" \
      -md "${EAGLE3_MODEL_GGUF}" \
      --eagle3 -p "$prompt" -n 256 --draft 8 \
      --temp 0 --top-k 1 --seed 42 -ngl 99 -ngld 99 
done

Performance Evaluation (RTX A6000 48GB)

Note: Using the chat_template for each model version can improve acceptance rates. Always apply the model’s corresponding chat_template when constructing prompts.

  • LLaMA3.1-Instruct-8B with BF16, its Eagle3 with FP16

| Prompt | Baseline (llama-cli) | EAGLE3 (draft_size=8) | Accept Rate | Speedup |
| --- | --- | --- | --- | --- |
| Write a quicksort algorithm in Python. Write code only. | 44.5 t/s | 146.2 t/s | 80.6% | 3.28x |
| Explain the Pythagorean theorem | 44.5 t/s | 127.1 t/s | 77.4% | 2.85x |
| Plan a 1 day trip to DC | 44.5 t/s | 113.8 t/s | 80.9% | 2.55x |

  • LLaMA3.1-Instruct-8B with Q4_K_M, its Eagle3 with Q4_K_M

| Prompt | Baseline (llama-cli) | EAGLE3 (draft_size=8) | Accept Rate | Speedup |
| --- | --- | --- | --- | --- |
| Write a quicksort algorithm in Python. Write code only. | 121.5 t/s | 274.4 t/s | 92.5% | 2.26x |
| Explain the Pythagorean theorem | 121.4 t/s | 238.9 t/s | 79.4% | 1.97x |
| Plan a 1 day trip to DC | 121.4 t/s | 196.5 t/s | 77.2% | 1.62x |

  • LLaMA3.3-Instruct-70B with Q4_K_M, its Eagle3 with Q4_K_M

| Prompt | Baseline (llama-cli) | EAGLE3 (draft_size=8) | Accept Rate | Speedup |
| --- | --- | --- | --- | --- |
| Write a quicksort algorithm in Python. Write code only. | 15.6 t/s | 33.4 t/s | 73.6% | 2.14x |
| Explain the Pythagorean theorem | 15.6 t/s | 37.6 t/s | 82.0% | 2.41x |
| Plan a 1 day trip to DC | 15.6 t/s | 28.8 t/s | 69.3% | 1.85x |

  • Qwen3-8B with BF16, its Eagle3 with BF16

| Prompt | Baseline (llama-cli) | EAGLE3 (draft_size=8) | Accept Rate | Speedup |
| --- | --- | --- | --- | --- |
| Write a quicksort algorithm in Python. Write code only. | 43.6 t/s | 94.8 t/s | 69.8% | 2.17x |
| Explain the Pythagorean theorem | 43.6 t/s | 86.8 t/s | 68.3% | 1.99x |
| Plan a 1 day trip to DC | 43.6 t/s | 70.7 t/s | 57.3% | 1.62x |

  • Qwen3-14B with BF16, its Eagle3 with BF16

| Prompt | Baseline (llama-cli) | EAGLE3 (draft_size=8) | Accept Rate | Speedup |
| --- | --- | --- | --- | --- |
| Write a quicksort algorithm in Python. Write code only. | 24.4 t/s | 35.7 t/s | 40.4% | 1.46x |
| Explain the Pythagorean theorem | 24.4 t/s | 34.5 t/s | 41.3% | 1.41x |
| Plan a 1 day trip to DC | 24.3 t/s | 30.5 t/s | 28.0% | 1.26x |

  • Qwen3-32B with Q4_K_M, its Eagle3 with Q4_K_M

| Prompt | Baseline (llama-cli) | EAGLE3 (draft_size=8) | Accept Rate | Speedup |
| --- | --- | --- | --- | --- |
| Write a quicksort algorithm in Python. Write code only. | 32.0 t/s | 39.7 t/s | 39.7% | 1.24x |
| Explain the Pythagorean theorem | 32.0 t/s | 41.5 t/s | 43.3% | 1.30x |
| Plan a 1 day trip to DC | 32.0 t/s | 37.1 t/s | 32.6% | 1.16x |

  • Qwen3-30B-A3B with BF16, its Eagle3 with BF16 (tested on NVIDIA DGX Spark 128GB, speedup might be better on other hardware)

| Prompt | Baseline (llama-cli) | EAGLE3 (draft_size=8) | Accept Rate | Speedup |
| --- | --- | --- | --- | --- |
| Write a quicksort algorithm in Python. Write code only. | 31.1 t/s | 43.3 t/s | 64.4% | 1.39x |
| Explain the Pythagorean theorem | 31.2 t/s | 41.2 t/s | 60.6% | 1.32x |
| Plan a 1 day trip to DC | 30.9 t/s | 38.6 t/s | 58.8% | 1.25x |

| Prompt | Baseline (llama-cli) | EAGLE3 (draft_size=8) | Accept Rate | Speedup |
| --- | --- | --- | --- | --- |
| Write a quicksort algorithm in Python. Write code only. | 61.3 t/s | 65.05 t/s | 74.25% | 1.06x |
| Explain the Pythagorean theorem | 61.2 t/s | 58.13 t/s | 69.23% | 0.95x |
| Plan a 1 day trip to DC | 61.4 t/s | 54.50 t/s | 62.96% | 0.89x |

Details of GGML backend modifications (Fixed, no longer needed)

In the Eagle3 decoder, two parallel inputs are processed:

input_embeds ──→ RMS_NORM ──┐
                            ├──→ CONCAT ──→ Transformer Decoder
g_embeddings ──→ RMS_NORM ──┘

When both RMS_NORM operations run in the same GPU split, a lack of synchronization causes buffer contention and race conditions (CPU execution is fine as it auto‑syncs between subgraphs).

Solution:
Use ggml_set_sync() to add a synchronization point after the first RMS_NORM, forcing the scheduler to create a split boundary and synchronize before continuing.

input_embeds ──→ RMS_NORM ──→ [SYNC] ──┐
                                       ├──→ CONCAT ──→ Transformer Decoder
g_embeddings ─────────────→ RMS_NORM ──┘
         (split 1)            |         (split 2)
                           barrier

This ensures correct execution and can be applied to any parallel path that needs synchronization, not just Eagle3.
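For illustration, a minimal sketch of the graph construction described above. ggml_set_sync / GGML_TENSOR_FLAG_SYNC are additions described in this PR (the heading above notes the fix is no longer needed), and the function below is a simplified stand-in, not the actual eagle3.cpp code:

#include "ggml.h"

// Norm both inputs, force a sync point after the first RMS_NORM, then concatenate.
// The concat axis and the missing norm weights are simplifications of the real graph.
static ggml_tensor * eagle3_concat_inputs(ggml_context * ctx, ggml_tensor * input_embeds, ggml_tensor * g_embeddings, float eps) {
    ggml_tensor * a = ggml_rms_norm(ctx, input_embeds, eps);
    ggml_set_sync(a);                     // split boundary: synchronize before the decoder continues
    ggml_tensor * b = ggml_rms_norm(ctx, g_embeddings, eps);
    return ggml_concat(ctx, a, b, 0);     // concatenate along the embedding dimension
}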

Examples results

  • Prompt: "Write a quicksort algorithm in Python. Write code only."
image
  • Prompt: "Explain the Pythagorean theorem"
image
  • Prompt: "Plan a 1 day trip to DC"
image

Future Steps

  • Support more Eagle3 models
  • Currently, Eagle3 is integrated only in llama-speculative-simple; support may need to be extended to other examples and APIs where possible
  • Support context-dependent tree sampling (tree attention) as described in the Eagle3 paper to improve accept rate
  • Support batch processing (batch size > 1) with Eagle3 speculative decoding

@ngxson (Collaborator) commented Dec 15, 2025

Judging by the description of this PR, I believe many models with multiple-token prediction also use the same strategy of reusing hidden features from the main model.

It can be quite interesting to generalize this feature to support other models. I would expect some kind of sub-llama_context that allows both the main and draft models to share the same cgraph, avoiding the need to explicitly pass the intermediate embeddings through host memory.

@ggerganov (Member)

It can be quite interesting to generalize this feature to support other models.

I will definitely be looking at refactoring the implementation to become more generic before merging it. The initial results in terms of performance are really great, but we'll need to work on cleaning up the code and reducing the special-casing in several places. I'll try to provide insights on how to do that in the coming days.

@ichbinhandsome (Author)

It can be quite interesting to generalize this feature to support other models.

I will definitely be looking at refactoring the implementation to become more generic before merging it. The initial results in terms of performance are really great, but we'll need to work on cleaning up the code and reducing the special-casing in several places. I'll try to provide insights on how to do that in the coming days.

Thanks @ggerganov @ngxson for your inputs. Definitely looking forward to hearing your feedback and improving this PR.

Comment on lines +60 to +65

// TODO: refactor into llm_graph_input
ggml_tensor * inp_g = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
ggml_set_input(inp_g);
cb(inp_g, "inp_g_embeddings", -1); // TODO: do not change the name! refactor into llm_graph_input

Member:

I will change this to llm_graph_input in order to remove the extra "set input" logic in llama_context::process_ubatch.
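For context, a rough sketch of what such a refactor could look like, modeled on the existing llm_graph_input classes in src/llama-graph.h; the class name, members, and base interface below are assumptions, not the final design:

// Hypothetical graph input for the EAGLE3 g_embeddings (sketch only).
class llm_graph_input_eagle3_g : public llm_graph_input_i {
public:
    explicit llm_graph_input_eagle3_g(int64_t n_embd) : n_embd(n_embd) {}
    virtual ~llm_graph_input_eagle3_g() = default;

    // called when the ubatch is prepared; would copy the host-side g_embeddings
    // into inp_g, replacing the ad hoc "set input" path in llama_context::process_ubatch
    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * inp_g = nullptr; // F32 [n_embd, n_tokens]
    const int64_t n_embd;
};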

Comment on lines +26 to +35
// EAGLE3: Extract intermediate layer features from target model at layer INPUT
if (eagle3 && cparams.eagle3_extract_enabled && !eagle3->extract_layer_indices.empty()) {
    static const char * eagle3_extract_names[] = {"eagle3_extract_0", "eagle3_extract_1", "eagle3_extract_2"};
    for (size_t i = 0; i < eagle3->extract_layer_indices.size() && i < 3; ++i) {
        if (eagle3->extract_layer_indices[i] == il) {
            cb(inpL, eagle3_extract_names[i], il);
            break;
        }
    }
}
Member:

I will next look to remove this ad hoc logic and generalize it some way. Likely by passing the extraction points in some more generic way during llama_context creation. TBD

Comment on lines +195 to +198

// EAGLE3 draft model - target model hidden size
uint32_t eagle3_target_hidden_size = 0;

Member:

This can become more generic by renaming it to n_embd_enc and utilizing the n_embd_inp() call.

Comment on lines +875 to +878
// Get pointer to target model features extracted for EAGLE3 encoder
// Returns NULL if no features are available
// Format: [3*n_embd, n_tokens] - use model.hparams.n_embd and batch.n_tokens for dimensions
LLAMA_API const float * llama_get_eagle3_target_features(struct llama_context * ctx);
Member:

This call should become more generic and not Eagle3 specific. Will be looking how to achieve this in the best way.

Comment on lines +880 to +887
// Set g_embeddings from EAGLE3 encoder output for decoder input
// g_embd: pointer to encoder output embeddings
LLAMA_API void llama_set_eagle3_g_embeddings(
        struct llama_context * ctx,
        const float * g_embd,
        int32_t n_embd,
        int32_t n_tokens);

Member:

Might be possible to avoid this API if we combine the Eagle encoder and decoder in a single context. TBD

@ichbinhandsome (Author), Dec 17, 2025:

When combining the Eagle3 encoder and decoder into a single context, note that the Eagle3 encoder is used only to fuse the extracted features from the target model, i.e. it is invoked as many times as the target model itself. The Eagle3 decoder, on the other hand, is solely responsible for generating draft tokens in an autoregressive way.
llama_set_eagle3_g_embeddings() sets the g_embedding both from the Eagle3 encoder (used in the first generation step of the Eagle3 decoder) and from the Eagle3 decoder itself (used in subsequent generation steps).

Member:

Yup, I noticed this interaction. We don't have a previous use case similar to this, but I think the enc-dec context could be adapted accordingly.

@pwilkin (Collaborator) commented Jan 6, 2026

Bumping, is there any progress on this? It's probably one of the more coveted features to have right now.

@ggerganov (Member)

Bumping, is there any progress on this?

I'm currently side-tracked by some graph reallocation optimizations. Will probably come back to this after that.

@pwilkin added the "hot" label on Jan 6, 2026
@ichbinhandsome (Author)

Eagle3 checkpoints for the Qwen3 series (including both dense and MoE models) are now supported; see the updated PR description for details.
Although these Eagle3 checkpoints come from third parties, they can still deliver a 1–2× speedup.
Speculative decoding performance for MoE models is not as good as dense models, which is expected, since more experts are invoked during the parallel verification phase than during the target model’s decoding phase.

@ichbinhandsome (Author) commented Jan 9, 2026

One question: it seems that CUDA Graphs are disabled when the input n_tokens > 1. During the target model verification stage of speculative decoding, CUDA Graphs are therefore always disabled for the target model, since it is only used to verify batches of more than one draft token. However, we can fix the number of draft tokens (e.g., by using padding) to make it constant and thus enable CUDA Graphs (may need to remove the n_tokens > 1 constraint)? @ggerganov

Context: I'm testing GPT-OSS-120B Eagle3 with llama.cpp, and I found that even with Eagle3 (accept rate 86%), the performance is worse than plain llama-cli. After profiling, I discovered that CUDA Graphs are consistently disabled for the target model during speculative decoding, whereas they remain enabled in llama-cli. This results in the target model's verification (prefill) phase being more than 5× slower than a normal autoregressive decoding step.
After disabling CUDA graphs for llama-cli using GGML_CUDA_DISABLE_GRAPHS=1, Eagle3 achieved roughly a 1.5× speedup.

I've only observed this performance issue with GPT-OSS-120B Eagle3. For other models, even without CUDA Graphs enabled for the target model during Eagle3 speculative decoding, performance remains great.

@ggerganov (Member)

Speculative decoding performance for MoE models is not as good as dense models, which is expected, since more experts are invoked during the parallel verification phase than during the target model’s decoding phase.

I think the small-batch mul_mat_id could be improved in the CUDA backend. AFAIR the performance for batch sizes in (1, 8] is not optimal atm. Need to double-check.

However, we can fix the number of draft tokens (e.g., by using padding) to make it constant and thus enable CUDA Graphs (may need to remove the n_tokens > 1 constraint)? @ggerganov

Possibly, but to me this sounds like second-order optimization. Optimizing the mul_mat_id for small batches should bring more generic benefits and would likely have larger impact for speculative decoding compared to enabling CUDA graphs.

After disabling CUDA graphs for llama-cli using GGML_CUDA_DISABLE_GRAPHS=1, Eagle3 achieved roughly a 1.5× speedup.

Hm, this is a bit of a surprising observation. Can you run a llama-batched-bench test on your system with and without CUDA graphs using the commands from #18308 (comment) and share the results? We are interested in batch sizes [1, 4]. So something like this:

llama-batched-bench -m [gpt-oss-120b] -c 65536 -b 2048 -ub 512 -npp 1024 -ntg 32 -npl 1,2,3,4,5,6,7,8

@ichbinhandsome (Author)

Thanks very much for your inputs! @ggerganov

After disabling CUDA graphs for llama-cli using GGML_CUDA_DISABLE_GRAPHS=1, Eagle3 achieved roughly a 1.5× speedup.

Hm, this is a bit of a surprising observation. Can you run a llama-batched-bench test on your system with and without CUDA graphs using the commands from #18308 (comment) and share the results? We are interested in batch sizes [1, 4]. So something like this:

I double-checked the run today. The previous statement about CUDA graphs was incorrect due to instability and concurrent CPU activity in my test environment, sorry about that! Currently, enabling or disabling CUDA Graphs doesn't have much impact in llama-cli for the GPT-OSS-120B model. (I am testing on DGX Spark)

  • With CUDA graphs enabled: [ Prompt: 120.1 t/s | Generation: 47.2 t/s ]
  • With CUDA graphs disabled: [ Prompt: 119.2 t/s | Generation: 45.7 t/s ]

Also, the results for llama-batched-bench:

  • With CUDA graphs enabled:

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1024 | 256 | 1 | 1280 | 1.227 | 834.89 | 5.255 | 48.71 | 6.482 | 197.47 |
| 1024 | 256 | 2 | 2560 | 1.579 | 1296.74 | 9.277 | 55.19 | 10.856 | 235.81 |
| 1024 | 256 | 3 | 3840 | 2.284 | 1344.72 | 10.447 | 73.51 | 12.731 | 301.61 |
| 1024 | 256 | 4 | 5120 | 3.031 | 1351.58 | 11.550 | 88.66 | 14.580 | 351.16 |
| 1024 | 256 | 5 | 6400 | 3.780 | 1354.59 | 12.433 | 102.96 | 16.212 | 394.76 |
| 1024 | 256 | 6 | 7680 | 4.528 | 1356.95 | 13.347 | 115.08 | 17.874 | 429.66 |
| 1024 | 256 | 7 | 8960 | 5.304 | 1351.48 | 13.982 | 128.16 | 19.286 | 464.59 |
| 1024 | 256 | 8 | 10240 | 6.018 | 1361.20 | 14.704 | 139.28 | 20.722 | 494.16 |

  • With CUDA graphs disabled:

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1024 | 256 | 1 | 1280 | 1.279 | 800.61 | 5.758 | 44.46 | 7.037 | 181.90 |
| 1024 | 256 | 2 | 2560 | 1.597 | 1282.22 | 9.297 | 55.07 | 10.895 | 234.98 |
| 1024 | 256 | 3 | 3840 | 2.286 | 1343.84 | 10.383 | 73.97 | 12.669 | 303.11 |
| 1024 | 256 | 4 | 5120 | 3.031 | 1351.51 | 11.547 | 88.68 | 14.577 | 351.23 |
| 1024 | 256 | 5 | 6400 | 3.771 | 1357.64 | 12.438 | 102.91 | 16.209 | 394.84 |
| 1024 | 256 | 6 | 7680 | 4.525 | 1357.71 | 13.342 | 115.12 | 17.868 | 429.83 |
| 1024 | 256 | 7 | 8960 | 5.289 | 1355.17 | 13.986 | 128.13 | 19.275 | 464.84 |
| 1024 | 256 | 8 | 10240 | 5.999 | 1365.48 | 14.653 | 139.77 | 20.652 | 495.83 |

Possibly, but to me this sounds like second-order optimization. Optimizing the mul_mat_id for small batches should bring more generic benefits and would likely have larger impact for speculative decoding compared to enabling CUDA graphs.

I agree. CUDA graphs could be a second-order optimization.
Here are the Eagle3 GPT-OSS-120B test results on DGX Spark (I will also test this on other hardware):

| Prompt | Baseline (llama-cli) | EAGLE3 (draft_size=8) | Accept Rate | Speedup |
| --- | --- | --- | --- | --- |
| Write a quicksort algorithm in Python. Write code only. | 48.3 t/s | 52.2 t/s | 85.0% | 1.08x |
| Explain the Pythagorean theorem | 47.8 t/s | 46.5 t/s | 74.0% | 0.97x |
| Plan a 1 day trip to DC | 48.4 t/s | 40.0 t/s | 55.7% | 0.83x |

For MoE models, prefill becomes the main performance bottleneck because more active experts are involved. As a result, the assumption that "processing multiple draft tokens concurrently is as fast as processing a single token" no longer holds, which is an important condition for effective speculative decoding. I also saw that as the draft token length increases, the verification cost of the target model also rises.
This explains the results shown in the table above: in some cases, Eagle3 can even degrade performance. To observe improvements, the accept rate must exceed a certain lower bound.

Do you have a rough idea of how much performance gain we can get by improving mul_mat_id?

@ggerganov (Member)

The llama-batched-bench results are actually better than I expected. In the previously reported numbers there was a sharp dip at BS = 2. Here the TG performance steadily increases with the batch size, which is good, though it is not as linear as we would like.

I suppose the explanation is that for MoE models, at low batch sizes the amount of weight data we need to read for each batch increases roughly linearly with the batch size (i.e. each extra token in the batch activates more experts, and at small batch sizes the experts for each token are very likely different from each other). So it's probably normal that TG for MoE does not scale as well with the batch size as TG for dense models.
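For intuition, a rough back-of-the-envelope model of this effect (an illustrative approximation that assumes independent, uniform expert routing): with $E$ experts and $k$ active per token, a batch of $n$ tokens touches on average

$$\mathbb{E}[\text{distinct experts}] = E\left(1 - \left(1 - \tfrac{k}{E}\right)^{n}\right),$$

which is close to $nk$ while $nk \ll E$. For example, with $E = 128$, $k = 4$ (gpt-oss-120b-like routing), this gives roughly 4, 7.9 and 15.3 distinct experts per layer for $n = 1, 2, 4$, so the weight traffic of a verification batch grows almost linearly with the number of draft tokens until the activated experts start to overlap.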

As a result, the assumption that “processing multiple draft tokens concurrently is as fast as processing a single token” no longer holds, which is an important condition for effective speculative decoding.

Yeah, that's my guess as well. Do we have some references to cross-check this? Do the Eagle3 authors discuss its performance for MoE models? Do we have sample numbers for gpt-oss-120b with Eagle3 using vLLM or TensorRT-LLM?

Do you have a rough idea of how much performance gain we can get by improving mul_mat_id?

Hm, not sure. Thinking about it now, I feel like mul_mat_id is unlikely to scale well enough due to the increasing amount of weight data read for each new token.

@arch-btw (Contributor)

The following eagle3 models should also work out of the box, though they haven’t been tested yet:
Qwen3-235B-A22B-EAGLE3

I tested the Baichuan-M3-235B model that was released yesterday (draft here). It's a finetune of the Qwen3 model above. It quantized successfully but failed due to having a different tensor shape (even in the original weights):

load_tensors: EAGLE3 using d2t mapping (draft_vocab_size = 32000)
llama_model_load: error loading model: check_tensor_dims: tensor 'blk.0.attn_q.weight' has wrong shape; expected  8192,  8192, got  8192, 16384,     1,     1
llama_model_load_from_file_impl: failed to load model
failed to load EAGLE3 draft model

I haven't looked into how often this happens in finetunes of the same model, especially in the context of eagle3.

However, the tensor shapes changing might be something to account for in the implementation (in this case Qwen3), unless such finetunes will be treated as completely new models, in which case please disregard this comment.

@ichbinhandsome (Author)

The following eagle3 models should also work out of the box, though they haven’t been tested yet:
Qwen3-235B-A22B-EAGLE3

I tested the Baichuan-M3-235B model that was released yesterday (draft here). It's a finetune of the Qwen3 model above. It quantized successfully but failed due to having a different tensor shape (even in the original weights):

load_tensors: EAGLE3 using d2t mapping (draft_vocab_size = 32000)
llama_model_load: error loading model: check_tensor_dims: tensor 'blk.0.attn_q.weight' has wrong shape; expected  8192,  8192, got  8192, 16384,     1,     1
llama_model_load_from_file_impl: failed to load model
failed to load EAGLE3 draft model

I haven't looked into how often this happens in finetunes of the same model, especially in the context of eagle3.

However, the tensor shapes changing might be something to account for in the implementation (in this case Qwen3), unless such finetunes will be treated as completely new models, in which case please disregard this comment.

I spent some time analyzing the Baichuan-EAGLE3 draft model. It has a slightly different architecture compared to the standard Qwen3-EAGLE3 model.
The main difference is in the self_attn.q_proj.weight tensor shape:

  • Standard Qwen3-EAGLE3 : [8192, 8192] — outputs Q only
  • Baichuan-EAGLE3: [16384, 8192] — outputs Q + Gate (2× the size)

This is because Baichuan-EAGLE3 uses an Attention Output Gate mechanism, which is not present in the standard EAGLE3 model. In this variant:

  • The Q projection outputs both query vectors and gate vectors
  • After attention computation, the output is element-wise multiplied by sigmoid(gate) before the output projection

This is essentially a variant architecture of EAGLE3, not just a tensor shape difference. Supporting this variant would require:

  • Detecting the gate mechanism during model loading
  • Modifying the graph construction to split Q/Gate and apply the gating after attention
  • Adding the ggml_sigmoid operation in the attention path
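For illustration, a rough ggml-level sketch of that gating step (a sketch only; the fused [Q | gate] layout, names, and split axis below are assumptions based on the description above, not the actual Baichuan implementation):

#include "ggml.h"

// attn_out: attention output, F32 [n_q, n_tokens]
// qg:       output of the fused q_proj, F32 [2*n_q, n_tokens], assumed laid out as [Q | gate]
static ggml_tensor * apply_attn_output_gate(ggml_context * ctx, ggml_tensor * attn_out, ggml_tensor * qg, int64_t n_q) {
    // take the second half of the fused projection as the gate vectors
    ggml_tensor * gate = ggml_view_2d(ctx, qg, n_q, qg->ne[1], qg->nb[1], n_q * ggml_element_size(qg));
    gate = ggml_cont(ctx, gate);            // make the view contiguous for the element-wise ops
    gate = ggml_sigmoid(ctx, gate);         // sigmoid(gate)
    return ggml_mul(ctx, attn_out, gate);   // element-wise gating before the output projection
}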

I would suggest we focus this PR on the standard EAGLE3 model first. Once merged, we can consider adding support for this gated variant in a follow-up PR.

Have you tested the standard Qwen3-EAGLE3 model as well? Does it work well with the current implementation? If yes, could you please share the t/s and speedup you got with eagle3? @arch-btw

@ngxson (Collaborator) commented Jan 14, 2026

Since EAGLE3 can vary quite a lot for each model, maybe a better way is to consider it as an adapter (the same logic as lora adapter), instead of a dedicated arch?

That way, it can hook into existing models more easily, making internal data like KV state, gate, etc, accessible to the draft model.

@ichbinhandsome (Author)

Since EAGLE3 can vary quite a lot for each model, maybe a better way is to consider it as an adapter (the same logic as lora adapter), instead of a dedicated arch?

That way, it can hook into existing models more easily, making internal data like KV state, gate, etc, accessible to the draft model.

Good point. However, Eagle3 doesn't vary much across models. So far, except for Baichuan-Eagle3, all other models essentially use the same Eagle3 architecture; please refer to the supported models listed in the PR description. I'd say the majority of models share the same Eagle3 architecture, with only a few exceptions. This standalone Eagle3-architecture strategy is also adopted in TensorRT-LLM, vLLM, and SGLang.

@ngxson (Collaborator) commented Jan 14, 2026

I doubt that. In theory, nothing prevents them or another team from making a variant of eagle3 that gets the state of more than 3 layers, or even reuses the KV state from earlier layers. Possibilities are endless, and that's why it's important to think about the bigger picture instead of just trying to make it work with one single existing architecture.

I think a more model-agnostic approach via an adapter API (or another API based on that form) will likely be the way to go ultimately. It would allow computing both the next token and the draft tokens in one pass, enabling even higher performance than this approach.

@ichbinhandsome (Author)

I doubt that. In theory, nothing prevents them or another team from making a variant of eagle3 that gets the state of more than 3 layers, or even reuses the KV state from earlier layers. Possibilities are endless, and that's why it's important to think about the bigger picture instead of just trying to make it work with one single existing architecture.

Could you please share some examples or real-world use cases of this? I’d like to better understand how such an approach might be applied in practice.

@ngxson (Collaborator) commented Jan 14, 2026

The main problem with this PR and #15225 is that both assume that MTP (multi-token prediction) works this way:

  • main LLM generates first tokens + hidden_state from a list of selected layers
  • hidden_state is then forwarded to the speculative model to generate N next tokens

(Note: the dashed line indicates that this may not be the case for all models; some only use the last hidden state)

[diagram: hidden states from selected layers of the main LLM are forwarded to the speculative model]

While it does work for the moment, this approach doesn't address the true nature of MTP models. In other words, it is not truly model-agnostic. The main drawback is that you must manually pass the embeddings between the two models, so you must know where to get the embeddings, their shapes, etc.

Instead, we should look at MTP models as a normal LLM with multiple output heads:

[diagram: the model viewed as a normal LLM with multiple output heads (mtp_head)]

From this point of view, it doesn't matter what the implementation of the mtp_head is. From the outside world, the model will just output N next tokens given one input token.

In practice, the mtp_head(s) can be:

Now, returning to your question:

Could you please share some examples or real-world use cases of this? I’d like to better understand how such an approach might be applied in practice.

If you already get the idea above, then consider gemma3n: the model has 30 layers, but only 20 layers have a KV projection. The last 10 layers reuse the KV from the 20th layer. Some other models also implement this idea, notably GLM and Bailing.

The same idea can be applied to MTP layers. Future models may have MTP layers that reuse not just the layer's output hidden state, but also the projected KV inside the layer. While there are no models in the wild currently doing that, Baichuan-EAGLE3 (as you showed) is already heading in this direction by exposing both Q and the gate to the MTP model.

@ngxson (Collaborator) commented Jan 14, 2026

(I have to split up my comment otherwise it's too long)

My proposal is that we must design this function + the API in a way that it is flexible enough for future models.

For EAGLE3, the MTP model is technically an mtp_head shipped as an extension to the main model (note that the eagle3 repo only contains the extra tensors, but does not contain the main LLM); it can be viewed as an adapter, much like how LoRA works.

For the API, we must avoid leaking information about the implementation under the hood. The downstream code only needs to know how many tokens can be generated; it doesn't need to know how these extra tokens are generated.

So, a set of API calls like the following should be enough:

  • llama_model_load_mtp: load the MTP head as a llama_adapter_lora, or maybe we can add a new struct for it
  • llama_mtp_set_n_draft: set the max number of draft tokens to be generated in the next llama_decode; set to 0 for a verification pass
  • llama_mtp_get_n_draft_max: get the max number of draft tokens that the MTP head can generate
  • llama_mtp_get_logits_ith: get logits for the i-th token in the batch; returns an array of float with size n_vocab*n_draft

All the info about embeddings and the draft model must be kept private.
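A rough header-level sketch of that proposal (purely illustrative declarations mirroring the bullets above; none of these exist in llama.h today, and the types and semantics are assumptions):

struct llama_adapter_mtp; // opaque handle for the MTP head, loaded much like a LoRA adapter

// load the MTP head that ships as an extension of the main model
struct llama_adapter_mtp * llama_model_load_mtp(struct llama_model * model, const char * path);

// number of draft tokens to produce during the next llama_decode; 0 = verification pass
void llama_mtp_set_n_draft(struct llama_context * ctx, int32_t n_draft);

// maximum number of draft tokens the MTP head can generate per step
int32_t llama_mtp_get_n_draft_max(const struct llama_context * ctx);

// logits for the i-th token of the batch; array of float with size n_vocab * n_draft
const float * llama_mtp_get_logits_ith(struct llama_context * ctx, int32_t i);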

CC @ggerganov maybe this is helpful for you

@ichbinhandsome (Author) commented Jan 15, 2026

Yeah, that's my guess as well. Do we have some references to cross-check this? Do the Eagle3 authors discuss its performance for MoE models? Do we have sample numbers for gpt-oss-120b with Eagle3 using vLLM or TensorRT-LLM?

As far as I know, the Eagle3 authors did not discuss MoE model performance in their paper. I am currently cross-checking the performance of GPT-OSS-120B Eagle3 on DGX Spark using SGLang, which employs essentially the same GPT-OSS-120B-Eagle3 draft model as I used for llama.cpp testing.

The running commands I used are as follows:
• Baseline: Run the GPT-OSS-120B target model only.

python3 -m sglang.launch_server --model-path gpt-oss-120b --host 0.0.0.0 --port 30000 --trust-remote-code

• Eagle3: Set the draft size to 8 and disable tree decoding to ensure a fair comparison with our tests on llama.cpp.

python3 -m sglang.launch_server --model gpt-oss-120b --speculative-algorithm EAGLE3 --speculative-draft-model-path lmsys/EAGLE3-gpt-oss-120b-bf16 --speculative-num-steps 8 --speculative-eagle-topk 1 --speculative-num-draft-tokens 8  --trust-remote-code --host 0.0.0.0 --port 30000

I am using curl to test prompts, e.g.

curl -sS -X POST http://localhost:30000/generate \
  -H "Content-Type: application/json" \
  -d '{
      "text": "Write a quicksort algorithm in Python. Write code only.",
      "sampling_params": {
          "max_new_tokens": 256
      }
  }' | python3 -c "
import sys, json
d = json.load(sys.stdin)
tokens = d['meta_info']['completion_tokens']
latency = d['meta_info']['e2e_latency']
tps = tokens / latency
print(f'completion_tokens: {tokens}')
print(f'e2e_latency: {latency:.3f}s')
print(f'token/s: {tps:.2f}')
"

Here are the test results on DGX Spark:

| Prompt | Baseline | EAGLE3 (draft_size=8) | Speedup |
| --- | --- | --- | --- |
| Write a quicksort algorithm in Python. Write code only. | 52.50 t/s | 36.4 t/s | 0.69x |
| Explain the Pythagorean theorem | 52.64 t/s | 24.4 t/s | 0.46x |
| Plan a 1 day trip to DC | 52.69 t/s | 24.7 t/s | 0.47x |

I also tested shorter draft sizes using the following command:

python3 -m sglang.launch_server --model /home/nvidia/models/gpt-oss-120b --speculative-algorithm EAGLE3 --speculative-draft-model-path /home/nvidia/ruixiangw/models/lmsys/EAGLE3-gpt-oss-120b-bf16 --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4  --trust-remote-code --host 0.0.0.0 --port 30000

The results:

| Prompt | Baseline | EAGLE3 (draft_size=3) | Speedup |
| --- | --- | --- | --- |
| Write a quicksort algorithm in Python. Write code only. | 52.50 t/s | 37.36 t/s | 0.71x |
| Explain the Pythagorean theorem | 52.64 t/s | 27.96 t/s | 0.53x |
| Plan a 1 day trip to DC | 52.69 t/s | 26.53 t/s | 0.50x |

From the tables above, we observed similar performance degradation for GPT-OSS-120B-Eagle3 on a single GPU device in SGLang as well. However, in their blog post, they claimed to have achieved some speedups for GPT-OSS-120B-Eagle3 inference using tp=4, i.e., Tensor Parallelism across four H200 GPUs. It’s unclear whether this configuration is the key factor contributing to the observed speedup.
In addition, I feel batching (i.e., processing multiple prompts concurrently) may also play an important role in achieving Eagle3 speedups. A larger batch size means more experts are already activated during native decoding, so activating additional experts in the target model may not become a bottleneck during the Eagle3 verification stage.
TODO: This can be verified after merging this PR and extending eagle3 within llama-server to support multi-batch inference.

In summary, I believe that for large MoE models such as GPT-OSS-120B, Eagle3 may not provide a performance gain in a single-GPU, single-prompt use case. However, this does not apply to all MoE models; for example, we observed a performance improvement with Qwen3-30B-A3B_eagle3. This might be related to the number of active experts per token and the overall model size, where loading the active experts (a memory-bound operation) dominates the inference time.
After profiling, I confirmed that draft token verification for the GPT-OSS MoE target model takes at least twice as long (depending on how many draft tokens are verified) as single-token autoregressive generation on the same GPT-OSS MoE model.
@ggerganov

@ichbinhandsome (Author) commented Jan 15, 2026

My proposal is that we must design this function + the API in a way that it is flexible enough for future models.

Thank you very much for taking the time to write up this insightful proposal. Although we discussed the Eagle3 design (#15902 (reply in thread)) several months ago, it's still great to hear your perspective. @ggerganov these might be things worth considering.

@ngxson (Collaborator) commented Jan 15, 2026

The mentioned discussion only covers the internal design, not the public API design. It's probably best to open a dedicated discussion on the public API design to avoid going too far in the wrong direction.

Even after reading #15902, I still believe that this implementation is just N layers on top of the main model, meaning the MTP model is just an extension of the main model instead of an external model. (I'll comment in the code)

Comment on lines +25 to +38
llm_build_eagle3_encode::llm_build_eagle3_encode(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
    ggml_tensor * cur = nullptr;

    cur = build_inp_embd();

    // Feature fusion layer
    cur = build_lora_mm(model.fc, cur);
    cb(cur, "fc_out", -1);

    // Output: g_embeddings e.g. [4096, n_tokens]
    res->t_embd = cur;

    ggml_build_forward_expand(gf, cur);
}
Collaborator:

If the whole point of the encoder is to just do a projection, I think it isn't truly an encoder in transformer terms.

An encoder is responsible for populating the KV cache. Here, we do not touch the KV cache at all. Instead, I believe this projection can be part of the decoder.

If we need to allow larger input embeddings than n_embd, there is an interface called n_embd_inp that allows doing just that.


// Single decoder layer (il = 0)
const int il = 0;
{
Collaborator:

Hmm ok, I thought that we could fuse this cgraph with the main LLM cgraph. But that won't work very well because we need to call the sampling system to sample a new token for each decoding pass of eagle3.

In such a case, keeping it as a dedicated model seems ok, although I believe that in terms of API design, we must keep llama_set_eagle3_g_embeddings private (not expose it in the public API)

I think the best option could be to have a notion of a sub-llama_context, where one llama_context can encapsulate another llama_context. Will see if this is something that can easily be implemented or not.

Member:

In such a case, keeping it as a dedicated model seems ok, although I believe that in terms of API design, we must keep llama_set_eagle3_g_embeddings private (not expose it in the public API)

I think it can be avoided using an enc-dec context:

#18039 (comment)

Collaborator:

It's not necessary, because my comment above suggests that eagle3 is not exactly an enc-dec model, but more like a decoder-only model with n_embd_inp > n_embd.

What I'm suggesting here is to pass the embeddings from the main LLM to the smaller speculative LLM. Because they currently live in 2 different llama_contexts, we have no better way than passing them via a public API (which makes it less future-proof).

@ngxson (Collaborator), Jan 16, 2026:

(I think I'm commenting on the wrong line, this comment should be placed on llama_get_eagle3_target_features)

@ngxson (Collaborator), Jan 16, 2026:

I looked deeper into the GLM-4.6 implementation today, and I'm pretty confident that eagle3 is almost the same as the MTP model of GLM-4.6.

The "encoder" here is basically equivalent to nextn.eh_proj. It is not an enc-dec in transformer terms (i.e. unlike T5), just bad naming.

And the rest is the same as the deepseekv3 MTP style, except that instead of passing the hidden state from one MTP pass to another MTP pass, eagle3 uses the KV cache.

[diagram]

I'm playing around with an implementation on my side that exposes just one single llama_decode_mtp call and handles the hidden-state passing under the hood (based on llama_cross), so you can think of the main LLM as the encoder that populates the cross state, and the MTP as the decoder, in transformer terms.

Will push it when I have a working version.

@ngxson (Collaborator), Jan 18, 2026:

In any case, I'm still not convinced that the linear projection should be a dedicated "encoder" cgraph. As I mentioned, the performance loss in this PR could also be due to the backend synchronization that happens between the encode and decode passes of the eagle3 model.

Solution 2 in my last comment seems to be the most feasible; I will try to implement that in my PR.

Author:

As I mentioned, the performance loss in this PR could also be due to the backend synchronization that happens between the encode and decode passes of the eagle3 model.

No, it is not. As mentioned earlier, the performance degradation occurs only with the MoE model (#18039 (comment)). This is because the MoE model requires significantly more time for draft token verification compared to a dense model.
If you perform profiling, you will notice that the backend synchronization between the encode and decode passes of the Eagle3 model is relatively negligible.

[profiling screenshot]

@ngxson (Collaborator), Jan 18, 2026:

Of course it is negligible if you compare it to the time it takes for the verification pass, but I don't believe it is negligible if you compare it to the time it takes to generate one single draft token. The draft model is very small and CPU time can have a significant impact on it.

But even if you say that's not important for whatever reason, the more important point is that copying data to host memory is redundant. At this point, I think it's a better use of my time to just improve this in my implementation instead of arguing here.

Collaborator:

Also, from your profiling screenshot, it seems like there is a big gap between the large cudaMemcpyAsync and the run after it (I suppose that's the encoder pass of eagle3). I'm curious what happens in that big gap; probably some calculations on the CPU?

Author:

I'm curious what happens in that big gap; probably some calculations on the CPU?

Yes. It is the rejection sampling phase during speculative decoding. Once we obtain the logits for the draft tokens from the target model, we need to verify which tokens are accepted and which need to be rejected, and prepare these as input for the draft model. Note that the token_id-to-embedding mapping also happens on the CPU.
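For reference, a minimal sketch of that verification step under greedy sampling (the examples above use --temp 0 --top-k 1); this is an illustrative simplification of what llama-speculative-simple does, not its actual code:

#include <cstdint>
#include <vector>

// Keep the longest prefix of drafted tokens that matches the target model's argmax
// at each position; at the first mismatch, emit the target's own token instead.
static std::vector<int32_t> accept_greedy(const std::vector<int32_t> & draft,            // drafted ids (target vocab)
                                          const std::vector<int32_t> & target_argmax) {  // target argmax per draft position
    std::vector<int32_t> out;
    for (size_t i = 0; i < draft.size(); ++i) {
        if (target_argmax[i] != draft[i]) {
            out.push_back(target_argmax[i]);   // first rejected position: take the target's token
            return out;
        }
        out.push_back(draft[i]);               // draft token accepted
    }
    // all drafts accepted; the target's final logits yield one bonus token (not shown)
    return out;
}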

@arch-btw (Contributor) commented Jan 16, 2026

@ichbinhandsome thank you for looking into the Baichuan model.

Have you tested the standard Qwen3-EAGLE3 model as well? Does it work well with the current implementation? If yes, could you please share the t/s and speedup you got with eagle3? @arch-btw

It took me a bit because I had to download the gguf of Qwen3. It does appear to work, but I'm noticing somewhat of a slowdown:

./llama-speculative-simple -m Qwen3-235B-A22B.Q2_K.gguf -md Qwen3-235B-A22B-EAGLE3-draft-Q2_K.gguf --eagle3 -p "Hello!" -n 256 --draft 8 --temp 0 --top-k 1 --seed 42 --no-mmap -fa on

With EAGLE3:

Result
encoded   10 tokens in    1.945 seconds, speed:    5.141 t/s
decoded  136 tokens in   28.931 seconds, speed:    4.701 t/s

n_draft   = 8
n_predict = 136
n_drafted = 112
n_accept  = 36
accept    = 32.143%

draft:

 Eagle3 Draft encoder:
llama_perf_context_print:        load time =   24387.75 ms
llama_perf_context_print: prompt eval time =      11.86 ms /    74 tokens (    0.16 ms per token,  6239.99 tokens per second)
llama_perf_context_print:        eval time =      23.77 ms /    69 runs   (    0.34 ms per token,  2902.94 tokens per second)
llama_perf_context_print:       total time =   53318.22 ms /   143 tokens
llama_perf_context_print:    graphs reused =          0

Eagle3 Draft decoder:
llama_perf_context_print:        load time =   24392.05 ms
llama_perf_context_print: prompt eval time =     155.84 ms /    74 tokens (    2.11 ms per token,   474.83 tokens per second)
llama_perf_context_print:        eval time =     386.38 ms /    82 runs   (    4.71 ms per token,   212.23 tokens per second)
llama_perf_context_print:       total time =   53318.22 ms /   156 tokens
llama_perf_context_print:    graphs reused =         62

target:

common_perf_print:    sampling time =      16.16 ms
common_perf_print:    samplers time =       5.71 ms /   136 tokens
common_perf_print:        load time =   24122.19 ms
common_perf_print: prompt eval time =   29381.13 ms /   221 tokens (  132.95 ms per token,     7.52 tokens per second)
common_perf_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
common_perf_print:       total time =   53053.81 ms /   222 tokens
common_perf_print: unaccounted time =   23656.52 ms /  44.6 %      (total - sampling - prompt eval - eval) / (total)
common_perf_print:    graphs reused =         83
llama_memory_breakdown_print: | memory breakdown [MiB]      | total    free     self   model   context   compute    unaccounted |
llama_memory_breakdown_print: |   - Vulkan0 (780M Graphics) | 88024 = 82159 + (  951 =     0 +       0 +     951) +        4912 |
llama_memory_breakdown_print: |   - Host                    |                  89323 = 81715 +    7520 +      88                |

Without EAGLE3:

[ Prompt: 5.0 t/s | Generation: 7.4 t/s ]

By the way, it says "inf tokens per second" under eval time, is that being replaced with the "decoded" section on top? Just making sure I'm reading it correctly.

@ichbinhandsome (Author)

It took me a bit because I had to download the gguf of Qwen3. It does appear to work, but I'm noticing somewhat of a slowdown:

Thank you very much for testing this! The slowdown may be due to the short prompt (“Hello!”) or potential MoE performance issues mentioned in this comment. Could you try running the experiments using the same prompts I provided as examples in this PR? I’d expect a higher accept rate with those prompts, which might result in some speedups.

By the way, it says "inf tokens per second" under eval time, is that being replaced with the "decoded" section on top? Just making sure I'm reading it correctly.

I'm using the same metrics as the original code. I think the reason for the inf value is that the target model is only used for draft token verification (prefill) rather than autoregressive decoding. Since no actual decode steps are performed, the eval time is recorded as 0 ms, resulting in inf t/s.
To clarify: this inf refers to the pure target model's eval time. The actual end-to-end t/s (EAGLE3 + target model) is shown at the top of the output.
@arch-btw

@arch-btw (Contributor)

@ichbinhandsome No problem!
Here are the results; I've also tested another model from the second list of models that haven't been tested yet:

Qwen3-1.7B regular vs Qwen3-1.7B + AngelSlim/Qwen3-1.7B_eagle3

| Task | Qwen_Qwen3-1.7B-Q6_K-EAGLE3 (t/s) | Qwen_Qwen3-1.7B-Q6_K (t/s) | EAGLE speedup (t/s) |
| --- | --- | --- | --- |
| Write quicksort algorithm in Python | 47.209 | 46.9 | +0.309 |
| Explain Pythagorean theorem | 48.109 | 47.0 | +1.109 |
| Plan 1-day trip to DC | 45.031 | 46.8 | -1.769 |

Qwen3-235B-A22B regular vs Qwen3-235B-A22B + lmsys/Qwen3-235B-A22B-EAGLE3

| Task | Qwen3-235B-A22B.Q2_K-EAGLE3 (t/s) | Qwen3-235B-A22B.Q2_K (t/s) | EAGLE speedup (t/s) |
| --- | --- | --- | --- |
| Write quicksort algorithm in Python | 4.299 | 7.3 | -3.001 |
| Explain Pythagorean theorem | 4.709 | 7.3 | -2.591 |
| Plan 1-day trip to DC | 4.374 | 7.4 | -3.026 |
Additional info for Qwen3-235B-A22B EAGLE3
Write a quicksort algorithm in Python. Write code only.

encoded   20 tokens in    4.786 seconds, speed:    4.179 t/s
decoded  257 tokens in   59.787 seconds, speed:    4.299 t/s

n_draft   = 8
n_predict = 257
n_drafted = 229
n_accept  = 66
accept    = 28.821%

Explain the Pythagorean theorem

encoded   15 tokens in    5.170 seconds, speed:    2.902 t/s
decoded  258 tokens in   54.794 seconds, speed:    4.709 t/s

n_draft   = 8
n_predict = 258
n_drafted = 219
n_accept  = 96
accept    = 43.836%

Plan a 1 day trip to DC

encoded   16 tokens in    3.757 seconds, speed:    4.258 t/s
decoded  257 tokens in   58.759 seconds, speed:    4.374 t/s

n_draft   = 8
n_predict = 257
n_drafted = 218
n_accept  = 60
accept    = 27.523%

Additional info for Qwen3-1.7B EAGLE3
Write a quicksort algorithm in Python. Write code only.

encoded   20 tokens in    0.121 seconds, speed:  165.070 t/s
decoded  257 tokens in    5.444 seconds, speed:   47.209 t/s

n_draft   = 8
n_predict = 257
n_drafted = 227
n_accept  = 88
accept    = 38.767%

Explain the Pythagorean theorem

encoded   15 tokens in    0.090 seconds, speed:  166.192 t/s
decoded  257 tokens in    5.342 seconds, speed:   48.109 t/s

n_draft   = 8
n_predict = 257
n_drafted = 231
n_accept  = 94
accept    = 40.693%

Plan a 1 day trip to DC

encoded   16 tokens in    0.094 seconds, speed:  169.339 t/s
decoded  258 tokens in    5.729 seconds, speed:   45.031 t/s

n_draft   = 8
n_predict = 258
n_drafted = 241
n_accept  = 81
accept    = 33.610%

@ichbinhandsome (Author)

Thanks for testing! Glad to see the model works, though the speedup for these models is small or even negative, and the accept rate is quite low.
This could be due to quantization — since Eagle3 models are usually trained in BF16/FP16 with the original data types of the target models, quantization may degrade model quality and reduce the accept rate.
Another possible factor is that the Eagle3 models themselves may not be of high quality, as they come from third-party sources. @arch-btw

@ngxson mentioned this pull request on Jan 16, 2026
@jukofyork (Collaborator)

This was just linked on Reddit today:

https://z-lab.ai/projects/dflash/

https://github.com/z-lab/dflash

and it seems worth thinking about for any future MTP/Eagle API:

We demonstrate that a free lunch does exist. Our key insight is that the large AR target model’s hidden features implicitly contain information about future tokens, a phenomenon also observed by [4].

Instead of asking a tiny diffusion model to reason from scratch, DFlash conditions the draft model on context features extracted from the target model. This fuses the deep reasoning capabilities of the large model with the parallel generation speed of the small diffusion drafter.

TimPietruskyRunPod pushed a commit to runpod-workers/openclaw2go-llamacpp that referenced this pull request Feb 14, 2026
squash-merge of ggml-org/llama.cpp PR ggml-org#18039 onto main.
adds Eagle-3 speculative decoding support for 1.5-2.5x
generation speedup with draft model pairing.
TimPietruskyRunPod pushed a commit to runpod-workers/openclaw2go-llamacpp that referenced this pull request Feb 14, 2026
checks daily for new llama.cpp releases.
auto-rebases cherry-picks (audio ggml-org#18641, outetss ggml-org#12794, eagle-3 ggml-org#18039).
creates tagged release on clean rebase, PR on conflicts.
PR ggml-org#19460 (GLM-5 DSA) already merged upstream, not in cherry-pick list.
Labels

examples · ggml (changes relating to the ggml tensor library for machine learning) · hot (Something that is hot) · model (Model specific) · python (python script changes)

Projects

None yet

Development

Successfully merging this pull request may close these issues:

Feature Request: Support EAGLE3 models for draft model / speculative decoding use cases

7 participants