-
Notifications
You must be signed in to change notification settings - Fork 236
MLFlow Integration #2112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
MLFlow Integration #2112
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
bf39c8a
MLFlow Integration
yaoyu-33 46635c4
update executors; be careful with perf script change.
yaoyu-33 b4efe69
update uv.lock
79507d7
fix: move mlflow_tracking_uri to correct indentation level
yaoyu-33 fa03b2c
test: add pytest.mark.unit markers to MLFlow unit tests
yaoyu-33 0d27c70
[test] feat: Add unit tests for MLflow integration
yaoyu-33 9a0a12d
fix: validate mlflow config before import and handle circular import …
yaoyu-33 80e623a
fix: mock mlflow import in tests to avoid timeout
yaoyu-33 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,7 @@ | |
| from megatron.bridge.training.config import ConfigContainer | ||
| from megatron.bridge.training.nvrx_straggler import NVRxStragglerDetectionManager | ||
| from megatron.bridge.training.tokenizers.tokenizer import build_tokenizer | ||
| from megatron.bridge.training.utils.log_utils import safe_serialize | ||
| from megatron.bridge.training.utils.sig_utils import DistributedSignalHandler | ||
| from megatron.bridge.utils.common_utils import get_rank_safe, get_world_size_safe | ||
|
|
||
|
|
@@ -124,6 +125,7 @@ def __init__(self) -> None: | |
| self._tokenizer: Optional[Any] = None | ||
| self._tensorboard_logger: Optional[SummaryWriter] = None | ||
| self._wandb_logger: Optional[Any] = None | ||
| self._mlflow_logger: Optional[Any] = None | ||
| self._timers: Optional[Timers] = None | ||
| self._train_state: Optional[TrainState] = None | ||
| self.rank_monitor_client: Optional[Any] = None | ||
|
|
@@ -193,26 +195,6 @@ def wandb_logger(self) -> Optional[Any]: | |
|
|
||
| save_dir = self.cfg.logger.wandb_save_dir or os.path.join(self.cfg.checkpoint.save, "wandb") | ||
|
|
||
| # Sanitize config for WandB by doing a JSON round-trip | ||
| # This ensures all objects are converted to basic Python types that WandB can handle | ||
| def safe_serialize(obj): | ||
| """Safely convert any object to a JSON-serializable type. | ||
|
|
||
| Handles objects with broken __str__ or __repr__ methods that return | ||
| non-string types (e.g., PipelineParallelLayerLayout returns list). | ||
| """ | ||
| try: | ||
| # Try str() first | ||
| result = str(obj) | ||
| # Verify it actually returns a string | ||
| if not isinstance(result, str): | ||
| # __str__ returned non-string type, use type name instead | ||
| return f"<{type(obj).__name__}>" | ||
| return result | ||
| except Exception: | ||
| # __str__ raised an exception, use type name as fallback | ||
| return f"<{type(obj).__name__}>" | ||
|
|
||
| config_dict = self.cfg.to_dict() | ||
| sanitized_config = json.loads(json.dumps(config_dict, default=safe_serialize)) | ||
|
|
||
|
|
@@ -234,12 +216,79 @@ def safe_serialize(obj): | |
| self._wandb_logger = None | ||
| return self._wandb_logger | ||
|
|
||
@property
def mlflow_logger(self) -> Optional[Any]:
    """Return the lazily initialised MLFlow handle, or ``None`` when disabled.

    The handle is the ``mlflow`` module itself. It is only set up when
    ``cfg.logger.mlflow_experiment`` is set and the current rank is the last
    rank (``world_size - 1``), so a single process creates the run and writes
    to it. On first successful access this sets the tracking URI (if given),
    selects the experiment, starts a run unless one is already active, and
    logs the flattened configuration as run parameters.

    Raises:
        ValueError: If MLFlow logging is enabled on this rank but
            ``mlflow_run_name`` is an empty string.
    """
    # Already initialised: hand back the cached handle.
    if self._mlflow_logger is not None:
        return self._mlflow_logger

    container = self.cfg
    if container is None:
        self._mlflow_logger = None
        return self._mlflow_logger

    log_cfg = container.logger
    # The rank helpers are only consulted when an experiment is configured.
    on_logging_rank = bool(log_cfg.mlflow_experiment) and get_rank_safe() == (get_world_size_safe() - 1)
    if not on_logging_rank:
        self._mlflow_logger = None
        return self._mlflow_logger

    if log_cfg.mlflow_run_name == "":
        raise ValueError("Please specify the mlflow_run_name for MLFlow logging!")

    # Imported here so the dependency is only needed when MLFlow is enabled.
    import mlflow

    tracking_uri = log_cfg.mlflow_tracking_uri
    if tracking_uri:
        mlflow.set_tracking_uri(tracking_uri)

    mlflow.set_experiment(log_cfg.mlflow_experiment)

    def _iter_leaves(tree: dict[str, Any], prefix: str = ""):
        """Yield ``(dotted_key, value)`` for every non-dict leaf of ``tree``."""
        for name, child in tree.items():
            dotted = f"{prefix}.{name}" if prefix else name
            if isinstance(child, dict):
                yield from _iter_leaves(child, dotted)
            elif isinstance(child, (list, tuple)):
                # Sequence elements are stringified one by one.
                yield dotted, [safe_serialize(element) for element in child]
            else:
                yield dotted, child

    # A JSON round-trip reduces the config to basic Python types first.
    sanitized = json.loads(json.dumps(container.to_dict(), default=safe_serialize))
    flat = dict(_iter_leaves(sanitized))

    run_tags = log_cfg.mlflow_tags or {}
    if mlflow.active_run() is None:
        mlflow.start_run(run_name=log_cfg.mlflow_run_name, tags=run_tags or None)
    elif run_tags:
        # A run is already active: attach the configured tags to it.
        mlflow.set_tags(run_tags)

    # int/float/bool/str values pass through; anything else is stringified.
    params: dict[str, Any] = {}
    for key, value in flat.items():
        if isinstance(value, (int, float, bool, str)):
            params[key] = value
        else:
            params[key] = safe_serialize(value)
    mlflow.log_params(params)

    self._mlflow_logger = mlflow
    return self._mlflow_logger
