Conversation
Add comprehensive MLFlow support to Megatron Bridge for experiment tracking and artifact logging.

- Add MLFlow logger support in `GlobalState` with configurable experiment, run name, tracking URI, and tags
- Log training metrics (losses, learning rate, batch size, throughput, timers, memory, runtime, norms, energy) to MLFlow
- Log checkpoint artifacts to MLFlow with iteration-based artifact paths
- Add MLFlow configuration options to `LoggerConfig` (`mlflow_experiment`, `mlflow_run_name`, `mlflow_tracking_uri`, `mlflow_tags`)
- Add validation in `LoggerConfig.finalize()` to check MLFlow availability
- Move `safe_serialize` to `log_utils.py` for reuse across WandB and MLFlow
- Add comprehensive unit tests for MLFlow utilities
- Add documentation for MLFlow logging configuration

Based on community contribution from @therealnaveenkamal in PR #1542.

Co-authored-by: Naveenraj Kamalakannan <therealnaveenkamal@users.noreply.github.com>
📝 Walkthrough

This PR adds comprehensive MLflow integration for training logging, mirroring the existing WandB support. Changes include configuration options for MLflow, metrics logging during training iterations, checkpoint artifact management, state initialization with lazy logger instantiation, and full unit test coverage.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Training Loop
    participant GlobalState
    participant MLflow Logger
    participant Checkpointing
    participant MLflow API
    Training Loop->>GlobalState: Access mlflow_logger property
    GlobalState->>GlobalState: Lazy initialize MLflow client
    GlobalState->>MLflow API: Set experiment & start run
    MLflow API-->>GlobalState: Return active run
    Training Loop->>MLflow Logger: Log metrics (loss, accuracy, etc.)
    MLflow Logger->>MLflow API: Log sanitized metrics
    MLflow API->>MLflow API: Store metrics
    Training Loop->>Checkpointing: Save checkpoint
    Checkpointing->>MLflow Logger: on_save_checkpoint_success()
    MLflow Logger->>MLflow API: Log checkpoint as artifact
    MLflow API->>MLflow API: Store artifact reference
    Training Loop->>Checkpointing: Load checkpoint (resume)
    Checkpointing->>MLflow Logger: on_load_checkpoint_success()
    MLflow Logger->>MLflow API: Set tags (last_loaded_checkpoint)
    MLflow API->>MLflow API: Update run tags
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ Passed checks (4 passed)
Actionable comments posted: 9
🤖 Fix all issues with AI agents
In `@docs/training/logging.md`:
- Around line 171-179: Split the combined install example into two separate
lines so the shell commands are clear (replace the single-line "pip install
mlflow / uv add mlflow" with two examples: "pip install mlflow" and "uv add
mlflow"), and fix the markdown list indentation under the tracking server
section by unindenting the two list items so they start at column 0 (replace the
indented "- Either set `MLFLOW_TRACKING_URI`..." and "- Pass an explicit
`mlflow_tracking_uri`..." with top-level list items).
In `@pyproject.toml`:
- Line 91: Revert the direct edit that added the "mlflow>=3.2.0" line to
pyproject.toml and instead add the dependency using the project dependency
manager by running the command uv add mlflow; ensure the repository no longer
contains the manual "mlflow>=3.2.0" entry (if committed, revert that change) so
the dependency is recorded and versioned via the uv tool rather than direct file
edits.
In `@src/megatron/bridge/training/config.py`:
- Around line 1003-1026: The long literal in the raise inside finalize()
triggers Ruff TRY003; replace the long ModuleNotFoundError raise with a short
custom exception (e.g., MLFlowNotInstalledError) declared in this module and
raise that with a short message, moving the verbose install guidance into the
exception class docstring or module-level constant, or alternatively add a
targeted "# noqa: TRY003" comment on the raise lines if you prefer not to
introduce a new exception; update the two raise sites in finalize() (the
ModuleNotFoundError raise and the ValueError line if it also exceeds length) to
use the chosen approach so Ruff no longer flags TRY003.
In `@src/megatron/bridge/training/state.py`:
- Around line 426-449: In _timers_write_to_mlflow, update the broad except block
around logger.log_metrics to capture the exception into a variable (e.g., except
Exception as e) and call warnings.warn with an informative message including the
exception details and stacklevel=2; if you prefer to keep a broad catch, add an
explicit noqa (e.g., "# noqa: BLE001,B028") to the except line so Ruff is
satisfied, otherwise narrow the exception type to the specific errors
logger.log_metrics can raise; ensure you reference the logger.log_metrics call
and the _timers_write_to_mlflow function when making the change.
- Around line 219-235: The raise in the mlflow_logger property (check
logger_cfg.mlflow_run_name) currently uses a long literal which trips Ruff
TRY003; fix it by either replacing the long literal with a short message and
raising a custom exception class (e.g., define MissingMLFlowRunNameError and
raise that from mlflow_logger) or, if you prefer minimal change, append a
targeted inline noqa to the raise line (raise ValueError("Specify
mlflow_run_name") # noqa: TRY003) so the linter is satisfied; update references
to logger_cfg.mlflow_run_name and the mlflow_logger property accordingly.
In `@src/megatron/bridge/training/utils/mlflow_utils.py`:
- Around line 43-51: The broad "except Exception as exc:" in the MLFlow artifact
logging blocks should be made lint-safe: replace it with a narrow set of
expected exceptions (e.g., "except (mlflow.exceptions.MlflowException, OSError,
ValueError) as exc:" and add "from mlflow.exceptions import MlflowException" at
top) and apply the same change to the second try/except block that wraps
mlflow_logger.log_artifacts (the code that references checkpoint_path,
expected_ckpt_path, artifact_subdir and print_rank_last); if you intentionally
want a broad catch keep the current handler but annotate it with "# noqa:
BLE001" on each except line instead of leaving the generic catch.
In `@tests/unit_tests/training/utils/test_mlflow_utils.py`:
- Around line 26-296: Add pytest.mark.unit markers to these unit tests by
importing pytest (if not present) and decorating the test classes or individual
test functions with `@pytest.mark.unit`; apply to TestOnSaveCheckpointSuccess,
TestOnLoadCheckpointSuccess, and TestSanitizeMlflowMetrics (or their contained
test_* methods) so the test runner can categorize them properly and marker-based
runs include them.
In `@tutorials/recipes/llama/conf/llama32_1b_finetune.yaml`:
- Around line 75-83: The mlflow_tracking_uri key is incorrectly nested under
mlflow_tags in the YAML; move mlflow_tracking_uri out of the mlflow_tags block
so it sits at the same level as mlflow_experiment/mlflow_run_name, ensuring
mlflow_tags only contains tag mappings (e.g., project, phase, variant), and
remove the trailing blank lines at the end of the file; update items around
mlflow_experiment, mlflow_run_name, mlflow_tags, and mlflow_tracking_uri
accordingly.
In `@tutorials/recipes/llama/conf/llama32_1b_pretrain.yaml`:
- Around line 65-71: The mlflow_tracking_uri key is incorrectly nested under
mlflow_tags; update the LoggerConfig by removing mlflow_tracking_uri from inside
mlflow_tags and placing it at the same indentation level as mlflow_experiment
and mlflow_run_name so mlflow_experiment, mlflow_run_name, mlflow_tags, and
mlflow_tracking_uri are all top-level keys in the config (reference keys:
mlflow_experiment, mlflow_run_name, mlflow_tags, mlflow_tracking_uri,
LoggerConfig).
🧹 Nitpick comments (3)
docs/training/logging.md (1)
197-199: Consider adding a note about lazy initialization similar to WandB. The WandB section includes a note about lazy initialization on the last rank. For consistency and to help users understand the behavior, consider adding a similar note for MLFlow.
📝 Suggested addition

````diff
     },
 )
+
+```{note}
+MLFlow is initialized lazily on the last rank when `mlflow_experiment` is set and `mlflow_run_name` is non-empty.
+```
````

src/megatron/bridge/training/utils/log_utils.py (1)

178-194: Good defensive implementation. Minor refactor suggestion.

The `safe_serialize` function correctly handles edge cases with broken `__str__` methods. Consider the minor refactor suggested by static analysis (TRY300) to move the return to an `else` block for cleaner control flow. The docstring mentions "JSON-serializable type" but the function always returns a `str`; consider updating the docstring to be more precise.

♻️ Suggested refactor

```diff
 def safe_serialize(obj) -> str:
-    """Safely convert any object to a JSON-serializable type.
+    """Safely convert any object to a string representation.

     Handles objects with broken __str__ or __repr__ methods that
     return non-string types (e.g., PipelineParallelLayerLayout returns list).
+
+    Args:
+        obj: Any object to serialize.
+
+    Returns:
+        String representation of the object, or type name as fallback.
     """
     try:
         # Try str() first
         result = str(obj)
         # Verify it actually returns a string
         if not isinstance(result, str):
             # __str__ returned non-string type, use type name instead
             return f"<{type(obj).__name__}>"
-        return result
     except Exception:
         # __str__ raised an exception, use type name as fallback
         return f"<{type(obj).__name__}>"
+    else:
+        return result
```

src/megatron/bridge/training/utils/mlflow_utils.py (1)
15-27: Prefer `T | None` over `Optional[T]` for new APIs.

Align type hints with the project's Python 3.10+ style. As per coding guidelines: use `T | None` for nullable types instead of `Optional[T]`.

♻️ Suggested refactor

```diff
-from typing import Any, Optional
+from typing import Any
@@
-def on_save_checkpoint_success(
-    checkpoint_path: str,
-    save_dir: str,
-    iteration: int,
-    mlflow_logger: Optional[Any],
-) -> None:
+def on_save_checkpoint_success(
+    checkpoint_path: str,
+    save_dir: str,
+    iteration: int,
+    mlflow_logger: Any | None,
+) -> None:
@@
-def on_load_checkpoint_success(
-    checkpoint_path: str,
-    load_dir: str,
-    mlflow_logger: Optional[Any],
-) -> None:
+def on_load_checkpoint_success(
+    checkpoint_path: str,
+    load_dir: str,
+    mlflow_logger: Any | None,
+) -> None:
```
In `src/megatron/bridge/training/state.py`:

```python
@property
def mlflow_logger(self) -> Optional[Any]:
    """The MLFlow logger instance.

    Uses the configuration under LoggerConfig to create or resume an MLFlow run.
    Restricted to the last rank to avoid duplicate entries or probably any racing conditions.
    """
    if self._mlflow_logger is None:
        cfg = self.cfg
        if cfg is None:
            self._mlflow_logger = None
            return self._mlflow_logger

        logger_cfg = cfg.logger
        if logger_cfg.mlflow_experiment and get_rank_safe() == (get_world_size_safe() - 1):
            if logger_cfg.mlflow_run_name == "":
                raise ValueError("Please specify the mlflow_run_name for MLFlow logging!")
```
Ruff TRY003: long raise message will trip lint.

Consider a custom exception or add a targeted `# noqa: TRY003` to the raise line.

✅ Suggested lint-safe tweak

```diff
-                raise ValueError("Please specify the mlflow_run_name for MLFlow logging!")
+                raise ValueError(
+                    "Please specify the mlflow_run_name for MLFlow logging!"
+                )  # noqa: TRY003
```

🧰 Tools
🪛 Ruff (0.14.14)
235-235: Avoid specifying long messages outside the exception class (TRY003)
🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/state.py` around lines 219 - 235, The raise in
the mlflow_logger property (check logger_cfg.mlflow_run_name) currently uses a
long literal which trips Ruff TRY003; fix it by either replacing the long
literal with a short message and raising a custom exception class (e.g., define
MissingMLFlowRunNameError and raise that from mlflow_logger) or, if you prefer
minimal change, append a targeted inline noqa to the raise line (raise
ValueError("Specify mlflow_run_name") # noqa: TRY003) so the linter is
satisfied; update references to logger_cfg.mlflow_run_name and the mlflow_logger
property accordingly.
In `src/megatron/bridge/training/utils/mlflow_utils.py`:

```python
try:
    checkpoint_path = str(Path(checkpoint_path).resolve())
    base_name = Path(save_dir).name or "checkpoints"
    expected_ckpt_path = get_checkpoint_name(save_dir, iteration)
    artifact_subdir = f"{base_name}/{Path(expected_ckpt_path).name}"
    mlflow_logger.log_artifacts(checkpoint_path, artifact_path=artifact_subdir)
except Exception as exc:
    # continue training
    print_rank_last(f"Failed to log checkpoint artifacts to MLFlow: {exc}")
```
Ruff BLE001: broad Exception catch needs explicit intent.

Either narrow the exception types or add a targeted `# noqa: BLE001` to keep lint clean.

✅ Suggested lint-safe tweak

```diff
-    except Exception as exc:
+    except Exception as exc:  # noqa: BLE001
@@
-    except Exception as exc:
+    except Exception as exc:  # noqa: BLE001
```

Also applies to: 73-83

🧰 Tools
🪛 Ruff (0.14.14)
49-49: Do not catch blind exception: Exception (BLE001)
🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/utils/mlflow_utils.py` around lines 43 - 51, The
broad "except Exception as exc:" in the MLFlow artifact logging blocks should be
made lint-safe: replace it with a narrow set of expected exceptions (e.g.,
"except (mlflow.exceptions.MlflowException, OSError, ValueError) as exc:" and
add "from mlflow.exceptions import MlflowException" at top) and apply the same
change to the second try/except block that wraps mlflow_logger.log_artifacts
(the code that references checkpoint_path, expected_ckpt_path, artifact_subdir
and print_rank_last); if you intentionally want a broad catch keep the current
handler but annotate it with "# noqa: BLE001" on each except line instead of
leaving the generic catch.
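To make the narrowed-catch suggestion concrete, here is a runnable sketch of the pattern the comment describes. `MlflowExceptionStub` stands in for `mlflow.exceptions.MlflowException` so the example runs without mlflow installed; the function and logger class here are illustrative, not the PR's actual code.

```python
from pathlib import Path


class MlflowExceptionStub(Exception):
    """Stand-in for mlflow.exceptions.MlflowException (so this sketch needs no mlflow)."""


def log_checkpoint_artifacts(mlflow_logger, checkpoint_path: str, artifact_subdir: str) -> None:
    """Attempt an artifact upload, swallowing only the failure modes we expect."""
    try:
        mlflow_logger.log_artifacts(
            str(Path(checkpoint_path).resolve()), artifact_path=artifact_subdir
        )
    except (MlflowExceptionStub, OSError, ValueError) as exc:
        # Logging must never kill training: report the failure and continue.
        print(f"Failed to log checkpoint artifacts to MLFlow: {exc}")


class FailingLogger:
    """Fake logger whose upload always fails, to exercise the except branch."""

    def log_artifacts(self, *args, **kwargs):
        raise OSError("connection refused")


log_checkpoint_artifacts(FailingLogger(), "/tmp/ckpt", "checkpoints/iter_0000100")
```

An unexpected exception type (say, a `TypeError` from a programming bug) would still propagate, which is exactly the intent of narrowing the catch.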
In `tests/unit_tests/training/utils/test_mlflow_utils.py`:

```python
class TestOnSaveCheckpointSuccess:
    """Test cases for on_save_checkpoint_success function."""

    def test_noop_when_mlflow_logger_is_none(self):
        """Test that the function does nothing when mlflow_logger is None."""
        # Should not raise any exception
        on_save_checkpoint_success(
            checkpoint_path="/path/to/checkpoint",
            save_dir="/path/to",
            iteration=100,
            mlflow_logger=None,
        )

    def test_logs_artifacts_with_correct_path(self):
        """Test that log_artifacts is called with correct arguments."""
        mock_mlflow = MagicMock()

        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = Path(tmpdir) / "checkpoint"
            checkpoint_path.mkdir()
            save_dir = tmpdir

            on_save_checkpoint_success(
                checkpoint_path=str(checkpoint_path),
                save_dir=save_dir,
                iteration=1000,
                mlflow_logger=mock_mlflow,
            )

            mock_mlflow.log_artifacts.assert_called_once()
            call_args = mock_mlflow.log_artifacts.call_args

            # Verify the checkpoint path is resolved
            assert call_args[0][0] == str(checkpoint_path.resolve())

            # Verify artifact_path format includes iteration
            artifact_path = call_args[1]["artifact_path"]
            base_name = Path(save_dir).name
            assert artifact_path == f"{base_name}/iter_0001000"

    def test_artifact_path_format_with_different_iterations(self):
        """Test that iteration is zero-padded to 7 digits in artifact path."""
        mock_mlflow = MagicMock()

        test_cases = [
            (0, "iter_0000000"),
            (1, "iter_0000001"),
            (999, "iter_0000999"),
            (1234567, "iter_1234567"),
            (9999999, "iter_9999999"),
        ]

        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = Path(tmpdir) / "checkpoint"
            checkpoint_path.mkdir()

            for iteration, expected_suffix in test_cases:
                mock_mlflow.reset_mock()

                on_save_checkpoint_success(
                    checkpoint_path=str(checkpoint_path),
                    save_dir=tmpdir,
                    iteration=iteration,
                    mlflow_logger=mock_mlflow,
                )

                artifact_path = mock_mlflow.log_artifacts.call_args[1]["artifact_path"]
                assert artifact_path.endswith(expected_suffix), (
                    f"Expected artifact_path to end with {expected_suffix}, got {artifact_path}"
                )

    def test_uses_checkpoints_as_default_base_name(self):
        """Test that 'checkpoints' is used when save_dir has no name."""
        mock_mlflow = MagicMock()

        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = Path(tmpdir) / "checkpoint"
            checkpoint_path.mkdir()

            # Use root-like path that would have empty name
            on_save_checkpoint_success(
                checkpoint_path=str(checkpoint_path),
                save_dir="/",
                iteration=100,
                mlflow_logger=mock_mlflow,
            )

            artifact_path = mock_mlflow.log_artifacts.call_args[1]["artifact_path"]
            assert artifact_path.startswith("checkpoints/")

    def test_handles_exception_gracefully(self):
        """Test that exceptions are caught and logged, not raised."""
        mock_mlflow = MagicMock()
        mock_mlflow.log_artifacts.side_effect = Exception("MLFlow connection error")

        with patch("megatron.bridge.training.utils.mlflow_utils.print_rank_last") as mock_print:
            # Should not raise exception
            on_save_checkpoint_success(
                checkpoint_path="/path/to/checkpoint",
                save_dir="/path/to",
                iteration=100,
                mlflow_logger=mock_mlflow,
            )

            # Should print error message
            mock_print.assert_called_once()
            error_msg = mock_print.call_args[0][0]
            assert "Failed to log checkpoint artifacts to MLFlow" in error_msg
            assert "MLFlow connection error" in error_msg


class TestOnLoadCheckpointSuccess:
    """Test cases for on_load_checkpoint_success function."""

    def test_noop_when_mlflow_logger_is_none(self):
        """Test that the function does nothing when mlflow_logger is None."""
        # Should not raise any exception
        on_load_checkpoint_success(
            checkpoint_path="/path/to/checkpoint",
            load_dir="/path/to",
            mlflow_logger=None,
        )

    def test_sets_correct_tags(self):
        """Test that set_tags is called with correct checkpoint information."""
        mock_mlflow = MagicMock()

        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = Path(tmpdir) / "checkpoint"
            checkpoint_path.mkdir()
            load_dir = tmpdir

            on_load_checkpoint_success(
                checkpoint_path=str(checkpoint_path),
                load_dir=load_dir,
                mlflow_logger=mock_mlflow,
            )

            mock_mlflow.set_tags.assert_called_once()
            tags = mock_mlflow.set_tags.call_args[0][0]

            assert "last_loaded_checkpoint" in tags
            assert "checkpoint_base_dir" in tags
            assert tags["last_loaded_checkpoint"] == str(checkpoint_path.resolve())
            assert tags["checkpoint_base_dir"] == str(Path(load_dir).resolve())

    def test_resolves_relative_paths(self):
        """Test that relative paths are resolved to absolute paths."""
        mock_mlflow = MagicMock()

        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = Path(tmpdir) / "checkpoint"
            checkpoint_path.mkdir()

            on_load_checkpoint_success(
                checkpoint_path=str(checkpoint_path),
                load_dir=tmpdir,
                mlflow_logger=mock_mlflow,
            )

            tags = mock_mlflow.set_tags.call_args[0][0]

            # Both paths should be absolute (resolved)
            assert Path(tags["last_loaded_checkpoint"]).is_absolute()
            assert Path(tags["checkpoint_base_dir"]).is_absolute()

    def test_handles_exception_gracefully(self):
        """Test that exceptions are caught and logged, not raised."""
        mock_mlflow = MagicMock()
        mock_mlflow.set_tags.side_effect = Exception("MLFlow API error")

        with patch("megatron.bridge.training.utils.mlflow_utils.print_rank_last") as mock_print:
            # Should not raise exception
            on_load_checkpoint_success(
                checkpoint_path="/path/to/checkpoint",
                load_dir="/path/to",
                mlflow_logger=mock_mlflow,
            )

            # Should print error message
            mock_print.assert_called_once()
            error_msg = mock_print.call_args[0][0]
            assert "Failed to record loaded checkpoint information to MLFlow" in error_msg
            assert "MLFlow API error" in error_msg


class TestSanitizeMlflowMetrics:
    """Test cases for _sanitize_mlflow_metrics function."""

    def test_replaces_slashes_with_underscores(self):
        """Test that forward slashes are replaced with underscores."""
        metrics = {
            "train/loss": 0.5,
            "train/accuracy": 0.95,
            "eval/loss": 0.3,
        }

        result = _sanitize_mlflow_metrics(metrics)

        assert result == {
            "train_loss": 0.5,
            "train_accuracy": 0.95,
            "eval_loss": 0.3,
        }

    def test_handles_multiple_slashes(self):
        """Test that multiple slashes in a key are all replaced."""
        metrics = {
            "train/layer/0/loss": 1.0,
            "model/encoder/attention/weight": 0.5,
        }

        result = _sanitize_mlflow_metrics(metrics)

        assert result == {
            "train_layer_0_loss": 1.0,
            "model_encoder_attention_weight": 0.5,
        }

    def test_preserves_keys_without_slashes(self):
        """Test that keys without slashes are unchanged."""
        metrics = {
            "loss": 0.5,
            "accuracy": 0.95,
            "learning_rate": 0.001,
        }

        result = _sanitize_mlflow_metrics(metrics)

        assert result == metrics

    def test_handles_empty_dict(self):
        """Test that empty dictionary returns empty dictionary."""
        result = _sanitize_mlflow_metrics({})
        assert result == {}

    def test_preserves_values(self):
        """Test that metric values are preserved unchanged."""
        metrics = {
            "train/int_metric": 42,
            "train/float_metric": 3.14159,
            "train/string_metric": "value",
            "train/none_metric": None,
            "train/list_metric": [1, 2, 3],
        }

        result = _sanitize_mlflow_metrics(metrics)

        assert result["train_int_metric"] == 42
        assert result["train_float_metric"] == 3.14159
        assert result["train_string_metric"] == "value"
        assert result["train_none_metric"] is None
        assert result["train_list_metric"] == [1, 2, 3]

    def test_mixed_keys(self):
        """Test dictionary with both slash and non-slash keys."""
        metrics = {
            "train/loss": 0.5,
            "global_step": 1000,
            "eval/accuracy": 0.9,
            "learning_rate": 0.001,
        }

        result = _sanitize_mlflow_metrics(metrics)

        assert result == {
            "train_loss": 0.5,
            "global_step": 1000,
            "eval_accuracy": 0.9,
            "learning_rate": 0.001,
        }
```
Add pytest markers for unit-test categorization.

Missing `pytest.mark.unit` breaks the test categorization guideline and may skip these tests in marker-based runs.

✅ Suggested update

```diff
+import pytest
+
+pytestmark = pytest.mark.unit
```

🤖 Prompt for AI Agents
In `@tests/unit_tests/training/utils/test_mlflow_utils.py` around lines 26 - 296,
Add pytest.mark.unit markers to these unit tests by importing pytest (if not
present) and decorating the test classes or individual test functions with
`@pytest.mark.unit`; apply to TestOnSaveCheckpointSuccess,
TestOnLoadCheckpointSuccess, and TestSanitizeMlflowMetrics (or their contained
test_* methods) so the test runner can categorize them properly and marker-based
runs include them.
In `tutorials/recipes/llama/conf/llama32_1b_finetune.yaml`:

```yaml
# mlflow_experiment: llama32_1b_finetuned # Uncomment to enable MLFlow logging
# mlflow_run_name: llama32_1b_finetuned
# mlflow_tags:
#   project: llama32
#   phase: finetune
#   variant: mlflow_example
#   mlflow_tracking_uri: http://localhost:5000 # Optional: use remote MLflow server
```
Same nesting issue: `mlflow_tracking_uri` should not be inside `mlflow_tags`.

Similar to the pretrain config, `mlflow_tracking_uri` is incorrectly nested inside the `mlflow_tags` block. Also, there are trailing blank lines that could be removed.

🔧 Proposed fix

```diff
 # mlflow_experiment: llama32_1b_finetuned # Uncomment to enable MLFlow logging
 # mlflow_run_name: llama32_1b_finetuned
 # mlflow_tags:
 #   project: llama32
 #   phase: finetune
 #   variant: mlflow_example
-#   mlflow_tracking_uri: http://localhost:5000 # Optional: use remote MLflow server
-
-
+# mlflow_tracking_uri: http://localhost:5000 # Optional: use remote MLflow server
 # Random seed
```
# Random seed🤖 Prompt for AI Agents
In `@tutorials/recipes/llama/conf/llama32_1b_finetune.yaml` around lines 75 - 83,
The mlflow_tracking_uri key is incorrectly nested under mlflow_tags in the YAML;
move mlflow_tracking_uri out of the mlflow_tags block so it sits at the same
level as mlflow_experiment/mlflow_run_name, ensuring mlflow_tags only contains
tag mappings (e.g., project, phase, variant), and remove the trailing blank
lines at the end of the file; update items around mlflow_experiment,
mlflow_run_name, mlflow_tags, and mlflow_tracking_uri accordingly.
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
mlflow_tracking_uri was incorrectly nested inside mlflow_tags. Move it to be a top-level key alongside mlflow_experiment, mlflow_run_name, and mlflow_tags in LoggerConfig.
Add @pytest.mark.unit decorators to TestOnSaveCheckpointSuccess, TestOnLoadCheckpointSuccess, and TestSanitizeMlflowMetrics classes so the test runner can categorize them properly.
/ok to test fa03b2c
Thanks for continuing this work, @yaoyu-33 |
What does this PR do?
This PR picks up and continues the excellent community contribution from #1542 by @therealnaveenkamal to add comprehensive MLFlow support to Megatron Bridge.
Features
- `mlflow_experiment`: The MLFlow experiment name
- `mlflow_run_name`: The MLFlow run name
- `mlflow_tracking_uri`: Optional MLFlow tracking URI
- `mlflow_tags`: Optional tags to apply to the MLFlow run
- Validation in `LoggerConfig.finalize()` to check MLFlow availability
- `safe_serialize` moved to `log_utils.py` for reuse across WandB and MLFlow

Usage
Attribution
This PR is based on the community contribution in #1542 by @therealnaveenkamal. Thank you for the excellent work!
Supersedes #1542
Before your PR is "Ready for review"
Pre checks:
cc @therealnaveenkamal @Phlip79
Summary by CodeRabbit
New Features
Documentation
Chores