Support EagleDraftModel.from_pretrained for finetuning#333
Conversation
|
📦 Build Artifacts Available |
cb6dea0 to
f8aed44
Compare
|
This pull request has merge conflicts that must be resolved before it can be |
45ea878 to
c3704b7
Compare
|
Tested locally and everything behaved smoothly as expected. 👍 One small observation: I noticed that the current tests are mainly unit tests. For a feature like this (loading from pretrained + finetuning), an E2E sanity check like the test here could be useful and robust to guard against regressions in the full workflow. That kind of test tends to catch subtle issues that unit tests may miss. Of course, it could also make CI heavier since it involves running a short training job, so this is just a suggestion. |
Yeah this make sense, I will add something like that. We have a separate nightly job that runs our E2E tests, so it won't block prs but will be a useful signal if there's a regression. |
c3704b7 to
95310e0
Compare
This comment was marked as outdated.
This comment was marked as outdated.
This comment was marked as outdated.
This comment was marked as outdated.
6be261f to
0212662
Compare
|
This pull request has merge conflicts that must be resolved before it can be |
…etrained Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
There were some potentially robustness issues with the previous implementation. In particular, we were resetting parameters on the model after loading the verifier weights. Although we were mostly avoiding the verifier weights there were some risks with this approach. This commit simplifies the logic and adds an additional broadcast step from rank0 to ensure all ranks have the same weight copy before training begins. Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Test correctness of different load from pretrained/checkpoint/fresh init on single gpu and multi-gpu setups. Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
0212662 to
c09fab9
Compare
|
Removed merge commit and rebased. |
…t#333) ## Prereq Depends on vllm-project#332, which should be merged first. ## Purpose We'd like to support finetuning existing Eagle3 models. This pr adds the `--from-pretrained` option, which loads a pretrained model from HF / local path. ### Setup flow #### Fresh model setup, i.e. `from_training_args` pathway 1. model = Eagle3DraftModel.__init__ - Initialize all modules/parameters/buffers - Use `torch.zeros` (w/ correct shape + dtype) as placeholders for t2d/d2t - Use `torch.nan` (w/ correct shape + dtype) as placeholders for lm_head, embed_tokens, and verifier_lm_head 2. model.load_vocab_mappings(t2d, d2t) - t2d/d2t can be `None`, which will cause an early return - Verify shapes match and then load t2d/d2t 3. model.load_verifier_weights() - Loads embed_tokens, lm_head (into both lm_head and verifier_lm_head), verifier_norm from verifier model - Checks if embed_tokens and lm_head are still `NaN` (from init) before loading. Otherwise they are skipped - verifier_lm_head is always loaded - verifier_norm is skipped if the weight isn't found (with a warning). Note we don't use NaN init for this module Continues below #### Finetuning setup, i.e. `from_pretrained` pathway 1. model = Eagle3DraftModel.__init__ call internally by `PreTrainedModel.from_pretrained` under meta device context 2. `PreTrainedModel.from_pretrained` also loads model weights for us 3. model.load_vocab_mappings(t2d, d2t) called, same as above 4. model.load_verifier_weights() called - lm_head and embed_token loading should be skipped because the values have already been set by `PreTrainedModel.from_pretrained` - verifier_lm_head and verifier_norm loaded (overriding existing values but should be the same) Continues below #### Joint (Fresh + Finetuning) next steps (in `Trainer.setup_model`) ``` If distributed: cache state dict on rank0 Apply `fully_shard` to model layers and model (don't move to meta device, don't re-init weights) if checkpoint exists: load checkpoint (distributed) else: broadcast cached state dict from rank0 to all ranks else: move model to local device if checkpoint exists: load checkpoint (single device) ``` ## Implementation This uses the `.from_pretrained` method from the base `PreTrainedModel` class to resolve local vs remote model and handle downloading checkpoints. As part of this pr, I also clean up the model initialize process, which previously also including loading verifier weights and t2d/d2t tensors. These will now be loaded after the init function. This was required to support `.from_pretrained` as the model init is run under a meta device context. This means the init instead sets up placeholder buffers for verifier parameters/vocab mapping. These are being set intentionally in a way that makes it easy to confirm they are overwritten when training starts (e.g. by initializing some values to `NaN` so that failing to overwrite them will result in immediate NaN outputs from the model). I also updated the fully shard handling to ensure values are set on all ranks correctly when using distributed training. ## Testing Added comprehensive test coverage for loading pathway combinations (e.g. loading from checkpoint, pretrained, fresh init, w/ and w/o vocab mappings, single gpu and distributed, etc.). Note some tests require single or multi-gpu. These are skipped if requirements are not met. It would be good to ensure these are being run correctly at regular intervals. --------- Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Prereq
Depends on #332, which should be merged first.
Purpose
We'd like to support finetuning existing Eagle3 models. This pr adds the
--from-pretrainedoption, which loads a pretrained model from HF / local path.Setup flow
Fresh model setup, i.e.
from_training_argspathwaytorch.zeros(w/ correct shape + dtype) as placeholders for t2d/d2ttorch.nan(w/ correct shape + dtype) as placeholders for lm_head, embed_tokens, and verifier_lm_headNone, which will cause an early returnNaN(from init) before loading. Otherwise they are skippedContinues below
Finetuning setup, i.e.
from_pretrainedpathwayPreTrainedModel.from_pretrainedunder meta device contextPreTrainedModel.from_pretrainedalso loads model weights for usPreTrainedModel.from_pretrainedContinues below
Joint (Fresh + Finetuning) next steps (in
Trainer.setup_model)Implementation
This uses the
.from_pretrainedmethod from the basePreTrainedModelclass to resolve local vs remote model and handle downloading checkpoints.As part of this pr, I also clean up the model initialize process, which previously also including loading verifier weights and t2d/d2t tensors. These will now be loaded after the init function. This was required to support
.from_pretrainedas the model init is run under a meta device context.This means the init instead sets up placeholder buffers for verifier parameters/vocab mapping. These are being set intentionally in a way that makes it easy to confirm they are overwritten when training starts (e.g. by initializing some values to
NaNso that failing to overwrite them will result in immediate NaN outputs from the model).I also updated the fully shard handling to ensure values are set on all ranks correctly when using distributed training.
Testing
Added comprehensive test coverage for loading pathway combinations (e.g. loading from checkpoint, pretrained, fresh init, w/ and w/o vocab mappings, single gpu and distributed, etc.).
Note some tests require single or multi-gpu. These are skipped if requirements are not met. It would be good to ensure these are being run correctly at regular intervals.