Support EagleDraftModel.from_pretrained for finetuning #333

Merged
rahul-tuli merged 7 commits into main from support_from_pretrained
Mar 11, 2026
Conversation

@fynnsu fynnsu commented Mar 5, 2026

Collaborator

Prereq

Depends on #332, which should be merged first.

Purpose

We'd like to support finetuning existing Eagle3 models. This PR adds the --from-pretrained option, which loads a pretrained model from HF or a local path.

Setup flow

Fresh model setup, i.e. from_training_args pathway

  1. model = Eagle3DraftModel.__init__
    • Initialize all modules/parameters/buffers
    • Use torch.zeros (w/ correct shape + dtype) as placeholders for t2d/d2t
    • Use torch.nan (w/ correct shape + dtype) as placeholders for lm_head, embed_tokens, and verifier_lm_head
  2. model.load_vocab_mappings(t2d, d2t)
    • t2d/d2t can be None, which will cause an early return
    • Verify shapes match and then load t2d/d2t
  3. model.load_verifier_weights()
    • Loads embed_tokens, lm_head (into both lm_head and verifier_lm_head), verifier_norm from verifier model
    • Checks if embed_tokens and lm_head are still NaN (from init) before loading. Otherwise they are skipped
    • verifier_lm_head is always loaded
    • verifier_norm is skipped if the weight isn't found (with a warning). Note we don't use NaN init for this module
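Steps 1 and 2 above can be sketched roughly as follows. This is a minimal, hypothetical illustration rather than the code in this PR: plain Python lists stand in for tensors, `DraftModelSketch` is an invented name, and the sketch assumes that a `None` for either mapping triggers the early return.

```python
from typing import Optional, Sequence


class DraftModelSketch:
    """Hypothetical stand-in for the draft model; lists replace tensors."""

    def __init__(self, draft_vocab_size: int, verifier_vocab_size: int) -> None:
        # Step 1: zero placeholders with the final shapes.
        self.d2t = [0] * draft_vocab_size
        self.t2d = [0] * verifier_vocab_size

    def load_vocab_mappings(
        self,
        t2d: Optional[Sequence[int]],
        d2t: Optional[Sequence[int]],
    ) -> bool:
        """Step 2: return True if the mappings were loaded, False if skipped."""
        # Assumption: a missing mapping means "keep the placeholders".
        if t2d is None or d2t is None:
            return False
        # Verify shapes against the placeholders before overwriting them.
        if len(t2d) != len(self.t2d) or len(d2t) != len(self.d2t):
            raise ValueError("vocab mapping shape does not match the model")
        self.t2d = list(t2d)
        self.d2t = list(d2t)
        return True


model = DraftModelSketch(draft_vocab_size=2, verifier_vocab_size=4)
skipped = model.load_vocab_mappings(None, None)          # early return
loaded = model.load_vocab_mappings([1, 0, 1, 0], [0, 1])  # shapes match
```

Allocating the placeholders with their final shapes in the constructor is what makes the later shape check possible.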

Continues below

Finetuning setup, i.e. from_pretrained pathway

  1. model = Eagle3DraftModel.__init__ called internally by PreTrainedModel.from_pretrained under meta device context
  2. PreTrainedModel.from_pretrained also loads model weights for us
  3. model.load_vocab_mappings(t2d, d2t) called, same as above
  4. model.load_verifier_weights() called
    • lm_head and embed_tokens loading should be skipped because the values have already been set by PreTrainedModel.from_pretrained
    • verifier_lm_head and verifier_norm loaded (overriding existing values but should be the same)
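The NaN check that separates the two pathways could look something like the sketch below. It is illustrative only: dictionaries of lists stand in for the model and verifier state, `is_placeholder` and this `load_verifier_weights` function are hypothetical, verifier_norm is left out for brevity, and treating a weight as a placeholder only when all of its values are NaN is an assumption.

```python
import math
from typing import Dict, List

Weights = Dict[str, List[float]]


def is_placeholder(values: List[float]) -> bool:
    # Assumption: a weight is "still from init" only if every value is NaN.
    return all(math.isnan(v) for v in values)


def load_verifier_weights(model: Weights, verifier: Weights) -> List[str]:
    """Copy verifier weights into the model; return the names that were loaded."""
    loaded = []
    for name in ("embed_tokens", "lm_head"):
        # Only overwrite weights that still hold the NaN placeholder.
        if is_placeholder(model[name]):
            model[name] = list(verifier[name])
            loaded.append(name)
    # verifier_lm_head is always loaded from the verifier's lm_head.
    model["verifier_lm_head"] = list(verifier["lm_head"])
    loaded.append("verifier_lm_head")
    return loaded


nan = math.nan
verifier = {"embed_tokens": [1.0, 2.0], "lm_head": [3.0, 4.0]}

# Fresh pathway: everything is still a NaN placeholder.
fresh = {"embed_tokens": [nan, nan], "lm_head": [nan, nan], "verifier_lm_head": [nan, nan]}
fresh_loaded = load_verifier_weights(fresh, verifier)

# Finetuning pathway: from_pretrained has already filled embed_tokens and lm_head.
finetuned = {"embed_tokens": [0.5, 0.5], "lm_head": [0.25, 0.25], "verifier_lm_head": [nan, nan]}
finetuned_loaded = load_verifier_weights(finetuned, verifier)
```

In the fresh case all three weights are loaded; in the finetuning case only verifier_lm_head is, so the pretrained lm_head and embed_tokens survive.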

Continues below

Joint (Fresh + Finetuning) next steps (in Trainer.setup_model)

if distributed:
    cache state dict on rank0
    apply `fully_shard` to model layers and model (don't move to meta device, don't re-init weights)

    if checkpoint exists:
        load checkpoint (distributed)
    else:
        broadcast cached state dict from rank0 to all ranks
else:
    move model to local device
    if checkpoint exists:
        load checkpoint (single device)

Implementation

This uses the .from_pretrained method from the base PreTrainedModel class to resolve local vs remote model and handle downloading checkpoints.

As part of this PR, I also cleaned up the model initialization process, which previously also included loading verifier weights and t2d/d2t tensors. These are now loaded after the init function. This was required to support .from_pretrained, as the model init is run under a meta device context.

This means the init instead sets up placeholder buffers for verifier parameters/vocab mapping. These are being set intentionally in a way that makes it easy to confirm they are overwritten when training starts (e.g. by initializing some values to NaN so that failing to overwrite them will result in immediate NaN outputs from the model).
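To see why NaN placeholders make a missed load obvious, here is a toy example with a hand-written dot product in place of a real layer. The `linear` helper and the values are made up for illustration; the point is only that any NaN weight propagates to the output.

```python
import math
from typing import List


def linear(weights: List[float], inputs: List[float]) -> float:
    """Dot product standing in for a real linear layer."""
    return sum(w * x for w, x in zip(weights, inputs))


inputs = [1.0, 2.0, 3.0]

# Placeholder weights that were never overwritten.
placeholder = [math.nan] * 3
bad = linear(placeholder, inputs)

# Weights that were properly loaded.
loaded = [0.5, -1.0, 2.0]
good = linear(loaded, inputs)

# Even a single leftover NaN is enough to poison the output.
partial = linear([0.5, math.nan, 2.0], inputs)

print(math.isnan(bad), good, math.isnan(partial))
```

A zero placeholder, by contrast, would produce plausible-looking but wrong outputs, which is much harder to notice.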

I also updated the fully_shard handling to ensure values are set correctly on all ranks when using distributed training.

Testing

Added comprehensive test coverage for loading pathway combinations (e.g. loading from checkpoint, pretrained, fresh init, w/ and w/o vocab mappings, single gpu and distributed, etc.).

Note some tests require single or multi-gpu. These are skipped if requirements are not met. It would be good to ensure these are being run correctly at regular intervals.
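One way to keep an eye on skipped hardware tests is to report the skip count alongside the run count. The sketch below uses the standard-library unittest module so it runs anywhere; the project's actual tests may well use a different framework, and `available_gpus`, `LoadingPathwayTests`, and `run_and_count_skips` are hypothetical names.

```python
import unittest
from typing import Tuple


def available_gpus() -> int:
    """Number of visible CUDA devices, or 0 if torch is not installed."""
    try:
        import torch
    except ImportError:
        return 0
    return torch.cuda.device_count()


class LoadingPathwayTests(unittest.TestCase):
    @unittest.skipUnless(available_gpus() >= 1, "requires at least one GPU")
    def test_single_gpu_load(self) -> None:
        self.assertGreaterEqual(available_gpus(), 1)

    @unittest.skipUnless(available_gpus() >= 2, "requires at least two GPUs")
    def test_distributed_load(self) -> None:
        self.assertGreaterEqual(available_gpus(), 2)


def run_and_count_skips() -> Tuple[int, int]:
    """Run the suite and return (tests run, tests skipped)."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(LoadingPathwayTests)
    result = unittest.TestResult()
    suite.run(result)
    return result.testsRun, len(result.skipped)


ran, skipped = run_and_count_skips()
print(f"ran={ran} skipped={skipped}")
```

A scheduled job could fail when `skipped` is non-zero on a machine that is expected to have GPUs, so silent skips do not go unnoticed.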

github-actions Bot commented Mar 5, 2026

📦 Build Artifacts Available
The build artifacts (`.whl` and `.tar.gz`) have been successfully generated and are available for download: https://github.com/vllm-project/speculators/actions/runs/22909838668/artifacts/5852352190.
They will be retained for up to 30 days.
Commit: c09fab9

Comment thread src/speculators/models/eagle3/core.py
Comment thread src/speculators/models/eagle3/core.py
Comment thread scripts/train.py
Base automatically changed from cleanup_repo to main March 6, 2026 11:48
mergify Bot commented Mar 6, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fynnsu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 6, 2026
@fynnsu fynnsu force-pushed the support_from_pretrained branch from 45ea878 to c3704b7 Compare March 6, 2026 15:06
@mergify mergify Bot removed the needs-rebase label Mar 6, 2026
Comment thread tmp.pt Outdated
VincentG1234 commented Mar 6, 2026

Contributor

Tested locally and everything behaved smoothly as expected. 👍

One small observation: I noticed that the current tests are mainly unit tests. For a feature like this (loading from pretrained + finetuning), an E2E sanity check like the test here could be useful and robust to guard against regressions in the full workflow. That kind of test tends to catch subtle issues that unit tests may miss.

Of course, it could also make CI heavier since it involves running a short training job, so this is just a suggestion.

fynnsu commented Mar 7, 2026

Collaborator Author

@VincentG1234

Yeah, this makes sense, I will add something like that. We have a separate nightly job that runs our E2E tests, so it won't block PRs but will be a useful signal if there's a regression.

@fynnsu fynnsu force-pushed the support_from_pretrained branch from c3704b7 to 95310e0 Compare March 7, 2026 15:14
@fynnsu fynnsu force-pushed the support_from_pretrained branch from 6be261f to 0212662 Compare March 9, 2026 18:34
mergify Bot commented Mar 10, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fynnsu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 10, 2026
fynnsu added 7 commits March 10, 2026 15:13
…etrained

Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
There were some potential robustness issues with the previous implementation.
In particular, we were resetting parameters on the model after loading the verifier
weights. Although we were mostly avoiding the verifier weights there were some
risks with this approach. This commit simplifies the logic and adds an additional
broadcast step from rank0 to ensure all ranks have the same weight copy before 
training begins.


Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Test correctness of different load from pretrained/checkpoint/fresh init
on single gpu and multi-gpu setups.

Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
@fynnsu fynnsu force-pushed the support_from_pretrained branch from 0212662 to c09fab9 Compare March 10, 2026 15:19
fynnsu commented Mar 10, 2026

Collaborator Author

Removed merge commit and rebased.

@mergify mergify Bot removed the needs-rebase label Mar 10, 2026
@rahul-tuli rahul-tuli merged commit 6c04447 into main Mar 11, 2026
12 checks passed
@rahul-tuli rahul-tuli deleted the support_from_pretrained branch March 11, 2026 14:16
YzTongNiar pushed a commit to YzTongNiar/speculators that referenced this pull request Apr 10, 2026
…t#333)

Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
5 participants