Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions configs/qwen3.5-4b-eagle3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{
"architectures": ["LlamaForCausalLMEagle3"],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 248044,
"head_dim": 256,
"hidden_act": "silu",
"hidden_size": 2560,
"initializer_range": 0.02,
"intermediate_size": 9216,
"max_position_embeddings": 32768,
"model_type": "llama",
"num_attention_heads": 16,
"num_hidden_layers": 1,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 10000000,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.57.0",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 248320,
"draft_vocab_size": 32000
}
30 changes: 30 additions & 0 deletions examples/run_qwen3.5-4b-eagle3.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/bin/bash
# Example: Train EAGLE3 for Qwen3.5-4B
#
# Stage 1 generates target-model hidden states; stage 2 trains the draft model
# on them. Fail fast so stage 2 never runs against missing/stale hidden states.
set -euo pipefail

# Resolve the repo root relative to this script (robust to spaces in paths).
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")

# Generate hidden states
CUDA_VISIBLE_DEVICES=0 torchrun \
    --standalone \
    --nproc_per_node 1 \
    "$ROOT_DIR/scripts/prepare_hidden_states.py" \
    --target-model-path Qwen/Qwen3.5-4B \
    --enable-aux-hidden-states \
    --data-path "$ROOT_DIR/cache/dataset/ultrachat_train.jsonl" \
    --output-path "$ROOT_DIR/cache/hidden_states/qwen3.5-4b" \
    --chat-template qwen \
    --max-length 1024 \
    --batch-size 8 \
    --sglang-mem-fraction-static 0.6

# Train draft model
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \
    --standalone \
    --nproc_per_node 4 \
    "$ROOT_DIR/scripts/train_eagle3.py" \
    --target-model-path Qwen/Qwen3.5-4B \
    --draft-model-config "$ROOT_DIR/configs/qwen3.5-4b-eagle3.json" \
    --train-hidden-states-path "$ROOT_DIR/cache/hidden_states/qwen3.5-4b" \
    --output-path "$ROOT_DIR/outputs/qwen3.5-4b-eagle3" \
    --embedding-key "model.language_model.embed_tokens.weight"
16 changes: 16 additions & 0 deletions scripts/train_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@
from argparse import ArgumentParser, Namespace
from typing import List, Optional, Tuple, Union

# Enable the Qwen3.5 EAGLE3 patch via environment variable.
# Default it on, but respect a caller-provided value — the original code
# unconditionally overwrote the variable, contradicting its own comment
# ("Set QWEN35_EAGLE3_ENABLE=1 to enable the patch").
os.environ.setdefault("QWEN35_EAGLE3_ENABLE", "1")

# Patch Qwen3.5 4B for EAGLE3 support. Patching is best-effort: if the patch
# module is missing or fails, warn and continue so other models still train.
try:
    from specforge.modeling.qwen3_5_eagle_patch import (
        patch_qwen3_5_for_eagle3,
        patch_qwen3_5_instance,
    )

    patch_qwen3_5_for_eagle3()
except Exception as e:  # best-effort: log and fall through to the no-op stub
    print(f"Warning: Failed to apply Qwen3.5 EAGLE3 patch: {e}")

    # Provide a no-op fallback so later instance-level patching calls do not
    # raise NameError when the patch module could not be imported.
    def patch_qwen3_5_instance(model):
        """No-op stand-in used when the Qwen3.5 patch module is unavailable."""
        return None

import torch
import torch.distributed as dist
import torch.nn as nn
Expand Down Expand Up @@ -305,6 +316,11 @@ def build_target_model(
draft_model_config.eagle_config["eagle_aux_hidden_state_layer_ids"]
)
else:
# Apply instance-level patch if model wasn't patched at class level
try:
patch_qwen3_5_instance(target_model.model_runner.model)
except:
pass
target_model.set_aux_hidden_states_layers()

if args.is_vlm:
Expand Down
9 changes: 8 additions & 1 deletion specforge/modeling/draft/llama3_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,14 @@ def rope_get(key, default=None):
scaling_type = rope_get("rope_type", rope_get("type"))
scaling_factor = rope_get("factor")

if scaling_type == "linear":
# Handle "default" as no scaling - use standard rotary embedding
if scaling_type is None or scaling_type == "default":
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=getattr(self.config, "rope_theta", 10000),
)
elif scaling_type == "linear":
if scaling_factor is None:
raise ValueError(
"Linear RoPE scaling requires 'factor' in rope_scaling config."
Expand Down
Loading