42 changes: 32 additions & 10 deletions src/transformers/integrations.py
@@ -16,6 +16,7 @@
"""
import functools
import importlib.util
import json
import numbers
import os
import sys
@@ -772,6 +773,7 @@ def __init__(self):
         self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH

         self._initialized = False
+        self._auto_end_run = False
         self._log_artifacts = False
         self._ml_flow = mlflow

@@ -790,18 +792,32 @@ def setup(self, args, state, model):
                 point to the "Default" experiment in MLflow. Otherwise, it is a case sensitive name of the experiment
                 to be activated. If an experiment with this name does not exist, a new experiment with this name is
                 created.
+            MLFLOW_TAGS (`str`, *optional*):
+                A string dump of a dictionary of key/value pairs to be added to the MLflow run as tags. Example:
+                os.environ['MLFLOW_TAGS']='{"release.candidate": "RC1", "release.version": "2.2.0"}'
+            MLFLOW_NESTED_RUN (`str`, *optional*):
+                Whether to use MLflow nested runs. If set to `True` or *1*, will create a nested run inside the current
+                run.
+            MLFLOW_RUN_ID (`str`, *optional*):
+                Allows reattaching to an existing run, which can be useful when resuming training from a checkpoint.
+                When the MLFLOW_RUN_ID environment variable is set, start_run attempts to resume a run with the
+                specified run ID, and other parameters are ignored.
         """
-        log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
-        if log_artifacts in {"TRUE", "1"}:
-            self._log_artifacts = True
-        experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
-        logger.debug(f"MLFLOW experiment_name={experiment_name}, run_name={args.run_name}")
+        self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
+        self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES
+        self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
+        self._run_id = os.getenv("MLFLOW_RUN_ID", None)
+        logger.debug(
+            f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run}"
+        )
         if state.is_world_process_zero:
-            if self._ml_flow.active_run() is None:
-                if experiment_name:
+            if self._ml_flow.active_run() is None or self._nested_run or self._run_id:
+                if self._experiment_name:
                     # Use of set_experiment() ensures that the experiment is created if it does not exist
-                    self._ml_flow.set_experiment(experiment_name)
-                self._ml_flow.start_run(run_name=args.run_name)
+                    self._ml_flow.set_experiment(self._experiment_name)
+                self._ml_flow.start_run(run_name=args.run_name, nested=self._nested_run)
+                logger.debug(f"MLflow run started with run_id={self._ml_flow.active_run().info.run_id}")
+                self._auto_end_run = True
             combined_dict = args.to_dict()
             if hasattr(model, "config") and model.config is not None:
                 model_config = model.config.to_dict()
@@ -821,6 +837,10 @@ def setup(self, args, state, model):
             combined_dict_items = list(combined_dict.items())
             for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
                 self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
+            mlflow_tags = os.getenv("MLFLOW_TAGS", None)
+            if mlflow_tags:
+                mlflow_tags = json.loads(mlflow_tags)
+                self._ml_flow.set_tags(mlflow_tags)
         self._initialized = True

     def on_train_begin(self, args, state, control, model=None, **kwargs):
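For reference, a minimal usage sketch of the environment variables documented in the setup() docstring above. This is not part of the diff: the experiment name, tag values, and run ID are placeholder assumptions, and it presumes MLflow is installed and the MLflow callback is enabled (e.g. report_to includes "mlflow").

import os

# Hypothetical configuration (placeholders, not from this PR): export the
# MLflow-related environment variables before a Trainer is created.
# MLflowCallback.setup() reads them when training starts.
os.environ["MLFLOW_EXPERIMENT_NAME"] = "my-experiment"  # created via set_experiment() if it does not exist
os.environ["MLFLOW_TAGS"] = '{"release.candidate": "RC1", "release.version": "2.2.0"}'  # parsed with json.loads, passed to set_tags()
os.environ["MLFLOW_NESTED_RUN"] = "1"  # start_run(nested=True) under an already-active run
os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "1"  # log_artifacts(args.output_dir) when training ends
# os.environ["MLFLOW_RUN_ID"] = "<existing run id>"  # uncomment to resume that run instead of starting a new one

Because setup() parses MLFLOW_TAGS with json.loads, the value must be valid JSON or a JSONDecodeError is raised.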
@@ -849,11 +869,13 @@ def on_train_end(self, args, state, control, **kwargs):
             if self._log_artifacts:
                 logger.info("Logging artifacts. This may take time.")
                 self._ml_flow.log_artifacts(args.output_dir)
+            if self._auto_end_run and self._ml_flow.active_run():
+                self._ml_flow.end_run()

     def __del__(self):
         # if the previous run is not terminated correctly, the fluent API will
         # not let you start a new run before the previous one is killed
-        if self._ml_flow.active_run() is not None:
+        if self._auto_end_run and self._ml_flow.active_run() is not None:
             self._ml_flow.end_run()

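To illustrate the _auto_end_run guard added in on_train_end and __del__ above: the callback now ends only runs it started itself, so a run opened by the caller survives training. A hedged sketch, with placeholder names not taken from this PR:

import os
import mlflow

# With MLFLOW_NESTED_RUN set, the callback calls start_run(nested=True) and
# marks _auto_end_run = True, so it closes only the nested child run it created.
os.environ["MLFLOW_NESTED_RUN"] = "1"

with mlflow.start_run(run_name="parent-sweep"):  # owned by the caller, not by the callback
    # ... build and train a Trainer here; MLflowCallback logs into a nested child run ...
    mlflow.log_param("sweep_id", "demo")  # the parent run is still active after training

Previously, __del__ called end_run() on whatever run was active, which would have terminated a caller-managed run like the one above.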