From b19404d8a89c0e76dbfac4bd158d3dbb8c9b813d Mon Sep 17 00:00:00 2001 From: Ryan Li Date: Fri, 30 Jan 2026 20:12:32 +0000 Subject: [PATCH] mlflow improvements Signed-off-by: Ryan Li --- src/megatron/bridge/training/eval.py | 11 ++ .../bridge/training/utils/mlflow_utils.py | 13 +- .../bridge/training/utils/train_utils.py | 153 +++++++++++------- .../training/utils/test_mlflow_utils.py | 34 ++-- .../training/utils/test_train_utils.py | 1 + 5 files changed, 132 insertions(+), 80 deletions(-) diff --git a/src/megatron/bridge/training/eval.py b/src/megatron/bridge/training/eval.py index 99b6d6a0af..a0782de860 100644 --- a/src/megatron/bridge/training/eval.py +++ b/src/megatron/bridge/training/eval.py @@ -35,6 +35,7 @@ from megatron.bridge.training.config import ConfigContainer from megatron.bridge.training.forward_step_func_types import ForwardStepCallable from megatron.bridge.training.state import GlobalState +from megatron.bridge.training.utils.mlflow_utils import _sanitize_mlflow_metrics from megatron.bridge.training.utils.pg_utils import get_pg_collection from megatron.bridge.training.utils.train_utils import prepare_forward_step_func from megatron.bridge.utils.common_utils import is_last_rank, print_rank_0, print_rank_last @@ -336,6 +337,7 @@ def evaluate_and_print_results( writer = None wandb_writer = state.wandb_logger + mlflow_writer = state.mlflow_logger if should_fire(callback_manager, start_event): callback_manager.fire( @@ -386,6 +388,15 @@ def evaluate_and_print_results( if state.cfg.logger.log_validation_ppl_to_tensorboard: wandb_writer.log({"{} validation ppl".format(key): ppl}, state.train_state.step) + if mlflow_writer and is_last_rank(): + mlflow_writer.log_metrics( + _sanitize_mlflow_metrics({f"val/{key}": total_loss_dict[key].item()}), step=state.train_state.step + ) + if state.cfg.logger.log_validation_ppl_to_tensorboard: + mlflow_writer.log_metrics( + _sanitize_mlflow_metrics({f"val/{key} ppl": ppl}), step=state.train_state.step + ) + if process_non_loss_data_func is not None and writer and is_last_rank(): process_non_loss_data_func(collected_non_loss_data, state.train_state.step, writer) diff --git a/src/megatron/bridge/training/utils/mlflow_utils.py b/src/megatron/bridge/training/utils/mlflow_utils.py index 9d6a678eb5..c05c9211a5 100644 --- a/src/megatron/bridge/training/utils/mlflow_utils.py +++ b/src/megatron/bridge/training/utils/mlflow_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re from pathlib import Path from typing import Any, Optional @@ -85,4 +86,14 @@ def on_load_checkpoint_success( def _sanitize_mlflow_metrics(metrics: dict[str, Any]) -> dict[str, Any]: """Sanitize all metric names in a dictionary for MLFlow logging.""" - return {key.replace("/", "_"): value for key, value in metrics.items()} + + def _sanitize_key(key): + sanitized = key.replace("@", "_at_") + sanitized = re.sub(r"/+", "/", sanitized) + if "/" in key: + first, rest = sanitized.split("/", 1) + sanitized = first + "/" + rest.replace("/", "_") + sanitized = re.sub(r"[^/\w.\- :]", "_", sanitized) + return sanitized + + return {_sanitize_key(key): value for key, value in metrics.items()} diff --git a/src/megatron/bridge/training/utils/train_utils.py b/src/megatron/bridge/training/utils/train_utils.py index b2f57afa9d..7ec970305c 100644 --- a/src/megatron/bridge/training/utils/train_utils.py +++ b/src/megatron/bridge/training/utils/train_utils.py @@ -385,6 +385,8 @@ def training_log( train_config = config.train pg_collection = get_pg_collection(model) + loggers_exist = writer is not None or wandb_writer is not None or mlflow_logger is not None + # Advanced, skipped, and Nan iterations. advanced_iters_key = "advanced iterations" skipped_iters_key = "skipped iterations" @@ -477,7 +479,7 @@ def training_log( dump(snapshot, f) print_rank_0(f"Saved memory snapshot to {filename}") - if writer and (iteration % logger_config.tensorboard_log_interval == 0): + if loggers_exist and iteration % logger_config.tensorboard_log_interval == 0: if logger_config.log_throughput_to_tensorboard: throughput_report = report_throughput( iteration=iteration, @@ -486,8 +488,9 @@ def training_log( history_wct=history_wct, window_size=logger_config.throughput_window_size, ) - for metric, value in throughput_report.items(): - writer.add_scalar(metric, value, iteration) + if writer: + for metric, value in throughput_report.items(): + writer.add_scalar(metric, value, iteration) if wandb_writer: wandb_writer.log(throughput_report, iteration) if mlflow_logger: @@ -495,8 +498,9 @@ def training_log( if logger_config.log_memory_to_tensorboard: memory_report = report_memory(memory_keys=logger_config.memory_keys) memory_report = {f"memory/{mem_stat}": val for (mem_stat, val) in memory_report.items()} - for metric, value in memory_report.items(): - writer.add_scalar(metric, value, iteration) + if writer: + for metric, value in memory_report.items(): + writer.add_scalar(metric, value, iteration) if wandb_writer: wandb_writer.log(memory_report, iteration) if mlflow_logger: @@ -509,16 +513,20 @@ def training_log( train_iters=train_config.train_iters, time_unit=logger_config.runtime_time_unit, ) - for metric, value in runtime_report.items(): - writer.add_scalar(metric, value, iteration) + if writer: + for metric, value in runtime_report.items(): + writer.add_scalar(metric, value, iteration) if wandb_writer: wandb_writer.log(runtime_report, iteration) if mlflow_logger: mlflow_logger.log_metrics(_sanitize_mlflow_metrics(runtime_report), step=iteration) + + # l2 grad norm if logger_config.log_l2_norm_grad_to_tensorboard: l2_report = report_l2_norm_grad(model) - for metric, value in l2_report.items(): - writer.add_scalar(metric, value, iteration) + if writer: + for metric, value in l2_report.items(): + writer.add_scalar(metric, value, iteration) if wandb_writer: wandb_writer.log(l2_report, iteration) if mlflow_logger: @@ -527,16 +535,26 @@ def training_log( wandb_writer.log({"samples vs steps": train_state.consumed_train_samples}, iteration) if mlflow_logger: mlflow_logger.log_metrics({"samples vs steps": train_state.consumed_train_samples}, step=iteration) - writer.add_scalar("learning-rate", learning_rate, iteration) - writer.add_scalar("learning-rate vs samples", learning_rate, train_state.consumed_train_samples) - if wandb_writer and learning_rate is not None: - wandb_writer.log({"learning-rate": learning_rate}, iteration) - if mlflow_logger and learning_rate is not None: - mlflow_logger.log_metrics({"learning-rate": learning_rate}, step=iteration) + + # learning rate + if learning_rate is not None: + if writer: + writer.add_scalar("learning-rate", learning_rate, iteration) + writer.add_scalar("learning-rate vs samples", learning_rate, train_state.consumed_train_samples) + if wandb_writer: + wandb_writer.log({"learning-rate": learning_rate}, iteration) + if mlflow_logger: + mlflow_logger.log_metrics({"learning-rate": learning_rate}, step=iteration) + + # decoupled lr if config.optimizer.decoupled_lr is not None: - writer.add_scalar("decoupled-learning-rate", decoupled_learning_rate, iteration) + if writer: + writer.add_scalar("decoupled-learning-rate", decoupled_learning_rate, iteration) + + # skipped samples if global_state.train_state.skipped_train_samples > 0: - writer.add_scalar("skipped-train-samples", global_state.train_state.skipped_train_samples, iteration) + if writer: + writer.add_scalar("skipped-train-samples", global_state.train_state.skipped_train_samples, iteration) if wandb_writer: wandb_writer.log({"skipped-train-samples": global_state.train_state.skipped_train_samples}, iteration) if mlflow_logger: @@ -544,66 +562,92 @@ def training_log( {"skipped-train-samples": global_state.train_state.skipped_train_samples}, step=iteration, ) - writer.add_scalar("batch-size", batch_size, iteration) - writer.add_scalar("batch-size vs samples", batch_size, global_state.train_state.consumed_train_samples) + + # batch size + if writer: + writer.add_scalar("batch-size", batch_size, iteration) + writer.add_scalar("batch-size vs samples", batch_size, global_state.train_state.consumed_train_samples) if wandb_writer: wandb_writer.log({"batch-size": batch_size}, iteration) if mlflow_logger: mlflow_logger.log_metrics({"batch-size": batch_size}, step=iteration) + + # loss dict for key in loss_dict: - writer.add_scalar(key, loss_dict[key], iteration) - writer.add_scalar(key + " vs samples", loss_dict[key], global_state.train_state.consumed_train_samples) + if writer: + writer.add_scalar(key, loss_dict[key], iteration) + writer.add_scalar(key + " vs samples", loss_dict[key], global_state.train_state.consumed_train_samples) if wandb_writer: wandb_writer.log({key: loss_dict[key]}, iteration) if mlflow_logger: loss_metrics = {key: float(val) for key, val in loss_dict.items()} mlflow_logger.log_metrics(loss_metrics, step=iteration) + + # loss scale if logger_config.log_loss_scale_to_tensorboard: - writer.add_scalar("loss-scale", loss_scale, iteration) - writer.add_scalar("loss-scale vs samples", loss_scale, global_state.train_state.consumed_train_samples) + if writer: + writer.add_scalar("loss-scale", loss_scale, iteration) + writer.add_scalar("loss-scale vs samples", loss_scale, global_state.train_state.consumed_train_samples) if wandb_writer: wandb_writer.log({"loss-scale": loss_scale}, iteration) if mlflow_logger: mlflow_logger.log_metrics({"loss-scale": loss_scale}, step=iteration) + + # world size if logger_config.log_world_size_to_tensorboard: - writer.add_scalar("world-size", get_world_size_safe(), iteration) - writer.add_scalar( - "world-size vs samples", get_world_size_safe(), global_state.train_state.consumed_train_samples - ) + if writer: + writer.add_scalar("world-size", get_world_size_safe(), iteration) + writer.add_scalar( + "world-size vs samples", get_world_size_safe(), global_state.train_state.consumed_train_samples + ) if wandb_writer: wandb_writer.log({"world-size": get_world_size_safe()}, iteration) if mlflow_logger: mlflow_logger.log_metrics({"world-size": get_world_size_safe()}, step=iteration) + + # grad norm if grad_norm is not None: - writer.add_scalar("grad-norm", grad_norm, iteration) - writer.add_scalar("grad-norm vs samples", grad_norm, global_state.train_state.consumed_train_samples) + if writer: + writer.add_scalar("grad-norm", grad_norm, iteration) + writer.add_scalar("grad-norm vs samples", grad_norm, global_state.train_state.consumed_train_samples) if wandb_writer: wandb_writer.log({"grad-norm": grad_norm}, iteration) if mlflow_logger: mlflow_logger.log_metrics({"grad-norm": grad_norm}, step=iteration) + + # num zeros in grad if num_zeros_in_grad is not None: - writer.add_scalar("num-zeros", num_zeros_in_grad, iteration) - writer.add_scalar( - "num-zeros vs samples", num_zeros_in_grad, global_state.train_state.consumed_train_samples - ) + if writer: + writer.add_scalar("num-zeros", num_zeros_in_grad, iteration) + writer.add_scalar( + "num-zeros vs samples", num_zeros_in_grad, global_state.train_state.consumed_train_samples + ) if wandb_writer: wandb_writer.log({"num-zeros": num_zeros_in_grad}, iteration) if mlflow_logger: mlflow_logger.log_metrics({"num-zeros": num_zeros_in_grad}, step=iteration) + + # params norm if params_norm is not None: - writer.add_scalar("params-norm", params_norm, iteration) - writer.add_scalar("params-norm vs samples", params_norm, global_state.train_state.consumed_train_samples) + if writer: + writer.add_scalar("params-norm", params_norm, iteration) + writer.add_scalar( + "params-norm vs samples", params_norm, global_state.train_state.consumed_train_samples + ) if wandb_writer: wandb_writer.log({"params-norm": params_norm}, iteration) if mlflow_logger: mlflow_logger.log_metrics({"params-norm": params_norm}, step=iteration) + + # max attention logit if log_max_attention_logit is not None: - writer.add_scalar("max-attention-logit", log_max_attention_logit, iteration) - writer.add_scalar( - "max-attention-logit vs samples", - log_max_attention_logit, - global_state.train_state.consumed_train_samples, - ) + if writer: + writer.add_scalar("max-attention-logit", log_max_attention_logit, iteration) + writer.add_scalar( + "max-attention-logit vs samples", + log_max_attention_logit, + global_state.train_state.consumed_train_samples, + ) if wandb_writer: wandb_writer.log({"max-attention-logit": log_max_attention_logit}, iteration) if mlflow_logger: @@ -657,24 +701,24 @@ def training_log( f"Step Time : {elapsed_time_per_iteration:.2f}s GPU utilization: {per_gpu_tf:.1f}MODEL_TFLOP/s/GPU" ) + # throughput if logger_config.log_throughput_to_tensorboard: if writer: writer.add_scalar("throughput/tflops/device", per_gpu_tf, iteration) writer.add_scalar("throughput/tflops", per_gpu_tf * get_world_size_safe(), iteration) - if wandb_writer: - wandb_writer.log({"throughput/tflops/device": per_gpu_tf}, iteration) - wandb_writer.log({"throughput/tflops": per_gpu_tf * get_world_size_safe()}, iteration) - if mlflow_logger: - mlflow_logger.log_metrics( - _sanitize_mlflow_metrics( - { - "throughput/tflops/device": per_gpu_tf, - "throughput/tflops": per_gpu_tf * get_world_size_safe(), - } - ), - step=iteration, - ) + if wandb_writer: + wandb_writer.log({"throughput/tflops/device": per_gpu_tf}, iteration) + wandb_writer.log({"throughput/tflops": per_gpu_tf * get_world_size_safe()}, iteration) + if mlflow_logger: + mlflow_logger.log_metrics( + metrics={ + "throughput/tflops_per_device": per_gpu_tf, + "throughput/tflops": per_gpu_tf * get_world_size_safe(), + }, + step=iteration, + ) + # timers if logger_config.log_timers_to_tensorboard: if writer: writer.add_scalar("iteration-time", elapsed_time_per_iteration, iteration) @@ -682,6 +726,7 @@ def training_log( wandb_writer.log({"iteration-time": elapsed_time_per_iteration}, iteration) if mlflow_logger: mlflow_logger.log_metrics({"iteration-time": elapsed_time_per_iteration}, step=iteration) + log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]" log_string += " iteration {:8d}/{:8d} |".format(iteration, train_config.train_iters) log_string += " consumed samples: {:12d} |".format(global_state.train_state.consumed_train_samples) diff --git a/tests/unit_tests/training/utils/test_mlflow_utils.py b/tests/unit_tests/training/utils/test_mlflow_utils.py index a64322eac4..079b0ecd06 100644 --- a/tests/unit_tests/training/utils/test_mlflow_utils.py +++ b/tests/unit_tests/training/utils/test_mlflow_utils.py @@ -217,22 +217,6 @@ def test_handles_exception_gracefully(self): class TestSanitizeMlflowMetrics: """Test cases for _sanitize_mlflow_metrics function.""" - def test_replaces_slashes_with_underscores(self): - """Test that forward slashes are replaced with underscores.""" - metrics = { - "train/loss": 0.5, - "train/accuracy": 0.95, - "eval/loss": 0.3, - } - - result = _sanitize_mlflow_metrics(metrics) - - assert result == { - "train_loss": 0.5, - "train_accuracy": 0.95, - "eval_loss": 0.3, - } - def test_handles_multiple_slashes(self): """Test that multiple slashes in a key are all replaced.""" metrics = { @@ -243,8 +227,8 @@ def test_handles_multiple_slashes(self): result = _sanitize_mlflow_metrics(metrics) assert result == { - "train_layer_0_loss": 1.0, - "model_encoder_attention_weight": 0.5, + "train/layer_0_loss": 1.0, + "model/encoder_attention_weight": 0.5, } def test_preserves_keys_without_slashes(self): @@ -276,11 +260,11 @@ def test_preserves_values(self): result = _sanitize_mlflow_metrics(metrics) - assert result["train_int_metric"] == 42 - assert result["train_float_metric"] == 3.14159 - assert result["train_string_metric"] == "value" - assert result["train_none_metric"] is None - assert result["train_list_metric"] == [1, 2, 3] + assert result["train/int_metric"] == 42 + assert result["train/float_metric"] == 3.14159 + assert result["train/string_metric"] == "value" + assert result["train/none_metric"] is None + assert result["train/list_metric"] == [1, 2, 3] def test_mixed_keys(self): """Test dictionary with both slash and non-slash keys.""" @@ -294,8 +278,8 @@ def test_mixed_keys(self): result = _sanitize_mlflow_metrics(metrics) assert result == { - "train_loss": 0.5, + "train/loss": 0.5, "global_step": 1000, - "eval_accuracy": 0.9, + "eval/accuracy": 0.9, "learning_rate": 0.001, } diff --git a/tests/unit_tests/training/utils/test_train_utils.py b/tests/unit_tests/training/utils/test_train_utils.py index 4f9bc916cc..c9e8f4ca9b 100644 --- a/tests/unit_tests/training/utils/test_train_utils.py +++ b/tests/unit_tests/training/utils/test_train_utils.py @@ -1315,6 +1315,7 @@ def test_no_loggers_present( # Remove loggers mock_global_state.tensorboard_logger = None mock_global_state.wandb_logger = None + mock_global_state.mlflow_logger = None # Set iteration to match logging intervals mock_global_state.train_state.step = 10