82 changes: 60 additions & 22 deletions src/transformers/integrations.py
@@ -644,7 +644,7 @@ def on_train_end(self, args, state, control, **kwargs):

class WandbCallback(TrainerCallback):
"""
A [`TrainerCallback`] that sends the logs to [Weight and Biases](https://www.wandb.com/).
    A [`TrainerCallback`] that logs metrics, media, and model checkpoints to [Weights and Biases](https://www.wandb.com/).
"""

def __init__(self):
@@ -656,28 +656,40 @@ def __init__(self):

self._wandb = wandb
self._initialized = False
# log outputs
self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
# log model
if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
logger.warning(
f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
"future versions. Use one of `end` or `checkpoint` instead."
)
logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
self._log_model = "end"
else:
self._log_model = os.getenv("WANDB_LOG_MODEL", "false").lower()

def setup(self, args, state, model, **kwargs):
"""
Setup the optional Weights & Biases (*wandb*) integration.

One can subclass and override this method to customize the setup if needed. Find more information
[here](https://docs.wandb.ai/integrations/huggingface). You can also override the following environment
[here](https://docs.wandb.ai/guides/integrations/huggingface). You can also override the following environment
variables:

Environment:
- **WANDB_LOG_MODEL** (`bool`, *optional*, defaults to `False`):
Whether or not to log model as artifact at the end of training. Use along with
[`~transformers.TrainingArguments.load_best_model_at_end`] to upload best model.
- **WANDB_WATCH** (`str`, *optional*, defaults to `gradients`):
Can be `gradients`, `all` or `false`. Set to `false` to disable gradient logging or `all` to log gradients
and parameters.
- **WANDB_PROJECT** (`str`, *optional*, defaults to `huggingface`):
Set this to a custom string to store results in a different project.
- **WANDB_LOG_MODEL** (`str`, *optional*, defaults to `"false"`):
Whether to log model and checkpoints during training. Can be `"end"`, `"checkpoint"` or `"false"`.
If set to `"end"`, the model will be uploaded at the end of training. If set to `"checkpoint"`, the
checkpoint will be uploaded every `args.save_steps`. If set to `"false"`, the model will not be
uploaded. Use along with *TrainingArguments.load_best_model_at_end* to upload the best model.
*Warning*: Setting `WANDB_LOG_MODEL` as a boolean is deprecated and will be removed in future
versions.
- **WANDB_WATCH** (`str`, *optional*, defaults to `"false"`):
Can be `"gradients"`, `"all"`, `"parameters"`, or `"false"`. Set to `"all"` to log gradients and
parameters.
- **WANDB_PROJECT** (`str`, *optional*, defaults to `"huggingface"`):
Set this to a custom string to store results in a different project.
- **WANDB_DISABLED** (`bool`, *optional*, defaults to `False`):
Whether or not to disable wandb entirely. Set `WANDB_DISABLED=True` to disable.
Whether to disable wandb entirely. Set *WANDB_DISABLED=true* to disable.
"""
if self._wandb is None:
return
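
For reference, a minimal sketch of how a user might set the environment variables documented in the docstring above before launching training; the project name, run name, and dataset wiring here are placeholders rather than part of this diff:

import os

from transformers import Trainer, TrainingArguments

# Hypothetical values, chosen only for illustration.
os.environ["WANDB_PROJECT"] = "my-project"  # store results under this W&B project
os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # upload a checkpoint artifact every `save_steps`
os.environ["WANDB_WATCH"] = "gradients"  # log gradient histograms via wandb.watch

args = TrainingArguments(output_dir="out", report_to="wandb", save_steps=500, run_name="my-run")
# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()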
@@ -694,15 +706,16 @@ def setup(self, args, state, model, **kwargs):
trial_name = state.trial_name
init_args = {}
if trial_name is not None:
run_name = trial_name
init_args["name"] = trial_name
init_args["group"] = args.run_name
else:
run_name = args.run_name
if not (args.run_name is None or args.run_name == args.output_dir):
init_args["name"] = args.run_name

if self._wandb.run is None:

self._wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"),
name=run_name,
**init_args,
)
# add config parameters (run may have been created manually)
@@ -714,10 +727,9 @@ def setup(self, args, state, model, **kwargs):
self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

# keep track of model topology and gradients, unsupported on TPU
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
self._wandb.watch(
model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
)
_watch_model = os.getenv("WANDB_WATCH", "false")
if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"):
self._wandb.watch(model, log=_watch_model, log_freq=max(100, args.logging_steps))

def on_train_begin(self, args, state, control, model=None, **kwargs):
if self._wandb is None:
@@ -733,7 +745,7 @@ def on_train_begin(self, args, state, control, model=None, **kwargs):
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
if self._wandb is None:
return
if self._log_model and self._initialized and state.is_world_process_zero:
if self._log_model in ("end", "checkpoint") and self._initialized and state.is_world_process_zero:
from .trainer import Trainer

fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
@@ -751,7 +763,13 @@ def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
"train/total_floss": state.total_flos,
}
)
artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
logger.info("Logging model artifacts. ...")
model_name = (
f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"model-{self._wandb.run.name}"
)
artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
for f in Path(temp_dir).glob("*"):
if f.is_file():
with artifact.new_file(f.name, mode="wb") as fa:
@@ -767,6 +785,26 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs):
logs = rewrite_logs(logs)
self._wandb.log({**logs, "train/global_step": state.global_step})

def on_save(self, args, state, control, **kwargs):
if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero:
checkpoint_metadata = {
k: v
for k, v in dict(self._wandb.summary).items()
if isinstance(v, numbers.Number) and not k.startswith("_")
}

ckpt_dir = f"checkpoint-{state.global_step}"
artifact_path = os.path.join(args.output_dir, ckpt_dir)
logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...")
checkpoint_name = (
f"checkpoint-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"checkpoint-{self._wandb.run.name}"
)
artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata)
artifact.add_dir(artifact_path)
self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"])


class CometCallback(TrainerCallback):
"""
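The new `on_save` hook uploads each checkpoint directory as a W&B artifact named after the run. A minimal sketch of pulling such a checkpoint back down with the wandb client, assuming the placeholder names below are replaced with a real project and run id:

import wandb

run = wandb.init(project="huggingface")  # same project the Trainer logged to (assumed)
# "checkpoint-<run-id>" mirrors the naming used in `on_save`; replace <run-id> with the original run's id.
artifact = run.use_artifact("checkpoint-<run-id>:latest", type="model")
checkpoint_dir = artifact.download()
# The downloaded directory can then be passed to Trainer.train(resume_from_checkpoint=checkpoint_dir).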
38 changes: 0 additions & 38 deletions src/transformers/utils/dummy_pt_objects.py
@@ -357,37 +357,6 @@ def load_tf_weights_in_albert(*args, **kwargs):
requires_backends(load_tf_weights_in_albert, ["torch"])


ALTCLIP_PRETRAINED_MODEL_ARCHIVE_LIST = None


class AltCLIPModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class AltCLIPPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class AltCLIPTextModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class AltCLIPVisionModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


AUDIO_SPECTROGRAM_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -4003,13 +3972,6 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class MT5PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


MVP_PRETRAINED_MODEL_ARCHIVE_LIST = None

