feat(wandb): save model as artifact #8119
Conversation
sgugger
left a comment
Thanks for your PR!
Could you split it into two parts? The fix in `_new_step` is something we want to merge ASAP (great catch btw!), but the model saving on the Wandb side might trigger a longer discussion.
```diff
 def _new_step(self):
     """ Internal method that resets the variable for a new step. """
-    self.should_save_model = False
+    self.should_save = False
```
Great catch!
OK, I submitted #8121 for the typo in `_new_step`. As for this PR, I'm thinking the logic should be:
After experimenting a bit more:
@sgugger Do you want me to make an attempt at giving access to the
Hi @borisdayma, sorry I took a bit of time to reply on this; we were waiting for the new version of the model hub to materialize before moving forward on this.

So! The callbacks aren't 2-way in Transformers because then you have to be very careful about the order of their execution. Here the design was to just allow for callbacks that can read the state, not write, and for any piece of code that needs write access, users should subclass the `Trainer`. Like you said, you have access to the state.
The issue with … The next issues would then be:

What do you think? The alternative would be to completely ignore that logic, let wandb save a model somewhere and upload it. I had not realized we could have access to the model.
I think the most logical is to save the final model; the intermediate checkpoints are there to resume training if something went wrong, or to load the best model at the end (which is done before the `on_train_end` event).

If you use the logic of unpacking the model from the kwargs, you can simply create a new `Trainer` to reuse its `save_model` method.
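For illustration, a minimal sketch of that idea (the helper name and artifact naming below are mine, not this PR's code): build a throwaway `Trainer` around the final model purely to reuse `Trainer.save_model`, then upload the saved folder as a wandb artifact.

```python
import os
import tempfile

import wandb
from transformers import Trainer


def log_final_model(args, model, tokenizer=None):
    # Hypothetical helper, not part of transformers: reuse Trainer.save_model
    # so the uploaded files match what the Trainer would write to disk.
    with tempfile.TemporaryDirectory() as temp_dir:
        fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
        fake_trainer.save_model(temp_dir)

        artifact = wandb.Artifact(name=f"model-{wandb.run.id}", type="model")
        for fname in os.listdir(temp_dir):
            artifact.add_file(os.path.join(temp_dir, fname))
        wandb.run.log_artifact(artifact)
```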
Finally getting closer! A few notes:
sgugger
left a comment
Thanks for getting back on this! Looks in good shape, I just have a few comments.
src/transformers/integrations.py
```
Environment:
    WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
        Whether or not to log model as artifact, requires use of `TrainingArguments.save_steps`
```
Suggested change:
```diff
- Whether or not to log model as artifact, requires use of `TrainingArguments.save_steps`
+ Whether or not to log model as artifact, requires use of :obj:`TrainingArguments.save_steps`
```
Thanks, actually it's not required anymore (it was the previous logic) since we now just save the model at the end.
I'll correct the description.
src/transformers/integrations.py
```python
        wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))

        # log outputs
        self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() == "TRUE"
```
We've got something in file_utils to pick all truthy values; that could help here.
Ok, I added the "TRUE" value (only in the wandb section). Not sure if you want it as part of `ENV_VARS_TRUE_VALUES`.
I don't think there is any harm in that.
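For illustration, the check could then look roughly like this (the exact contents of `ENV_VARS_TRUE_VALUES` are exactly what is being discussed above, so the set below is an assumption):

```python
import os

# Assumed truthy set; the real ENV_VARS_TRUE_VALUES lives in file_utils and may
# or may not include "TRUE" depending on the outcome of this discussion.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}

# Inside WandbCallback.setup(), the flag would then read:
log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES
```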
src/transformers/integrations.py
```python
            self.setup(args, state, model, reinit=hp_search, **kwargs)

    def on_train_end(self, args, state, control, **kwargs):
        if self._log_model and self._initialized and state.is_world_process_zero and "model" in kwargs:
```
Model will always be in kwargs, it's passed to all events. I think it would be better to extract it in the signature (with model=None)
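For example, the suggested unpacking could look like this (a sketch only, not the merged code):

```python
from transformers import TrainerCallback


class WandbCallback(TrainerCallback):
    # Only the relevant part of the callback is sketched here.
    def on_train_end(self, args, state, control, model=None, **kwargs):
        # `model` is passed to every event, so it can be unpacked in the
        # signature instead of being fetched from kwargs.
        if self._log_model and self._initialized and state.is_world_process_zero and model is not None:
            ...  # save and upload the model
```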
src/transformers/integrations.py
```python
        if self._log_model and self._initialized and state.is_world_process_zero and "model" in kwargs:
            from .trainer import Trainer

            fake_trainer = Trainer(args=args, model=kwargs["model"], tokenizer=kwargs.get("tokenizer"))
```
The tokenizer is not passed however, which is something we could add in this PR.
src/transformers/integrations.py
```python
            fake_trainer.save_model(temp_dir)
            # use run name and ensure it's a valid Artifact name
            artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", wandb.run.name)
            state = dict(vars(state))
```
Can we rename the new object to `state_dict` to avoid confusion with `state`?
Sounds good. Actually, I'm not sure it's a good idea to log all of it; maybe just the total flos and the final metrics are really interesting there.
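For illustration, trimming it down to the total flos plus the final metrics could look roughly like this (the helper name and field choices are assumptions, not this PR's code):

```python
import wandb


def build_model_artifact(state, artifact_name):
    # Keep only the interesting bits of TrainerState instead of dumping all of vars(state).
    metadata = {
        "total_flos": state.total_flos,
        # the final logged metrics, if any
        **(state.log_history[-1] if state.log_history else {}),
    }
    return wandb.Artifact(name=artifact_name, type="model", metadata=metadata)
```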
I think it's becoming pretty cool. Here is an artifact logged with this method. The current limitation is that it only works with PyTorch for now.
src/transformers/integrations.py
```python
        if self._log_model and self._initialized and state.is_world_process_zero:
            from .trainer import Trainer

            fake_trainer = Trainer(args=args, model=model, tokenizer=kwargs.get("tokenizer"))
```
Cool, now that the tokenizer is passed along, let's unpack it in the signature :-)
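i.e. roughly this shape inside `WandbCallback` (sketch only, not the merged code):

```python
    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        # Both the model and the tokenizer are passed to the event, so unpack them directly.
        if self._log_model and self._initialized and state.is_world_process_zero:
            from .trainer import Trainer

            fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
            ...  # save to a temp dir and log as a wandb artifact
```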
I adjusted the metadata when we use
Small change:

I'm now ready on my side. Feel free to ping me!
@LysandreJik let me know if you have any comments.

Happy new year everybody!
LysandreJik
left a comment
LGTM! Sorry for taking so long to review, and thank you for your work on this @borisdayma.
What does this PR do?
EDIT
The logic has been simplified: the model is just saved to a temporary folder and uploaded as an artifact at the end of training.
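Hypothetical usage once this is merged (the checkpoint name and dataset below are placeholders, and wandb is assumed to be installed so the callback is active):

```python
import os

# Opt in before training starts; "true" is upper-cased by the callback's check.
os.environ["WANDB_LOG_MODEL"] = "true"

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
args = TrainingArguments(output_dir="out", num_train_epochs=1)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,  # placeholder: your torch Dataset
)
trainer.train()  # when training ends, the final model is uploaded as a W&B artifact
```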
ORIGINAL message
Save trained model as artifact.
A few different possibilities:
- log model at `on_save` callback -> the issue is there could quickly be too many checkpoints to upload, high bandwidth…
- log model at `on_train_end`:
  - `state.best_model_checkpoint`: we should just upload that folder
  - `output_dir`: but it could be very large (same problem as `on_save`), ideally we only upload one model
  - `Trainer.save_model` (where are the 2-way callbacks 😉)
  - `_sorted_checkpoints` and log only the last element (which would also be the best model when metrics are given)

I'm thinking I should actually go with the last option (use of `_sorted_checkpoints`). What do you think?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors which may be interested in your PR. @sgugger