Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions docs/training/logging.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,49 @@ The plugin automatically forwards the `WANDB_API_KEY` and by default injects CLI
This allows seamless integration of W&B logging into your training workflow without manual configuration.


### MLFlow

Megatron Bridge can log metrics and artifacts to MLFlow, following the same pattern as the W&B integration.

#### What Gets Logged

When enabled, MLFlow receives:

- Training configuration as run parameters
- Scalar metrics (losses, learning rate, batch size, throughput, timers, memory, runtime, norms, energy, etc.)
- Checkpoint artifacts saved under an experiment-specific artifact path per iteration

#### Enable MLFlow Logging

1) Install MLFlow (already installed by default with Megatron Bridge, so this step is only needed for custom environments):

```bash
pip install mlflow
# or, if you use uv:
uv add mlflow
```

2) Configure the tracking server (Optional):
- Either set `MLFLOW_TRACKING_URI` in the environment, or
- Pass an explicit `mlflow_tracking_uri` in the logger config.

3) Configure logging in your training setup.

```python
from megatron.bridge.training.config import LoggerConfig

cfg.logger = LoggerConfig(
tensorboard_dir="./runs/tensorboard",
mlflow_experiment="my_megatron_experiment",
mlflow_run_name="llama32_1b_pretrain_run",
mlflow_tracking_uri="http://mlflow:5000", # optional
mlflow_tags={ # optional
"project": "llama32",
"phase": "pretrain",
},
)
```



#### Progress Log

When `logger.log_progress` is enabled, the framework generates a `progress.txt` file in the checkpoint save directory.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ dependencies = [
"flash-linear-attention",
"timm",
"open-clip-torch>=3.2.0",
"mlflow>=3.2.0",
]


Expand Down
13 changes: 12 additions & 1 deletion src/megatron/bridge/training/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
from megatron.bridge.training.state import GlobalState, TrainState
from megatron.bridge.training.tokenizers.config import TokenizerConfig
from megatron.bridge.training.tokenizers.tokenizer import MegatronTokenizer
from megatron.bridge.training.utils import wandb_utils
from megatron.bridge.training.utils import mlflow_utils, wandb_utils
from megatron.bridge.training.utils.checkpoint_utils import (
checkpoint_exists,
ensure_directory_exists,
Expand Down Expand Up @@ -788,11 +788,21 @@ def wandb_finalize_fn() -> None:
wandb_writer=state.wandb_logger,
)

def mlflow_finalize_fn() -> None:
    # Record the successful checkpoint save in MLFlow (mirrors wandb_finalize_fn).
    # Invoked directly for synchronous saves, or registered as a finalize
    # callback that runs once an async save request completes.
    mlflow_utils.on_save_checkpoint_success(
        checkpoint_name,
        save_dir,
        train_state.step,
        mlflow_logger=state.mlflow_logger,
    )

if ckpt_cfg.async_save:
assert async_save_request is not None
async_save_request.add_finalize_fn(wandb_finalize_fn)
async_save_request.add_finalize_fn(mlflow_finalize_fn)
else:
wandb_finalize_fn()
mlflow_finalize_fn()

