diff --git a/tests/unit_tests/test_job_config.py b/tests/unit_tests/test_job_config.py index 039981dbed..9325ea1861 100644 --- a/tests/unit_tests/test_job_config.py +++ b/tests/unit_tests/test_job_config.py @@ -52,40 +52,34 @@ def test_job_config_file_cmd_overrides(self): ) assert config.job.dump_folder == "/tmp/test_tt/" - def test_parse_pp_split_points(self): - toml_splits = ["layers.2", "layers.4", "layers.6"] - cmdline_splits = ["layers.1", "layers.3", "layers.5"] - # no split points specified - config_manager = ConfigManager() - config = config_manager.parse_args( - [ - "--job.config_file", - "./torchtitan/models/llama3/train_configs/debug_model.toml", - ] - ) - assert config.parallelism.pipeline_parallel_split_points == [] + def test_parse_module_fqns_per_model_part(self): + toml_chunks = [ + ["tok_embeddings", "layers.0"], + ["layers.1", "layers.2"], + ["layers.3", "norm", "output"], + ] + cmdline_chunks = [ + ["tok_embeddings", "layers.0", "layers.1"], + ["layers.2", "layers.3", "norm", "output"], + ] - # toml has no split points, but cmdline splits are specified + # no module names specified config_manager = ConfigManager() config = config_manager.parse_args( [ "--job.config_file", "./torchtitan/models/llama3/train_configs/debug_model.toml", - "--parallelism.pipeline_parallel_split_points", - ",".join(cmdline_splits), ] ) - assert ( - config.parallelism.pipeline_parallel_split_points == cmdline_splits - ), config.parallelism.pipeline_parallel_split_points + assert config.parallelism.module_fqns_per_model_part is None - # toml has split points, cmdline does not + # toml has module names, cmdline does not with tempfile.NamedTemporaryFile() as fp: with open(fp.name, "wb") as f: tomli_w.dump( { "parallelism": { - "pipeline_parallel_split_points": toml_splits, + "module_fqns_per_model_part": toml_chunks, } }, f, @@ -93,32 +87,43 @@ def test_parse_pp_split_points(self): config_manager = ConfigManager() config = config_manager.parse_args(["--job.config_file", fp.name]) assert ( - config.parallelism.pipeline_parallel_split_points == toml_splits - ), config.parallelism.pipeline_parallel_split_points + config.parallelism.module_fqns_per_model_part == toml_chunks + ), config.parallelism.module_fqns_per_model_part - # toml has split points, cmdline overrides them + # test that the field accepts list of lists structure with tempfile.NamedTemporaryFile() as fp: with open(fp.name, "wb") as f: tomli_w.dump( { "parallelism": { - "pipeline_parallel_split_points": toml_splits, + "module_fqns_per_model_part": cmdline_chunks, } }, f, ) config_manager = ConfigManager() - config = config_manager.parse_args( - [ - "--job.config_file", - fp.name, - "--parallelism.pipeline_parallel_split_points", - ",".join(cmdline_splits), - ] - ) + config = config_manager.parse_args(["--job.config_file", fp.name]) + assert ( + config.parallelism.module_fqns_per_model_part == cmdline_chunks + ), config.parallelism.module_fqns_per_model_part + + # test empty chunks are handled correctly + empty_chunks = [[], ["tok_embeddings"], []] + with tempfile.NamedTemporaryFile() as fp: + with open(fp.name, "wb") as f: + tomli_w.dump( + { + "parallelism": { + "module_fqns_per_model_part": empty_chunks, + } + }, + f, + ) + config_manager = ConfigManager() + config = config_manager.parse_args(["--job.config_file", fp.name]) assert ( - config.parallelism.pipeline_parallel_split_points == cmdline_splits - ), config.parallelism.pipeline_parallel_split_points + config.parallelism.module_fqns_per_model_part == empty_chunks + ), 
config.parallelism.module_fqns_per_model_part def test_parse_exclude_from_loading(self): toml_splits = ["optimizer", "dataloader"] diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index eaf73bdff8..4f033e2c0a 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -290,6 +290,7 @@ class Parallelism: pipeline_parallel_split_points: list[str] = field(default_factory=list) """ + DEPRECATED: Use module_fqns_per_model_part instead. Specify comma-separated names of modules to use as the beginning of a split point. e.g. "layers.0,layers.2" will cause the model to be split into 3 stages, the first containing all the layers up to layers.0, @@ -299,9 +300,31 @@ class Parallelism: but currently the split points must be specified manually. """ + module_fqns_per_model_part: list[list[str]] | None = None + """ + Specify a list of lists containing the FQNs (Fully Qualified Names) of modules for each model chunk. + Each inner list represents one model chunk and contains the module names that belong to that chunk. + e.g. [['tok_embeddings', 'layers.0'], ['layers.1', 'layers.2'], ['layers.3', 'layers.4']] + will create 3 chunks: the first containing tok_embeddings and layers.0, + the second containing layers.1 and layers.2, and the third containing layers.3 and layers.4. + This provides more explicit control over which modules belong to each chunk compared to split points. + """ + + pipeline_parallel_first_stage_less_layers: int = 1 + """ + The number of layers to reduce in the first stage of pipeline parallelism. This is because + the first stage has the extra overhead of the embedding layer, which is not present in the other stages. + """ + + pipeline_parallel_last_stage_less_layers: int = 1 + """ + The number of layers to reduce in the last stage of pipeline parallelism. This is because + the last stage has the extra overhead of the output layer, which is not present in the other stages. + """ + pipeline_parallel_layers_per_stage: int | None = None """ - The number of layers per (virtual) pipeline stage. If specified, the split points will be + The number of layers per (virtual) pipeline stage. If specified, the module_fqns_per_model_part will be calculated from the number of layers and pipeline_parallel_degree. If not specified, the layers per stage will be inferred from the model, schedule, and pipeline_parallel_degree. """ diff --git a/torchtitan/distributed/pipeline.py b/torchtitan/distributed/pipeline.py index 9526a7e3b7..96cf2ed790 100644 --- a/torchtitan/distributed/pipeline.py +++ b/torchtitan/distributed/pipeline.py @@ -3,122 +3,34 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import copy import os from typing import Callable +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.pipelining import PipelineStage + from torch.distributed.pipelining.schedules import ( _PipelineSchedule, _PipelineScheduleRuntime, get_schedule_class, PipelineScheduleMulti, PipelineScheduleSingle, + ScheduleZBVZeroBubble, ) -from torch.distributed.pipelining.stage import PipelineStage from torchtitan.config import JobConfig from torchtitan.tools.logging import logger -__all__ = ["build_pipeline_schedule", "generate_split_points", "stage_ids_this_rank"] - - -# TODO: It's unclear if this API is general enough to be used by other models. -# If not, we should move it to a Transformer-specific directory. 
-def generate_split_points( - schedule_str: str, - pp_degree: int, - num_layers: int, - num_layers_per_stage: int | None, - input_weight: int = 1, - output_weight: int = 1, -) -> list[str]: - """ - Generate a list of split points based on the input configs. In this function, - the number of effective layers considered is the summation of num_layers, - input_weight, and output_weight. - - If num_layers_per_virtual_stage is given, we require rigid fit of the - effective layers (regular layers + weighted input + weighted output) - onto pipeline stages and ranks, with several assertions. It is the users' - responsibility to figure out the input weight, output weight, and the - number of regular layers, so that they can be arranged neatly. - - If num_layers_per_virtual_stage is None, we by default set each pipeline rank - to have 1 stage if schedule_str is a single-stage schedule, or 2 virtual stages - if it is a multi-stage schedule, and try to distribute all effective layers - evenly onto the PP stages. If there are extra layers, we disperse them in - the starting stages. - - Args: - schedule_str (str): The string of the schedule name. - pp_degree (int): The pipeline parallel dimension. - num_layers (int): The number of layers in the model. - input_weight (int): The number of layers to consider the input modules in layer calculation. - output_weight (int): The number of layers to consider the output modules in layer calculation. - num_layers_per_stage (int): The number of layers per (virtual) pipeline stage. - - Returns: - list[str]: A list of split point FQNs. - """ - - schedule_class = get_schedule_class(schedule_str) - is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) - - num_effective_layers = num_layers + input_weight + output_weight - - if num_layers_per_stage is not None: - # If num_layers_per_stage is provided, we require a rigid fit of the effective layers - assert num_effective_layers % pp_degree == 0 - num_layers_per_pipeline_rank = num_effective_layers // pp_degree - - assert num_layers_per_pipeline_rank % num_layers_per_stage == 0 - num_stages_per_rank = num_layers_per_pipeline_rank // num_layers_per_stage - - num_total_virtual_stages = num_stages_per_rank * pp_degree - num_extra_layers = 0 - - if is_single_stage_schedule: - assert ( - num_stages_per_rank == 1 - ), f"Number of stages per rank ({num_stages_per_rank}) must be 1 for single-stage schedules." - else: - assert ( - num_stages_per_rank >= 2 - ), f"Number of stages per rank ({num_stages_per_rank}) must be >= 2 for multi-stage schedules." - else: - # In a multi-stage schedule, if num_layers_per_stage is not - # provided, by default each pipeline rank has 2 virtual stages. - num_stages_per_rank = 1 if is_single_stage_schedule else 2 - num_total_virtual_stages = pp_degree * num_stages_per_rank - - if num_total_virtual_stages > num_effective_layers: - raise ValueError( - "The number of total stages cannot be greater than the number of effective layers." 
-            )
-
-        num_layers_per_stage = num_effective_layers // num_total_virtual_stages
-        num_extra_layers = num_effective_layers % num_total_virtual_stages
-
-    assert num_layers_per_stage >= max(input_weight, output_weight)
-
-    splits = []
-    current_layer = 0
-    for i in range(num_total_virtual_stages - 1):
-        if i == 0:
-            current_layer += num_layers_per_stage - input_weight
-        else:
-            current_layer += num_layers_per_stage
-        # extra layers will be dispersed to the first stages
-        if num_extra_layers > 0:
-            current_layer += 1
-            num_extra_layers -= 1
-        splits.append("layers." + str(current_layer))
-
-    logger.info(
-        "No 'pipeline_parallel_split_points' provided. Here is the auto-generated split, "
-        f"which may be sub-optimal: {splits}."
-    )
-    return splits
+__all__ = [
+    "build_pipeline_schedule",
+    "stage_ids_this_rank",
+    "generate_llm_fqn_per_model_part",
+    "pipeline_module_split",
+]


 def build_pipeline_schedule(
@@ -154,7 +66,7 @@ def build_pipeline_schedule(
     # validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training
     if batch_size % microbatch_size != 0:
         raise ValueError(
-            f"Batch size {job_config.training.local_batch_size} must be divisible by number of microbatches {n_microbatches}. "
+            f"Batch size {job_config.training.local_batch_size} must be divisible by microbatch_size {microbatch_size}. "
            "Update the config arguments for either batch_size or pipeline_parallel_microbatch_size."
         )
     n_microbatches = batch_size // microbatch_size
@@ -209,3 +121,234 @@ def stage_ids_this_rank(
             zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1))
         )
     return stage_v_pairs[pp_rank]
