21 changes: 11 additions & 10 deletions src/transformers/integrations.py
@@ -23,7 +23,7 @@
import tempfile
from pathlib import Path

-from .utils import is_datasets_available, logging
+from .utils import flatten_dict, is_datasets_available, logging


logger = logging.get_logger(__name__)
@@ -802,10 +802,13 @@ def setup(self, args, state, model):
                Allows reattaching to an existing run, which can be useful when resuming training from a
                checkpoint. When the MLFLOW_RUN_ID environment variable is set, start_run attempts to resume a run
                with the specified run ID; other parameters are ignored.
+            MLFLOW_FLATTEN_PARAMS (`str`, *optional*):
+                Whether to flatten the parameters dictionary before logging. Defaults to `False`.
"""
self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES
self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
self._run_id = os.getenv("MLFLOW_RUN_ID", None)
        logger.debug(
            f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run}"
@@ -822,15 +825,15 @@ def setup(self, args, state, model):
        if hasattr(model, "config") and model.config is not None:
            model_config = model.config.to_dict()
            combined_dict = {**model_config, **combined_dict}
+        combined_dict = flatten_dict(combined_dict) if self._flatten_params else combined_dict
        # remove params that are too long for MLflow
        for name, value in list(combined_dict.items()):
            # internally, all values are converted to str in MLflow
            if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
                logger.warning(
-                    f"Trainer is attempting to log a value of "
-                    f'"{value}" for key "{name}" as a parameter. '
-                    f"MLflow's log_param() only accepts values no longer than "
-                    f"250 characters so we dropped this attribute."
+                    f'Trainer is attempting to log a value of "{value}" for key "{name}" as a parameter. '
+                    f"MLflow's log_param() only accepts values no longer than 250 characters so we dropped this attribute. "
+                    f"You can use the `MLFLOW_FLATTEN_PARAMS` environment variable to flatten the parameters and avoid this message."
                )
                del combined_dict[name]
        # MLflow cannot log more than 100 values in one go, so we have to split it
@@ -857,10 +860,8 @@ def on_log(self, args, state, control, logs, model=None, **kwargs):
                    metrics[k] = v
                else:
                    logger.warning(
-                        f"Trainer is attempting to log a value of "
-                        f'"{v}" of type {type(v)} for key "{k}" as a metric. '
-                        f"MLflow's log_metric() only accepts float and "
-                        f"int types so we dropped this attribute."
+                        f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
+                        f"MLflow's log_metric() only accepts float and int types so we dropped this attribute."
                    )
            self._ml_flow.log_metrics(metrics=metrics, step=state.global_step)

@@ -875,7 +876,7 @@ def on_train_end(self, args, state, control, **kwargs):
    def __del__(self):
        # if the previous run is not terminated correctly, the fluent API will
        # not let you start a new run before the previous one is killed
-        if self._auto_end_run and self._ml_flow.active_run() is not None:
+        if self._auto_end_run and self._ml_flow and self._ml_flow.active_run() is not None:
            self._ml_flow.end_run()


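For context, a minimal usage sketch of the new flag (only the environment variable names come from the callback above; the experiment name and the Trainer wiring are hypothetical). Setting MLFLOW_FLATTEN_PARAMS makes the callback flatten nested config dicts into dotted keys, so each leaf is logged as its own short parameter instead of one oversized value that gets dropped:

import os

# Assumed setup: these must be set before Trainer/MLflowCallback.setup() runs.
os.environ["MLFLOW_FLATTEN_PARAMS"] = "TRUE"  # flatten nested params before logging
os.environ["MLFLOW_EXPERIMENT_NAME"] = "my-experiment"  # hypothetical experiment name

# trainer = Trainer(model=model, args=training_args, ...)  # usual Trainer construction
# trainer.train()  # nested entries are logged as e.g. "task_specific_params.summarization.num_beams"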
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -38,6 +38,7 @@
    TensorType,
    cached_property,
    find_labels,
+    flatten_dict,
    is_tensor,
    to_numpy,
    to_py_obj,
15 changes: 15 additions & 0 deletions src/transformers/utils/generic.py
@@ -17,6 +17,7 @@

import inspect
from collections import OrderedDict, UserDict
+from collections.abc import MutableMapping
from contextlib import ExitStack
from dataclasses import fields
from enum import Enum
@@ -310,3 +311,17 @@ def find_labels(model_class):
        return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
    else:
        return [p for p in signature.parameters if "label" in p]


+def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."):
+    """Flatten a nested dict into a single-level dict."""
+
+    def _flatten_dict(d, parent_key="", delimiter="."):
+        for k, v in d.items():
+            key = str(parent_key) + delimiter + str(k) if parent_key else k
+            if v and isinstance(v, MutableMapping):
+                yield from flatten_dict(v, key, delimiter=delimiter).items()
+            else:
+                yield key, v
+
+    return dict(_flatten_dict(d, parent_key, delimiter))
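A quick sketch of what the helper above produces, including the optional delimiter argument (the sample dict is made up for illustration):

from transformers.utils import flatten_dict

config = {"optimizer": {"lr": 5e-05, "betas": {"b1": 0.9}}}
print(flatten_dict(config))                 # {'optimizer.lr': 5e-05, 'optimizer.betas.b1': 0.9}
print(flatten_dict(config, delimiter="_"))  # {'optimizer_lr': 5e-05, 'optimizer_betas_b1': 0.9}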
45 changes: 45 additions & 0 deletions tests/utils/test_generic.py
@@ -0,0 +1,45 @@
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers.utils import flatten_dict


class GenericTester(unittest.TestCase):
    def test_flatten_dict(self):
        input_dict = {
            "task_specific_params": {
                "summarization": {"length_penalty": 1.0, "max_length": 128, "min_length": 12, "num_beams": 4},
                "summarization_cnn": {"length_penalty": 2.0, "max_length": 142, "min_length": 56, "num_beams": 4},
                "summarization_xsum": {"length_penalty": 1.0, "max_length": 62, "min_length": 11, "num_beams": 6},
            }
        }
        expected_dict = {
            "task_specific_params.summarization.length_penalty": 1.0,
            "task_specific_params.summarization.max_length": 128,
            "task_specific_params.summarization.min_length": 12,
            "task_specific_params.summarization.num_beams": 4,
            "task_specific_params.summarization_cnn.length_penalty": 2.0,
            "task_specific_params.summarization_cnn.max_length": 142,
            "task_specific_params.summarization_cnn.min_length": 56,
            "task_specific_params.summarization_cnn.num_beams": 4,
            "task_specific_params.summarization_xsum.length_penalty": 1.0,
            "task_specific_params.summarization_xsum.max_length": 62,
            "task_specific_params.summarization_xsum.min_length": 11,
            "task_specific_params.summarization_xsum.num_beams": 6,
        }

        self.assertEqual(flatten_dict(input_dict), expected_dict)
2 changes: 1 addition & 1 deletion utils/check_inits.py
@@ -249,7 +249,7 @@ def get_transformers_submodules():
            if fname == "__init__.py":
                continue
            short_path = str((Path(path) / fname).relative_to(PATH_TO_TRANSFORMERS))
-            submodule = short_path.replace(os.path.sep, ".").replace(".py", "")
+            submodule = short_path.replace(".py", "").replace(os.path.sep, ".")
            if len(submodule.split(".")) == 1:
                submodules.append(submodule)
    return submodules
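The reordering above matters because str.replace substitutes every occurrence, so turning path separators into dots first can manufacture a spurious ".py" inside a module name. A small illustration (the path is hypothetical, with "/" standing in for os.path.sep):

short_path = "models/pytorch_utils.py"  # hypothetical path
print(short_path.replace("/", ".").replace(".py", ""))  # 'modelstorch_utils' -- the '.py' in 'models.pytorch' is eaten
print(short_path.replace(".py", "").replace("/", "."))  # 'models.pytorch_utils' -- the fixed order leaves the name intact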
2 changes: 1 addition & 1 deletion utils/tests_fetcher.py
@@ -268,7 +268,7 @@ def create_reverse_dependency_map():
"feature_extraction_sequence_utils.py": "test_sequence_feature_extraction_common.py",
"feature_extraction_utils.py": "test_feature_extraction_common.py",
"file_utils.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
"utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
"utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py", "utils/test_generic.py"],
"utils/hub.py": "utils/test_file_utils.py",
"modelcard.py": "utils/test_model_card.py",
"modeling_flax_utils.py": "test_modeling_flax_common.py",