Merged
10 changes: 10 additions & 0 deletions src/megatron/bridge/training/eval.py
@@ -35,6 +35,7 @@
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.forward_step_func_types import ForwardStepCallable
from megatron.bridge.training.state import GlobalState
from megatron.bridge.training.utils.mlflow_utils import _sanitize_mlflow_metrics
from megatron.bridge.training.utils.pg_utils import get_pg_collection
from megatron.bridge.training.utils.train_utils import prepare_forward_step_func
from megatron.bridge.utils.common_utils import is_last_rank, print_rank_0, print_rank_last
@@ -336,6 +337,7 @@ def evaluate_and_print_results(
writer = None

wandb_writer = state.wandb_logger
mlflow_writer = state.mlflow_logger
comet_logger = state.comet_logger

if should_fire(callback_manager, start_event):
@@ -387,6 +389,14 @@ def evaluate_and_print_results(
if state.cfg.logger.log_validation_ppl_to_tensorboard:
wandb_writer.log({"{} validation ppl".format(key): ppl}, state.train_state.step)

if mlflow_writer and is_last_rank():
mlflow_writer.log_metrics(
_sanitize_mlflow_metrics({f"val/{key}": total_loss_dict[key].item()}), step=state.train_state.step
)
if state.cfg.logger.log_validation_ppl_to_tensorboard:
mlflow_writer.log_metrics(
_sanitize_mlflow_metrics({f"val/{key} ppl": ppl}), step=state.train_state.step
)
if comet_logger and is_last_rank():
comet_logger.log_metrics(
{"{} validation".format(key): total_loss_dict[key].item()}, step=state.train_state.step
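For readers tracing the hunk above: the sketch below isolates the logging pattern it adds. The `MlflowWriterLike` protocol and the `log_validation_to_mlflow` helper are hypothetical stand-ins (the real writer is `state.mlflow_logger`, and keys additionally pass through `_sanitize_mlflow_metrics`); the capped-exponent perplexity is an assumed convention, since the `ppl` computation itself sits outside this diff.

```python
# Minimal sketch of the MLflow validation-logging branch added above.
import math
from typing import Any, Optional, Protocol


class MlflowWriterLike(Protocol):
    # Only log_metrics(metrics, step=...) is assumed, matching the calls in the diff.
    def log_metrics(self, metrics: dict[str, Any], step: int) -> None: ...


def log_validation_to_mlflow(
    writer: Optional[MlflowWriterLike],
    key: str,
    loss: float,
    step: int,
    log_ppl: bool = True,
) -> None:
    """Log validation loss (and optionally perplexity) for one loss key."""
    if writer is None:  # mirrors the `if mlflow_writer and is_last_rank():` guard
        return
    writer.log_metrics({f"val/{key}": loss}, step=step)
    if log_ppl:
        # Assumed convention: ppl = exp(loss), with the exponent capped so a
        # diverging loss cannot overflow the float.
        writer.log_metrics({f"val/{key} ppl": math.exp(min(20.0, loss))}, step=step)
```

In the diff itself the branch is also gated on `is_last_rank()`, so only a single rank emits the metrics.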
13 changes: 12 additions & 1 deletion src/megatron/bridge/training/utils/mlflow_utils.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from pathlib import Path
from typing import Any, Optional

@@ -85,4 +86,14 @@ def on_load_checkpoint_success(

def _sanitize_mlflow_metrics(metrics: dict[str, Any]) -> dict[str, Any]:
"""Sanitize all metric names in a dictionary for MLFlow logging."""
return {key.replace("/", "_"): value for key, value in metrics.items()}

def _sanitize_key(key):
sanitized = key.replace("@", "_at_")
sanitized = re.sub(r"/+", "/", sanitized)
if "/" in key:
first, rest = sanitized.split("/", 1)
sanitized = first + "/" + rest.replace("/", "_")
sanitized = re.sub(r"[^/\w.\- :]", "_", sanitized)
return sanitized

return {_sanitize_key(key): value for key, value in metrics.items()}
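To make the new sanitizer concrete, here is a standalone copy of the nested `_sanitize_key` with worked examples; the inputs are invented for illustration, and the expected outputs follow directly from the three substitution steps above.

```python
# Standalone copy of _sanitize_key, for illustration only.
import re


def sanitize_key(key: str) -> str:
    sanitized = key.replace("@", "_at_")       # '@' is not MLflow-legal; spell it out
    sanitized = re.sub(r"/+", "/", sanitized)  # collapse runs of slashes
    if "/" in key:
        # Keep the first '/' as a grouping separator; flatten any later ones.
        first, rest = sanitized.split("/", 1)
        sanitized = first + "/" + rest.replace("/", "_")
    sanitized = re.sub(r"[^/\w.\- :]", "_", sanitized)  # replace remaining illegal chars
    return sanitized


assert sanitize_key("val/acc@5") == "val/acc_at_5"
assert sanitize_key("a//b/c") == "a/b_c"
assert sanitize_key("val/lm loss") == "val/lm loss"  # spaces, ':' and '.' stay as-is
```

Note the asymmetry: only the first slash survives as a separator, so a nested name like `a//b/c` flattens to `a/b_c` while a single-level `val/acc_at_5` keeps its slash.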
176 changes: 109 additions & 67 deletions src/megatron/bridge/training/utils/train_utils.py
@@ -386,6 +386,8 @@ def training_log(
train_config = config.train
pg_collection = get_pg_collection(model)

loggers_exist = writer is not None or wandb_writer is not None or mlflow_logger is not None

# Advanced, skipped, and Nan iterations.
advanced_iters_key = "advanced iterations"
skipped_iters_key = "skipped iterations"
@@ -480,7 +482,7 @@
dump(snapshot, f)
print_rank_0(f"Saved memory snapshot to {filename}")

if writer and (iteration % logger_config.tensorboard_log_interval == 0):
if loggers_exist and iteration % logger_config.tensorboard_log_interval == 0:
if logger_config.log_throughput_to_tensorboard:
throughput_report = report_throughput(
iteration=iteration,
@@ -489,8 +491,9 @@
history_wct=history_wct,
window_size=logger_config.throughput_window_size,
)
for metric, value in throughput_report.items():
writer.add_scalar(metric, value, iteration)
if writer:
for metric, value in throughput_report.items():
writer.add_scalar(metric, value, iteration)
if wandb_writer:
wandb_writer.log(throughput_report, iteration)
if mlflow_logger:
@@ -500,8 +503,9 @@
if logger_config.log_memory_to_tensorboard:
memory_report = report_memory(memory_keys=logger_config.memory_keys)
memory_report = {f"memory/{mem_stat}": val for (mem_stat, val) in memory_report.items()}
for metric, value in memory_report.items():
writer.add_scalar(metric, value, iteration)
if writer:
for metric, value in memory_report.items():
writer.add_scalar(metric, value, iteration)
if wandb_writer:
wandb_writer.log(memory_report, iteration)
if mlflow_logger:
@@ -516,18 +520,22 @@
train_iters=train_config.train_iters,
time_unit=logger_config.runtime_time_unit,
)
for metric, value in runtime_report.items():
writer.add_scalar(metric, value, iteration)
if writer:
for metric, value in runtime_report.items():
writer.add_scalar(metric, value, iteration)
if wandb_writer:
wandb_writer.log(runtime_report, iteration)
if mlflow_logger:
mlflow_logger.log_metrics(_sanitize_mlflow_metrics(runtime_report), step=iteration)
if comet_logger:
comet_logger.log_metrics(runtime_report, step=iteration)

# l2 grad norm
if logger_config.log_l2_norm_grad_to_tensorboard:
l2_report = report_l2_norm_grad(model)
for metric, value in l2_report.items():
writer.add_scalar(metric, value, iteration)
if writer:
for metric, value in l2_report.items():
writer.add_scalar(metric, value, iteration)
if wandb_writer:
wandb_writer.log(l2_report, iteration)
if mlflow_logger:
@@ -538,20 +546,28 @@
wandb_writer.log({"samples vs steps": train_state.consumed_train_samples}, iteration)
if mlflow_logger:
mlflow_logger.log_metrics({"samples vs steps": train_state.consumed_train_samples}, step=iteration)
if comet_logger:
comet_logger.log_metrics({"samples vs steps": train_state.consumed_train_samples}, step=iteration)
writer.add_scalar("learning-rate", learning_rate, iteration)
writer.add_scalar("learning-rate vs samples", learning_rate, train_state.consumed_train_samples)
if wandb_writer and learning_rate is not None:
wandb_writer.log({"learning-rate": learning_rate}, iteration)
if mlflow_logger and learning_rate is not None:
mlflow_logger.log_metrics({"learning-rate": learning_rate}, step=iteration)
if comet_logger and learning_rate is not None:
comet_logger.log_metrics({"learning-rate": learning_rate}, step=iteration)

# learning rate
if learning_rate is not None:
if writer:
writer.add_scalar("learning-rate", learning_rate, iteration)
writer.add_scalar("learning-rate vs samples", learning_rate, train_state.consumed_train_samples)
if wandb_writer:
wandb_writer.log({"learning-rate": learning_rate}, iteration)
if mlflow_logger:
mlflow_logger.log_metrics({"learning-rate": learning_rate}, step=iteration)
if comet_logger:
comet_logger.log_metrics({"learning-rate": learning_rate}, step=iteration)

# decoupled lr
if config.optimizer.decoupled_lr is not None:
writer.add_scalar("decoupled-learning-rate", decoupled_learning_rate, iteration)
if writer:
writer.add_scalar("decoupled-learning-rate", decoupled_learning_rate, iteration)

# skipped samples
if global_state.train_state.skipped_train_samples > 0:
writer.add_scalar("skipped-train-samples", global_state.train_state.skipped_train_samples, iteration)
if writer:
writer.add_scalar("skipped-train-samples", global_state.train_state.skipped_train_samples, iteration)
if wandb_writer:
wandb_writer.log({"skipped-train-samples": global_state.train_state.skipped_train_samples}, iteration)
if mlflow_logger:
@@ -564,80 +580,106 @@ def training_log(
{"skipped-train-samples": global_state.train_state.skipped_train_samples},
step=iteration,
)
writer.add_scalar("batch-size", batch_size, iteration)
writer.add_scalar("batch-size vs samples", batch_size, global_state.train_state.consumed_train_samples)

# batch size
if writer:
writer.add_scalar("batch-size", batch_size, iteration)
writer.add_scalar("batch-size vs samples", batch_size, global_state.train_state.consumed_train_samples)
if wandb_writer:
wandb_writer.log({"batch-size": batch_size}, iteration)
if mlflow_logger:
mlflow_logger.log_metrics({"batch-size": batch_size}, step=iteration)
if comet_logger:
comet_logger.log_metrics({"batch-size": batch_size}, step=iteration)

# loss dict
for key in loss_dict:
writer.add_scalar(key, loss_dict[key], iteration)
writer.add_scalar(key + " vs samples", loss_dict[key], global_state.train_state.consumed_train_samples)
if writer:
writer.add_scalar(key, loss_dict[key], iteration)
writer.add_scalar(key + " vs samples", loss_dict[key], global_state.train_state.consumed_train_samples)
if wandb_writer:
wandb_writer.log({key: loss_dict[key]}, iteration)
if mlflow_logger:
loss_metrics = {key: float(val) for key, val in loss_dict.items()}
mlflow_logger.log_metrics(loss_metrics, step=iteration)
if comet_logger:
comet_logger.log_metrics({key: float(val) for key, val in loss_dict.items()}, step=iteration)

# loss scale
if logger_config.log_loss_scale_to_tensorboard:
writer.add_scalar("loss-scale", loss_scale, iteration)
writer.add_scalar("loss-scale vs samples", loss_scale, global_state.train_state.consumed_train_samples)
if writer:
writer.add_scalar("loss-scale", loss_scale, iteration)
writer.add_scalar("loss-scale vs samples", loss_scale, global_state.train_state.consumed_train_samples)
if wandb_writer:
wandb_writer.log({"loss-scale": loss_scale}, iteration)
if mlflow_logger:
mlflow_logger.log_metrics({"loss-scale": loss_scale}, step=iteration)
if comet_logger:
comet_logger.log_metrics({"loss-scale": loss_scale}, step=iteration)

# world size
if logger_config.log_world_size_to_tensorboard:
writer.add_scalar("world-size", get_world_size_safe(), iteration)
writer.add_scalar(
"world-size vs samples", get_world_size_safe(), global_state.train_state.consumed_train_samples
)
if writer:
writer.add_scalar("world-size", get_world_size_safe(), iteration)
writer.add_scalar(
"world-size vs samples", get_world_size_safe(), global_state.train_state.consumed_train_samples
)
if wandb_writer:
wandb_writer.log({"world-size": get_world_size_safe()}, iteration)
if mlflow_logger:
mlflow_logger.log_metrics({"world-size": get_world_size_safe()}, step=iteration)
if comet_logger:
comet_logger.log_metrics({"world-size": get_world_size_safe()}, step=iteration)

# grad norm
if grad_norm is not None:
writer.add_scalar("grad-norm", grad_norm, iteration)
writer.add_scalar("grad-norm vs samples", grad_norm, global_state.train_state.consumed_train_samples)
if writer:
writer.add_scalar("grad-norm", grad_norm, iteration)
writer.add_scalar("grad-norm vs samples", grad_norm, global_state.train_state.consumed_train_samples)
if wandb_writer:
wandb_writer.log({"grad-norm": grad_norm}, iteration)
if mlflow_logger:
mlflow_logger.log_metrics({"grad-norm": grad_norm}, step=iteration)
if comet_logger:
comet_logger.log_metrics({"grad-norm": grad_norm}, step=iteration)

# num zeros in grad
if num_zeros_in_grad is not None:
writer.add_scalar("num-zeros", num_zeros_in_grad, iteration)
writer.add_scalar(
"num-zeros vs samples", num_zeros_in_grad, global_state.train_state.consumed_train_samples
)
if writer:
writer.add_scalar("num-zeros", num_zeros_in_grad, iteration)
writer.add_scalar(
"num-zeros vs samples", num_zeros_in_grad, global_state.train_state.consumed_train_samples
)
if wandb_writer:
wandb_writer.log({"num-zeros": num_zeros_in_grad}, iteration)
if mlflow_logger:
mlflow_logger.log_metrics({"num-zeros": num_zeros_in_grad}, step=iteration)
if comet_logger:
comet_logger.log_metrics({"num-zeros": num_zeros_in_grad}, step=iteration)

# params norm
if params_norm is not None:
writer.add_scalar("params-norm", params_norm, iteration)
writer.add_scalar("params-norm vs samples", params_norm, global_state.train_state.consumed_train_samples)
if writer:
writer.add_scalar("params-norm", params_norm, iteration)
writer.add_scalar(
"params-norm vs samples", params_norm, global_state.train_state.consumed_train_samples
)
if wandb_writer:
wandb_writer.log({"params-norm": params_norm}, iteration)
if mlflow_logger:
mlflow_logger.log_metrics({"params-norm": params_norm}, step=iteration)
if comet_logger:
comet_logger.log_metrics({"params-norm": params_norm}, step=iteration)

# max attention logit
if log_max_attention_logit is not None:
writer.add_scalar("max-attention-logit", log_max_attention_logit, iteration)
writer.add_scalar(
"max-attention-logit vs samples",
log_max_attention_logit,
global_state.train_state.consumed_train_samples,
)
if writer:
writer.add_scalar("max-attention-logit", log_max_attention_logit, iteration)
writer.add_scalar(
"max-attention-logit vs samples",
log_max_attention_logit,
global_state.train_state.consumed_train_samples,
)
if wandb_writer:
wandb_writer.log({"max-attention-logit": log_max_attention_logit}, iteration)
if mlflow_logger:
@@ -693,32 +735,31 @@ def training_log(
f"Step Time : {elapsed_time_per_iteration:.2f}s GPU utilization: {per_gpu_tf:.1f}MODEL_TFLOP/s/GPU"
)

# throughput
if logger_config.log_throughput_to_tensorboard:
if writer:
writer.add_scalar("throughput/tflops/device", per_gpu_tf, iteration)
writer.add_scalar("throughput/tflops", per_gpu_tf * get_world_size_safe(), iteration)
if wandb_writer:
wandb_writer.log({"throughput/tflops/device": per_gpu_tf}, iteration)
wandb_writer.log({"throughput/tflops": per_gpu_tf * get_world_size_safe()}, iteration)
if mlflow_logger:
mlflow_logger.log_metrics(
_sanitize_mlflow_metrics(
{
"throughput/tflops/device": per_gpu_tf,
"throughput/tflops": per_gpu_tf * get_world_size_safe(),
}
),
step=iteration,
)
if comet_logger:
comet_logger.log_metrics(
{
"throughput/tflops/device": per_gpu_tf,
"throughput/tflops": per_gpu_tf * get_world_size_safe(),
},
step=iteration,
)

if wandb_writer:
wandb_writer.log({"throughput/tflops/device": per_gpu_tf}, iteration)
wandb_writer.log({"throughput/tflops": per_gpu_tf * get_world_size_safe()}, iteration)
if mlflow_logger:
mlflow_logger.log_metrics(
metrics={
"throughput/tflops_per_device": per_gpu_tf,
"throughput/tflops": per_gpu_tf * get_world_size_safe(),
},
step=iteration,
)
if comet_logger:
comet_logger.log_metrics(
{
"throughput/tflops/device": per_gpu_tf,
"throughput/tflops": per_gpu_tf * get_world_size_safe(),
},
step=iteration,
)
# timers
if logger_config.log_timers_to_tensorboard:
if writer:
writer.add_scalar("iteration-time", elapsed_time_per_iteration, iteration)
Expand All @@ -728,6 +769,7 @@ def training_log(
mlflow_logger.log_metrics({"iteration-time": elapsed_time_per_iteration}, step=iteration)
if comet_logger:
comet_logger.log_metrics({"iteration-time": elapsed_time_per_iteration}, step=iteration)

log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
log_string += " iteration {:8d}/{:8d} |".format(iteration, train_config.train_iters)
log_string += " consumed samples: {:12d} |".format(global_state.train_state.consumed_train_samples)
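Every metric family in this refactor now follows the same shape: build a metrics dict, then write it to each backend that is actually configured. The hypothetical helper below condenses that fan-out; it is not part of the PR, but the backend call signatures match the ones used in the diff.

```python
# Condensed sketch of the guard pattern repeated throughout training_log.
from typing import Any, Callable, Optional


def log_to_all_backends(
    metrics: dict[str, Any],
    step: int,
    writer: Optional[Any] = None,         # TensorBoard SummaryWriter (add_scalar)
    wandb_writer: Optional[Any] = None,   # wandb run/module (log)
    mlflow_logger: Optional[Any] = None,  # MLflow wrapper (log_metrics)
    comet_logger: Optional[Any] = None,   # Comet experiment (log_metrics)
    sanitize: Callable[[dict[str, Any]], dict[str, Any]] = lambda m: m,
) -> None:
    """Fan one metrics dict out to whichever logging backends are configured."""
    if writer:
        for name, value in metrics.items():
            writer.add_scalar(name, value, step)
    if wandb_writer:
        wandb_writer.log(metrics, step)
    if mlflow_logger:
        # MLflow metric names are restricted, so keys go through a sanitizer
        # (pass _sanitize_mlflow_metrics here).
        mlflow_logger.log_metrics(sanitize(metrics), step=step)
    if comet_logger:
        comet_logger.log_metrics(metrics, step=step)
```

Keeping the per-metric blocks explicit, as the PR does, preserves per-backend differences such as MLflow's sanitized keys and its `throughput/tflops_per_device` spelling.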