torchtitan/tools/metrics.py (33 additions, 1 deletion)
@@ -15,7 +15,7 @@
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims
from torchtitan.tools.logging import logger
-from torchtitan.tools.utils import device_module, device_type
+from torchtitan.tools.utils import Color, device_module, device_type

# named tuple for passing device memory stats for logging
DeviceMemStats = namedtuple(
@@ -154,6 +154,38 @@ def close(self) -> None:
self.wandb.finish()


+def ensure_pp_loss_visible(
+    parallel_dims: ParallelDims, job_config: JobConfig, color: Color
+) -> None:
+    """
+    Ensure that the loss is visible on the console for pipeline-parallel training.
+
+    With pipeline parallelism, the loss is only computed on the last pipeline
+    stage, so it is only printed by a rank in that stage. This function checks
+    whether that rank is included in the LOG_RANK environment variable and
+    warns if it is not.
+    """
+
+    # V-block schedules return the loss on rank 0
+    if job_config.experimental.pipeline_parallel_schedule == "ZBVZeroBubble":
+        return
+
+    # Calculate the rank where the loss is visible
+    # (the first rank of the last pipeline stage)
+    world_size = parallel_dims.world_size
+    pp_size = parallel_dims.pp
+    loss_visible_rank = (world_size // pp_size) * (pp_size - 1)
+
+    # Check whether the loss-visible rank is included in the LOG_RANK
+    # environment variable
+    env_logged_ranks = os.environ.get("LOG_RANK", "").split(",")
+    if env_logged_ranks == [""]:
+        env_logged_ranks = []
+
+    if str(loss_visible_rank) not in env_logged_ranks:
+        logger.warning(
+            f"{color.red}Pipeline parallel loss is not visible. "
+            f"Add {color.yellow}rank {loss_visible_rank}{color.red} to the "
+            f"LOG_RANK environment variable in run_train.sh.{color.reset}"
+        )


def _get_metrics_rank(
    parallel_dims: ParallelDims,
    job_config: JobConfig,
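To make the rank arithmetic above concrete, here is a minimal standalone sketch of the check ensure_pp_loss_visible performs, using hypothetical values (an 8-GPU job with pipeline degree 2). It assumes, as the formula in the diff does, that pipeline stages occupy contiguous rank blocks; in the real code, world_size and pp come from ParallelDims and LOG_RANK is set by the launcher.

import os

def loss_visible_rank(world_size: int, pp_size: int) -> int:
    # Pipeline stages own contiguous blocks of world_size // pp_size ranks,
    # so the last stage begins at this rank.
    return (world_size // pp_size) * (pp_size - 1)

# 8 GPUs across 2 pipeline stages: ranks 0-3 form stage 0, ranks 4-7 form
# stage 1, and the loss materializes on rank 4.
assert loss_visible_rank(8, 2) == 4

# LOG_RANK parsing mirrors the helper: an unset or empty variable means no
# ranks log to the console, so the warning would fire.
logged = os.environ.get("LOG_RANK", "").split(",")
logged = [] if logged == [""] else logged
if str(loss_visible_rank(8, 2)) not in logged:
    print("warning: pipeline-parallel loss not visible on any logged rank")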
torchtitan/train.py (9 additions, 1 deletion)
@@ -21,7 +21,11 @@

from torchtitan.tools import utils
from torchtitan.tools.logging import init_logger, logger
-from torchtitan.tools.metrics import build_device_memory_monitor, build_metric_logger
+from torchtitan.tools.metrics import (
+    build_device_memory_monitor,
+    build_metric_logger,
+    ensure_pp_loss_visible,
+)
from torchtitan.tools.profiling import (
    maybe_enable_memory_snapshot,
    maybe_enable_profiling,
@@ -185,6 +189,10 @@ def main(job_config: JobConfig):
                with torch.no_grad():
                    m.init_weights(buffer_device=buffer_device)
                m.train()

+            # confirm that the user will be able to view loss metrics on the console
+            ensure_pp_loss_visible(parallel_dims, job_config, color)

        else:
            # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
            train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
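With this call in place, a LOG_RANK that would hide the loss is flagged once at startup rather than surfacing as a silently loss-free console during training. For the 8-GPU, pipeline-degree-2 example above, setting LOG_RANK to include rank 4 before launch, e.g. LOG_RANK=0,4 ./run_train.sh (a hypothetical invocation; the warning message itself points at run_train.sh as the place LOG_RANK is configured), satisfies the check, since rank 4 is the first rank of the last pipeline stage.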