diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index ddae993f1d2c..971d55da9041 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -16,6 +16,7 @@ """ import functools import importlib.util +import json import numbers import os import sys @@ -772,6 +773,7 @@ def __init__(self): self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH self._initialized = False + self._auto_end_run = False self._log_artifacts = False self._ml_flow = mlflow @@ -790,18 +792,32 @@ def setup(self, args, state, model): point to the "Default" experiment in MLflow. Otherwise, it is a case sensitive name of the experiment to be activated. If an experiment with this name does not exist, a new experiment with this name is created. + MLFLOW_TAGS (`str`, *optional*): + A string dump of a dictionary of key/value pair to be added to the MLflow run as tags. Example: + os.environ['MLFLOW_TAGS']='{"release.candidate": "RC1", "release.version": "2.2.0"}' + MLFLOW_NESTED_RUN (`str`, *optional*): + Whether to use MLflow nested runs. If set to `True` or *1*, will create a nested run inside the current + run. + MLFLOW_RUN_ID (`str`, *optional*): + Allow to reattach to an existing run which can be usefull when resuming training from a checkpoint. + When MLFLOW_RUN_ID environment variable is set, start_run attempts to resume a run with the specified + run ID and other parameters are ignored. """ - log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() - if log_artifacts in {"TRUE", "1"}: - self._log_artifacts = True - experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None) - logger.debug(f"MLFLOW experiment_name={experiment_name}, run_name={args.run_name}") + self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES + self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES + self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None) + self._run_id = os.getenv("MLFLOW_RUN_ID", None) + logger.debug( + f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run}, tags={self._nested_run}" + ) if state.is_world_process_zero: - if self._ml_flow.active_run() is None: - if experiment_name: + if self._ml_flow.active_run() is None or self._nested_run or self._run_id: + if self._experiment_name: # Use of set_experiment() ensure that Experiment is created if not exists - self._ml_flow.set_experiment(experiment_name) - self._ml_flow.start_run(run_name=args.run_name) + self._ml_flow.set_experiment(self._experiment_name) + self._ml_flow.start_run(run_name=args.run_name, nested=self._nested_run) + logger.debug(f"MLflow run started with run_id={self._ml_flow.active_run().info.run_id}") + self._auto_end_run = True combined_dict = args.to_dict() if hasattr(model, "config") and model.config is not None: model_config = model.config.to_dict() @@ -821,6 +837,10 @@ def setup(self, args, state, model): combined_dict_items = list(combined_dict.items()) for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH): self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH])) + mlflow_tags = os.getenv("MLFLOW_TAGS", None) + if mlflow_tags: + mlflow_tags = json.loads(mlflow_tags) + self._ml_flow.set_tags(mlflow_tags) self._initialized = True def on_train_begin(self, args, state, control, model=None, **kwargs): @@ -849,11 +869,13 @@ def on_train_end(self, args, state, control, **kwargs): if self._log_artifacts: logger.info("Logging artifacts. This may take time.") self._ml_flow.log_artifacts(args.output_dir) + if self._auto_end_run and self._ml_flow.active_run(): + self._ml_flow.end_run() def __del__(self): # if the previous run is not terminated correctly, the fluent API will # not let you start a new run before the previous one is killed - if self._ml_flow.active_run() is not None: + if self._auto_end_run and self._ml_flow.active_run() is not None: self._ml_flow.end_run()