|
|
||
@property
def timers(self) -> Timers:
    """Return the Megatron ``Timers`` instance, creating it on first access.

    The instance is built from ``cfg.logger.timing_log_level`` and
    ``cfg.logger.timing_log_option``, then given ``write_to_wandb`` and
    ``write_to_mlflow`` methods bound from the module-level patch functions.
    """
    if self._timers is not None:
        return self._timers

    self._timers = Timers(self.cfg.logger.timing_log_level, self.cfg.logger.timing_log_option)
    # Attach the backend-specific writers as bound methods on this instance.
    for attr_name, patch_fn in (
        ("write_to_wandb", _timers_write_to_wandb),
        ("write_to_mlflow", _timers_write_to_mlflow),
    ):
        setattr(self._timers, attr_name, types.MethodType(patch_fn, self._timers))
    return self._timers
|
|
||
| @property | ||
|
|
@@ -344,6 +393,7 @@ def reset_for_restart(self) -> None: | |
| self._train_state = None | ||
| self._tensorboard_logger = None | ||
| self._wandb_logger = None | ||
| self._mlflow_logger = None | ||
| self._energy_monitor = None | ||
| self._energy_monitor_created = False | ||
| self._signal_handler = None | ||
|
|
@@ -371,3 +421,29 @@ def _timers_write_to_wandb( | |
| for name in name_to_min_max_time: | ||
| _, max_time = name_to_min_max_time[name] | ||
| writer.log({name + "-time": max_time}, iteration) | ||
|
|
||
|
|
||
def _timers_write_to_mlflow(
    self: "Timers",
    names: list[str],
    logger: Any,
    iteration: int,
    normalizer: float = 1.0,
    reset: bool = True,
    barrier: bool = False,
) -> None:
    """Write the max elapsed time of each named timer to MLFlow.

    Patch for Megatron Core ``Timers``; ``self`` is the timers instance this
    function is bound to.

    Args:
        names: Timer names to report.
        logger: Object exposing ``log_metrics(metrics, step=...)``. When
            ``None``, nothing is logged.
        iteration: Step passed to ``log_metrics``.
        normalizer: Positive divisor forwarded to the timer query.
        reset: Forwarded to the timer query.
        barrier: Forwarded to the timer query.

    Each metric is keyed ``"<name>-time"`` with ``/`` in the timer name
    replaced by ``_``. A failure inside ``log_metrics`` is reported as a
    warning (including the underlying error) and is not re-raised.
    """
    # Kept as an assert so the failure type stays AssertionError for callers.
    assert normalizer > 0.0
    # Queried before the logger check, so it runs even when nothing is logged.
    name_to_min_max_time = self._get_global_min_max_time(names, reset, barrier, normalizer)
    if logger is None:
        return

    metrics: dict[str, float] = {
        name.replace("/", "_") + "-time": max_time
        for name, (_, max_time) in name_to_min_max_time.items()
    }
    if not metrics:
        # Nothing to report; skip the logging call entirely.
        return

    try:
        logger.log_metrics(metrics, step=iteration)
    except Exception as exc:
        # Best-effort: the failure is surfaced as a warning and never propagates.
        import warnings

        warnings.warn(
            "Failed to log timer metrics to MLFlow; continuing without timer metrics. "
            f"({type(exc).__name__}: {exc})",
            stacklevel=2,
        )
yaoyu-33 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.