+
+
+def generate_llm_fqn_per_model_part(
+    num_stages: int,
+    num_layers: int,
+    input_weight: int = 1,
+    output_weight: int = 1,
+) -> list[list[str]]:
+    """
+    Programmatically generates module names per model part, focused on LLM models.
+
+    Args:
+        num_stages: Number of pipeline stages
+        num_layers: Total number of transformer layers in the model
+        input_weight: Weight for input modules (tok_embeddings) in layer calculation
+        output_weight: Weight for output modules (norm + output) in layer calculation
+
+    Returns:
+        List of lists containing module names for each model part
+
+    Example:
+        generate_llm_fqn_per_model_part(2, 3, input_weight=2, output_weight=2)
+        treats embeddings as 2 layers and norm+output as 2 layers for distribution
+    """
+    if num_stages < 1:
+        raise ValueError("Number of stages must be at least 1")
+
+    if num_stages == 1:
+        # Single stage gets everything
+        layer_names = [f"layers.{i}" for i in range(num_layers)]
+        return [["tok_embeddings"] + layer_names + ["norm", "output"]]
+
+    # Calculate effective layers including weights
+    num_effective_layers = num_layers + input_weight + output_weight
+
+    if num_stages > num_effective_layers:
+        raise ValueError(
+            f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})"
+        )
+
+    # Calculate layers per stage (distribute evenly)
+    layers_per_stage = num_effective_layers // num_stages
+    extra_layers = num_effective_layers % num_stages
+
+    # Feasibility check: ensure at least 1 layer in each PP stage
+    if layers_per_stage == 0:
+        raise ValueError(
+            f"Configuration would result in empty stages. "
+            f"With {num_stages} stages and {num_effective_layers} effective layers "
+            f"(num_layers={num_layers} + input_weight={input_weight} + output_weight={output_weight}), "
+            f"each stage would get {layers_per_stage} layers on average. "
+            f"Reduce num_stages or increase num_layers/weights."
+ ) + + # Balance check: Ensure weights don't exceed minimum layers per stage + if input_weight > layers_per_stage: + raise ValueError( + f"input_weight ({input_weight}) exceeds minimum layers per stage ({layers_per_stage})." + ) + if output_weight > layers_per_stage: + raise ValueError( + f"output_weight ({output_weight}) exceeds minimum layers per stage ({layers_per_stage})." + ) + + module_names_per_stage = [] + current_layer = 0 + + for stage_idx in range(num_stages): + stage_modules = [] + + # Calculate effective layers for this stage + effective_layers_for_stage = layers_per_stage + if stage_idx < extra_layers: + effective_layers_for_stage += 1 + + # First stage: handle input modules with weighting + if stage_idx == 0: + stage_modules.append("tok_embeddings") + # Account for input weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - input_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Last stage: handle output modules with weighting + elif stage_idx == num_stages - 1: + # Account for output weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - output_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Add output modules + stage_modules.extend(["norm", "output"]) + + # Middle stages: only transformer layers + else: + for _ in range(effective_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + module_names_per_stage.append(stage_modules) + + return module_names_per_stage + + +def pipeline_module_split( + whole_model: nn.Module, + pp_mesh: DeviceMesh, + pp_schedule: str, + device: torch.device, + module_names_per_stage: list[list[str]], +) -> tuple[list[PipelineStage], list[nn.Module]]: + """ + This API creates pipeline stages based on specified module names for each stage. + + Some model restrictions include: + - forward() method should tolerate deleted layers + - weight initialization methods should tolerate deleted layers + - Does not support nested moduledict and modulelist structures + + Args: + whole_model: The complete model to be split + pp_mesh: Pipeline parallel device mesh + pp_schedule: Name of pipeline parallelism schedule + device: Device + module_names_per_stage: List of lists, where each inner list contains the module names + that should be included in that stage. Module names should be + dot-separated paths. 
Examples: + - "tok_embeddings" for token embeddings + - "layers.0", "layers.1" for specific transformer layers + - "norm" for the final normalization layer + - "output" for the output projection layer + + Returns: + Tuple of (stages, models) where stages are PipelineStage objects and models are the + corresponding model chunks + + Example usage: + module_names_per_stage = [ + ["tok_embeddings", "layers.0"], # Stage 0: embeddings + first layer + ["layers.1", "layers.2"], # Stage 1: middle layers + ["norm", "output"] # Stage 2: final norm + output + ] + """ + pp_rank = pp_mesh.get_local_rank() + pp_size = pp_mesh.size() + + def _build_stage_from_modules( + stage_idx: int, module_names: list[str], num_stages: int + ) -> tuple[PipelineStage, nn.Module]: + model = copy.deepcopy(whole_model) + + # Create a set of modules to keep for faster lookup + modules_to_keep = set(module_names) + print(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}") + for module_name, module_value in model.named_children(): + # Handle layer-like structures (e.g., "layers.0", "layers.1") + if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)): + layers_to_keep = { + name.split(".", 1)[1] + for name in modules_to_keep + if name.startswith(f"{module_name}.") + } + if layers_to_keep: + # Keep only specified layers + if isinstance(module_value, nn.ModuleDict): + for layer_name in list(module_value.keys()): + if layer_name not in layers_to_keep: + del module_value[layer_name] + elif isinstance(module_value, nn.ModuleList): + indices_to_keep = { + int(idx) for idx in layers_to_keep if idx.isdigit() + } + new_layers = nn.ModuleList( + [ + layer + for i, layer in enumerate(module_value) + if i in indices_to_keep + ] + ) + setattr(model, module_name, new_layers) + else: + # No layers from this structure needed, set to empty structure + if isinstance(module_value, nn.ModuleDict): + setattr(model, module_name, nn.ModuleDict()) + elif isinstance(module_value, nn.ModuleList): + setattr(model, module_name, nn.ModuleList()) + # Handle simple module attributes (e.g., "linear", "norm") + elif module_name not in modules_to_keep: + # Replace with None + setattr(model, module_name, None) + + stage = PipelineStage( + model, + stage_idx, + num_stages, + device, + group=pp_mesh.get_group("pp"), + ) + return stage, model + + num_stages = len(module_names_per_stage) + stages = [] + models = [] + + schedule_class = get_schedule_class(pp_schedule) + style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop" + + for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style): + module_names = module_names_per_stage[stage_idx] + stage, model_chunk = _build_stage_from_modules( + stage_idx, + module_names, + num_stages, + ) + logger.info( + f"PP rank {pp_rank} is building stage_idx {stage_idx} " + f"with modules {module_names}" + ) + stages.append(stage) + models.append(model_chunk) + + return stages, models diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index af95492b82..7585b4e03e 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -11,11 +11,11 @@ from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers +from torchtitan.models.llama3.infra.pipeline import pipeline_llama from torchtitan.protocols.train_spec import register_train_spec, TrainSpec from 
.infra.parallelize import parallelize_deepseekv3 -from .infra.pipeline import pipeline_deepseekv3 from .model.args import DeepSeekV3ModelArgs from .model.model import DeepSeekV3Model @@ -138,7 +138,7 @@ model_cls=DeepSeekV3Model, model_args=deepseekv3_configs, parallelize_fn=parallelize_deepseekv3, - pipelining_fn=pipeline_deepseekv3, + pipelining_fn=pipeline_llama, build_optimizers_fn=build_llama4_optimizers, # use optimizer hooks to update expert weights build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_hf_dataloader, diff --git a/torchtitan/models/deepseek_v3/infra/pipeline.py b/torchtitan/models/deepseek_v3/infra/pipeline.py deleted file mode 100644 index b28ed39ee4..0000000000 --- a/torchtitan/models/deepseek_v3/infra/pipeline.py +++ /dev/null @@ -1,310 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# This file applies the PT-D pipeline parallelism to the Llama model. - -import copy - -import torch -import torch.nn as nn -from torch.distributed import DeviceMesh -from torch.distributed.pipelining import PipelineStage -from torch.distributed.pipelining.schedules import ( - _PipelineSchedule, - get_schedule_class, - PipelineScheduleSingle, - ScheduleZBVZeroBubble, -) - -from torchtitan.components.loss import LossFunction -from torchtitan.config import JobConfig -from torchtitan.distributed import ParallelDims -from torchtitan.distributed.pipeline import build_pipeline_schedule, stage_ids_this_rank -from torchtitan.protocols.train_spec import ParallelizeFunction -from torchtitan.tools.logging import logger - -from ..model.args import DeepSeekV3ModelArgs - - -def generate_module_names_per_stage( - num_stages: int, - num_layers: int, - input_weight: int = 1, - output_weight: int = 1, -) -> list[list[str]]: - """ - Programmatically generates module names per stage for pipeline parallelism with weighting. 
- - Args: - num_stages: Number of pipeline stages - num_layers: Total number of transformer layers in the model - input_weight: Weight for input modules (tok_embeddings) in layer calculation - output_weight: Weight for output modules (norm + output) in layer calculation - - Returns: - List of lists containing module names for each stage - - Example: - generate_module_names_per_stage(2, 3, input_weight=2, output_weight=2) - treats embeddings as 2 layers and norm+output as 2 layers for distribution - """ - if num_stages < 1: - raise ValueError("Number of stages must be at least 1") - - if num_stages == 1: - # Single stage gets everything - layer_names = [f"layers.{i}" for i in range(num_layers)] - return [["tok_embeddings"] + layer_names + ["norm", "output"]] - - # Calculate effective layers including weights - num_effective_layers = num_layers + input_weight + output_weight - - if num_stages > num_effective_layers: - raise ValueError( - f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})" - ) - - # Calculate layers per stage (distribute evenly) - layers_per_stage = num_effective_layers // num_stages - extra_layers = num_effective_layers % num_stages - - # Ensure each stage gets at least the weight of input/output modules - if layers_per_stage < max(input_weight, output_weight): - raise ValueError( - f"Layers per stage ({layers_per_stage}) must be >= max(input_weight={input_weight}, output_weight={output_weight})" - ) - - module_names_per_stage = [] - current_layer = 0 - - for stage_idx in range(num_stages): - stage_modules = [] - - # Calculate effective layers for this stage - effective_layers_for_stage = layers_per_stage - if stage_idx < extra_layers: - effective_layers_for_stage += 1 - - # First stage: handle input modules with weighting - if stage_idx == 0: - stage_modules.append("tok_embeddings") - # Account for input weight in layer distribution - remaining_layers_for_stage = effective_layers_for_stage - input_weight - - # Add transformer layers - for _ in range(remaining_layers_for_stage): - if current_layer < num_layers: - stage_modules.append(f"layers.{current_layer}") - current_layer += 1 - - # Last stage: handle output modules with weighting - elif stage_idx == num_stages - 1: - # Account for output weight in layer distribution - remaining_layers_for_stage = effective_layers_for_stage - output_weight - - # Add transformer layers - for _ in range(remaining_layers_for_stage): - if current_layer < num_layers: - stage_modules.append(f"layers.{current_layer}") - current_layer += 1 - - # Add output modules - stage_modules.extend(["norm", "output"]) - - # Middle stages: only transformer layers - else: - for _ in range(effective_layers_for_stage): - if current_layer < num_layers: - stage_modules.append(f"layers.{current_layer}") - current_layer += 1 - - module_names_per_stage.append(stage_modules) - - return module_names_per_stage - - -def pipeline_deepseekv3( - model: nn.Module, - parallel_dims: ParallelDims, - job_config: JobConfig, - device: torch.device, - model_config: DeepSeekV3ModelArgs, - parallelize_fn: ParallelizeFunction, - loss_fn: LossFunction, -) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: - pp_mesh = parallel_dims.world_mesh["pp"] - - # Determine the number of virtual stages based on schedule type - schedule_class = get_schedule_class( - job_config.parallelism.pipeline_parallel_schedule - ) - is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) - - # For multi-stage schedules, default is 2 
virtual stages per rank - # For single-stage schedules, default is 1 virtual stage per rank - stages_per_rank = 1 if is_single_stage_schedule else 2 - num_virtual_stages = parallel_dims.pp * stages_per_rank - - # Generate module names per stage programmatically with weighting - num_layers = model_config.n_layers - - # You can adjust these weights based on the computational cost of embeddings and output layers - # Higher weights mean these modules are treated as "heavier" in the distribution - input_weight = 1 # Weight for tok_embeddings - output_weight = 1 # Weight for norm + output layers - - module_names_per_stage = generate_module_names_per_stage( - num_virtual_stages, num_layers, input_weight, output_weight - ) - for i, stage_ms in enumerate(module_names_per_stage): - logger.info(f"Stage {i}: {stage_ms}") - - stages, model_parts = pipeline_module_split( - model, - pp_mesh, - job_config.parallelism.pipeline_parallel_schedule, - device, - module_names_per_stage, - ) - - # For PP with looped schedules, each item in model_parts is one stage-model-chunk. - # We need to iterate through model_parts to apply SPMD parallelisms, compilation, - # optimizer, and checkpointing - for i, m in enumerate(model_parts): - # apply SPMD-style PT-D techniques - m = parallelize_fn(m, parallel_dims, job_config) - model_parts[i] = m - # NOTE: this is to update the model in the stage - # in case the model is modified e.g. by torch.compile - stages[i].submod = m - - pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn) - - # This is used in the train loop to determine whether to pass in the input_ids and labels - has_first_stage = False - has_last_stage = False - for stage in stages: - if stage.is_first: - has_first_stage = True - if stage.is_last: - has_last_stage = True - - return pp_schedule, model_parts, has_first_stage, has_last_stage - - -def pipeline_module_split( - whole_model: nn.Module, - pp_mesh: DeviceMesh, - pp_schedule: str, - device: torch.device, - module_names_per_stage: list[list[str]], -) -> tuple[list[PipelineStage], list[nn.Module]]: - """ - This API creates pipeline stages based on specified module names for each stage. - - Args: - whole_model: The complete model to be split - pp_mesh: Pipeline parallel device mesh - pp_schedule: Name of pipeline parallelism schedule - device: Device type - module_names_per_stage: List of lists, where each inner list contains the module names - that should be included in that stage. Module names should be - dot-separated paths. 
Examples: - - "tok_embeddings" for token embeddings - - "layers.0", "layers.1" for specific transformer layers - - "norm" for the final normalization layer - - "output" for the output projection layer - - Returns: - Tuple of (stages, models) where stages are PipelineStage objects and models are the - corresponding model chunks - - Example usage: - module_names_per_stage = [ - ["tok_embeddings", "layers.0"], # Stage 0: embeddings + first layer - ["layers.1", "layers.2"], # Stage 1: middle layers - ["norm", "output"] # Stage 2: final norm + output - ] - """ - pp_rank = pp_mesh.get_local_rank() - pp_size = pp_mesh.size() - - def _build_stage_from_modules( - stage_idx: int, module_names: list[str], num_stages: int - ) -> tuple[PipelineStage, nn.Module]: - model = copy.deepcopy(whole_model) - - # Create a set of modules to keep for faster lookup - modules_to_keep = set(module_names) - print(f"Stage {stage_idx}: Modules to keep: {modules_to_keep}") - for module_name, module_value in model.named_children(): - # Handle layer-like structures (e.g., "layers.0", "layers.1") - if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)): - layers_to_keep = { - name.split(".", 1)[1] - for name in modules_to_keep - if name.startswith(f"{module_name}.") - } - if layers_to_keep: - # Keep only specified layers - if isinstance(module_value, nn.ModuleDict): - for layer_name in list(module_value.keys()): - if layer_name not in layers_to_keep: - del module_value[layer_name] - elif isinstance(module_value, nn.ModuleList): - indices_to_keep = { - int(idx) for idx in layers_to_keep if idx.isdigit() - } - new_layers = nn.ModuleList( - [ - layer - for i, layer in enumerate(module_value) - if i in indices_to_keep - ] - ) - setattr(model, module_name, new_layers) - else: - # No layers from this structure needed, set to empty structure - if isinstance(module_value, nn.ModuleDict): - setattr(model, module_name, nn.ModuleDict()) - elif isinstance(module_value, nn.ModuleList): - setattr(model, module_name, nn.ModuleList()) - # Handle simple module attributes (e.g., "linear", "norm") - elif module_name not in modules_to_keep: - # Replace with None - setattr(model, module_name, None) - - stage = PipelineStage( - model, - stage_idx, - num_stages, - device, - group=pp_mesh.get_group("pp"), - ) - return stage, model - - num_stages = len(module_names_per_stage) - stages = [] - models = [] - - schedule_class = get_schedule_class(pp_schedule) - style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop" - - for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style): - module_names = module_names_per_stage[stage_idx] - stage, model_chunk = _build_stage_from_modules( - stage_idx, - module_names, - num_stages, - ) - logger.info( - f"PP rank {pp_rank} is building stage_idx {stage_idx} " - f"with modules {module_names}" - ) - stages.append(stage) - models.append(model_chunk) - - return stages, models diff --git a/torchtitan/models/llama3/infra/pipeline.py b/torchtitan/models/llama3/infra/pipeline.py index bf88f74322..db3d6465e6 100644 --- a/torchtitan/models/llama3/infra/pipeline.py +++ b/torchtitan/models/llama3/infra/pipeline.py @@ -6,16 +6,14 @@ # This file applies the PT-D pipeline parallelism to the Llama model. 
-import copy
+import math

 import torch
 import torch.nn as nn
-from torch.distributed import DeviceMesh
-from torch.distributed.pipelining import PipelineStage
 from torch.distributed.pipelining.schedules import (
     _PipelineSchedule,
     get_schedule_class,
-    ScheduleZBVZeroBubble,
+    PipelineScheduleSingle,
 )

 from torchtitan.components.loss import LossFunction
@@ -23,13 +21,12 @@
 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.pipeline import (
     build_pipeline_schedule,
-    generate_split_points,
-    stage_ids_this_rank,
+    generate_llm_fqn_per_model_part,
+    pipeline_module_split,
 )
-from torchtitan.protocols.train_spec import ParallelizeFunction
-from torchtitan.tools.logging import logger
-
-from ..model.args import TransformerModelArgs
+
+from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
+from torchtitan.tools.logging import logger


 def pipeline_llama(
@@ -37,14 +34,95 @@ def pipeline_llama(
     parallel_dims: ParallelDims,
     job_config: JobConfig,
     device: torch.device,
-    model_args: TransformerModelArgs,
+    model_args: BaseModelArgs,
     parallelize_fn: ParallelizeFunction,
     loss_fn: LossFunction,
 ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
+    if job_config.parallelism.pipeline_parallel_split_points != []:
+        raise ValueError(
+            "pipeline_parallel_split_points is deprecated. Please use module_fqns_per_model_part instead. "
+            "You can generate module_fqns_per_model_part programmatically with generate_llm_fqn_per_model_part."
+        )
+
     pp_mesh = parallel_dims.world_mesh["pp"]

-    stages, model_parts = pipeline_llama_manual_split(
-        model, pp_mesh, parallel_dims, job_config, device, model_args
+    # Determine the number of virtual stages based on schedule type
+    schedule_class = get_schedule_class(
+        job_config.parallelism.pipeline_parallel_schedule
+    )
+    is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
+    layers_per_stage = job_config.parallelism.pipeline_parallel_layers_per_stage
+    if hasattr(model_args, "n_layers"):
+        num_layers = model_args.n_layers
+    else:
+        raise ValueError("Model does not have an n_layers attribute.")
+
+    # You can adjust these weights based on the computational cost of embeddings and output layers
+    # Higher weights mean these modules are treated as "heavier" in the distribution
+    input_weight = job_config.parallelism.pipeline_parallel_first_stage_less_layers
+    output_weight = job_config.parallelism.pipeline_parallel_last_stage_less_layers
+
+    # Calculate number of virtual stages
+    if layers_per_stage is not None:
+
+        # Calculate number of virtual stages needed (using ceiling division)
+        # This allows for unequal distribution where stages can differ by at most 1 layer
+        num_virtual_stages = math.ceil(
+            (num_layers + input_weight + output_weight) / layers_per_stage
+        )
+
+        # Validation: check stages per rank based on schedule type
+        model_config_info = f"Model has {num_layers} layers with pipeline_parallel_layers_per_stage={layers_per_stage}"
+        stage_distribution_info = (
+            f"resulting in {num_virtual_stages=} across {parallel_dims.pp} PP ranks"
+        )
+
+        if num_virtual_stages % parallel_dims.pp != 0:
+            raise ValueError(
+                f"Number of virtual stages ({num_virtual_stages}) must be divisible by "
+                f"pipeline parallel size ({parallel_dims.pp}). "
+                f"{model_config_info}. "
+                f"Please adjust pipeline_parallel_layers_per_stage to a value that results in a number of stages "
+                f"divisible by {parallel_dims.pp}."
+ ) + + stages_per_rank = num_virtual_stages // parallel_dims.pp + + if is_single_stage_schedule and stages_per_rank != 1: + raise ValueError( + f"Single stage schedule requires exactly 1 stage per rank, but got {stages_per_rank} stages per rank. " + f"{model_config_info}, {stage_distribution_info}. " + f"Please increase pipeline_parallel_layers_per_stage to {num_layers // parallel_dims.pp} or higher " + f"to achieve 1 stage per rank." + ) + + if not is_single_stage_schedule and stages_per_rank < 2: + raise ValueError( + f"Multi-stage schedule requires at least 2 stages per rank, but got {stages_per_rank} stages per rank. " + f"{model_config_info}, {stage_distribution_info}. " + f"Please decrease pipeline_parallel_layers_per_stage to achieve at least 2 stages per rank." + ) + else: + # Fallback to default behavior when layers_per_stage is not provided + # For multi-stage schedules, default is 2 virtual stages per rank + # For single-stage schedules, default is 1 virtual stage per rank + stages_per_rank = 1 if is_single_stage_schedule else 2 + num_virtual_stages = parallel_dims.pp * stages_per_rank + + module_names_per_stage = job_config.parallelism.module_fqns_per_model_part + if module_names_per_stage is None: + module_names_per_stage = generate_llm_fqn_per_model_part( + num_virtual_stages, num_layers, input_weight, output_weight + ) + for i, stage_ms in enumerate(module_names_per_stage): + logger.debug(f"Stage {i}: {stage_ms}") + + stages, model_parts = pipeline_module_split( + model, + pp_mesh, + job_config.parallelism.pipeline_parallel_schedule, + device, + module_names_per_stage, ) # For PP with looped schedules, each item in model_parts is one stage-model-chunk. @@ -70,92 +148,3 @@ def pipeline_llama( has_last_stage = True return pp_schedule, model_parts, has_first_stage, has_last_stage - - -def pipeline_llama_manual_split( - whole_model: nn.Module, - pp_mesh: DeviceMesh, - parallel_dims: ParallelDims, - job_config: JobConfig, - device: torch.device, - model_args: TransformerModelArgs, -) -> tuple[list[PipelineStage], list[nn.Module]]: - """ - This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage. - - It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects. - - The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD - parallelism. 
- """ - pp_rank = pp_mesh.get_local_rank() - pp_size = pp_mesh.size() - parallelism_config = job_config.parallelism - - splits = parallelism_config.pipeline_parallel_split_points or generate_split_points( - parallelism_config.pipeline_parallel_schedule, - parallel_dims.pp, - model_args.n_layers, - parallelism_config.pipeline_parallel_layers_per_stage, - ) - - def _build_stage( - stage_idx: int, - start_layer: str | None, - stop_layer: str | None, - is_first: bool = False, - is_last: bool = False, - ) -> tuple[PipelineStage, nn.Module]: - model = copy.deepcopy(whole_model) - if not is_first: - model.tok_embeddings = None - - drop_layers = start_layer is not None - for name in list(model.layers.keys()): - # we keep layers in a contiguous region between start (inclusive) and stop (exclusive) - if f"layers.{name}" == start_layer: - drop_layers = False - if f"layers.{name}" == stop_layer: - drop_layers = True - if drop_layers: - del model.layers[name] - - if not is_last: - model.norm = None - model.output = None - - stage = PipelineStage( - model, - stage_idx, - num_stages, - device, - group=pp_mesh.get_group("pp"), - ) - return stage, model - - num_stages = len(splits) + 1 - stage_idx = pp_rank - - stages = [] - models = [] - - schedule_class = get_schedule_class(parallelism_config.pipeline_parallel_schedule) - style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop" - - for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style): - start_layer = splits[stage_idx - 1] if stage_idx > 0 else None - stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None - stage, model_chunk = _build_stage( - stage_idx, - start_layer, - stop_layer, - is_first=stage_idx == 0, - is_last=stage_idx == num_stages - 1, - ) - logger.info( - f"PP rank {pp_rank} is building stage_idx {stage_idx}" - f" with start_layer {start_layer}, stop_layer {stop_layer}" - ) - stages.append(stage) - models.append(model_chunk) - return stages, models