diff --git a/torchtitan/tools/metrics.py b/torchtitan/tools/metrics.py
index f302d25c92..9a99b586c1 100644
--- a/torchtitan/tools/metrics.py
+++ b/torchtitan/tools/metrics.py
@@ -15,7 +15,7 @@
 from torchtitan.config_manager import JobConfig
 from torchtitan.distributed import ParallelDims
 from torchtitan.tools.logging import logger
-from torchtitan.tools.utils import device_module, device_type
+from torchtitan.tools.utils import Color, device_module, device_type
 
 # named tuple for passing device memory stats for logging
 DeviceMemStats = namedtuple(
@@ -154,6 +154,38 @@ def close(self) -> None:
         self.wandb.finish()
 
 
+def ensure_pp_loss_visible(
+    parallel_dims: ParallelDims, job_config: JobConfig, color: Color
+) -> None:
+    """
+    Ensures that the loss is visible on the console for pipeline-parallel training.
+
+    For pipeline-parallel training, the loss is only visible on the last pipeline stage.
+    This function checks if the appropriate rank is included in the LOG_RANK environment
+    variable and warns if it's not.
+    """
+
+    # V Block Schedules return loss on rank 0
+    if job_config.experimental.pipeline_parallel_schedule == "ZBVZeroBubble":
+        return
+
+    # Calculate the rank where loss is visible (first rank of the last pipeline stage)
+    world_size = parallel_dims.world_size
+    pp_size = parallel_dims.pp
+    loss_visible_rank = (world_size // pp_size) * (pp_size - 1)
+
+    # Check if the loss-visible rank is included in LOG_RANK environment variable
+    env_logged_ranks = os.environ.get("LOG_RANK", "").split(",")
+    if env_logged_ranks == [""]:
+        env_logged_ranks = []
+
+    if str(loss_visible_rank) not in env_logged_ranks:
+        logger.warning(
+            f"{color.red}Pipeline parallel loss is not visible. "
+            f"Add {color.yellow}rank {loss_visible_rank}{color.red} to LOG_RANK environment variable in run_train.sh.{color.reset}"
+        )
+
+
 def _get_metrics_rank(
     parallel_dims: ParallelDims,
     job_config: JobConfig,
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 630626dbec..0993cd2614 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -21,7 +21,11 @@
 from torchtitan.tools import utils
 from torchtitan.tools.logging import init_logger, logger
-from torchtitan.tools.metrics import build_device_memory_monitor, build_metric_logger
+from torchtitan.tools.metrics import (
+    build_device_memory_monitor,
+    build_metric_logger,
+    ensure_pp_loss_visible,
+)
 from torchtitan.tools.profiling import (
     maybe_enable_memory_snapshot,
     maybe_enable_profiling,
@@ -185,6 +189,10 @@ def main(job_config: JobConfig):
             with torch.no_grad():
                 m.init_weights(buffer_device=buffer_device)
             m.train()
+
+        # confirm that user will be able to view loss metrics on the console
+        ensure_pp_loss_visible(parallel_dims, job_config, color)
+
     else:
         # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
         train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
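For context, a minimal sketch of the rank computation the patch performs, using assumed values (8 ranks total, pipeline-parallel degree 4) that are illustrative and not taken from the patch; the LOG_RANK usage in the comments mirrors the warning message and torchtitan's run_train.sh launcher, and the exact rank depends on the actual parallelism layout.

# Illustrative sketch only (assumed values, not part of the patch):
# with 8 ranks and a pipeline-parallel degree of 4, the loss is reported
# on the first rank of the last pipeline stage.
world_size = 8
pp_size = 4
loss_visible_rank = (world_size // pp_size) * (pp_size - 1)  # (8 // 4) * (4 - 1) == 6
print(loss_visible_rank)

# The warning emitted by ensure_pp_loss_visible then asks for something like:
#   LOG_RANK=0,6 ./run_train.sh
# so that the rank holding the loss is included in the console log filter.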