
MLFlow Integration #1542

Closed

therealnaveenkamal wants to merge 19 commits into NVIDIA-NeMo:main from therealnaveenkamal:main
Conversation


@therealnaveenkamal (Contributor) commented Dec 1, 2025

What does this PR do ?

  • Add MLFlow integration for experiment tracking and artifact logging.

  • This PR adds comprehensive MLFlow support to Megatron Bridge, enabling users to log training metrics, configuration parameters, and checkpoint artifacts to MLFlow tracking servers. The integration includes automatic checkpoint artifact logging with iteration-based artifact paths.

  • MLFlow logging follows the same pattern as the existing W&B integration and can be enabled via LoggerConfig with mlflow_experiment and mlflow_run_name parameters.
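As a rough illustration of the configuration surface described above, the MLFlow options might be modeled like this (a hypothetical stand-in for the real `LoggerConfig`, showing only the MLFlow-related fields this PR introduces, not the actual class):

```python
from dataclasses import dataclass, field
from typing import Optional

# Hypothetical sketch of the MLFlow-related LoggerConfig fields from this PR;
# the real LoggerConfig in Megatron Bridge has many more fields.
@dataclass
class LoggerConfig:
    mlflow_experiment: Optional[str] = None    # setting this enables MLFlow logging
    mlflow_run_name: Optional[str] = None      # required when MLFlow is enabled
    mlflow_tracking_uri: Optional[str] = None  # else falls back to MLFLOW_TRACKING_URI
    mlflow_tags: dict = field(default_factory=dict)

    def mlflow_enabled(self) -> bool:
        """MLFlow logging is on iff an experiment name was provided."""
        return self.mlflow_experiment is not None

cfg = LoggerConfig(
    mlflow_experiment="llama32_1b_pretrain",
    mlflow_run_name="baseline_run",
    mlflow_tags={"project": "llama32"},
)
print(cfg.mlflow_enabled())  # True
```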

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
  • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

cc @Phlip79

Summary by CodeRabbit

  • New Features

    • Added MLFlow support for logging training metrics and checkpoint artifacts during training runs.
    • Introduced MLFlow experiment tracking configuration options including experiment name, run name, tracking URI, and custom tags.
  • Documentation

    • Added comprehensive MLFlow logging configuration and usage guide.
  • Tests

    • Added unit test suite for MLFlow integration utilities.
  • Chores

    • Added mlflow>=3.2.0 as a project dependency.


Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>

copy-pr-bot bot commented Dec 1, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Phlip79 Phlip79 linked an issue Dec 1, 2025 that may be closed by this pull request
@ericharper ericharper requested a review from Phlip79 December 1, 2025 16:37
@ericharper ericharper requested review from maanug-nv and removed request for maanug-nv December 1, 2025 16:38
@Phlip79 Phlip79 requested a review from maanug-nv December 2, 2025 02:35
therealnaveenkamal and others added 3 commits December 2, 2025 11:45
Co-authored-by: Philip Petrakian <pgpetrak@gmail.com>
Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
@therealnaveenkamal
Contributor Author

Hi @Phlip79 - implemented all the changes.

@yaoyu-33
Contributor

yaoyu-33 commented Dec 2, 2025

thanks for the contribution, all look good. Had one comment.


@maanug-nv (Contributor) left a comment


Thanks for adding the support. Left 1 suggestion. Can I also ask you to add unit tests similar to what we do for wandb?

Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
@therealnaveenkamal therealnaveenkamal requested a review from a team as a code owner December 5, 2025 20:11
Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
@Phlip79
Member

Phlip79 commented Dec 5, 2025

/ok to test 3148469

Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
```python
logger.log(*args, **kwargs)


def safe_serialize(obj) -> str:
```
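The `safe_serialize` signature above is truncated review context. For orientation, a minimal sketch of what such a helper typically does (this body is a guess, not the PR's actual implementation):

```python
import json

def safe_serialize(obj) -> str:
    """Best-effort string form of obj for param logging; never raises."""
    try:
        # Prefer JSON so nested dicts/lists stay readable; default=str converts
        # values json can't handle (paths, dtypes, ...) via str().
        return json.dumps(obj, default=str)
    except Exception:
        # Last resort: even __str__ may raise on partially initialized objects,
        # so fall back to a fixed placeholder rather than propagate the error.
        try:
            return str(obj)
        except Exception:
            return "<unserializable>"

print(safe_serialize({"lr": 3e-4, "epochs": 2}))
```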


i'm not sure if we need this anymore after NVIDIA/Megatron-LM#2055
cc @suiyoubi

Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
maanug-nv previously approved these changes Dec 13, 2025
@santurini

Hello, is the MLFlow integration to be released soon?

Signed-off-by: Naveenraj Kamalakannan <therealnaveenkamal@gmail.com>
@therealnaveenkamal
Contributor Author

@Phlip79 / @maanug-nv - Can we run the CI?

@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Jan 20, 2026
@ananthsub
Contributor

/ok to test e07db35

@yaoyu-33
Contributor

/ok to test cd4a36a

@coderabbitai
Contributor

coderabbitai bot commented Jan 28, 2026

📝 Walkthrough

This PR adds comprehensive MLFlow logging integration to the training framework, parallel to existing W&B support. Changes include MLFlow configuration options, logger initialization, checkpoint artifact logging, and metric collection throughout the training loop.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **Documentation & Dependency**<br>`docs/training/logging.md`, `pyproject.toml` | Adds an MLFlow logging guide documenting configuration, setup, and progress logging; adds the `mlflow>=3.2.0` dependency. |
| **Configuration & Validation**<br>`src/megatron/bridge/training/config.py` | Introduces `LoggerConfig` fields for `mlflow_experiment`, `mlflow_run_name`, `mlflow_tracking_uri`, `mlflow_tags`; adds a `finalize()` method to validate MLFlow setup and enforce the `mlflow_run_name` requirement when enabled. |
| **MLFlow Utilities**<br>`src/megatron/bridge/training/utils/log_utils.py`, `src/megatron/bridge/training/utils/mlflow_utils.py` | Adds `safe_serialize()` for JSON-safe object serialization; new `mlflow_utils` module with `on_save_checkpoint_success()`, `on_load_checkpoint_success()`, and `_sanitize_mlflow_metrics()` for checkpoint artifact logging and metric name sanitization. |
| **State & Logger Management**<br>`src/megatron/bridge/training/state.py` | Introduces an `mlflow_logger` property on `GlobalState` to initialize/manage MLFlow runs only on the last rank; integrates timer-metrics logging via a `_timers_write_to_mlflow` handler; refactors the shared `safe_serialize` import. |
| **Training Loop Integration**<br>`src/megatron/bridge/training/utils/train_utils.py` | Augments `training_log` to log sanitized metrics (throughput, loss, gradients, etc.) to MLFlow at each logging interval when `mlflow_logger` is present. |
| **Checkpoint Integration**<br>`src/megatron/bridge/training/checkpointing.py` | Registers an `mlflow_finalize_fn` callback on checkpoint-save success for artifact logging; invokes `mlflow_utils.on_load_checkpoint_success` on checkpoint-load completion. |
| **Testing & Configuration Examples**<br>`tests/unit_tests/training/utils/test_mlflow_utils.py`, `tutorials/recipes/llama/conf/llama32_1b_finetune.yaml`, `tutorials/recipes/llama/conf/llama32_1b_pretrain.yaml` | Comprehensive unit tests covering checkpoint artifact logging, tag setting, metric sanitization, and error handling; adds commented MLFlow config examples to tutorial configs; updates batch-size parameters. |
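As one concrete illustration of the "iteration-based artifact paths" mentioned above, the path construction amounts to plain string formatting (the exact naming scheme here is an assumption, not necessarily the PR's layout):

```python
def checkpoint_artifact_path(iteration: int, prefix: str = "checkpoints") -> str:
    """Build the MLFlow artifact directory for a given iteration.

    Zero-padding (e.g. checkpoints/iter_0000100) keeps artifacts in a stable
    lexicographic order in the MLFlow UI. The prefix and padding width are
    illustrative choices, not the PR's confirmed scheme.
    """
    return f"{prefix}/iter_{iteration:07d}"

print(checkpoint_artifact_path(100))  # checkpoints/iter_0000100
```

On checkpoint-save success, a function like `on_save_checkpoint_success()` would pass such a path as the `artifact_path` argument of `mlflow.log_artifacts(local_dir, artifact_path=...)`.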

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Config as Config Validation
    participant GlobalState as GlobalState
    participant MLFlow as MLFlow
    participant Training as Training Loop
    participant Checkpoint as Checkpoint Manager

    Config->>Config: Validate MLFlow config<br/>(finalize())
    Config->>GlobalState: Config ready

    Training->>GlobalState: Request mlflow_logger
    GlobalState->>MLFlow: Initialize run<br/>(experiment, tracking_uri)
    MLFlow-->>GlobalState: MLFlow run handle
    GlobalState-->>Training: Return mlflow_logger

    loop Each training iteration
        Training->>Training: Compute metrics<br/>(loss, throughput, etc.)
        Training->>MLFlow: Log sanitized metrics<br/>(log_metrics)
        Training->>MLFlow: Log timer data<br/>(write_to_mlflow)
    end

    Training->>Checkpoint: Save checkpoint
    Checkpoint->>MLFlow: Log checkpoint artifacts<br/>(on_save_checkpoint_success)
    MLFlow-->>Checkpoint: Artifacts logged

    Training->>Checkpoint: Load checkpoint
    Checkpoint->>MLFlow: Set MLFlow tags<br/>(on_load_checkpoint_success)
    MLFlow-->>Checkpoint: Tags set
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Test Results For Major Changes | ⚠️ Warning | The PR introduces a major MLFlow integration feature, but the PR description states that testing items remain incomplete, with no documented test-execution results. | Document test results (unit-test pass confirmation, coverage metrics, end-to-end validation, and regression testing) before merge. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title "MLFlow Integration" directly describes the main objective of the pull request, which adds MLFlow logging and artifact tracking functionality throughout the codebase. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 91.43%, which meets the required threshold of 80.00%. |



@coderabbitai bot (Contributor) left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@tutorials/recipes/llama/conf/llama32_1b_finetune.yaml`:
- Around line 75-82: The mlflow_tracking_uri is currently nested under
mlflow_tags in the YAML, which will treat it as a tag instead of a top-level
LoggerConfig field; update the YAML so mlflow_tracking_uri is a sibling of
mlflow_tags (same indentation level) rather than nested under it, ensuring it
matches the LoggerConfig field name used in config.py and will be parsed as the
logger's tracking URI.

In `@tutorials/recipes/llama/conf/llama32_1b_pretrain.yaml`:
- Around line 65-71: The mlflow_tracking_uri key is incorrectly nested under
mlflow_tags; move mlflow_tracking_uri out of the mlflow_tags mapping and place
it as a top-level MLflow logger field (peer to mlflow_experiment /
mlflow_run_name), i.e., remove mlflow_tracking_uri from under mlflow_tags and
add mlflow_tracking_uri: <your_uri> at the same indentation level as
mlflow_experiment so the logger can read it correctly.
🧹 Nitpick comments (6)
src/megatron/bridge/training/utils/log_utils.py (1)

178-194: Consider adding type hint for the obj parameter.

The function is well-designed for its purpose of safely serializing objects that may have broken __str__ methods. The broad Exception catch is appropriate here since this is a last-resort fallback.

Optional: Add type hint
```diff
-def safe_serialize(obj) -> str:
+def safe_serialize(obj: object) -> str:
```
docs/training/logging.md (2)

173-175: Improve the install command formatting.

The current format pip install mlflow / uv add mlflow is confusing. Consider separating these as distinct options.

Suggested improvement
   1) Install MLFlow (installed by default with Megatron Bridge):

-  ```bash
-  pip install mlflow / uv add mlflow
-  ```
+  ```bash
+  pip install mlflow
+  # or with uv:
+  # uv add mlflow
+  ```

177-179: Fix markdown list indentation.

The list items should not have leading indentation to comply with markdown linting standards.

Suggested fix
```diff
   2) Configure the tracking server (Optional):
-  - Either set `MLFLOW_TRACKING_URI` in the environment, or
-  - Pass an explicit `mlflow_tracking_uri` in the logger config.
+- Either set `MLFLOW_TRACKING_URI` in the environment, or
+- Pass an explicit `mlflow_tracking_uri` in the logger config.
```
src/megatron/bridge/training/utils/train_utils.py (1)

538-540: Loss metrics should be sanitized for MLFlow.

The loss_dict keys may contain "/" characters (e.g., "lm_loss/validation"). Other metric dictionaries are sanitized via _sanitize_mlflow_metrics, but here the raw keys are used directly.

♻️ Proposed fix
```diff
         if mlflow_logger:
-            loss_metrics = {key: float(val) for key, val in loss_dict.items()}
+            loss_metrics = _sanitize_mlflow_metrics({key: float(val) for key, val in loss_dict.items()})
             mlflow_logger.log_metrics(loss_metrics, step=iteration)
```
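For context, a sanitizer along the lines of the PR's `_sanitize_mlflow_metrics` could look like the sketch below. The character policy is an assumption: MLFlow's accepted set varies by backend store, so a real implementation should match whatever the tracking server enforces.

```python
import re

# Conservative policy: keep alphanumerics, underscore, dash, period, slash and
# space; replace anything else (e.g. "|", ":") with "_".
_INVALID = re.compile(r"[^A-Za-z0-9_\-./ ]")

def sanitize_mlflow_metrics(metrics: dict) -> dict:
    """Return a copy of metrics whose keys use only backend-safe characters."""
    return {_INVALID.sub("_", key): value for key, value in metrics.items()}

print(sanitize_mlflow_metrics({"lm loss": 2.1, "grad|norm": 0.9}))
# {'lm loss': 2.1, 'grad_norm': 0.9}
```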
tests/unit_tests/training/utils/test_mlflow_utils.py (1)

26-27: Consider adding pytest markers for test categorization.

As per coding guidelines, tests should use pytest.mark to categorize tests (unit, integration, system). These are unit tests and should be marked accordingly.

♻️ Proposed fix
```diff
+import pytest
+
+
+@pytest.mark.unit
 class TestOnSaveCheckpointSuccess:
     """Test cases for on_save_checkpoint_success function."""
```
Apply similar @pytest.mark.unit decorator to TestOnLoadCheckpointSuccess and TestSanitizeMlflowMetrics classes.

Also applies to: 137-138, 212-213

src/megatron/bridge/training/state.py (1)

446-449: Add stacklevel parameter to warnings.warn.

The warning message will point to this line rather than the caller. Setting stacklevel=2 would improve debuggability.

♻️ Proposed fix
```diff
         except Exception:
             import warnings

-            warnings.warn("Failed to log timer metrics to MLFlow; continuing without timer metrics.")
+            warnings.warn("Failed to log timer metrics to MLFlow; continuing without timer metrics.", stacklevel=2)
```

Comment on lines +75 to +82
```yaml
  # mlflow_experiment: llama32_1b_finetuned  # Uncomment to enable MLFlow logging
  # mlflow_run_name: llama32_1b_finetuned
  # mlflow_tags:
  #   project: llama32
  #   phase: finetune
  #   variant: mlflow_example
  #   mlflow_tracking_uri: http://localhost:5000  # Optional: use remote MLflow server
```


⚠️ Potential issue | 🟡 Minor

mlflow_tracking_uri is incorrectly nested under mlflow_tags.

Based on the LoggerConfig structure in config.py, mlflow_tracking_uri should be a sibling field to mlflow_tags, not nested within it. The current indentation would cause the tracking URI to be treated as a tag value rather than the configuration parameter.

Suggested fix
```diff
   # mlflow_experiment: llama32_1b_finetuned  # Uncomment to enable MLFlow logging
   # mlflow_run_name: llama32_1b_finetuned
   # mlflow_tags:
   #   project: llama32
   #   phase: finetune
   #   variant: mlflow_example
-  #   mlflow_tracking_uri: http://localhost:5000  # Optional: use remote MLflow server
+  # mlflow_tracking_uri: http://localhost:5000  # Optional: use remote MLflow server
```
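To make the indentation issue concrete, the two variants parse to different structures. The dict literals below are an illustration of what a YAML loader would return for each (values taken from the tutorial config; this is not parser output from the actual files):

```python
# Buggy: mlflow_tracking_uri indented under mlflow_tags -> it becomes a tag.
buggy = {
    "mlflow_experiment": "llama32_1b_finetuned",
    "mlflow_run_name": "llama32_1b_finetuned",
    "mlflow_tags": {
        "project": "llama32",
        "phase": "finetune",
        "mlflow_tracking_uri": "http://localhost:5000",  # swallowed as a tag
    },
}

# Fixed: mlflow_tracking_uri at the same level as mlflow_tags -> a real field.
fixed = {
    "mlflow_experiment": "llama32_1b_finetuned",
    "mlflow_run_name": "llama32_1b_finetuned",
    "mlflow_tags": {"project": "llama32", "phase": "finetune"},
    "mlflow_tracking_uri": "http://localhost:5000",
}

# LoggerConfig reads a top-level key; only the fixed form provides one.
print("mlflow_tracking_uri" in buggy)   # False
print("mlflow_tracking_uri" in fixed)   # True
```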

Comment on lines +65 to +71
```yaml
  # mlflow_experiment: llama32_1b_pretrain  # Uncomment to enable MLFlow logging
  # mlflow_run_name: llama32_1b_pretrain_run
  # mlflow_tags:
  #   project: llama32
  #   phase: pretrain
  #   variant: mlflow_example
  #   mlflow_tracking_uri: http://localhost:5000  # Optional: use remote MLflow server
```

⚠️ Potential issue | 🟡 Minor

mlflow_tracking_uri is incorrectly nested under mlflow_tags.

Same issue as in the finetune config - mlflow_tracking_uri should be a top-level logger field, not nested under mlflow_tags.

Suggested fix
```diff
   # mlflow_experiment: llama32_1b_pretrain  # Uncomment to enable MLFlow logging
   # mlflow_run_name: llama32_1b_pretrain_run
   # mlflow_tags:
   #   project: llama32
   #   phase: pretrain
   #   variant: mlflow_example
-  #   mlflow_tracking_uri: http://localhost:5000  # Optional: use remote MLflow server
+  # mlflow_tracking_uri: http://localhost:5000  # Optional: use remote MLflow server
```

yaoyu-33 added a commit that referenced this pull request Jan 28, 2026
Add comprehensive MLFlow support to Megatron Bridge for experiment tracking and artifact logging.

- Add MLFlow logger support in GlobalState with configurable experiment, run name, tracking URI, and tags
- Log training metrics (losses, learning rate, batch size, throughput, timers, memory, runtime, norms, energy) to MLFlow
- Log checkpoint artifacts to MLFlow with iteration-based artifact paths
- Add MLFlow configuration options to LoggerConfig (mlflow_experiment, mlflow_run_name, mlflow_tracking_uri, mlflow_tags)
- Add validation in LoggerConfig.finalize() to check MLFlow availability
- Move safe_serialize to log_utils.py for reuse across WandB and MLFlow
- Add comprehensive unit tests for MLFlow utilities
- Add documentation for MLFlow logging configuration

Based on community contribution from @therealnaveenkamal in PR #1542.

Co-authored-by: Naveenraj Kamalakannan <therealnaveenkamal@users.noreply.github.com>
@yaoyu-33 yaoyu-33 mentioned this pull request Jan 28, 2026
4 tasks
@therealnaveenkamal
Contributor Author

Continued by @yaoyu-33 in #2112

Successfully merging this pull request may close these issues.

Logging to MLFlow

7 participants