
MLFlow Integration #2112

Merged
yaoyu-33 merged 8 commits into main from yuya/mlflow-integration
Jan 29, 2026

Conversation

@yaoyu-33
Contributor

@yaoyu-33 yaoyu-33 commented Jan 28, 2026

What does this PR do?

This PR picks up and continues the excellent community contribution from #1542 by @therealnaveenkamal to add comprehensive MLFlow support to Megatron Bridge.

Features

  • MLFlow Logger Integration: Add MLFlow logger support in GlobalState with configurable experiment, run name, tracking URI, and tags
  • Metrics Logging: Log training metrics (losses, learning rate, batch size, throughput, timers, memory, runtime, norms, energy) to MLFlow
  • Checkpoint Artifacts: Log checkpoint artifacts to MLFlow with iteration-based artifact paths
  • Configuration: Add MLFlow configuration options to LoggerConfig:
    • mlflow_experiment: The MLFlow experiment name
    • mlflow_run_name: The MLFlow run name
    • mlflow_tracking_uri: Optional MLFlow tracking URI
    • mlflow_tags: Optional tags to apply to the MLFlow run
  • Validation: Add validation in LoggerConfig.finalize() to check MLFlow availability
  • Code Refactoring: Move safe_serialize to log_utils.py for reuse across WandB and MLFlow
  • Tests: Add comprehensive unit tests for MLFlow utilities
  • Documentation: Add MLFlow logging configuration documentation

Usage

```python
from megatron.bridge.training.config import LoggerConfig

cfg.logger = LoggerConfig(
    tensorboard_dir="./runs/tensorboard",
    mlflow_experiment="my_megatron_experiment",
    mlflow_run_name="llama32_1b_pretrain_run",
    mlflow_tracking_uri="http://mlflow:5000",  # optional
    mlflow_tags={                              # optional
        "project": "llama32",
        "phase": "pretrain",
    },
)
```

Attribution

This PR is based on the community contribution in #1542 by @therealnaveenkamal. Thank you for the excellent work!

Supersedes #1542


Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g., Numba, Pynini, Apex)
    • MLFlow is added as a required dependency in pyproject.toml

cc @therealnaveenkamal @Phlip79

Summary by CodeRabbit

  • New Features

    • Added MLFlow integration for comprehensive training logging, supporting metrics, checkpoints, and artifacts alongside WandB and TensorBoard.
    • Configure MLFlow experiment tracking with customizable run names, tags, and tracking URIs.
  • Documentation

    • Added MLFlow logging setup and configuration guide with installation and enablement instructions.
  • Chores

    • Added mlflow as a runtime dependency.


Add comprehensive MLFlow support to Megatron Bridge for experiment tracking and artifact logging.

- Add MLFlow logger support in GlobalState with configurable experiment, run name, tracking URI, and tags
- Log training metrics (losses, learning rate, batch size, throughput, timers, memory, runtime, norms, energy) to MLFlow
- Log checkpoint artifacts to MLFlow with iteration-based artifact paths
- Add MLFlow configuration options to LoggerConfig (mlflow_experiment, mlflow_run_name, mlflow_tracking_uri, mlflow_tags)
- Add validation in LoggerConfig.finalize() to check MLFlow availability
- Move safe_serialize to log_utils.py for reuse across WandB and MLFlow
- Add comprehensive unit tests for MLFlow utilities
- Add documentation for MLFlow logging configuration

Based on community contribution from @therealnaveenkamal in PR #1542.

Co-authored-by: Naveenraj Kamalakannan <therealnaveenkamal@users.noreply.github.com>
@yaoyu-33 yaoyu-33 requested a review from a team as a code owner January 28, 2026 21:32
@copy-pr-bot

copy-pr-bot bot commented Jan 28, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Jan 28, 2026

📝 Walkthrough

This PR adds comprehensive MLflow integration for training logging, mirroring the existing WandB support. Changes include configuration options for MLflow, metrics logging during training iterations, checkpoint artifact management, state initialization with lazy logger instantiation, and full unit test coverage.

Changes

  • Configuration & Dependencies (pyproject.toml, src/megatron/bridge/training/config.py, docs/training/logging.md): Added mlflow>=3.2.0 dependency. Introduced four MLflow config fields (experiment, run_name, tracking_uri, tags) to LoggerConfig with validation in a new finalize() method. Added MLflow logging documentation mirroring the W&B pattern.
  • Core Logging Infrastructure (src/megatron/bridge/training/state.py, src/megatron/bridge/training/utils/log_utils.py, src/megatron/bridge/training/utils/mlflow_utils.py): Added mlflow_logger property to GlobalState with lazy initialization gated to the last rank. Introduced safe_serialize() utility for safe object-to-string conversion. Created mlflow_utils.py with checkpoint success callbacks (on_save/load) and a metric sanitization function. Extended Timers integration with _timers_write_to_mlflow().
  • Training Integration (src/megatron/bridge/training/utils/train_utils.py, src/megatron/bridge/training/checkpointing.py): Extended train_utils.py to log training metrics to MLflow across multiple logging points with _sanitize_mlflow_metrics(). Updated checkpointing.py to register MLflow finalization callbacks alongside WandB, executing in both async and sync save paths and on checkpoint load.
  • Testing & Examples (tests/unit_tests/training/utils/test_mlflow_utils.py, tutorials/recipes/llama/conf/llama32_1b_pretrain.yaml, tutorials/recipes/llama/conf/llama32_1b_finetune.yaml): Added 296 lines of comprehensive unit tests covering checkpoint logging, tag setting, metric sanitization, error handling, and path resolution. Updated tutorial configs with commented MLflow configuration blocks and adjusted batch size parameters.

Sequence Diagram

```mermaid
sequenceDiagram
    participant Training Loop
    participant GlobalState
    participant MLflow Logger
    participant Checkpointing
    participant MLflow API

    Training Loop->>GlobalState: Access mlflow_logger property
    GlobalState->>GlobalState: Lazy initialize MLflow client
    GlobalState->>MLflow API: Set experiment & start run
    MLflow API-->>GlobalState: Return active run

    Training Loop->>MLflow Logger: Log metrics (loss, accuracy, etc.)
    MLflow Logger->>MLflow API: Log sanitized metrics
    MLflow API->>MLflow API: Store metrics

    Training Loop->>Checkpointing: Save checkpoint
    Checkpointing->>MLflow Logger: on_save_checkpoint_success()
    MLflow Logger->>MLflow API: Log checkpoint as artifact
    MLflow API->>MLflow API: Store artifact reference

    Training Loop->>Checkpointing: Load checkpoint (resume)
    Checkpointing->>MLflow Logger: on_load_checkpoint_success()
    MLflow Logger->>MLflow API: Set tags (last_loaded_checkpoint)
    MLflow API->>MLflow API: Update run tags
```

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested reviewers

  • cuichenx
🚥 Pre-merge checks | ✅ 4

✅ Passed checks (4 passed)

  • Description Check — ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed: The title 'MLFlow Integration' directly and accurately summarizes the primary change of adding MLFlow support throughout the codebase.
  • Docstring Coverage — ✅ Passed: Docstring coverage is 91.43%, which is sufficient; the required threshold is 80.00%.
  • Test Results For Major Changes — ✅ Passed: The PR description documents the tests and documentation added for the MLFlow integration, with 296 lines of unit tests and comprehensive coverage.




@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 9

🤖 Fix all issues with AI agents
In `@docs/training/logging.md`:
- Around line 171-179: Split the combined install example into two separate
lines so the shell commands are clear (replace the single-line "pip install
mlflow / uv add mlflow" with two examples: "pip install mlflow" and "uv add
mlflow"), and fix the markdown list indentation under the tracking server
section by unindenting the two list items so they start at column 0 (replace the
indented "- Either set `MLFLOW_TRACKING_URI`..." and "- Pass an explicit
`mlflow_tracking_uri`..." with top-level list items).

In `@pyproject.toml`:
- Line 91: Revert the direct edit that added the "mlflow>=3.2.0" line to
pyproject.toml and instead add the dependency using the project dependency
manager by running the command uv add mlflow; ensure the repository no longer
contains the manual "mlflow>=3.2.0" entry (if committed, revert that change) so
the dependency is recorded and versioned via the uv tool rather than direct file
edits.

In `@src/megatron/bridge/training/config.py`:
- Around line 1003-1026: The long literal in the raise inside finalize()
triggers Ruff TRY003; replace the long ModuleNotFoundError raise with a short
custom exception (e.g., MLFlowNotInstalledError) declared in this module and
raise that with a short message, moving the verbose install guidance into the
exception class docstring or module-level constant, or alternatively add a
targeted "# noqa: TRY003" comment on the raise lines if you prefer not to
introduce a new exception; update the two raise sites in finalize() (the
ModuleNotFoundError raise and the ValueError line if it also exceeds length) to
use the chosen approach so Ruff no longer flags TRY003.

In `@src/megatron/bridge/training/state.py`:
- Around line 426-449: In _timers_write_to_mlflow, update the broad except block
around logger.log_metrics to capture the exception into a variable (e.g., except
Exception as e) and call warnings.warn with an informative message including the
exception details and stacklevel=2; if you prefer to keep a broad catch, add an
explicit noqa (e.g., "# noqa: BLE001,B028") to the except line so Ruff is
satisfied, otherwise narrow the exception type to the specific errors
logger.log_metrics can raise; ensure you reference the logger.log_metrics call
and the _timers_write_to_mlflow function when making the change.
- Around line 219-235: The raise in the mlflow_logger property (check
logger_cfg.mlflow_run_name) currently uses a long literal which trips Ruff
TRY003; fix it by either replacing the long literal with a short message and
raising a custom exception class (e.g., define MissingMLFlowRunNameError and
raise that from mlflow_logger) or, if you prefer minimal change, append a
targeted inline noqa to the raise line (raise ValueError("Specify
mlflow_run_name")  # noqa: TRY003) so the linter is satisfied; update references
to logger_cfg.mlflow_run_name and the mlflow_logger property accordingly.

In `@src/megatron/bridge/training/utils/mlflow_utils.py`:
- Around line 43-51: The broad "except Exception as exc:" in the MLFlow artifact
logging blocks should be made lint-safe: replace it with a narrow set of
expected exceptions (e.g., "except (mlflow.exceptions.MlflowException, OSError,
ValueError) as exc:" and add "from mlflow.exceptions import MlflowException" at
top) and apply the same change to the second try/except block that wraps
mlflow_logger.log_artifacts (the code that references checkpoint_path,
expected_ckpt_path, artifact_subdir and print_rank_last); if you intentionally
want a broad catch keep the current handler but annotate it with "# noqa:
BLE001" on each except line instead of leaving the generic catch.

In `@tests/unit_tests/training/utils/test_mlflow_utils.py`:
- Around line 26-296: Add pytest.mark.unit markers to these unit tests by
importing pytest (if not present) and decorating the test classes or individual
test functions with `@pytest.mark.unit`; apply to TestOnSaveCheckpointSuccess,
TestOnLoadCheckpointSuccess, and TestSanitizeMlflowMetrics (or their contained
test_* methods) so the test runner can categorize them properly and marker-based
runs include them.

In `@tutorials/recipes/llama/conf/llama32_1b_finetune.yaml`:
- Around line 75-83: The mlflow_tracking_uri key is incorrectly nested under
mlflow_tags in the YAML; move mlflow_tracking_uri out of the mlflow_tags block
so it sits at the same level as mlflow_experiment/mlflow_run_name, ensuring
mlflow_tags only contains tag mappings (e.g., project, phase, variant), and
remove the trailing blank lines at the end of the file; update items around
mlflow_experiment, mlflow_run_name, mlflow_tags, and mlflow_tracking_uri
accordingly.

In `@tutorials/recipes/llama/conf/llama32_1b_pretrain.yaml`:
- Around line 65-71: The mlflow_tracking_uri key is incorrectly nested under
mlflow_tags; update the LoggerConfig by removing mlflow_tracking_uri from inside
mlflow_tags and placing it at the same indentation level as mlflow_experiment
and mlflow_run_name so mlflow_experiment, mlflow_run_name, mlflow_tags, and
mlflow_tracking_uri are all top-level keys in the config (reference keys:
mlflow_experiment, mlflow_run_name, mlflow_tags, mlflow_tracking_uri,
LoggerConfig).
🧹 Nitpick comments (3)
docs/training/logging.md (1)

197-199: Consider adding a note about lazy initialization similar to WandB.

The WandB section includes a note about lazy initialization on the last rank. For consistency and to help users understand the behavior, consider adding a similar note for MLFlow.

📝 Suggested addition
```diff
         },
     )

+{note}
+MLFlow is initialized lazily on the last rank when `mlflow_experiment` is set and `mlflow_run_name` is non-empty.
+
```

src/megatron/bridge/training/utils/log_utils.py (1)

178-194: Good defensive implementation. Minor refactor suggestion.

The `safe_serialize` function correctly handles edge cases with broken `__str__` methods. Consider the minor refactor suggested by static analysis (TRY300) to move the return to an `else` block for cleaner control flow.

The docstring mentions "JSON-serializable type" but the function always returns a `str`. Consider updating the docstring to be more precise.



♻️ Suggested refactor

```diff
 def safe_serialize(obj) -> str:
-    """Safely convert any object to a JSON-serializable type.
+    """Safely convert any object to a string representation.

     Handles objects with broken __str__ or __repr__ methods that return
     non-string types (e.g., PipelineParallelLayerLayout returns list).
+
+    Args:
+        obj: Any object to serialize.
+
+    Returns:
+        String representation of the object, or type name as fallback.
     """
     try:
         # Try str() first
         result = str(obj)
         # Verify it actually returns a string
         if not isinstance(result, str):
             # __str__ returned non-string type, use type name instead
             return f"<{type(obj).__name__}>"
-        return result
     except Exception:
         # __str__ raised an exception, use type name as fallback
         return f"<{type(obj).__name__}>"
+    else:
+        return result
```
src/megatron/bridge/training/utils/mlflow_utils.py (1)

15-27: Prefer T | None over Optional[T] for new APIs.
Align type hints with the project’s Python 3.10+ style.

♻️ Suggested refactor

```diff
-from typing import Any, Optional
+from typing import Any
@@
-def on_save_checkpoint_success(
-    checkpoint_path: str,
-    save_dir: str,
-    iteration: int,
-    mlflow_logger: Optional[Any],
-) -> None:
+def on_save_checkpoint_success(
+    checkpoint_path: str,
+    save_dir: str,
+    iteration: int,
+    mlflow_logger: Any | None,
+) -> None:
@@
-def on_load_checkpoint_success(
-    checkpoint_path: str,
-    load_dir: str,
-    mlflow_logger: Optional[Any],
-) -> None:
+def on_load_checkpoint_success(
+    checkpoint_path: str,
+    load_dir: str,
+    mlflow_logger: Any | None,
+) -> None:
```
As per coding guidelines: Use 'T | None' for nullable types instead of 'Optional[T]'.

Comment on lines +219 to +235

```python
    @property
    def mlflow_logger(self) -> Optional[Any]:
        """The MLFlow logger instance.

        Uses the configuration under LoggerConfig to create or resume an MLFlow run.
        Restricted to the last rank to avoid duplicate entries or probably any racing conditions.
        """
        if self._mlflow_logger is None:
            cfg = self.cfg
            if cfg is None:
                self._mlflow_logger = None
                return self._mlflow_logger

            logger_cfg = cfg.logger
            if logger_cfg.mlflow_experiment and get_rank_safe() == (get_world_size_safe() - 1):
                if logger_cfg.mlflow_run_name == "":
                    raise ValueError("Please specify the mlflow_run_name for MLFlow logging!")
```

⚠️ Potential issue | 🟡 Minor

Ruff TRY003: long raise message will trip lint.
Consider a custom exception or add a targeted # noqa: TRY003 to the raise line.

✅ Suggested lint-safe tweak

```diff
-                    raise ValueError("Please specify the mlflow_run_name for MLFlow logging!")
+                    raise ValueError(
+                        "Please specify the mlflow_run_name for MLFlow logging!"
+                    )  # noqa: TRY003
```
🧰 Tools
🪛 Ruff (0.14.14)

235-235: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/state.py` around lines 219 - 235, The raise in
the mlflow_logger property (check logger_cfg.mlflow_run_name) currently uses a
long literal which trips Ruff TRY003; fix it by either replacing the long
literal with a short message and raising a custom exception class (e.g., define
MissingMLFlowRunNameError and raise that from mlflow_logger) or, if you prefer
minimal change, append a targeted inline noqa to the raise line (raise
ValueError("Specify mlflow_run_name")  # noqa: TRY003) so the linter is
satisfied; update references to logger_cfg.mlflow_run_name and the mlflow_logger
property accordingly.

Comment on lines +43 to +51
try:
checkpoint_path = str(Path(checkpoint_path).resolve())
base_name = Path(save_dir).name or "checkpoints"
expected_ckpt_path = get_checkpoint_name(save_dir, iteration)
artifact_subdir = f"{base_name}/{Path(expected_ckpt_path).name}"
mlflow_logger.log_artifacts(checkpoint_path, artifact_path=artifact_subdir)
except Exception as exc:
# continue training
print_rank_last(f"Failed to log checkpoint artifacts to MLFlow: {exc}")

⚠️ Potential issue | 🟡 Minor

Ruff BLE001: broad Exception catch needs explicit intent.
Either narrow the exception types or add a targeted # noqa: BLE001 to keep lint clean.

✅ Suggested lint-safe tweak

```diff
-    except Exception as exc:
+    except Exception as exc:  # noqa: BLE001
@@
-    except Exception as exc:
+    except Exception as exc:  # noqa: BLE001
```

Also applies to: 73-83

🧰 Tools
🪛 Ruff (0.14.14)

49-49: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/utils/mlflow_utils.py` around lines 43 - 51, The
broad "except Exception as exc:" in the MLFlow artifact logging blocks should be
made lint-safe: replace it with a narrow set of expected exceptions (e.g.,
"except (mlflow.exceptions.MlflowException, OSError, ValueError) as exc:" and
add "from mlflow.exceptions import MlflowException" at top) and apply the same
change to the second try/except block that wraps mlflow_logger.log_artifacts
(the code that references checkpoint_path, expected_ckpt_path, artifact_subdir
and print_rank_last); if you intentionally want a broad catch keep the current
handler but annotate it with "# noqa: BLE001" on each except line instead of
leaving the generic catch.

Comment on lines +26 to +296

```python
class TestOnSaveCheckpointSuccess:
    """Test cases for on_save_checkpoint_success function."""

    def test_noop_when_mlflow_logger_is_none(self):
        """Test that the function does nothing when mlflow_logger is None."""
        # Should not raise any exception
        on_save_checkpoint_success(
            checkpoint_path="/path/to/checkpoint",
            save_dir="/path/to",
            iteration=100,
            mlflow_logger=None,
        )

    def test_logs_artifacts_with_correct_path(self):
        """Test that log_artifacts is called with correct arguments."""
        mock_mlflow = MagicMock()

        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = Path(tmpdir) / "checkpoint"
            checkpoint_path.mkdir()
            save_dir = tmpdir

            on_save_checkpoint_success(
                checkpoint_path=str(checkpoint_path),
                save_dir=save_dir,
                iteration=1000,
                mlflow_logger=mock_mlflow,
            )

            mock_mlflow.log_artifacts.assert_called_once()
            call_args = mock_mlflow.log_artifacts.call_args

            # Verify the checkpoint path is resolved
            assert call_args[0][0] == str(checkpoint_path.resolve())

            # Verify artifact_path format includes iteration
            artifact_path = call_args[1]["artifact_path"]
            base_name = Path(save_dir).name
            assert artifact_path == f"{base_name}/iter_0001000"

    def test_artifact_path_format_with_different_iterations(self):
        """Test that iteration is zero-padded to 7 digits in artifact path."""
        mock_mlflow = MagicMock()

        test_cases = [
            (0, "iter_0000000"),
            (1, "iter_0000001"),
            (999, "iter_0000999"),
            (1234567, "iter_1234567"),
            (9999999, "iter_9999999"),
        ]

        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = Path(tmpdir) / "checkpoint"
            checkpoint_path.mkdir()

            for iteration, expected_suffix in test_cases:
                mock_mlflow.reset_mock()

                on_save_checkpoint_success(
                    checkpoint_path=str(checkpoint_path),
                    save_dir=tmpdir,
                    iteration=iteration,
                    mlflow_logger=mock_mlflow,
                )

                artifact_path = mock_mlflow.log_artifacts.call_args[1]["artifact_path"]
                assert artifact_path.endswith(expected_suffix), (
                    f"Expected artifact_path to end with {expected_suffix}, got {artifact_path}"
                )

    def test_uses_checkpoints_as_default_base_name(self):
        """Test that 'checkpoints' is used when save_dir has no name."""
        mock_mlflow = MagicMock()

        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = Path(tmpdir) / "checkpoint"
            checkpoint_path.mkdir()

            # Use root-like path that would have empty name
            on_save_checkpoint_success(
                checkpoint_path=str(checkpoint_path),
                save_dir="/",
                iteration=100,
                mlflow_logger=mock_mlflow,
            )

            artifact_path = mock_mlflow.log_artifacts.call_args[1]["artifact_path"]
            assert artifact_path.startswith("checkpoints/")

    def test_handles_exception_gracefully(self):
        """Test that exceptions are caught and logged, not raised."""
        mock_mlflow = MagicMock()
        mock_mlflow.log_artifacts.side_effect = Exception("MLFlow connection error")

        with patch("megatron.bridge.training.utils.mlflow_utils.print_rank_last") as mock_print:
            # Should not raise exception
            on_save_checkpoint_success(
                checkpoint_path="/path/to/checkpoint",
                save_dir="/path/to",
                iteration=100,
                mlflow_logger=mock_mlflow,
            )

            # Should print error message
            mock_print.assert_called_once()
            error_msg = mock_print.call_args[0][0]
            assert "Failed to log checkpoint artifacts to MLFlow" in error_msg
            assert "MLFlow connection error" in error_msg


class TestOnLoadCheckpointSuccess:
    """Test cases for on_load_checkpoint_success function."""

    def test_noop_when_mlflow_logger_is_none(self):
        """Test that the function does nothing when mlflow_logger is None."""
        # Should not raise any exception
        on_load_checkpoint_success(
            checkpoint_path="/path/to/checkpoint",
            load_dir="/path/to",
            mlflow_logger=None,
        )

    def test_sets_correct_tags(self):
        """Test that set_tags is called with correct checkpoint information."""
        mock_mlflow = MagicMock()

        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = Path(tmpdir) / "checkpoint"
            checkpoint_path.mkdir()
            load_dir = tmpdir

            on_load_checkpoint_success(
                checkpoint_path=str(checkpoint_path),
                load_dir=load_dir,
                mlflow_logger=mock_mlflow,
            )

            mock_mlflow.set_tags.assert_called_once()
            tags = mock_mlflow.set_tags.call_args[0][0]

            assert "last_loaded_checkpoint" in tags
            assert "checkpoint_base_dir" in tags
            assert tags["last_loaded_checkpoint"] == str(checkpoint_path.resolve())
            assert tags["checkpoint_base_dir"] == str(Path(load_dir).resolve())

    def test_resolves_relative_paths(self):
        """Test that relative paths are resolved to absolute paths."""
        mock_mlflow = MagicMock()

        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = Path(tmpdir) / "checkpoint"
            checkpoint_path.mkdir()

            on_load_checkpoint_success(
                checkpoint_path=str(checkpoint_path),
                load_dir=tmpdir,
                mlflow_logger=mock_mlflow,
            )

            tags = mock_mlflow.set_tags.call_args[0][0]

            # Both paths should be absolute (resolved)
            assert Path(tags["last_loaded_checkpoint"]).is_absolute()
            assert Path(tags["checkpoint_base_dir"]).is_absolute()

    def test_handles_exception_gracefully(self):
        """Test that exceptions are caught and logged, not raised."""
        mock_mlflow = MagicMock()
        mock_mlflow.set_tags.side_effect = Exception("MLFlow API error")

        with patch("megatron.bridge.training.utils.mlflow_utils.print_rank_last") as mock_print:
            # Should not raise exception
            on_load_checkpoint_success(
                checkpoint_path="/path/to/checkpoint",
                load_dir="/path/to",
                mlflow_logger=mock_mlflow,
            )

            # Should print error message
            mock_print.assert_called_once()
            error_msg = mock_print.call_args[0][0]
            assert "Failed to record loaded checkpoint information to MLFlow" in error_msg
            assert "MLFlow API error" in error_msg


class TestSanitizeMlflowMetrics:
    """Test cases for _sanitize_mlflow_metrics function."""

    def test_replaces_slashes_with_underscores(self):
        """Test that forward slashes are replaced with underscores."""
        metrics = {
            "train/loss": 0.5,
            "train/accuracy": 0.95,
            "eval/loss": 0.3,
        }

        result = _sanitize_mlflow_metrics(metrics)

        assert result == {
            "train_loss": 0.5,
            "train_accuracy": 0.95,
            "eval_loss": 0.3,
        }

    def test_handles_multiple_slashes(self):
        """Test that multiple slashes in a key are all replaced."""
        metrics = {
            "train/layer/0/loss": 1.0,
            "model/encoder/attention/weight": 0.5,
        }

        result = _sanitize_mlflow_metrics(metrics)

        assert result == {
            "train_layer_0_loss": 1.0,
            "model_encoder_attention_weight": 0.5,
        }

    def test_preserves_keys_without_slashes(self):
        """Test that keys without slashes are unchanged."""
        metrics = {
            "loss": 0.5,
            "accuracy": 0.95,
            "learning_rate": 0.001,
        }

        result = _sanitize_mlflow_metrics(metrics)

        assert result == metrics

    def test_handles_empty_dict(self):
        """Test that empty dictionary returns empty dictionary."""
        result = _sanitize_mlflow_metrics({})
        assert result == {}

    def test_preserves_values(self):
        """Test that metric values are preserved unchanged."""
        metrics = {
            "train/int_metric": 42,
            "train/float_metric": 3.14159,
            "train/string_metric": "value",
            "train/none_metric": None,
            "train/list_metric": [1, 2, 3],
        }

        result = _sanitize_mlflow_metrics(metrics)

        assert result["train_int_metric"] == 42
        assert result["train_float_metric"] == 3.14159
        assert result["train_string_metric"] == "value"
        assert result["train_none_metric"] is None
        assert result["train_list_metric"] == [1, 2, 3]

    def test_mixed_keys(self):
        """Test dictionary with both slash and non-slash keys."""
        metrics = {
            "train/loss": 0.5,
            "global_step": 1000,
            "eval/accuracy": 0.9,
            "learning_rate": 0.001,
        }

        result = _sanitize_mlflow_metrics(metrics)

        assert result == {
            "train_loss": 0.5,
            "global_step": 1000,
            "eval_accuracy": 0.9,
            "learning_rate": 0.001,
        }
```

⚠️ Potential issue | 🟡 Minor

Add pytest markers for unit-test categorization.
Missing pytest.mark.unit breaks the test categorization guideline and may skip these tests in marker-based runs.

✅ Suggested update

```diff
+import pytest
+
+pytestmark = pytest.mark.unit
```
As per coding guidelines: Use 'pytest.mark' to categorize tests (unit, integration, system).
🤖 Prompt for AI Agents
In `@tests/unit_tests/training/utils/test_mlflow_utils.py` around lines 26 - 296,
Add pytest.mark.unit markers to these unit tests by importing pytest (if not
present) and decorating the test classes or individual test functions with
`@pytest.mark.unit`; apply to TestOnSaveCheckpointSuccess,
TestOnLoadCheckpointSuccess, and TestSanitizeMlflowMetrics (or their contained
test_* methods) so the test runner can categorize them properly and marker-based
runs include them.
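The checkpoint artifact-path convention exercised by the tests above can be condensed into a small sketch; `artifact_subdir_for` is a hypothetical helper name (not from the PR) illustrating the `<save_dir basename>/iter_<7-digit iteration>` format the tests assert:

```python
from pathlib import Path

def artifact_subdir_for(save_dir: str, iteration: int) -> str:
    # Hypothetical helper: mirrors the format the unit tests assert,
    # "<basename of save_dir>/iter_<iteration zero-padded to 7 digits>",
    # falling back to "checkpoints" when save_dir has no basename (e.g. "/").
    base_name = Path(save_dir).name or "checkpoints"
    return f"{base_name}/iter_{iteration:07d}"
```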

Comment on lines +75 to +83
# mlflow_experiment: llama32_1b_finetuned # Uncomment to enable MLFlow logging
# mlflow_run_name: llama32_1b_finetuned
# mlflow_tags:
# project: llama32
# phase: finetune
# variant: mlflow_example
# mlflow_tracking_uri: http://localhost:5000 # Optional: use remote MLflow server



⚠️ Potential issue | 🟡 Minor

Same nesting issue: mlflow_tracking_uri should not be inside mlflow_tags.

Similar to the pretrain config, mlflow_tracking_uri is incorrectly nested inside the mlflow_tags block. Also, there are trailing blank lines that could be removed.

🔧 Proposed fix

```diff
   # mlflow_experiment: llama32_1b_finetuned  # Uncomment to enable MLFlow logging
   # mlflow_run_name: llama32_1b_finetuned
   # mlflow_tags:
   #   project: llama32
   #   phase: finetune
   #   variant: mlflow_example
-  #   mlflow_tracking_uri: http://localhost:5000  # Optional: use remote MLflow server
-
-
+  # mlflow_tracking_uri: http://localhost:5000  # Optional: use remote MLflow server

 # Random seed
```
🤖 Prompt for AI Agents
In `@tutorials/recipes/llama/conf/llama32_1b_finetune.yaml` around lines 75 - 83,
The mlflow_tracking_uri key is incorrectly nested under mlflow_tags in the YAML;
move mlflow_tracking_uri out of the mlflow_tags block so it sits at the same
level as mlflow_experiment/mlflow_run_name, ensuring mlflow_tags only contains
tag mappings (e.g., project, phase, variant), and remove the trailing blank
lines at the end of the file; update items around mlflow_experiment,
mlflow_run_name, mlflow_tags, and mlflow_tracking_uri accordingly.

yaoyu-33 and others added 4 commits January 28, 2026 13:42
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
mlflow_tracking_uri was incorrectly nested inside mlflow_tags.
Move it to be a top-level key alongside mlflow_experiment, mlflow_run_name,
and mlflow_tags in LoggerConfig.
Add @pytest.mark.unit decorators to TestOnSaveCheckpointSuccess,
TestOnLoadCheckpointSuccess, and TestSanitizeMlflowMetrics classes
so the test runner can categorize them properly.
@yaoyu-33
Copy link
Contributor Author

/ok to test fa03b2c

@cuichenx cuichenx self-requested a review January 29, 2026 18:16
@therealnaveenkamal therealnaveenkamal mentioned this pull request Jan 29, 2026
5 tasks
@therealnaveenkamal
Copy link
Contributor

Thanks for continuing this work, @yaoyu-33

