
Conversation

@borisdayma (Contributor) commented Oct 28, 2020

What does this PR do?

EDIT

The logic has been simplified.
The model is simply saved to a temporary folder and uploaded as an artifact at the end of training.
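A minimal sketch of that flow, assuming the callback receives the model at on_train_end; the class name and artifact name are illustrative, not the PR's actual code:

import tempfile
from pathlib import Path

import wandb
from transformers import Trainer, TrainerCallback


class ArtifactLoggingCallback(TrainerCallback):
    """Sketch: save the final model to a temporary folder and upload it as a W&B artifact."""

    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        if model is None or not state.is_world_process_zero:
            return
        with tempfile.TemporaryDirectory() as temp_dir:
            # Reuse Trainer.save_model so the tokenizer and training args are saved too.
            fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
            fake_trainer.save_model(temp_dir)
            artifact = wandb.Artifact(name=f"model-{wandb.run.id}", type="model")
            for f in Path(temp_dir).glob("*"):
                if f.is_file():
                    # Copy file contents into the artifact so the temp dir can be deleted safely.
                    with artifact.new_file(f.name, mode="wb") as fa:
                        fa.write(f.read_bytes())
            wandb.run.log_artifact(artifact)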

ORIGINAL message

Save trained model as artifact.

A few different possibilities:

  • log the model in the on_save callback -> the issue is that there could quickly be too many checkpoints to upload, using a lot of bandwidth…

  • log the model at on_train_end

    • when we have access to state.best_model_checkpoint, we should just upload that folder
    • we could upload the entire output_dir, but it could be very large (same problem as on_save; ideally we only upload one model)
    • we can save the last model in a separate folder and upload it -> the issue is that we don't have access to Trainer.save_model (where are the 2-way callbacks 😉)
    • we could just use _sorted_checkpoints and log only the last element (which would also be the best model when metrics are given)

I'm thinking I should actually go with the last option (use of _sorted_checkpoints); a rough sketch of that idea is shown below. What do you think?
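For reference, a rough sketch of that last option. Since _sorted_checkpoints is private Trainer machinery, this version simply scans output_dir for checkpoint-* folders sorted by step, which is an assumption about equivalent behaviour rather than the actual helper. (The EDIT above supersedes this idea: the merged version saves the final model itself.)

import wandb
from pathlib import Path


def log_last_checkpoint_as_artifact(output_dir: str) -> None:
    """Sketch: upload the most recent checkpoint-<step> folder as a W&B artifact."""
    checkpoints = sorted(
        Path(output_dir).glob("checkpoint-*"),
        key=lambda path: int(path.name.split("-")[-1]),
    )
    if not checkpoints:
        return  # nothing has been saved yet
    artifact = wandb.Artifact(name=f"model-{wandb.run.id}", type="model")
    artifact.add_dir(str(checkpoints[-1]))
    wandb.run.log_artifact(artifact)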

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR. @sgugger

@sgugger (Collaborator) left a comment

Thanks for your PR!
Could you split it into two parts? The fix in _new_step is something we want to merge ASAP (great catch btw!), but the model-saving on the Wandb side might trigger a longer discussion.

def _new_step(self):
    """ Internal method that resets the variable for a new step. """
-    self.should_save_model = False
+    self.should_save = False
Collaborator:
Great catch!

@borisdayma (Author)

Ok I submitted #8121 for the typo in _new_step.

As for this PR, I'm thinking the logic should be:

  • use the folder referenced by the last item returned by _sorted_checkpoints
  • in case it's empty, we should probably save the current checkpoint locally and upload it (since we specifically requested an upload to wandb)

@borisdayma (Author)

After experimenting a bit more:

  1. Should we upload a model only if _sorted_checkpoints(…) is non-empty?

    • we don't necessarily get the last model (e.g. saving every 100 steps with 520 steps total)
  2. Should we just save the current model at the end of training in args.output_dir + "/wandb"?

    • we need to have access to Trainer.save_model from WandbCallback
    • we could decide to use state.best_model_checkpoint when present instead
    • we would ignore any existing checkpoints

@borisdayma (Author)

@sgugger Do you want me to make an attempt at giving access to the Trainer from callbacks or is it a pattern you want to avoid?

@sgugger (Collaborator) commented Nov 9, 2020

Hi @borisdayma, sorry I took a bit of time to reply on this; we were waiting for the new version of the model hub to materialize before moving forward.

So! The callbacks aren't 2-way in Transformers because then you have to be very careful about the order of their execution. Here the design was to just allow for callbacks that can read the state, not write it, and for any piece of code that needs write access, users should subclass the Trainer. The circular reference is also problematic for memory management, so we leave 2-way callbacks to libraries focused on training models, and keep our simple reporting callbacks as they are :-)

Like you said, you have access to the state with best_model_checkpoint. You can also unpack the model from the kwargs and access it. What is in the Trainer.save_model method that you need? Worst-case scenario, you can even instantiate an empty Trainer with just the model and the training arguments, and use its save_model method.
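A minimal sketch of that last suggestion, i.e. building a throwaway Trainer just to reuse its save_model logic (the function name is illustrative):

from transformers import PreTrainedModel, Trainer, TrainingArguments


def save_with_fake_trainer(model: PreTrainedModel, args: TrainingArguments, output_dir: str) -> None:
    """Sketch: reuse Trainer.save_model without a full training setup."""
    # model and args are exactly what a TrainerCallback event receives.
    fake_trainer = Trainer(args=args, model=model)
    fake_trainer.save_model(output_dir)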

@borisdayma (Author)

The issue with best_model_checkpoint is that it does not exist if no evaluation metric is set.
It could make sense to define it as the last checkpoint in that case.

The next issues would then be:

  • sometimes no model has been saved yet (maybe not enough epochs) while we want to log the model -> we could accept that it's an issue on the user side and give a warning
  • sometimes we may save every 100 steps and run for 180 steps, so the last checkpoint is a bit old -> on this aspect I feel like the Trainer should automatically save the final step as a checkpoint

What do you think?

The alternative would be to completely ignore that logic, let wandb save a model somewhere and upload it. I had not realized we could have access to the model from the callback (though saving from the Trainer is better, as it handles TPU, saves the tokenizer and args, and may also change in the future).

@sgugger (Collaborator) commented Nov 9, 2020

I think the most logical is to save the final model; the intermediate checkpoints are there to resume training if something went wrong, or to load the best model at the end (which is done before the on_train_end event). That's also why we don't always save the model at the end of training, leaving that part to the user in a script.

If you use the logic of unpacking the model from the kwargs, you can simply create a new Trainer with it, which can then save it easily with the Trainer.save_model method. Normally the model you unpack is the reference to the real model, so you won't have a DistributedDataParallel wrapper or something like that, and everything should work smoothly.

@borisdayma (Author)

Finally getting closer!

A few notes:

  • I import Trainer inside my function to avoid a circular import
  • I need to find a way to tell whether I need Trainer or TFTrainer; should I infer it through type(model)? (see the sketch below)
  • I use the Trainer state as model metadata, but maybe it's not that useful. The artifact is associated with a run which already has all the config parameters, but it could be useful to re-log it; or maybe I should just log the final metrics instead (which I can get through wandb)
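Roughly what the last two notes could translate to. The isinstance check is just one possible answer to the Trainer-vs-TFTrainer question, and reusing the whole state as metadata is the option being questioned; the circular-import point only matters for code living inside transformers itself, where the Trainer import would go inside the function body.

import dataclasses
from typing import Optional

from transformers import PreTrainedModel, TrainerState


def build_artifact_metadata(model, state: TrainerState) -> Optional[dict]:
    """Sketch: only handle PyTorch models and reuse the TrainerState as artifact metadata."""
    # PreTrainedModel is the PyTorch base class; TF models derive from TFPreTrainedModel,
    # so this is one way to infer the framework from type(model).
    if not isinstance(model, PreTrainedModel):
        return None  # a TF model would need TFTrainer, out of scope for this sketch
    # TrainerState is a dataclass, so it converts cleanly into a metadata dict.
    return dataclasses.asdict(state)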

@sgugger (Collaborator) left a comment

Thanks for getting back on this! Looks in good shape, I just have a few comments.


Environment:
WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to log model as artifact, requires use of `TrainingArguments.save_steps`
Collaborator:

Suggested change
Whether or not to log model as artifact, requires use of `TrainingArguments.save_steps`
Whether or not to log model as artifact, requires use of :obj:`TrainingArguments.save_steps`

@borisdayma (Author):

Thanks, actually it's not required anymore (it was the previous logic) since we now just save the model at the end.
I'll correct the description.

wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))

# log outputs
self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() == "TRUE"
Collaborator:
We've got something in file_utils to pick all truthy values; that could help here.

@borisdayma (Author):
Ok, I added the True value (only in the wandb section). Not sure if you want it as part of ENV_VARS_TRUE_VALUES

Collaborator:
I don't think there is any harm in that.
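For illustration, here is roughly what the check being discussed looks like, assuming ENV_VARS_TRUE_VALUES is the file_utils constant mentioned above (the exact spellings shown are an assumption):

import os

# Assumed shape of the shared constant from file_utils; redefined here only for the sketch.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}

# Comparing against a single literal only accepts WANDB_LOG_MODEL=TRUE...
log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() == "TRUE"

# ...while checking membership accepts the usual truthy spellings (1, on, yes, true).
log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES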

self.setup(args, state, model, reinit=hp_search, **kwargs)

def on_train_end(self, args, state, control, **kwargs):
if self._log_model and self._initialized and state.is_world_process_zero and "model" in kwargs:
Collaborator:
The model will always be in kwargs; it's passed to all events. I think it would be better to extract it in the signature (with model=None).
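In other words, something along these lines (a sketch of the shape, not the exact diff):

def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
    # The model (and tokenizer) are passed to every callback event, so they can be
    # unpacked in the signature instead of being read from kwargs.
    if self._log_model and self._initialized and state.is_world_process_zero and model is not None:
        ...  # save to a temp dir and log the artifact, as in the hunks below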

if self._log_model and self._initialized and state.is_world_process_zero and "model" in kwargs:
from .trainer import Trainer

fake_trainer = Trainer(args=args, model=kwargs["model"], tokenizer=kwargs.get("tokenizer"))
Collaborator:
The tokenizer is not passed, however; that is something we could add in this PR.

fake_trainer.save_model(temp_dir)
# use run name and ensure it's a valid Artifact name
artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", wandb.run.name)
state = dict(vars(state))
Collaborator:
Can we rename the new object to state_dict, to avoid confusion with state?

@borisdayma (Author):
Sounds good. Actually, I'm not sure it's a good idea to log it;
maybe just the total flos and the final metrics are really interesting there.
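Tying the hunk above to this exchange, a sketch of the name sanitisation plus a slimmed-down metadata dict; the helper name and the use of log_history for the "final metrics" are assumptions:

import re

import wandb
from transformers import TrainerState


def build_model_artifact(state: TrainerState) -> wandb.Artifact:
    """Sketch: valid artifact name from the run name, minimal metadata from the state."""
    # W&B artifact names only allow alphanumerics, dashes, underscores and dots,
    # so e.g. a run named "my run #3" would become "myrun3".
    artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", wandb.run.name)
    metadata = {"total_flos": state.total_flos}
    if state.log_history:
        # Use the latest logged metrics instead of the whole state dict.
        metadata.update(state.log_history[-1])
    return wandb.Artifact(name=artifact_name, type="model", metadata=metadata)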

@borisdayma (Author)

I think it's becoming pretty cool. Here is an artifact logged with this method.

The current limitation is that it only works with PyTorch for now.
Are there any plans for more synergy between Trainer and TFTrainer, or should they be considered independently?

@borisdayma marked this pull request as ready for review December 11, 2020 18:00
@sgugger (Collaborator) commented Dec 11, 2020

TFTrainer will be reworked in the near future to be a simple wrapper around the Keras fit method (and the callbacks will be regular Keras callbacks).

if self._log_model and self._initialized and state.is_world_process_zero:
from .trainer import Trainer

fake_trainer = Trainer(args=args, model=model, tokenizer=kwargs.get("tokenizer"))
Collaborator:
Cool, now that the tokenizer is passed along, let's unpack it in the signature :-)

@sgugger requested a review from LysandreJik December 11, 2020 18:21
@borisdayma (Author)

I adjusted the metadata for when we use load_best_model_at_end.
In that case we don't want to log the last metrics, but only the flos and the best metric.

@borisdayma (Author)

Small change:

  • force a commit of the last step
  • a more robust way to get the metadata: it will consider any data that has been logged and is a number (see the sketch after this comment)

I'm now ready on my side. Feel free to ping me!
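A sketch of that metadata logic, assuming the run summary is where "any data that has been logged" lives (keys starting with an underscore are wandb internals):

import numbers

import wandb


def numeric_summary_metadata() -> dict:
    """Sketch: keep only logged values that are plain numbers as artifact metadata."""
    return {
        k: v
        for k, v in dict(wandb.run.summary).items()
        if isinstance(v, numbers.Number) and not k.startswith("_")
    }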

@borisdayma (Author)

@LysandreJik let me know if you have any comments

@borisdayma (Author)

Happy new year everybody!
Since it has already been a while since this PR was made, am I supposed to merge master and verify the tests are passing?
Let me know if I need to do anything on my side.

@LysandreJik (Member) left a comment

LGTM! Sorry for taking so long to review, and thank you for your work on this @borisdayma.
