diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index 971d55da9041..625ee6875a6c 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -23,7 +23,7 @@
 import tempfile
 from pathlib import Path
 
-from .utils import is_datasets_available, logging
+from .utils import flatten_dict, is_datasets_available, logging
 
 
 logger = logging.get_logger(__name__)
@@ -802,10 +802,13 @@ def setup(self, args, state, model):
                 Allow to reattach to an existing run which can be usefull when resuming training from a checkpoint.
                 When MLFLOW_RUN_ID environment variable is set, start_run attempts to resume a run with the specified
                 run ID and other parameters are ignored.
+            MLFLOW_FLATTEN_PARAMS (`str`, *optional*):
+                Whether to flatten the parameters dictionary before logging. Defaults to `False`.
         """
         self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
         self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES
         self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
+        self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
         self._run_id = os.getenv("MLFLOW_RUN_ID", None)
         logger.debug(
             f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run}, tags={self._nested_run}"
@@ -822,15 +825,15 @@ def setup(self, args, state, model):
             if hasattr(model, "config") and model.config is not None:
                 model_config = model.config.to_dict()
                 combined_dict = {**model_config, **combined_dict}
+            combined_dict = flatten_dict(combined_dict) if self._flatten_params else combined_dict
             # remove params that are too long for MLflow
             for name, value in list(combined_dict.items()):
                 # internally, all values are converted to str in MLflow
                 if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
                     logger.warning(
-                        f"Trainer is attempting to log a value of "
-                        f'"{value}" for key "{name}" as a parameter. '
-                        f"MLflow's log_param() only accepts values no longer than "
-                        f"250 characters so we dropped this attribute."
+                        f'Trainer is attempting to log a value of "{value}" for key "{name}" as a parameter. '
+                        f"MLflow's log_param() only accepts values no longer than 250 characters so we dropped this attribute. "
+                        f"You can use the `MLFLOW_FLATTEN_PARAMS` environment variable to flatten the parameters and avoid this message."
                     )
                     del combined_dict[name]
             # MLflow cannot log more than 100 values in one go, so we have to split it
@@ -857,10 +860,8 @@ def on_log(self, args, state, control, logs, model=None, **kwargs):
                     metrics[k] = v
                 else:
                     logger.warning(
-                        f"Trainer is attempting to log a value of "
-                        f'"{v}" of type {type(v)} for key "{k}" as a metric. '
-                        f"MLflow's log_metric() only accepts float and "
-                        f"int types so we dropped this attribute."
+                        f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
+                        f"MLflow's log_metric() only accepts float and int types so we dropped this attribute."
                     )
             self._ml_flow.log_metrics(metrics=metrics, step=state.global_step)
 
@@ -875,7 +876,7 @@ def on_train_end(self, args, state, control, **kwargs):
 
     def __del__(self):
         # if the previous run is not terminated correctly, the fluent API will
        # not let you start a new run before the previous one is killed
-        if self._auto_end_run and self._ml_flow.active_run() is not None:
+        if self._auto_end_run and self._ml_flow and self._ml_flow.active_run() is not None:
             self._ml_flow.end_run()
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 2c473b389d4e..88891365f03e 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -38,6 +38,7 @@
     TensorType,
     cached_property,
     find_labels,
+    flatten_dict,
     is_tensor,
     to_numpy,
     to_py_obj,
diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py
index bea5b3dd4775..136762a37858 100644
--- a/src/transformers/utils/generic.py
+++ b/src/transformers/utils/generic.py
@@ -17,6 +17,7 @@
 import inspect
 from collections import OrderedDict, UserDict
+from collections.abc import MutableMapping
 from contextlib import ExitStack
 from dataclasses import fields
 from enum import Enum
@@ -310,3 +311,17 @@ def find_labels(model_class):
         return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
     else:
         return [p for p in signature.parameters if "label" in p]
+
+
+def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."):
+    """Flatten a nested dict into a single level dict."""
+
+    def _flatten_dict(d, parent_key="", delimiter="."):
+        for k, v in d.items():
+            key = str(parent_key) + delimiter + str(k) if parent_key else k
+            if v and isinstance(v, MutableMapping):
+                yield from flatten_dict(v, key, delimiter=delimiter).items()
+            else:
+                yield key, v
+
+    return dict(_flatten_dict(d, parent_key, delimiter))
diff --git a/tests/utils/test_generic.py b/tests/utils/test_generic.py
new file mode 100644
index 000000000000..6fbdbee40360
--- /dev/null
+++ b/tests/utils/test_generic.py
@@ -0,0 +1,45 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from transformers.utils import flatten_dict
+
+
+class GenericTester(unittest.TestCase):
+    def test_flatten_dict(self):
+        input_dict = {
+            "task_specific_params": {
+                "summarization": {"length_penalty": 1.0, "max_length": 128, "min_length": 12, "num_beams": 4},
+                "summarization_cnn": {"length_penalty": 2.0, "max_length": 142, "min_length": 56, "num_beams": 4},
+                "summarization_xsum": {"length_penalty": 1.0, "max_length": 62, "min_length": 11, "num_beams": 6},
+            }
+        }
+        expected_dict = {
+            "task_specific_params.summarization.length_penalty": 1.0,
+            "task_specific_params.summarization.max_length": 128,
+            "task_specific_params.summarization.min_length": 12,
+            "task_specific_params.summarization.num_beams": 4,
+            "task_specific_params.summarization_cnn.length_penalty": 2.0,
+            "task_specific_params.summarization_cnn.max_length": 142,
+            "task_specific_params.summarization_cnn.min_length": 56,
+            "task_specific_params.summarization_cnn.num_beams": 4,
+            "task_specific_params.summarization_xsum.length_penalty": 1.0,
+            "task_specific_params.summarization_xsum.max_length": 62,
+            "task_specific_params.summarization_xsum.min_length": 11,
+            "task_specific_params.summarization_xsum.num_beams": 6,
+        }
+
+        self.assertEqual(flatten_dict(input_dict), expected_dict)
diff --git a/utils/check_inits.py b/utils/check_inits.py
index 99700ccec598..5b95e609e18d 100644
--- a/utils/check_inits.py
+++ b/utils/check_inits.py
@@ -249,7 +249,7 @@ def get_transformers_submodules():
             if fname == "__init__.py":
                 continue
             short_path = str((Path(path) / fname).relative_to(PATH_TO_TRANSFORMERS))
-            submodule = short_path.replace(os.path.sep, ".").replace(".py", "")
+            submodule = short_path.replace(".py", "").replace(os.path.sep, ".")
             if len(submodule.split(".")) == 1:
                 submodules.append(submodule)
     return submodules
diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py
index f733f301e8dc..1eda2be47f57 100644
--- a/utils/tests_fetcher.py
+++ b/utils/tests_fetcher.py
@@ -268,7 +268,7 @@ def create_reverse_dependency_map():
         "feature_extraction_sequence_utils.py": "test_sequence_feature_extraction_common.py",
         "feature_extraction_utils.py": "test_feature_extraction_common.py",
         "file_utils.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
-        "utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
+        "utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py", "utils/test_generic.py"],
         "utils/hub.py": "utils/test_file_utils.py",
         "modelcard.py": "utils/test_model_card.py",
         "modeling_flax_utils.py": "test_modeling_flax_common.py",
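
For context, a minimal sketch of how the new flag plays out end to end, mirroring what `MLflowCallback.setup()` does after this change. The nested config values are made up, and the truthy-string set below is a stand-in for the library's `ENV_VARS_TRUE_VALUES` constant:

```python
import os

from transformers.utils import flatten_dict  # added by this diff

# A nested dict like the one `model.config.to_dict()` produces; values are illustrative.
config = {
    "task_specific_params": {
        "summarization": {"max_length": 128, "num_beams": 4},
    },
    "vocab_size": 30522,
}

# Opt in the same way the callback does: the env var gates the flatten_dict() call.
os.environ["MLFLOW_FLATTEN_PARAMS"] = "TRUE"
flatten = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in {"1", "ON", "YES", "TRUE"}

params = flatten_dict(config) if flatten else config
print(params)
# {'task_specific_params.summarization.max_length': 128,
#  'task_specific_params.summarization.num_beams': 4,
#  'vocab_size': 30522}
```

Without the flag, a realistically sized nested value such as `task_specific_params` is stringified as a single parameter, typically exceeds the 250-character guard shown in the `setup()` hunk, and is dropped; flattening turns it into individually loggable scalar parameters instead.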