if ckpt_cfg.async_save:
schedule_async_save(state, async_save_request)
Expand Down Expand Up @@ -1739,6 +1749,7 @@ def _load_checkpoint_from_path(

if not torch.distributed.is_initialized() or is_last_rank():
wandb_utils.on_load_checkpoint_success(checkpoint_name, load_dir, state.wandb_logger)
mlflow_utils.on_load_checkpoint_success(checkpoint_name, load_dir, state.mlflow_logger)

torch.cuda.empty_cache()

Expand Down
38 changes: 38 additions & 0 deletions src/megatron/bridge/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,18 @@ class LoggerConfig:
wandb_entity: Optional[str] = None
"""The wandb entity name."""

mlflow_experiment: Optional[str] = None
"""The MLFlow experiment name."""

mlflow_run_name: Optional[str] = None
"""The MLFlow run name."""

mlflow_tracking_uri: Optional[str] = None
"""Optional MLFlow tracking URI."""

mlflow_tags: Optional[dict[str, str]] = None
"""Optional tags to apply to the MLFlow run."""

logging_level: int = logging.INFO
"""Set default logging level"""

Expand All @@ -988,6 +1000,31 @@ class LoggerConfig:
save_config_filepath: Optional[str] = None
"""If set, save the task configuration (ConfigContainer) to this file."""

def finalize(self) -> None:
    """Validate logger settings and the optional MLFlow dependency.

    Raises:
        ModuleNotFoundError: If any MLFlow option is configured but the
            ``mlflow`` package is not installed.
        ValueError: If ``mlflow_experiment`` is set without a (non-empty)
            ``mlflow_run_name``.
    """
    # Setting any MLFlow-related option implies the user wants MLFlow logging.
    using_mlflow = any(
        (
            self.mlflow_experiment,
            self.mlflow_run_name,
            self.mlflow_tracking_uri,
            self.mlflow_tags,
        )
    )

    if using_mlflow:
        # Check availability without importing the package: config validation
        # should not pay mlflow's import cost or trigger its import-time side
        # effects. find_spec returns None when the package is absent.
        import importlib.util

        if importlib.util.find_spec("mlflow") is None:
            raise ModuleNotFoundError(
                "MLFlow logging is configured, but the 'mlflow' package is not installed. "
                "Install it via pip install mlflow or uv add mlflow"
            )

    # A run name is required whenever an experiment is configured; an empty
    # string counts as unset.
    if self.mlflow_experiment and not self.mlflow_run_name:
        raise ValueError("Set logger.mlflow_run_name when enabling MLFlow logging.")


@dataclass(kw_only=True)
class ProfilingConfig:
Expand Down Expand Up @@ -1337,6 +1374,7 @@ def validate(self) -> None:
if hasattr(self.model, "finalize"):
self.model.finalize()

self.logger.finalize()
self.train.finalize()
self.scheduler.finalize()
self.checkpoint.finalize()
Expand Down
116 changes: 96 additions & 20 deletions src/megatron/bridge/training/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.nvrx_straggler import NVRxStragglerDetectionManager
from megatron.bridge.training.tokenizers.tokenizer import build_tokenizer
from megatron.bridge.training.utils.log_utils import safe_serialize
from megatron.bridge.training.utils.sig_utils import DistributedSignalHandler
from megatron.bridge.utils.common_utils import get_rank_safe, get_world_size_safe

Expand Down Expand Up @@ -124,6 +125,7 @@ def __init__(self) -> None:
self._tokenizer: Optional[Any] = None
self._tensorboard_logger: Optional[SummaryWriter] = None
self._wandb_logger: Optional[Any] = None
self._mlflow_logger: Optional[Any] = None
self._timers: Optional[Timers] = None
self._train_state: Optional[TrainState] = None
self.rank_monitor_client: Optional[Any] = None
Expand Down Expand Up @@ -193,26 +195,6 @@ def wandb_logger(self) -> Optional[Any]:

save_dir = self.cfg.logger.wandb_save_dir or os.path.join(self.cfg.checkpoint.save, "wandb")

# Sanitize config for WandB by doing a JSON round-trip
# This ensures all objects are converted to basic Python types that WandB can handle
def safe_serialize(obj):
"""Safely convert any object to a JSON-serializable type.

Handles objects with broken __str__ or __repr__ methods that return
non-string types (e.g., PipelineParallelLayerLayout returns list).
"""
try:
# Try str() first
result = str(obj)
# Verify it actually returns a string
if not isinstance(result, str):
# __str__ returned non-string type, use type name instead
return f"<{type(obj).__name__}>"
return result
except Exception:
# __str__ raised an exception, use type name as fallback
return f"<{type(obj).__name__}>"

config_dict = self.cfg.to_dict()
sanitized_config = json.loads(json.dumps(config_dict, default=safe_serialize))

Expand All @@ -234,12 +216,79 @@ def safe_serialize(obj):
self._wandb_logger = None
return self._wandb_logger

@property
def mlflow_logger(self) -> Optional[Any]:
    """Lazily-initialized MLFlow logger (the ``mlflow`` module itself).

    Uses the MLFlow settings on ``LoggerConfig`` to create or resume an
    MLFlow run. Initialization is restricted to the last rank to avoid
    duplicate runs and race conditions across ranks. Returns ``None`` when
    MLFlow logging is not configured or on all other ranks.
    """
    if self._mlflow_logger is None:
        cfg = self.cfg
        if cfg is None:
            # Config not attached yet; nothing to initialize.
            self._mlflow_logger = None
            return self._mlflow_logger

        logger_cfg = cfg.logger
        # Only the last rank creates the run (matches the W&B logger behavior).
        if logger_cfg.mlflow_experiment and get_rank_safe() == (get_world_size_safe() - 1):
            if logger_cfg.mlflow_run_name == "":
                raise ValueError("Please specify the mlflow_run_name for MLFlow logging!")

            import mlflow

            # Point the client at an explicit tracking server if configured;
            # otherwise mlflow falls back to its own defaults/env vars.
            if logger_cfg.mlflow_tracking_uri:
                mlflow.set_tracking_uri(logger_cfg.mlflow_tracking_uri)

            # Select (creating if necessary) the experiment for this run.
            mlflow.set_experiment(logger_cfg.mlflow_experiment)

            # Flatten the nested config dict into dotted keys so it can be
            # logged as flat MLFlow params (e.g. "train.global_batch_size").
            def _flatten_dict(d: dict[str, Any], parent_key: str = "", sep: str = ".") -> dict[str, Any]:
                items: dict[str, Any] = {}
                for k, v in d.items():
                    new_key = f"{parent_key}{sep}{k}" if parent_key else k
                    if isinstance(v, dict):
                        items.update(_flatten_dict(v, new_key, sep=sep))
                    else:
                        if isinstance(v, (list, tuple)):
                            v = [safe_serialize(x) for x in v]
                        items[new_key] = v
                return items

            # JSON round-trip with safe_serialize sanitizes objects that are
            # not natively serializable (or have broken __str__).
            config_dict = cfg.to_dict()
            sanitized_config = json.loads(json.dumps(config_dict, default=safe_serialize))
            flat_params = _flatten_dict(sanitized_config)

            # Start a new run, or attach to one that is already active.
            run_name = logger_cfg.mlflow_run_name
            tags = logger_cfg.mlflow_tags or {}

            active_run = mlflow.active_run()
            if active_run is None:
                # `tags or None` passes None instead of an empty dict.
                mlflow.start_run(run_name=run_name, tags=tags or None)
            elif tags:
                # A run is already active: at least apply the configured tags.
                mlflow.set_tags(tags)

            # Log the flattened configuration as run params; non-primitive
            # values are stringified first (best-effort).
            stringified_params = {
                key: (safe_serialize(value) if not isinstance(value, (int, float, bool, str)) else value)
                for key, value in flat_params.items()
            }
            mlflow.log_params(stringified_params)
            self._mlflow_logger = mlflow
        else:
            self._mlflow_logger = None
    return self._mlflow_logger

@property
def timers(self) -> Timers:
    """Lazily-constructed Megatron Timers instance used for tracking execution times."""
    if self._timers is None:
        logger_cfg = self.cfg.logger
        timers = Timers(logger_cfg.timing_log_level, logger_cfg.timing_log_option)
        # Attach logger-specific writer methods so the core Timers object can
        # emit timing metrics to W&B and MLFlow.
        timers.write_to_wandb = types.MethodType(_timers_write_to_wandb, timers)
        timers.write_to_mlflow = types.MethodType(_timers_write_to_mlflow, timers)
        self._timers = timers
    return self._timers

@property
Expand Down Expand Up @@ -344,6 +393,7 @@ def reset_for_restart(self) -> None:
self._train_state = None
self._tensorboard_logger = None
self._wandb_logger = None
self._mlflow_logger = None
self._energy_monitor = None
self._energy_monitor_created = False
self._signal_handler = None
Expand Down Expand Up @@ -371,3 +421,29 @@ def _timers_write_to_wandb(
for name in name_to_min_max_time:
_, max_time = name_to_min_max_time[name]
writer.log({name + "-time": max_time}, iteration)


def _timers_write_to_mlflow(
    self: Timers,
    names: list[str],
    logger: Any,
    iteration: int,
    normalizer: float = 1.0,
    reset: bool = True,
    barrier: bool = False,
) -> None:
    """Patch to write timers to MLFlow for Megatron Core Timers."""
    assert normalizer > 0.0
    # Gather the globally-reduced timer values first, so reset/barrier
    # semantics apply even when no MLFlow logger is attached.
    global_times = self._get_global_min_max_time(names, reset, barrier, normalizer)
    if logger is None:
        return
    # Log the max time per timer; '/' is replaced so the metric name stays
    # MLFlow-friendly (presumably nested timer scopes use '/').
    metrics: dict[str, float] = {
        timer_name.replace("/", "_") + "-time": min_max[1]
        for timer_name, min_max in global_times.items()
    }
    try:
        logger.log_metrics(metrics, step=iteration)
    except Exception:
        import warnings

        warnings.warn("Failed to log timer metrics to MLFlow; continuing without timer metrics.")
19 changes: 19 additions & 0 deletions src/megatron/bridge/training/utils/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,22 @@ def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs:
logger.log(*args, **kwargs)
else:
logger.log(*args, **kwargs)


def safe_serialize(obj) -> str:
    """Safely convert any object to a JSON-serializable string.

    Intended as the ``default=`` hook for ``json.dumps``: it tolerates objects
    with broken ``__str__``/``__repr__`` implementations that raise or return
    a non-string value (e.g. PipelineParallelLayerLayout returns a list) by
    falling back to a ``"<TypeName>"`` placeholder.

    Args:
        obj: Any object to render as a string.

    Returns:
        ``str(obj)`` when that succeeds and actually yields a string,
        otherwise ``"<TypeName>"``.
    """
    try:
        result = str(obj)
    except Exception:
        # __str__ raised an exception; fall back to the type name.
        return f"<{type(obj).__name__}>"
    if not isinstance(result, str):
        # __str__ returned a non-string type; use the type name instead.
        return f"<{type(obj).__name__}>"
    return result
Loading
Loading