Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions scripts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,9 @@ The scripts has the following optional arguments:
- `--log-dir`: The path to save the logs. Defaults to `./logs`.
- `--run-name`: The name of the run. Defaults to None.
- `--num-layers`: The number of layers to use. Defaults to 1.
- `--d2t-path`: The path to the d2t tensor. Defaults to `d2t.npy`.
- `--t2d-path`: The path to the t2d tensor. Defaults to `t2d.npy`.
- `--pretrained-model-path`: Path to a pretrained EAGLE3 model (HuggingFace Hub or local path) for fine-tuning. When specified, vocabulary mappings (`d2t`/`t2d`) are automatically extracted from the model. Cannot be used together with `--d2t-path` and `--t2d-path`.
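- `--d2t-path`: The path to the d2t tensor (a `.npy` file). Defaults to None. Must be provided together with `--t2d-path`; when both are omitted, the full verifier vocabulary is used. Not needed when using `--pretrained-model-path`.
- `--t2d-path`: The path to the t2d tensor (a `.npy` file). Defaults to None. Must be provided together with `--d2t-path`. Not needed when using `--pretrained-model-path`.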
- `--ttt-steps`: The number of TTT steps to use. Defaults to 3.
- `--ttt-step-loss-decay`: The loss decay factor to use for the TTT steps. Defaults to 1.0.

Expand All @@ -290,6 +291,29 @@ torchrun --nnodes=1 --nproc_per_node=8 scripts/train.py \
--ttt-step-loss-decay 1.0
```

## Fine-tuning from Pretrained Models

Fine-tune an existing EAGLE3 model using `--pretrained-model-path`. Vocabulary mappings (`d2t`/`t2d`) are automatically extracted - no need to provide them separately.

**Example (Single GPU with python):**

```bash
python scripts/train.py \
--verifier-name-or-path "meta-llama/Llama-3.1-8B-Instruct" \
--pretrained-model-path "RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3" \
--data-path "./new_data" \
--save-path "./checkpoints/finetuned" \
--epochs 3 \
--lr 5e-5
```

**Notes:**

- `--pretrained-model-path` accepts either a HuggingFace Hub model ID or a local path.
- `--d2t-path`/`--t2d-path` cannot be combined with `--pretrained-model-path`; passing them together raises an error.
- The optimizer state starts fresh.
- Use a lower learning rate (e.g., 5e-5) for fine-tuning.

## E2E Pipeline

### Overview
Expand Down
151 changes: 129 additions & 22 deletions scripts/train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import logging
import random
import warnings

Expand All @@ -9,6 +10,7 @@
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config

from speculators.config import SpeculatorModelConfig
from speculators.model import SpeculatorModel
from speculators.train.data import (
Eagle3SampleFileDataset,
Expand All @@ -23,6 +25,13 @@
from speculators.train.noise_transforms import AddUniformNoise
from speculators.train.trainer import Trainer, TrainerConfig
from speculators.train.utils import maybe_destroy_distributed, maybe_setup_distributed
from speculators.utils.loading import (
extract_vocab_mappings,
load_full_state_dict,
load_pretrained_weights,
)

logger = logging.getLogger(__name__)

DRAFT_ARCH_CONFIGS: dict[str, type] = {
"llama": LlamaConfig,
Expand All @@ -43,6 +52,90 @@ def set_seed(seed: int, deterministic: bool = False):
torch.backends.cudnn.benchmark = False


def load_pretrained_model(
    pretrained_path: str, device: torch.device
) -> tuple[dict[str, torch.Tensor], torch.Tensor, torch.Tensor, int]:
    """
    Load a pretrained model's weights together with its vocabulary mappings.

    The full state dict is read with ``load_full_state_dict`` and the
    ``d2t``/``t2d`` mappings are pulled out of it with
    ``extract_vocab_mappings`` for the given ``device``. The draft vocabulary
    size is the length of the first dimension of ``d2t``.

    Args:
        pretrained_path: Location of the pretrained model, passed straight
            through to ``load_full_state_dict``.
        device: Device handed to ``extract_vocab_mappings``.

    Returns:
        Tuple of (state_dict, d2t, t2d, draft_vocab_size)
    """
    logger.info(f"Loading pretrained model from {pretrained_path}")

    weights = load_full_state_dict(pretrained_path)
    draft_to_target, target_to_draft = extract_vocab_mappings(weights, device)

    # The draft vocabulary has one entry per row of the d2t mapping.
    draft_vocab_size = draft_to_target.shape[0]
    logger.info(f"Derived draft_vocab_size={draft_vocab_size}")

    return weights, draft_to_target, target_to_draft, draft_vocab_size


def load_vocab_mappings(
    d2t_path: str, t2d_path: str, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Load the d2t and t2d vocabulary mappings from numpy files.

    Args:
        d2t_path: Path to the numpy file holding the d2t mapping.
        t2d_path: Path to the numpy file holding the t2d mapping.
        device: Device the loaded tensors are moved to.

    Returns:
        Tuple of (d2t, t2d, draft_vocab_size), where draft_vocab_size is the
        length of the first dimension of d2t.

    Raises:
        ValueError: If either path is empty or None.
    """
    # Guard clause: the two mappings are only meaningful as a pair.
    if not d2t_path or not t2d_path:
        raise ValueError(
            "Both d2t and t2d paths must be provided together. "
            f"Got d2t={'provided' if d2t_path else 'missing'}, "
            f"t2d={'provided' if t2d_path else 'missing'}"
        )

    # Load d2t first, then t2d, moving each onto the requested device.
    d2t, t2d = (
        torch.from_numpy(np.load(mapping_path)).to(device)
        for mapping_path in (d2t_path, t2d_path)
    )

    return d2t, t2d, d2t.shape[0]


def initialize_vocab_config(
args: argparse.Namespace, device: torch.device
) -> tuple[
torch.Tensor | None,
torch.Tensor | None,
int,
dict[str, torch.Tensor] | None,
]:
"""
Initialize vocabulary configuration from args.

Returns:
Tuple of (d2t, t2d, draft_vocab_size, pretrained_state_dict)
"""
# Check for conflicting args
if args.pretrained_model_path and (args.d2t_path or args.t2d_path):
raise ValueError(
"--pretrained-model-path overrides --d2t-path and "
"--t2d-path. Please remove --d2t-path and --t2d-path."
)

# Load from pretrained model
if args.pretrained_model_path:
state_dict, d2t, t2d, vocab_size = load_pretrained_model(
args.pretrained_model_path,
device,
)
return d2t, t2d, vocab_size, state_dict

# Load from numpy files
if args.d2t_path or args.t2d_path:
d2t, t2d, vocab_size = load_vocab_mappings(args.d2t_path, args.t2d_path, device)
return d2t, t2d, vocab_size, None

# No vocab mapping provided
verifier_config = AutoConfig.from_pretrained(args.verifier_name_or_path)
if hasattr(verifier_config, "text_config"):
verifier_config = verifier_config.text_config

return None, None, verifier_config.vocab_size, None


def setup_dataloader(
file_list: list[str],
world_size: int,
Expand Down Expand Up @@ -151,30 +244,24 @@ def main(args: argparse.Namespace):
local_rank, world_size, rank, is_distributed = maybe_setup_distributed()
device = torch.device(local_rank)

# Load t2d and d2t tensors if provided
if args.d2t_path or args.t2d_path:
if not (args.d2t_path and args.t2d_path):
raise ValueError(
"Both t2d and d2t must be provided together, or both must be omitted. "
f"Got t2d={'provided' if args.t2d_path is not None else 'not provided'}"
f"d2t={'provided' if args.d2t_path is not None else 'not provided'}"
)
d2t = torch.from_numpy(np.load(args.d2t_path)).to(device)
t2d = torch.from_numpy(np.load(args.t2d_path)).to(device)
draft_vocab_size = d2t.shape[0]
else:
d2t = None
t2d = None
# When vocab mapping is not provided, use the full verifier vocab
verifier_config = AutoConfig.from_pretrained(args.verifier_name_or_path)
if hasattr(verifier_config, "text_config"):
verifier_config = verifier_config.text_config
draft_vocab_size = verifier_config.vocab_size
# Initialize vocabulary configuration
d2t, t2d, draft_vocab_size, pretrained_state_dict = initialize_vocab_config(
args, device
)

# Setup speculator config
transformer_layer_config = create_transformer_layer_config(
args.verifier_name_or_path, args.num_layers, draft_arch=args.draft_arch
)
# If finetuning, preserve the transformer_layer_config from pretrained model
if args.pretrained_model_path:
pretrained_config = SpeculatorModelConfig.from_pretrained(
args.pretrained_model_path
)
transformer_layer_config = pretrained_config.transformer_layer_config
transformer_layer_config._attn_implementation = "simple_flex_attention" # noqa: SLF001
logger.info("Using transformer_layer_config from pretrained model ")
else:
transformer_layer_config = create_transformer_layer_config(
args.verifier_name_or_path, args.num_layers, draft_arch=args.draft_arch
)

# Get model class from registry and create model using its factory method
if SpeculatorModel.registry_auto_discovery:
Expand All @@ -195,6 +282,12 @@ def main(args: argparse.Namespace):
**vars(args),
)

# Load pretrained weights if provided (for fine-tuning)
if pretrained_state_dict is not None:
load_pretrained_weights(
draft_model, pretrained_state_dict, args.pretrained_model_path
)

# Setup dataloaders
train_files, val_files = split_files(args.data_path, ratio=0.9)
train_loader = setup_dataloader(
Expand Down Expand Up @@ -276,6 +369,20 @@ def parse_args():
)
parser.add_argument("--d2t-path", type=str, default=None)
parser.add_argument("--t2d-path", type=str, default=None)
parser.add_argument(
"--pretrained-model-path",
type=str,
default=None,
help=(
"Path to pretrained EAGLE3 model directory "
"(HuggingFace format with safetensors). "
"When provided, d2t/t2d mappings and model weights "
"will be loaded from this model, enabling "
"warm-start/fine-tuning. Overrides --d2t-path "
"and --t2d-path."
),
)

parser.add_argument("--ttt-steps", type=int, default=3)
parser.add_argument("--ttt-step-loss-decay", type=float, default=1.0)
parser.add_argument(
Expand Down
Loading