diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index e3374b3c4108..00ebaa29afcc 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -644,7 +644,7 @@ def on_train_end(self, args, state, control, **kwargs): class WandbCallback(TrainerCallback): """ - A [`TrainerCallback`] that sends the logs to [Weight and Biases](https://www.wandb.com/). + A [`TrainerCallback`] that logs metrics, media, and model checkpoints to [Weights & Biases](https://www.wandb.com/). """ def __init__(self): @@ -656,28 +656,44 @@ def __init__(self): self._wandb = wandb self._initialized = False - # log outputs - self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}) + # log model + if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}): + logger.warning( + f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in " + "version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead." + ) + logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead") + self._log_model = "end" + else: + self._log_model = os.getenv("WANDB_LOG_MODEL", "false").lower() def setup(self, args, state, model, **kwargs): """ Setup the optional Weights & Biases (*wandb*) integration. One can subclass and override this method to customize the setup if needed. Find more information - [here](https://docs.wandb.ai/integrations/huggingface). You can also override the following environment + [here](https://docs.wandb.ai/guides/integrations/huggingface). You can also override the following environment variables: Environment: - - **WANDB_LOG_MODEL** (`bool`, *optional*, defaults to `False`): - Whether or not to log model as artifact at the end of training. Use along with - [`~transformers.TrainingArguments.load_best_model_at_end`] to upload best model. 
- - **WANDB_WATCH** (`str`, *optional*, defaults to `gradients`): - Can be `gradients`, `all` or `false`. Set to `false` to disable gradient logging or `all` to log gradients - and parameters. - - **WANDB_PROJECT** (`str`, *optional*, defaults to `huggingface`): + - **WANDB_LOG_MODEL** (`str`, *optional*, defaults to `"false"`): + Whether to log model and checkpoints during training. Can be `"end"`, `"checkpoint"` or `"false"`. If set + to `"end"`, the model will be uploaded at the end of training. If set to `"checkpoint"`, the checkpoint + will be uploaded every `args.save_steps` steps. If set to `"false"`, the model will not be uploaded. Use along + with [`~transformers.TrainingArguments.load_best_model_at_end`] to upload best model. + + + + Setting `WANDB_LOG_MODEL` as `bool` will be deprecated in version 5 of 🤗 Transformers. + + + - **WANDB_WATCH** (`str`, *optional*, defaults to `"false"`): + Can be `"gradients"`, `"all"`, `"parameters"`, or `"false"`. Set to `"all"` to log gradients and + parameters. + - **WANDB_PROJECT** (`str`, *optional*, defaults to `"huggingface"`): Set this to a custom string to store results in a different project. - **WANDB_DISABLED** (`bool`, *optional*, defaults to `False`): - Whether or not to disable wandb entirely. Set `WANDB_DISABLED=true` to disable. + Whether to disable wandb entirely. Set `WANDB_DISABLED=true` to disable. 
""" if self._wandb is None: return @@ -694,15 +710,16 @@ def setup(self, args, state, model, **kwargs): trial_name = state.trial_name init_args = {} if trial_name is not None: - run_name = trial_name + init_args["name"] = trial_name init_args["group"] = args.run_name else: - run_name = args.run_name + if not (args.run_name is None or args.run_name == args.output_dir): + init_args["name"] = args.run_name if self._wandb.run is None: + self._wandb.init( project=os.getenv("WANDB_PROJECT", "huggingface"), - name=run_name, **init_args, ) # add config parameters (run may have been created manually) @@ -714,10 +731,9 @@ def setup(self, args, state, model, **kwargs): self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True) # keep track of model topology and gradients, unsupported on TPU - if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false": - self._wandb.watch( - model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps) - ) + _watch_model = os.getenv("WANDB_WATCH", "false") + if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"): + self._wandb.watch(model, log=_watch_model, log_freq=max(100, args.logging_steps)) def on_train_begin(self, args, state, control, model=None, **kwargs): if self._wandb is None: @@ -733,7 +749,7 @@ def on_train_begin(self, args, state, control, model=None, **kwargs): def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs): if self._wandb is None: return - if self._log_model and self._initialized and state.is_world_process_zero: + if self._log_model in ("end", "checkpoint") and self._initialized and state.is_world_process_zero: from .trainer import Trainer fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer) @@ -751,7 +767,13 @@ def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwarg "train/total_floss": state.total_flos, } ) - artifact = 
self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata) + logger.info("Logging model artifacts. ...") + model_name = ( + f"model-{self._wandb.run.id}" + if (args.run_name is None or args.run_name == args.output_dir) + else f"model-{self._wandb.run.name}" + ) + artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata) for f in Path(temp_dir).glob("*"): if f.is_file(): with artifact.new_file(f.name, mode="wb") as fa: @@ -767,6 +789,26 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs): logs = rewrite_logs(logs) self._wandb.log({**logs, "train/global_step": state.global_step}) + def on_save(self, args, state, control, **kwargs): + if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero: + checkpoint_metadata = { + k: v + for k, v in dict(self._wandb.summary).items() + if isinstance(v, numbers.Number) and not k.startswith("_") + } + + ckpt_dir = f"checkpoint-{state.global_step}" + artifact_path = os.path.join(args.output_dir, ckpt_dir) + logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...") + checkpoint_name = ( + f"checkpoint-{self._wandb.run.id}" + if (args.run_name is None or args.run_name == args.output_dir) + else f"checkpoint-{self._wandb.run.name}" + ) + artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata) + artifact.add_dir(artifact_path) + self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"]) + class CometCallback(TrainerCallback): """