diff --git a/src/megatron/bridge/data/loaders.py b/src/megatron/bridge/data/loaders.py index 15786df708..1ea3be5890 100644 --- a/src/megatron/bridge/data/loaders.py +++ b/src/megatron/bridge/data/loaders.py @@ -226,6 +226,12 @@ def build_train_valid_test_data_loaders( exit_signal = cfg.train.exit_signal def worker_init_fn(_): + """ + Install the distributed exit signal handler in a dataloader worker process. + + Parameters: + _ (int): Worker process index (unused). + """ DistributedSignalHandler(exit_signal).__enter__() maybe_worker_init_fn = worker_init_fn if cfg.train.exit_signal_handler_for_dataloader else None @@ -322,18 +328,19 @@ def build_train_valid_test_data_iterators( build_train_valid_test_datasets_provider: Callable, dp_group: torch.distributed.ProcessGroup, ) -> tuple[Optional[RerunDataIterator], Optional[RerunDataIterator], Optional[RerunDataIterator]]: - """Build train, validation, and test data iterators. - - Builds the data loaders first, then wraps them in appropriate iterators - (e.g., RerunDataIterator, cyclic_iter) based on the configuration. - - Args: - cfg: The main configuration container. - train_state: The current training state. - build_train_valid_test_datasets_provider: A function to build the datasets. - + """ + Build iterators for training, validation, and testing data. + + Wraps dataloaders produced by build_train_valid_test_data_loaders into iterator objects according to the dataset's dataloader_type; uses a "cyclic" iterator for validation when the dataset is a GPTDatasetConfig. For "external" dataloaders, preserves list or single dataloader semantics. + + Parameters: + cfg (ConfigContainer): Configuration that determines dataloader types and dataset settings. + train_state (TrainState): Current training state used when constructing dataloaders. + build_train_valid_test_datasets_provider (Callable): Provider used to build datasets for the dataloaders. 
+ dp_group (torch.distributed.ProcessGroup): Data-parallel process group used when building dataloaders. + Returns: - A tuple (train_data_iterator, valid_data_iterator, test_data_iterator). + tuple: (train_data_iterator, valid_data_iterator, test_data_iterator) where each element is a RerunDataIterator or a list of RerunDataIterator for "external" dataloaders, or `None` if the corresponding dataloader was not created. """ # Build loaders. diff --git a/src/megatron/bridge/data/mimo/base_provider.py b/src/megatron/bridge/data/mimo/base_provider.py index 4aff0b5008..bda105b9d0 100644 --- a/src/megatron/bridge/data/mimo/base_provider.py +++ b/src/megatron/bridge/data/mimo/base_provider.py @@ -39,25 +39,23 @@ class MimoDatasetProvider(DatasetProvider): def build_datasets( self, context: DatasetBuildContext ) -> Tuple[Optional[Dataset], Optional[Dataset], Optional[Dataset]]: - """Build train, validation, and test datasets. + """ + Construct the provider's train, validation, and test datasets. + + Parameters: + context (DatasetBuildContext): Build context containing sample counts used during dataset construction. - Args: - context: Build context with sample counts. - Returns: - Tuple of (train_dataset, valid_dataset, test_dataset). - Any element can be None if not needed. + Tuple[Optional[Dataset], Optional[Dataset], Optional[Dataset]]: A 3-tuple (train_dataset, valid_dataset, test_dataset); any element may be `None` if that split is not produced. """ ... @abstractmethod def get_collate_fn(self) -> Callable: - """Return the collate function for batching. - - The collate function should handle the modality_inputs dict - and batch them appropriately for the model. + """ + Provide the callable used to collate a list of samples into a batched `modality_inputs` dictionary. Returns: - Callable that takes a list of samples and returns a batch dict. 
+ Callable[[List[Any]], Dict[str, Any]]: A callable that accepts a list of samples and returns a dictionary mapping modality keys to their batched tensors/structures. """ ... diff --git a/src/megatron/bridge/data/mimo/collate.py b/src/megatron/bridge/data/mimo/collate.py index 76583e56db..de3f13389d 100644 --- a/src/megatron/bridge/data/mimo/collate.py +++ b/src/megatron/bridge/data/mimo/collate.py @@ -13,43 +13,28 @@ def mimo_collate_fn( batch: List[Dict[str, Any]], modality_names: List[str], ) -> Dict[str, Any]: - """Collate function for MIMO datasets. - - Stacks batch items and organizes modality inputs into a structure - suitable for MIMO model forward pass. - - Args: - batch: List of examples from MimoDataset, each containing: - - input_ids: Token IDs with placeholder tokens - - labels: Labels for causal LM training - - attention_mask: Attention mask - - position_ids: Position indices - - modality_inputs: Dict[str, Dict[str, Any]] with preprocessed inputs - modality_names: List of modality names to collate. - + """ + Collate a list of MIMO dataset examples into a batched dict suitable for model input. + + Parameters: + batch: List of example dictionaries. Each example is expected to contain: + - input_ids: Tensor of token IDs. + - labels: Tensor of target token IDs. + - attention_mask: Tensor attention mask aligning with tokens. + - position_ids: Tensor of position indices for tokens. + - loss_mask: Tensor indicating per-token loss contribution. + - modality_inputs: Dict[str, Dict[str, Any]] mapping modality names to modality-specific preprocessed inputs. + modality_names: List of modality names to gather from each example's `modality_inputs`. 
+ Returns: - Dict containing: - - input_ids: (batch, seq) stacked token IDs - - labels: (batch, seq) stacked labels - - attention_mask: (batch, seq) attention mask - - position_ids: (batch, seq) position indices - - modality_inputs: Dict[str, Dict[str, Tensor]] with batched modality tensors - Each modality's tensors are stacked along batch dimension. - - Example: - >>> batch = [ - ... { - ... "input_ids": torch.tensor([32000, 1, 2, 3]), - ... "labels": torch.tensor([32000, 1, 2, 3]), - ... "attention_mask": torch.ones(4), - ... "position_ids": torch.arange(4), - ... "modality_inputs": { - ... "vision": {"pixel_values": torch.randn(3, 224, 224)}, - ... }, - ... }, - ... # ... more examples - ... ] - >>> collated = mimo_collate_fn(batch, modality_names=["vision"]) + A dictionary with the following entries: + - input_ids: Tensor shaped (batch, seq) of stacked input IDs. + - labels: Tensor shaped (batch, seq) of stacked labels. + - loss_mask: Tensor shaped (batch, seq) of stacked loss masks. + - attention_mask: Tensor shaped (batch, seq) of stacked attention masks. + - position_ids: Tensor shaped (batch, seq) of stacked position indices. + - modality_inputs: Dict mapping each present modality name to a dict of batched modality values. + For each modality key, tensor values are stacked along the batch dimension when shapes permit; values that cannot be stacked or non-tensor values are returned as lists of per-example entries. """ if not batch: return {} diff --git a/src/megatron/bridge/data/mimo/dataset.py b/src/megatron/bridge/data/mimo/dataset.py index f1264ed5f9..74d836bde0 100644 --- a/src/megatron/bridge/data/mimo/dataset.py +++ b/src/megatron/bridge/data/mimo/dataset.py @@ -112,17 +112,22 @@ def __len__(self) -> int: return self._size def __getitem__(self, idx: int) -> Dict[str, Any]: - """Get a single example with preprocessed modality inputs. 
- + """ + Retrieve and preprocess a single dataset example, producing tokenized input IDs with modality placeholder tokens, next-token prediction labels, masking tensors, and per-modality processed inputs. + + The returned dictionary contains: + - input_ids: 1D tensor of token IDs of length `seq_length`, with modality placeholder token IDs prepended and padded/truncated to fit. + - labels: 1D tensor of next-token targets where labels[i] == input_ids[i+1], the final position is set to `-100`, and any position corresponding to padding or a modality placeholder is set to `-100` so it is ignored by loss. + - loss_mask: 1D float tensor with `1.0` for positions contributing to the loss and `0.0` for positions that should be ignored (padding, modality placeholders, and the final position). + - attention_mask: 1D tensor indicating non-padded token positions. + - position_ids: 1D tensor of position indices (0..seq_length-1). + - modality_inputs: dict mapping modality name to the processor outputs for that modality (tensors with batch dim removed where applicable). + + Parameters: + idx (int): Index of the example to retrieve. + Returns: - Dict containing: - - input_ids: Tokenized text with placeholder tokens - - labels: Shifted input_ids for next-token prediction (-100 for masked positions) - - loss_mask: Float mask (0.0 for padding/image placeholder targets, 1.0 otherwise) - - attention_mask: Attention mask - - position_ids: Position indices - - modality_inputs: Dict[str, Any] with preprocessed inputs per modality - e.g., {"vision": {"pixel_values": tensor, ...}} + dict: A mapping with keys `"input_ids"`, `"labels"`, `"loss_mask"`, `"attention_mask"`, `"position_ids"`, and `"modality_inputs"` as described above. 
""" if idx >= self._size: raise IndexError(f"Index {idx} out of range for dataset of size {self._size}") diff --git a/src/megatron/bridge/data/mimo/dp_utils.py b/src/megatron/bridge/data/mimo/dp_utils.py index 4ff632e781..c0dafa8e34 100644 --- a/src/megatron/bridge/data/mimo/dp_utils.py +++ b/src/megatron/bridge/data/mimo/dp_utils.py @@ -19,31 +19,21 @@ def get_mimo_dp_info( mimo_cfg: "MimoParallelismConfig", grids: Dict[str, "HyperCommGrid"], ) -> Tuple[int, int, bool, str]: - """Get DP rank, size, data-loading responsibility, and loader module for MIMO. - - Determines which module's DP settings to use for data loading based on - current rank's participation in heterogeneous deployment. - - In heterogeneous mode, each rank uses its own module's DP settings. - - Args: - mimo_cfg: MIMO parallelism configuration. - grids: Module name to HyperCommGrid mapping from build_hypercomm_grids(). - + """ + Return the DP rank and size, whether the current rank must load data, and which module's DP settings to use for the current distributed rank. + + Determines which module (from `grids`) contains the current global rank and computes the DP process-group rank and size. If the rank does not belong to any provided grid, returns defaults for a non-participating rank: (0, 1, False, MIMO_LANGUAGE_MODULE_KEY). `needs_data` is true for the language module when the rank is on the first or last PP stage; for other modules it is true only on the first PP stage. + + Parameters: + mimo_cfg (MimoParallelismConfig): MIMO parallelism configuration (kept for API compatibility). + grids (Dict[str, HyperCommGrid]): Mapping from module name to its HyperCommGrid produced by build_hypercomm_grids(). + Returns: - Tuple of (dp_rank, dp_size, needs_data, loader_module): - - dp_rank: This rank's position in DP group. - - dp_size: Size of DP group for data sharding. - - needs_data: Whether this rank needs to load data (first/last PP stage). - - loader_module: Which module's DP settings are being used. 
- - Example: - >>> from megatron.bridge.models.mimo.mimo_builder import build_hypercomm_grids - >>> grids = build_hypercomm_grids(mimo_cfg) - >>> dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids) - >>> if needs_data: - ... # Build data loader with dp_rank and dp_size - ... sampler = DistributedSampler(dataset, num_replicas=dp_size, rank=dp_rank) + Tuple[int, int, bool, str]: `(dp_rank, dp_size, needs_data, loader_module)` where + - `dp_rank`: this rank's index within its DP process group, + - `dp_size`: size of the DP process group used for data sharding, + - `needs_data`: `True` if the rank should load data, `False` otherwise, + - `loader_module`: name of the module whose DP settings should be used (or `MIMO_LANGUAGE_MODULE_KEY` for non-participating ranks). """ current_rank = dist.get_rank() diff --git a/src/megatron/bridge/data/mimo/hf_provider.py b/src/megatron/bridge/data/mimo/hf_provider.py index 095a437560..cecde1590f 100644 --- a/src/megatron/bridge/data/mimo/hf_provider.py +++ b/src/megatron/bridge/data/mimo/hf_provider.py @@ -122,7 +122,15 @@ def _load_tokenizer(self) -> Any: return tokenizer def _load_hf_dataset(self, split: str) -> Any: - """Load a HuggingFace dataset split.""" + """ + Load a specified split from the configured HuggingFace dataset. + + Parameters: + split (str): Name of the split to load (e.g., "train", "validation", "test"). + + Returns: + dataset (Any): The loaded dataset split, or `None` if the requested split does not exist. + """ try: dataset = load_dataset( self.hf_dataset_path, @@ -146,7 +154,16 @@ def _build_split_dataset( processors: Dict[str, Any], tokenizer: Any, ) -> Optional[MimoDataset]: - """Build dataset for a single split.""" + """ + Create a MimoDataset for the specified HuggingFace dataset split. + + Parameters: + split: The name of the dataset split to load (e.g., "train", "validation", "test"). 
+ target_samples: Maximum number of examples to include; if less than or equal to 0 the function returns `None`. + + Returns: + MimoDataset: A dataset configured with the loaded split, processors, tokenizer, and provider settings, or `None` if the split is missing or `target_samples` is <= 0. + """ if target_samples <= 0: return None @@ -170,14 +187,14 @@ def _build_split_dataset( def build_datasets( self, context: DatasetBuildContext ) -> Tuple[Optional[Dataset], Optional[Dataset], Optional[Dataset]]: - """Build train, validation, and test datasets. - - Args: - context: Build context with sample counts. - + """ + Build datasets for training, validation, and testing from the configured HuggingFace dataset. + + Parameters: + context (DatasetBuildContext): Provides target sample counts for train, valid, and test builds. + Returns: - Tuple of (train_dataset, valid_dataset, test_dataset). - Any element can be None if split doesn't exist or sample count is 0. + tuple: (train_dataset, valid_dataset, test_dataset) where each element is a Dataset or `None` if the corresponding split is missing or the requested sample count is less than or equal to zero. """ processors = self._load_processors() tokenizer = self._load_tokenizer() diff --git a/src/megatron/bridge/data/mimo/loaders.py b/src/megatron/bridge/data/mimo/loaders.py index 9f934c665b..057a47f74c 100644 --- a/src/megatron/bridge/data/mimo/loaders.py +++ b/src/megatron/bridge/data/mimo/loaders.py @@ -26,41 +26,21 @@ def build_mimo_data_loaders( valid_samples: int, test_samples: int, ) -> Tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader]]: - """Build MIMO data loaders with per-module DP settings. - - Creates data loaders with DP-aware sampling based on the MIMO parallelism - configuration. Only ranks that need data (first/last PP stage) will get - non-None loaders. - - Args: - cfg: Configuration container with MimoModelProvider as cfg.model. - train_state: Current training state. 
- mimo_provider: MIMO dataset provider (e.g., MockMimoDatasetProvider) - with get_collate_fn() method. - train_samples: Number of training samples. - valid_samples: Number of validation samples. - test_samples: Number of test samples. - + """ + Build MIMO data loaders using per-module data-parallel (DP) sampling derived from the MIMO parallelism configuration. + + Parameters: + cfg: Configuration container whose `model` must be a `MimoModelProvider` with a populated `_grids` attribute and a non-`None` `mimo_parallelism_config`. + mimo_provider: MIMO dataset provider that exposes `build_datasets(context)` and `get_collate_fn()`, and provides loader settings (`num_workers`, `pin_memory`, `drop_last`). + train_samples (int): Number of training samples to request when building datasets. + valid_samples (int): Number of validation samples to request when building datasets. + test_samples (int): Number of test samples to request when building datasets. + Returns: - Tuple of (train_loader, valid_loader, test_loader). - Returns (None, None, None) if this rank doesn't need data. - + Tuple of `(train_loader, valid_loader, test_loader)`. Each element is a `torch.utils.data.DataLoader` configured with DP-aware `DistributedSampler`, or `None` when no dataset was built or when the current rank does not require data (in which case all three are `None`). + Raises: - ValueError: If cfg.model is not MimoModelProvider or mimo_parallelism_config is None. - - Example: - >>> from megatron.bridge.data.mimo import MockMimoProvider, build_mimo_data_loaders - >>> provider = MockMimoProvider( - ... seq_length=2048, - ... processor_paths={"vision": "openai/clip-vit-large-patch14"}, - ... tokenizer_path="meta-llama/Llama-2-7b-hf", - ... special_token_ids={"vision": 32000}, - ... modality_configs={"vision": {"type": "image", "width": 224, "height": 224}}, - ... ) - >>> train_loader, valid_loader, test_loader = build_mimo_data_loaders( - ... cfg, train_state, provider, - ... 
train_samples=10000, valid_samples=1000, test_samples=1000, - ... ) + ValueError: If `cfg.model` is not a `MimoModelProvider`, if `cfg.model.mimo_parallelism_config` is `None`, or if `cfg.model._grids` is `None` (indicating model infra has not been built). """ from megatron.bridge.models.mimo.mimo_provider import MimoModelProvider @@ -105,6 +85,16 @@ def build_mimo_data_loaders( micro_batch_size = cfg.train.micro_batch_size def _make_loader(dataset, shuffle: bool = True) -> Optional[DataLoader]: + """ + Create a DataLoader for the given dataset configured with a DistributedSampler for the current MIMO data-parallel replica. + + Parameters: + dataset: The dataset to load; if `None`, no loader is created and `None` is returned. + shuffle (bool): Whether the distributed sampler should shuffle sample order for this split. + + Returns: + DataLoader or `None`: A DataLoader using the configured `DistributedSampler` and provider settings, or `None` when `dataset` is `None`. + """ if dataset is None: return None sampler = torch.utils.data.DistributedSampler( diff --git a/src/megatron/bridge/data/mimo/mock_provider.py b/src/megatron/bridge/data/mimo/mock_provider.py index 4f364de941..aa1efda689 100644 --- a/src/megatron/bridge/data/mimo/mock_provider.py +++ b/src/megatron/bridge/data/mimo/mock_provider.py @@ -23,7 +23,17 @@ def _generate_random_image(width: int, height: int, rng: np.random.Generator) -> Image.Image: - """Generate a random RGB image.""" + """ + Create a PIL RGB image of the given pixel dimensions with random pixel values. + + Parameters: + width (int): Image width in pixels. + height (int): Image height in pixels. + rng (np.random.Generator): NumPy random number generator used to sample uint8 RGB pixel values deterministically when seeded. + + Returns: + Image.Image: A PIL Image in "RGB" mode containing random pixels. 
+ """ pixels = rng.integers(low=0, high=256, size=(height, width, 3), dtype=np.uint8) return Image.fromarray(pixels, mode="RGB") @@ -110,7 +120,17 @@ def _load_processors(self) -> Dict[str, Any]: return processors def _load_tokenizer(self) -> Any: - """Load HuggingFace tokenizer.""" + """ + Load and cache the HuggingFace tokenizer specified by `tokenizer_path`. + + If a tokenizer is already cached on this instance, the cached object is returned. Ensures the tokenizer has a `pad_token` by assigning `eos_token` to `pad_token` when missing and stores the loaded tokenizer in `self._tokenizer`. + + Raises: + ValueError: If `self.tokenizer_path` is empty or falsy. + + Returns: + tokenizer: The loaded HuggingFace tokenizer instance. + """ if self._tokenizer is not None: return self._tokenizer diff --git a/src/megatron/bridge/models/mimo/mimo_builder.py b/src/megatron/bridge/models/mimo/mimo_builder.py index 0f70e3d19b..2a10be3d4a 100644 --- a/src/megatron/bridge/models/mimo/mimo_builder.py +++ b/src/megatron/bridge/models/mimo/mimo_builder.py @@ -15,16 +15,16 @@ def build_hypercomm_grids( mimo_parallelism_config: MimoParallelismConfig, ) -> Dict[str, "HyperCommGrid"]: - """Create HyperCommGrid objects per module from MIMO parallelism config. - - Creates grids on ALL ranks (required for consistent collective calls), - but only ranks in each grid's range will participate in its operations. - - Args: - mimo_parallelism_config: MimoParallelismConfig specifying parallelism per module. - + """ + Constructs a HyperCommGrid for each module in the given MIMO parallelism configuration. + + Grids are created on all ranks (so collective operations are consistent); only ranks within a grid's range participate in that grid's operations. + + Parameters: + mimo_parallelism_config (MimoParallelismConfig): Configuration mapping module names to their parallelism specifications. + Returns: - Dict mapping module names to their HyperCommGrids. 
+ Dict[str, "HyperCommGrid"]: Mapping from module name to its HyperCommGrid. """ from megatron.core.hyper_comm_grid import HyperCommGrid @@ -59,20 +59,16 @@ def build_hypercomm_grids( def populate_embedding_and_position_groups( pp_group: dist.ProcessGroup, ) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]: - """Create embedding-related process groups from PP group ranks. - - Following MCore semantics: - - pos_embd_pg: Only rank 0 of PP (first stage) - for position embeddings - - embd_pg: Ranks 0 and -1 of PP (first and last stages) - for tied word embeddings - - IMPORTANT: This calls dist.new_group which is a collective operation. - Must be called on all ranks that could participate. - + """ + Create process groups for position embeddings and tied word embeddings based on pipeline-parallel ranks. + + Position-embedding group contains only the first PP stage rank; embedding group contains the first and, if different, the last PP stage ranks. This operation calls `dist.new_group`, which is a collective and must be invoked on all ranks that could participate. + Args: - pp_group: The pipeline parallel process group. - + pp_group (dist.ProcessGroup): The pipeline-parallel process group or `None`. + Returns: - Tuple of (pos_embd_pg, embd_pg). Returns (None, None) if pp_group is None. + Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]: `(pos_embd_pg, embd_pg)` where `pos_embd_pg` is the group for position embeddings and `embd_pg` is the group for tied word embeddings; returns `(None, None)` if `pp_group` is `None`. """ if pp_group is None: return None, None @@ -93,7 +89,15 @@ def populate_embedding_and_position_groups( def is_pp_first_stage(pp_group: Optional[dist.ProcessGroup]) -> bool: - """Check if current rank is first stage in pipeline.""" + """ + Determine whether the current process is the first stage in the given pipeline-parallel process group. 
+ + Parameters: + pp_group (Optional[dist.ProcessGroup]): The pipeline-parallel process group to inspect. If `None`, the current process is treated as first stage. + + Returns: + bool: `True` if the current rank equals the smallest rank in `pp_group` or if `pp_group` is `None`, `False` otherwise. + """ if pp_group is None: return True pp_ranks = sorted(dist.get_process_group_ranks(pp_group)) @@ -101,7 +105,15 @@ def is_pp_first_stage(pp_group: Optional[dist.ProcessGroup]) -> bool: def is_pp_last_stage(pp_group: Optional[dist.ProcessGroup]) -> bool: - """Check if current rank is last stage in pipeline.""" + """ + Determine whether the current process is the last stage of the given pipeline-parallel process group. + + Parameters: + pp_group (Optional[dist.ProcessGroup]): The pipeline-parallel process group to inspect. If `None`, the pipeline is treated as a single-stage group. + + Returns: + bool: `True` if the current process rank is the last rank in `pp_group`, `False` otherwise. + """ if pp_group is None: return True pp_ranks = sorted(dist.get_process_group_ranks(pp_group)) diff --git a/src/megatron/bridge/models/mimo/mimo_config.py b/src/megatron/bridge/models/mimo/mimo_config.py index 5d895cac55..6e63357497 100644 --- a/src/megatron/bridge/models/mimo/mimo_config.py +++ b/src/megatron/bridge/models/mimo/mimo_config.py @@ -84,7 +84,15 @@ def total_world_size(self) -> int: return max(ranges) if ranges else 0 def _validate_heterogeneous(self) -> None: - """Validate heterogeneous deployment: no overlapping rank ranges.""" + """ + Ensure module rank ranges do not overlap in a heterogeneous deployment. + + This verifies that every module has `data_parallel_size` set and that the half-open rank ranges + defined by (`rank_offset`, `rank_offset + total_ranks`) for all modules are non-overlapping. + Raises: + ValueError: If a module's `data_parallel_size` is `None`. + ValueError: If any module's rank range overlaps a previous module's range. 
+ """ ranges = [] for name, parallelism in self.module_parallelisms.items(): if parallelism.data_parallel_size is None: @@ -99,13 +107,25 @@ def _validate_heterogeneous(self) -> None: raise ValueError("rank_offset ranges overlap in heterogeneous deployment.") def _validate_parallelism_constraints(self) -> None: - """Validate parallelism constraints for cross-module communication. - - - TP sizes must be powers of 2 - - DP sizes must be pairwise divisible (one divides the other) + """ + Validate that module parallelism settings meet constraints required for cross-module communication and embedding alignment. + + Performs these checks: + - Tensor-parallel (TP) sizes for every module must be powers of two. + - Data-parallel (DP) sizes between every pair of modules (when both are set) must be divisible (one must divide the other). + - For embedding alignment, every non-language ("encoder") module's DP (when set) must be greater than or equal to the language module's DP. + + Raises: + ValueError: If any TP is not a power of two, if any pair of set DP sizes are not divisible, or if an encoder module's DP is less than the language module's DP. """ def is_power_of_two(n: int) -> bool: + """ + Check whether an integer is a power of two. + + Returns: + `True` if `n` is a power of two and greater than zero, `False` otherwise. + """ return n > 0 and (n & (n - 1)) == 0 # Validate TP is power of 2 @@ -148,11 +168,16 @@ def is_power_of_two(n: int) -> bool: ) def finalize(self, world_size: int) -> None: - """Finalize parallelism config: compute data_parallel_size and validate. - - Args: - world_size: Total number of ranks in the distributed world. - MIMO requires a distributed environment, so this must always be provided. + """ + Finalize and validate all module parallelism configurations against the provided world size. + + Parameters: + world_size (int): Total number of ranks in the distributed world; must match the computed total from module configurations. 
+ + Raises: + ValueError: If the language module (MIMO_LANGUAGE_MODULE_KEY) is missing from module_parallelisms. + ValueError: If any module's configuration is invalid or inconsistent with heterogeneous constraints. + ValueError: If the provided world_size does not equal the computed total world size when the computed total is non-zero. """ if MIMO_LANGUAGE_MODULE_KEY not in self.module_parallelisms: raise ValueError( diff --git a/src/megatron/bridge/models/mimo/mimo_ddp.py b/src/megatron/bridge/models/mimo/mimo_ddp.py index ca091abca9..fd1718fb1f 100644 --- a/src/megatron/bridge/models/mimo/mimo_ddp.py +++ b/src/megatron/bridge/models/mimo/mimo_ddp.py @@ -29,19 +29,21 @@ def wrap_mimo_model_distributed( grids: Dict[str, "HyperCommGrid"], pg_collections: Dict[str, Optional["ProcessGroupCollection"]], ) -> "MimoModel": - """Wrap MIMO model's submodules with DDP. - - Modifies mimo_model in-place and returns it. - - Args: - mimo_model: The MimoModel to wrap. - ddp_config: DDP configuration from Bridge. - mimo_parallelism_config: MIMO parallelism configuration. - grids: Module name to HyperCommGrid mapping. - pg_collections: Module name to ProcessGroupCollection mapping. - + """ + Wraps the MimoModel's language model and modality submodules with Megatron DistributedDataParallel (DDP) in-place for ranks that participate in each module's HyperCommGrid. + + Parameters: + mimo_model: The MimoModel to modify in-place. + ddp_config: DDP configuration used for each DistributedDataParallel wrapper. + mimo_parallelism_config: MIMO parallelism configuration (accepted but not used by this function). + grids: Mapping from module name keys to HyperCommGrid instances; presence and membership determine whether a module is wrapped. + pg_collections: Mapping from module name keys to ProcessGroupCollection instances; a module is wrapped only if its process-group collection is present. + Returns: - The same mimo_model with wrapped submodules. 
+ The same mimo_model instance with any wrapped submodules replaced by DistributedDataParallel wrappers. + + Raises: + AttributeError: If a modality submodule has encoders but the selected first encoder lacks a `config` attribute (required for DDP configuration). """ from megatron.core.distributed import DistributedDataParallel diff --git a/src/megatron/bridge/models/mimo/mimo_provider.py b/src/megatron/bridge/models/mimo/mimo_provider.py index f63ee126be..bd5b6d82ce 100644 --- a/src/megatron/bridge/models/mimo/mimo_provider.py +++ b/src/megatron/bridge/models/mimo/mimo_provider.py @@ -134,18 +134,15 @@ class MimoModelProvider(ModelProviderMixin[MimoModel]): init_model_with_meta_device: bool = False def build_infra(self) -> MimoModelInfra: - """Build MIMO parallelism infrastructure. - - This method builds HyperCommGrids, ProcessGroupCollections, and topology - for MIMO's heterogeneous parallelism. It is idempotent and does not - mutate provider state (results are not cached). - - Can be called before or after provide(). Call finalize() first to - validate the parallelism configuration. - + """ + Builds the MIMO parallelism infrastructure and returns a MimoModelInfra describing it. + + Builds per-module HyperCommGrids (when a MIMO parallelism config is provided), derives per-module ProcessGroupCollections for the current rank, constructs the module topology (using provider `topology` if set, otherwise a default routing where each modality maps to the language module), and determines which modules this rank participates in. This method caches the created grids on the provider in the private `_grids` attribute. + + Call `finalize()` before invoking this method to validate the parallelism configuration. If `mimo_parallelism_config` is None, returned `module_to_grid_map` and `pg_collections` will be empty; topology will still be derived from `modality_submodules_spec` (or from `self.topology` when provided). 
+ Returns: - MimoModelInfra containing grids, topology, pg_collections, - and the list of modules this rank participates in. + MimoModelInfra: contains `module_to_grid_map` (module → HyperCommGrid), `topology` (module DAG), `pg_collections` (module → ProcessGroupCollection or None for non-participating ranks), `participating_modules` (list of modules with non-None pg_collections), and `module_output_ndim` (module → expected output tensor ndim). """ if self.mimo_parallelism_config is not None: grids = build_hypercomm_grids(self.mimo_parallelism_config) @@ -184,10 +181,14 @@ def _get_pg_collections_from_grids( self, grids: Dict[str, "HyperCommGrid"], ) -> Dict[str, Optional[ProcessGroupCollection]]: - """Get ProcessGroupCollections from HyperCommGrids. - - Creates all standard process groups plus embedding groups for PP > 1. - Returns None for modules this rank doesn't participate in. + """ + Construct per-module ProcessGroupCollection objects from provided HyperCommGrid instances. + + Parameters: + grids (Dict[str, HyperCommGrid]): Mapping from module name to its HyperCommGrid. + + Returns: + Dict[str, Optional[ProcessGroupCollection]]: Mapping from module name to a ProcessGroupCollection containing the standard process groups and, when applicable, embedding/position groups; `None` for modules in which the current rank does not participate. """ pg_collections: Dict[str, Optional[ProcessGroupCollection]] = {} current_rank = dist.get_rank() @@ -228,7 +229,18 @@ def _inject_pg_collection_into_language_spec( pre_process: Optional[bool] = None, post_process: Optional[bool] = None, ) -> ModuleSpec: - """Deep copy language model spec and inject stage-aware params.""" + """ + Deep-copy a language ModuleSpec and inject MIMO process-group information and optional pipeline stage flags. + + Parameters: + spec (ModuleSpec): Language model module specification to copy and modify. + pg_collection (ProcessGroupCollection): Process-group collection to attach to the spec's params. 
+ pre_process (Optional[bool]): If provided, set `params["pre_process"]` to this value to indicate the module runs in a pre-processing PP stage. + post_process (Optional[bool]): If provided, set `params["post_process"]` to this value to indicate the module runs in a post-processing PP stage. + + Returns: + ModuleSpec: A deep-copied spec with `params["pg_collection"]` set and `params["pre_process"]`/`params["post_process"]` set when provided. + """ spec = copy.deepcopy(spec) if spec.params is None: spec.params = {} @@ -244,7 +256,20 @@ def _inject_pg_collection_into_modality_spec( spec: ModuleSpec, pg_collection: ProcessGroupCollection, ) -> ModuleSpec: - """Inject pg_collection into encoder specs within a modality submodule.""" + """ + Create and return a copy of a modality ModuleSpec with MIMO process-group information injected into its submodules. + + The returned spec is a deep copy of `spec` where: + - `pg_collection` is assigned to `params` of each encoder spec found under `submodules["encoders"]`. + - `tp_group` is assigned to `params` of each ModuleSpec found in `submodules["input_projections"]` if `tp_group` is not already present. + + Parameters: + spec (ModuleSpec): The modality module specification to copy and modify. + pg_collection (ProcessGroupCollection): The process-group collection whose `tp` group and full collection are injected. + + Returns: + ModuleSpec: A deep-copied ModuleSpec with the injected `pg_collection` and `tp_group`. + """ spec = copy.deepcopy(spec) # Inject into encoders @@ -271,27 +296,20 @@ def provide( post_process: Optional[bool] = None, vp_stage: Optional[int] = None, ) -> MimoModel: - """Build and return the MimoModel instance. - - This method follows the standard ModelProviderMixin.provide() contract, - returning only the model instance. For infrastructure metadata (grids, - topology, pg_collections), use build_infra() separately. - - Args: - pre_process: Unused for MIMO (accepted for API compatibility). 
- post_process: Unused for MIMO (accepted for API compatibility). - vp_stage: Unused for MIMO (accepted for API compatibility). - + """ + Constructs and returns a CPU MimoModel with per-module process-group injections applied when MIMO parallelism is configured. + + Parameters: + pre_process (Optional[bool]): Accepted for API compatibility; ignored by this provider. + post_process (Optional[bool]): Accepted for API compatibility; ignored by this provider. + vp_stage (Optional[int]): Accepted for API compatibility; ignored by this provider. + Returns: - MimoModel instance. - - Note: - Device/dtype handling is done by provide_distributed_model(), - consistent with other providers. This method returns a CPU model. - + MimoModel: The constructed model; language and modality specs will have module-specific + process-group information injected when a MIMO parallelism configuration is present. + Raises: - ValueError: If language_model_spec is not set, or if this rank - doesn't participate in any module. + ValueError: If `language_model_spec` is not set prior to calling this method. """ if self.language_model_spec is None: raise ValueError( @@ -518,7 +536,17 @@ def initialize_model_parallel( ) def _apply_freezing(self, model: MimoModel) -> None: - """Apply freezing based on configuration.""" + """ + Apply configured parameter freezing to the provided MimoModel. + + Sets requires_grad = False for: + - the entire language model when `self.freeze_language_model` is True and the model has `language_model`; + - encoder parameters of modalities listed in `self.freeze_modality_encoders` where the value is True and the model has the modality with an `encoders` attribute; + - input projection parameters of modalities listed in `self.freeze_modality_projections` where the value is True and the model has the modality with an `input_projections` attribute. + + Parameters: + model (MimoModel): The MIMO model whose submodule parameters will be frozen according to provider settings. 
+ """ if self.freeze_language_model and hasattr(model, "language_model"): for param in model.language_model.parameters(): param.requires_grad = False @@ -539,16 +567,16 @@ def _apply_freezing(self, model: MimoModel) -> None: param.requires_grad = False def finalize(self) -> None: - """Finalize MIMO parallelism configuration. - - This validates the parallelism config and should be called before - build_infra() or provide(). It is called automatically by - provide_distributed_model(). - + """ + Finalize the MIMO parallelism configuration by validating it against the current distributed world size. + + Validates that torch.distributed is initialized and then calls the configured + MIMO parallelism config's finalize routine with the current world size. + Raises: - ValueError: If any rank doesn't participate in at least one module. - This indicates the parallelism configuration doesn't cover all - ranks in the world (validated by MimoParallelismConfig.finalize()). + RuntimeError: If torch.distributed is not initialized prior to calling finalize(). + ValueError: If the parallelism configuration is invalid for the current world size + (propagated from MimoParallelismConfig.finalize()). """ if self.mimo_parallelism_config is not None: if not dist.is_initialized(): diff --git a/src/megatron/bridge/training/mimo_parallel_utils.py b/src/megatron/bridge/training/mimo_parallel_utils.py index 94dd801ef8..8ba2ec002d 100644 --- a/src/megatron/bridge/training/mimo_parallel_utils.py +++ b/src/megatron/bridge/training/mimo_parallel_utils.py @@ -60,13 +60,14 @@ def unwrap_mimo_model(model) -> MimoModel: def is_current_rank_in_grid(grid: "HyperCommGrid") -> bool: - """Check if current rank participates in the given grid. - - Args: - grid: HyperCommGrid to check participation in. - + """ + Determine whether the current distributed process rank lies within the grid's contiguous rank range. + + Parameters: + grid (HyperCommGrid): Grid whose rank range will be checked. 
+    Returns:
-        True if current rank is within the grid's rank range.
+        `True` if the current process rank is within the grid's rank range, `False` otherwise.
     """
     current_rank = dist.get_rank()
     return grid.rank_offset <= current_rank < (grid.rank_offset + grid.size)
@@ -76,14 +77,15 @@ def get_module_to_grid_tuple(
     mimo_model: MimoModel,
     infra: MimoModelInfra,
 ) -> List[Tuple]:
-    """Build list of (module, grid) tuples for all modules the current rank participates in.
-
-    Args:
-        mimo_model: The MimoModel instance.
-        infra: MimoModelInfra containing module_to_grid_map.
-
+    """
+    Map participating submodules of the provided MimoModel to their corresponding HyperCommGrid entries.
+
+    Parameters:
+        mimo_model (MimoModel): The MimoModel to inspect; wrapped models (e.g., DDP/Float16Module) are unwrapped.
+        infra (MimoModelInfra): Infrastructure containing `module_to_grid_map` that associates module names with grids.
+
     Returns:
-        List of (module, grid) tuples for modules this rank participates in.
+        List[Tuple]: A list of (module, grid) tuples for each submodule whose grid includes the current rank.
     """
     module_to_grid_tuple = []

@@ -109,19 +111,16 @@ def get_module_to_grid_tuple(


 def build_pg_collection_for_schedule(infra: MimoModelInfra):
-    """Build pg_collection compatible with schedule.
-
-    Primary: Use MultiModuleProcessGroupCollection if PR 3212 allows
-    missing LLM PG on encoder-only ranks.
-    Fallback: Return list of ProcessGroupCollections for participating modules.
-
-    IMPORTANT: Uses infra.pg_collections directly. Do NOT rebuild PGs.
-
-    Args:
-        infra: MimoModelInfra with pg_collections for each module.
-
+    """
+    Constructs a schedule-compatible process-group collection from the infra's per-module process groups.
+
+    Prefers an aggregated MultiModuleProcessGroupCollection when available; otherwise returns a list of the existing per-module ProcessGroupCollection objects. Uses the mappings in `infra.pg_collections` and does not create or rebuild process groups.
+ + Parameters: + infra (MimoModelInfra): Infrastructure object containing `pg_collections` mapping module names to their ProcessGroupCollection (or None). + Returns: - MultiModuleProcessGroupCollection or list of ProcessGroupCollections. + MultiModuleProcessGroupCollection or list: A MultiModuleProcessGroupCollection aggregating the per-module PGs when supported, or a list of the existing per-module ProcessGroupCollection instances otherwise. """ try: from megatron.core.process_groups_config import MultiModuleProcessGroupCollection @@ -141,17 +140,13 @@ def build_pg_collection_for_schedule(infra: MimoModelInfra): @contextmanager def multimodule_no_sync(*, module_to_grid_tuple: List[Tuple]): - """Context manager to disable gradient sync for all modules during microbatch accumulation. - - This function is designed to be used with functools.partial() to pre-bind - the module_to_grid_tuple parameter, since the schedule calls no_sync_func() - with no arguments. - - Args: - module_to_grid_tuple: List of (module, grid) tuples (keyword-only, bound via partial). - - Yields: - None - context manager for gradient sync control. + """ + Disable gradient synchronization for all participating modules by entering each module's `no_sync()` context. + + This context manager enters `no_sync()` for every module in `module_to_grid_tuple` whose grid includes the current rank, yielding control while those contexts are active and exiting them on completion. + + Parameters: + module_to_grid_tuple (List[Tuple]): List of `(module, grid)` pairs; modules that are `None` or whose grid does not include the current rank are ignored. """ contexts = [] for module, grid in module_to_grid_tuple: @@ -179,21 +174,18 @@ def finalize_model_grads_multimodule( infra: MimoModelInfra, module_to_grid_tuple: List[Tuple], ): - """Finalize gradients for each module using infra.pg_collections. 
- - IMPORTANT: Signature matches schedule's call pattern: - config.finalize_model_grads_func([model], num_tokens, pg_collection, force_all_reduce=flag) - - The `infra` and `module_to_grid_tuple` parameters are pre-bound via partial(). - We ignore the schedule-provided `pg_collection` and use per-module PGs. - - Args: - model: Model list (passed by schedule, ignored - we use module_to_grid_tuple). - num_tokens: Token count for gradient scaling. - pg_collection: Schedule-provided PG (ignored - we use per-module PGs). - force_all_reduce: Schedule-provided flag (ignored - per-module PGs control sync). - infra: MimoModelInfra with per-module pg_collections (keyword-only, bound via partial). - module_to_grid_tuple: List of (module, grid) tuples (keyword-only, bound via partial). + """ + Finalize gradients for each participating MIMO submodule using that module's process-group collection. + + For every (module, grid) in `module_to_grid_tuple` where the current rank belongs to `grid`, this calls Megatron's internal finalize routine for that single module using the per-module PG from `infra.pg_collections`. This function intentionally ignores the schedule-provided `pg_collection` and `force_all_reduce` arguments; `infra` and `module_to_grid_tuple` are expected to be pre-bound (e.g., via partial). + + Parameters: + model: Ignored. Present to match the schedule's call signature. + num_tokens: Token count forwarded to the underlying finalize call for gradient scaling. + pg_collection: Ignored. Per-module PGs from `infra.pg_collections` are used instead. + force_all_reduce: Ignored. Per-module PG behavior determines synchronization. + infra: MimoModelInfra providing `module_to_grid_map` and `pg_collections`. + module_to_grid_tuple: List of (module, grid) tuples indicating modules and their grids; only modules whose grid contains the current rank are finalized. 
""" for module, grid in module_to_grid_tuple: if module is not None and is_current_rank_in_grid(grid): @@ -210,10 +202,13 @@ def finalize_model_grads_multimodule( def zero_grad_buffer_for_multimodule(module_to_grid_tuple: List[Tuple]): - """Reset gradient buffers for all DDP-wrapped modules. - - Args: - module_to_grid_tuple: List of (module, grid) tuples. + """ + Reset gradient buffers for participating multimodule submodules that expose `zero_grad_buffer`. + + Parameters: + module_to_grid_tuple (List[Tuple]): Sequence of `(module, grid)` pairs. For each pair, + if the current rank is inside `grid` and `module` implements `zero_grad_buffer()`, + that method will be invoked. """ for module, grid in module_to_grid_tuple: if module is not None and is_current_rank_in_grid(grid): @@ -257,21 +252,21 @@ def validate_data_loader_contract( micro_batch_size: int, num_microbatches: int, ): - """Validate data loading constraints for multimodule training. - - Checks: - - Global batch size divisible by all module DP sizes - - Micro-batch size consistent with per-module sharding - - num_microbatches * micro_batch_size == global_batch_size / DP_size (per module) - - Args: - infra: MimoModelInfra with module_to_grid_map. - global_batch_size: Total batch size across all data parallel ranks. - micro_batch_size: Batch size per microbatch. - num_microbatches: Number of microbatches per iteration. - + """ + Validate that global and microbatch sizes satisfy each module's data-parallel constraints. + + Checks per configured module that (1) `global_batch_size` is divisible by the module's data-parallel (DP) size and (2) `num_microbatches * micro_batch_size` equals the per-DP partition of `global_batch_size`. + + Parameters: + infra (MimoModelInfra): Infrastructure containing `module_to_grid_map`. + global_batch_size (int): Total batch size across all data-parallel ranks. + micro_batch_size (int): Batch size for a single microbatch. 
+ num_microbatches (int): Number of microbatches accumulated per iteration. + Raises: - ValueError: If any constraint is violated. + ValueError: If `global_batch_size` is not divisible by a module's DP size, or if + `num_microbatches * micro_batch_size` does not equal `global_batch_size // dp_size` + for any module. """ for module_name, grid in infra.module_to_grid_map.items(): # Get DP size from grid diff --git a/src/megatron/bridge/training/mimo_step.py b/src/megatron/bridge/training/mimo_step.py index eb6598dd89..808dd6f155 100644 --- a/src/megatron/bridge/training/mimo_step.py +++ b/src/megatron/bridge/training/mimo_step.py @@ -58,19 +58,14 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor) -> Tuple: def get_batch(data_iterator: Iterable) -> Optional[Dict[str, torch.Tensor]]: - """Get batch from data iterator. - - Returns dict with: - - input_ids, labels, loss_mask, position_ids (for LLM) - - modality_inputs: {modality_name: preprocessed_tensors} (for encoders) - - Uses existing MimoDataset format from Phase 3. - - Args: - data_iterator: Iterator over the dataset. - + """ + Retrieve the next dataset batch and move any contained torch.Tensors to CUDA. + + Parameters: + data_iterator (Iterable): Iterator over dataset batches; may be None. + Returns: - Batch dictionary or None if iterator is exhausted. + dict or None: The next batch dictionary with tensors transferred to CUDA (non-blocking), or `None` if `data_iterator` is None or exhausted. The batch follows the MimoDataset format and typically includes keys such as `input_ids`, `labels`, `loss_mask`, `position_ids`, and `modality_inputs` (a mapping of modality names to preprocessed tensors). 
""" if data_iterator is None: return None @@ -82,6 +77,15 @@ def get_batch(data_iterator: Iterable) -> Optional[Dict[str, torch.Tensor]]: # Move tensors to GPU if not already there def _move_to_cuda(obj): + """ + Recursively move any torch.Tensor values found in obj to CUDA, preserving container types and leaving non-tensor values unchanged. + + Parameters: + obj (Any): A tensor, dict, list, tuple, or nested combination thereof. + + Returns: + Any: The same structure as `obj` with all torch.Tensor instances moved to CUDA (using non_blocking=True) if they were not already on CUDA. + """ if isinstance(obj, torch.Tensor): return obj.cuda(non_blocking=True) if not obj.is_cuda else obj if isinstance(obj, dict): @@ -102,30 +106,24 @@ def forward_step( data_iterator: Iterable, model: MimoModel, ) -> Tuple[torch.Tensor, Optional[partial]]: - """Forward step for MIMO model training. - - Uses 3-arg signature with GlobalState for Bridge compatibility. - The training loop wraps this with prepare_forward_step_func() which: - - Injects GlobalState automatically if forward_step accepts it - - Provides access to state.timers, state.cfg, state.train_state - - The MimoModel handles dict-based tensor flow internally: - - Encoder modules produce activations sent via BridgeCommunicator - - LLM module receives encoder outputs and produces loss - - At terminal stage: returns (loss_tensor, loss_func) - At intermediate stages: returns (output_dict, None) - schedule handles communication - - GUARDRAIL: At last stage, assert output is scalar tensor (not dict) to catch - misconfigurations early with a clear error message. - - Args: - state: GlobalState containing timers, config, train_state. - data_iterator: Iterator over the dataset. - model: MimoModel instance. - + """ + Perform a pipeline-compatible forward pass for a MIMO model, returning either the terminal-stage loss tensor and a loss-function partial or intermediate activations and None. 
+ + Determines whether the current pipeline rank requires dataset inputs, fetches or constructs the appropriate data batch, runs model(**data_batch), and normalizes model outputs to (output_tensor, optional_loss_mask). If this rank is the language-module's last pipeline stage, returns the scalar loss tensor and a callable (partial of loss_func) that, when invoked with that tensor, computes (total_loss, total_tokens, {"lm loss": reporting_loss}). For non-terminal stages, returns activations and None so the pipeline schedule can pass them to other stages. + + Parameters: + state (GlobalState): Bridge-compatible global state (timers, config, train_state). + data_iterator (Iterable): Iterator over dataset batches; may be None for non-data ranks. + model (MimoModel): The wrapped MIMO model to execute. + Returns: - Tuple of (output_tensor, loss_function or None). + Tuple[torch.Tensor, Optional[partial]]: + - First element: the model's output tensor (scalar loss at terminal language stage or activations at intermediate stages). + - Second element: a `functools.partial` of `loss_func` bound to the applicable loss mask for terminal language-stage losses, or `None` for intermediate stages. + + Raises: + RuntimeError: If this rank requires data but get_batch returns None (indicates data-loading/parallelism misconfiguration). + ValueError: If the terminal language-stage returns a dict instead of a scalar loss tensor. """ # Get the model's role to determine if we're at first pipeline stage mimo_model = unwrap_mimo_model(model) diff --git a/src/megatron/bridge/training/pretrain_mimo.py b/src/megatron/bridge/training/pretrain_mimo.py index a78bd58490..7e3ab171a1 100644 --- a/src/megatron/bridge/training/pretrain_mimo.py +++ b/src/megatron/bridge/training/pretrain_mimo.py @@ -47,13 +47,22 @@ def _set_mimo_random_seeds( cfg: ConfigContainer, mimo_infra: "MimoModelInfra", ) -> None: - """Initialize random seeds with per-module TP/PP awareness. 
- - Mirrors the standard path's ``_set_random_seed()`` but derives TP/PP ranks - from the per-module HyperCommGrids instead of global MPU state. - - Must be called **after** ``build_infra()`` (grids exist) and **before** - ``provide_distributed_model()`` (weight init needs the CUDA RNG tracker). + """ + Set random seeds for Python, NumPy, PyTorch, and (if available) CUDA using the + tensor-parallel (TP) and pipeline-parallel (PP) ranks derived from the MIMO + module grids. + + This function reads the base seed from `cfg.seed` or `cfg.rng.seed` (default 1234), + determines the TP and PP ranks for the current process by inspecting + `mimo_infra.module_to_grid_map`, offsets the base seed by `100 * pp_rank`, and + applies the resulting seed to Python `random`, `numpy.random`, `torch.manual_seed`, + and the Megatron tensor-parallel CUDA RNG initializer. + + Parameters: + cfg: Configuration container exposing `seed` or `rng.seed`. + mimo_infra: MIMO infrastructure containing `module_to_grid_map` used to + determine per-module TP/PP ranks. + """ import random @@ -120,28 +129,19 @@ def setup_mimo( build_data_iterators_fn: Optional[Callable] = None, global_state: Optional[GlobalState] = None, ) -> MimoSetupOutput: - """MIMO-specific setup helper. - - This function sets up all components needed for MIMO training: - - Builds distributed model via MimoModelProvider - - Builds MIMO infrastructure (grids, topology, pg_collections) - - Creates MultiModulePipelineCommunicator - - Builds data iterators (if function provided) - - Validates configuration - - Args: - cfg: ConfigContainer with training configuration. - mimo_provider: MimoModelProvider for building model and infrastructure. - build_data_iterators_fn: Optional function to build data iterators. - Should have signature: (cfg, mimo_infra) -> (train_iter, valid_iter) - global_state: Optional GlobalState. If not provided, creates a new one. 
- + """ + Set up all components required for MIMO pretraining and return them as a MimoSetupOutput. + + This initializes GlobalState if absent, finalizes and builds MIMO infrastructure, seeds RNGs per module grid, constructs the distributed model and the MultiModulePipelineCommunicator, prepares scheduling/gradient helper structures, and optionally builds train/validation data iterators. + + Parameters: + cfg (ConfigContainer): Global configuration container. + mimo_provider (MimoModelProvider): Provider responsible for finalizing and constructing MIMO model and infra. + build_data_iterators_fn (Optional[Callable]): Optional callable with signature (cfg, mimo_infra) -> (train_iter, valid_iter) to create data iterators. + global_state (Optional[GlobalState]): Pre-existing GlobalState to reuse; if omitted a new GlobalState is created. + Returns: - MimoSetupOutput containing all components for training. - - Reuses from setup.py: - - Logging setup (via global_state) - - Timer infrastructure (via global_state) + MimoSetupOutput: Container with the constructed model, mimo_infra, multimodule_pg_collection, multimodule_communicator, module_to_grid_tuple, train/valid iterators (may be None), and the GlobalState. """ # Create GlobalState if not provided if global_state is None: @@ -259,23 +259,19 @@ def pretrain_mimo( schedulers: Optional[Dict[str, "OptimizerParamScheduler"]] = None, global_state: Optional[GlobalState] = None, ) -> None: - """Entry point for MIMO pretraining. - - Steps: - 1. Call setup_mimo() to get model, infra, communicators - 2. Validate constructor-time MIMO config wiring - 3. Create MimoOptimizer using get_mimo_optimizer() - 4. Call train_mimo() with all components - - Args: - cfg: ConfigContainer with training configuration. - mimo_provider: MimoModelProvider for building model and infrastructure. - forward_step_func: Forward step function for training. - build_data_iterators_fn: Function to build data iterators. 
- Signature: (cfg, mimo_infra) -> (train_iter, valid_iter) - opt_config: OptimizerConfig for creating MimoOptimizer. - schedulers: Per-module learning rate schedulers {module_name: scheduler}. - global_state: Optional GlobalState. If not provided, creates a new one. + """ + Orchestrate MIMO pretraining: prepare infrastructure and model, create optimizer and schedulers, and run the training loop. + + Sets up MIMO infrastructure and distributed model via the provided provider, constructs a MIMO optimizer from the unwrapped model and the given optimizer configuration, optionally auto-creates per-module learning-rate schedulers when `schedulers` is empty, and executes the training loop using the supplied forward step and data iterators. + + Parameters: + cfg: Configuration container with training, model, scheduler, and logging settings. + mimo_provider: Provider responsible for finalizing and constructing the MIMO model and infra. + forward_step_func: Callable that performs a single forward/backward step for training. + build_data_iterators_fn: Callable used to build data iterators. Signature: (cfg, mimo_infra) -> (train_iter, valid_iter). + opt_config: Optimizer configuration used to construct the MIMO optimizer; if it exposes `finalize()`, that will be invoked. + schedulers: Optional mapping of module name to OptimizerParamScheduler; if empty or None, per-module schedulers are created automatically. + global_state: Optional GlobalState to use for timers and training state; if omitted, a new GlobalState is created by setup. 
""" if schedulers is None: schedulers = {} diff --git a/src/megatron/bridge/training/train_mimo.py b/src/megatron/bridge/training/train_mimo.py index e8d1dc2a3d..7feebde137 100644 --- a/src/megatron/bridge/training/train_mimo.py +++ b/src/megatron/bridge/training/train_mimo.py @@ -78,25 +78,30 @@ def train_step_mimo( seq_length: int, micro_batch_size: int, ) -> Tuple[Dict[str, torch.Tensor], Optional[float], Optional[int]]: - """Single MIMO training step. - - Args: - forward_step_func: Forward step function (wrapped with GlobalState). - data_iterator: Iterator over the dataset. - model: MimoModel instance. - optimizer: MimoOptimizer managing per-module optimizers. - schedulers: Per-module learning rate schedulers {module_name: scheduler}. - global_state: GlobalState containing timers, config, train_state. - multimodule_communicator: MultiModulePipelineCommunicator for P2P. - multimodule_pg_collection: PG collection for schedule. - infra: MimoModelInfra with grids, topology, pg_collections. - module_to_grid_tuple: List of (module, grid) tuples. - num_microbatches: Number of microbatches per iteration. - seq_length: Sequence length. - micro_batch_size: Micro batch size. - + """ + Execute a single MIMO training iteration: run forward/backward pipelined across modules, perform the coordinated optimizer update, advance per-module schedulers on success, reduce per-stage losses on the pipeline's final stage, and broadcast the final loss dictionary to the logging rank. + + Parameters: + forward_step_func (Callable): Wrapped forward step function used by the pipeline. + data_iterator (Iterator): Iterator yielding training microbatches. + model (MimoModel): The MIMO-wrapped model participating in the pipeline. + optimizer (MimoOptimizer): Optimizer that manages per-module updates and returns update status and gradient statistics. + schedulers (Dict[str, OptimizerParamScheduler]): Per-module learning-rate schedulers; any None entries are skipped. 
+ global_state (GlobalState): Global training state and timers (provides cfg.data_parallel_size and timers). + multimodule_communicator (MultiModulePipelineCommunicator): P2P communicator used by the multimodule pipeline schedule. + multimodule_pg_collection: Process-group collection used by the pipeline schedule. + infra (MimoModelInfra): MIMO infra containing module grids and pg_collections (used for final-stage reductions and optional all-reduce). + module_to_grid_tuple (List): Mapping of modules to their device/grid tuples used for gradient clearing. + num_microbatches (int): Number of microbatches per training iteration. + seq_length (int): Sequence length for the forward step. + micro_batch_size (int): Micro-batch size. + Returns: - Tuple of (loss_dict, skipped_iter, grad_norm, num_zeros_in_grad). + Tuple containing: + - loss_dict (Dict[str, torch.Tensor]): Reduced loss metrics produced by the pipeline's last stage (empty if no losses produced on this rank). All tensors are on the local CUDA device for the logging rank. + - skipped_iter (int): `0` if the optimizer update succeeded and schedulers were stepped, `1` if the update was skipped. + - grad_norm (Optional[float]): Global gradient norm computed by the optimizer, or `None` if unavailable. + - num_zeros_in_grad (Optional[int]): Number of zero gradients reported by the optimizer, or `None` if unavailable. """ timers = global_state.timers @@ -197,33 +202,25 @@ def train_mimo( mimo_infra: "MimoModelInfra", multimodule_communicator: "MultiModulePipelineCommunicator", ) -> None: - """Main MIMO training loop. 
- - Key differences from standard train(): - - Creates MultiModuleProcessGroupCollection for the schedule - - Uses forward_backward_pipelining_without_interleaving with multimodule support - - Uses zero_grad_buffer_for_multimodule() for gradient clearing - - Uses MimoOptimizer for coordinated gradient clipping with global norm - - Reuses from existing Bridge training: - - GlobalState for timers, config, train_state - - training_log() for metrics reporting - - handle_profiling_step() and handle_profiling_stop() for profiler lifecycle - - save_checkpoint() with MimoOptimizer for checkpointing - - evaluate_and_print_results() for validation with multimodule support - - maybe_finalize_async_save() for async checkpoint finalization - - - Args: - forward_step_func: Forward step function. - model: MimoModel instance. - optimizer: MimoOptimizer managing per-module optimizers. - schedulers: Per-module learning rate schedulers {module_name: scheduler}. - train_data_iterator: Training data iterator. - valid_data_iterator: Validation data iterator (optional). - global_state: GlobalState containing timers, config, train_state. - mimo_infra: MimoModelInfra with grids, topology, pg_collections. - multimodule_communicator: MultiModulePipelineCommunicator for P2P. + """ + Run the full MIMO training loop coordinating multimodule pipelining, optimization, scheduling, logging, evaluation, and checkpointing. + + This function: + - Builds multimodule process-group collection and validates MIMO-specific requirements. + - Configures model gradient hooks for multimodule training and prepares the forward step. + - Iterates until the configured number of training iterations, calling train_step_mimo each iteration, updating training state, stepping per-module schedulers on successful updates, logging metrics, running evaluation, and saving checkpoints (including async finalize behavior). + - Manages optional profiling and ensures graceful shutdown of profilers and async checkpoint saves. 
+ + Parameters: + forward_step_func (Callable): Per-microbatch forward function; will be wrapped with GlobalState. + model (MimoModel): The distributed MIMO model to train. + optimizer (MimoOptimizer): Optimizer coordinating per-module optimizer behavior and gradient scaling. + schedulers (Dict[str, OptimizerParamScheduler]): Mapping of module names to their learning-rate schedulers. + train_data_iterator (Iterator): Iterator yielding training microbatches. + valid_data_iterator (Optional[Iterator]): Optional iterator for validation/evaluation. + global_state (GlobalState): Global training state and configuration (timers, train_state, loggers). + mimo_infra (MimoModelInfra): MIMO infrastructure describing module grids, topology, and process-group collections. + multimodule_communicator (MultiModulePipelineCommunicator): Communicator for cross-module point-to-point pipeline transfers. """ timers = global_state.timers train_state = global_state.train_state diff --git a/src/megatron/bridge/training/utils/train_utils.py b/src/megatron/bridge/training/utils/train_utils.py index 64ba31151c..7890dd03b9 100644 --- a/src/megatron/bridge/training/utils/train_utils.py +++ b/src/megatron/bridge/training/utils/train_utils.py @@ -350,32 +350,31 @@ def training_log( pg_collection: Optional[Any] = None, log_max_attention_logit: Optional[float] = None, ) -> bool: - """Log training stats (losses, learning rate, timings, etc.). - - Aggregates losses, logs metrics to TensorBoard and WandB (if enabled), - and prints a formatted log string to the console on the last rank. - - Args: - loss_dict (dict[str, torch.Tensor]): Dictionary of losses for the current step. - total_loss_dict (dict[str, Any]): Dictionary to accumulate losses and stats - across logging intervals. - learning_rate (Optional[float]): Current learning rate. - decoupled_learning_rate (Optional[float]): Current decoupled learning rate (if used). - loss_scale (float): Current loss scale value. 
- report_memory_flag (bool): Flag to indicate if memory usage should be reported. - skipped_iter (int): 1 if the iteration was skipped, 0 otherwise. - grad_norm (Optional[float]): Gradient norm if computed, else None. - params_norm (Optional[float]): Parameter L2 norm if computed, else None. - num_zeros_in_grad (Optional[int]): Number of zeros in gradient if computed, else None. - config: The main configuration container. - global_state: The global training state. - history_wct (list): list of elapsed time per each iteration. - model (list[MegatronModule]): megatron model state. - pg_collection (Optional[Any]): ProcessGroupCollection to use for logging reductions. - If None, falls back to extracting from model wrappers. - log_max_attention_logit (Optional[float]): Maximum attention logit if available, None otherwise. + """ + Log and report training metrics (losses, timing, throughput, memory, and auxiliary model metrics). + + Aggregates per-step losses into `total_loss_dict`, writes configured scalars to TensorBoard / Weights & Biases / MLflow / Comet, prints a periodic console summary on the last rank, and tracks MoE/MTP specific metrics when enabled. + + Parameters: + loss_dict (dict[str, torch.Tensor]): Loss tensors for the current iteration. + total_loss_dict (dict[str, Any]): Accumulator for losses and counters across the logging interval. + learning_rate (Optional[float]): Current learning rate (may be None on ranks without trainable params). + decoupled_learning_rate (Optional[float]): Current decoupled learning rate for optimizers that use it. + loss_scale (float): Current loss scaling factor. + report_memory_flag (bool): If True, memory usage is reported on this logging iteration and the flag is then cleared. + skipped_iter (int): 1 if the current iteration was skipped, 0 otherwise. + grad_norm (Optional[float]): Computed gradient L2 norm, or None if not available. + params_norm (Optional[float]): Computed parameter L2 norm, or None if not available. 
+ num_zeros_in_grad (Optional[int]): Number of zero entries in gradients, or None if not available. + config (ConfigContainer): Global configuration container. + global_state (GlobalState): Global training state and utilities (timers, loggers, counters). + history_wct (list): Wall-clock time history for recent iterations (used for throughput rollups). + model (list[MegatronModule]): Model wrapper(s) used to derive process-group collection when needed. + pg_collection (Optional[Any]): Optional ProcessGroupCollection to use for reductions; if None, it is derived from `model`. + log_max_attention_logit (Optional[float]): Maximum attention logit observed this iteration, if provided. + Returns: - bool: The updated report_memory_flag. + bool: `True` if memory reporting should remain enabled for subsequent logging iterations, `False` otherwise. """ timers = global_state.timers train_state = global_state.train_state diff --git a/tests/unit_tests/data/mimo/test_collate.py b/tests/unit_tests/data/mimo/test_collate.py index 6fd71997ff..78ae90407b 100644 --- a/tests/unit_tests/data/mimo/test_collate.py +++ b/tests/unit_tests/data/mimo/test_collate.py @@ -10,7 +10,23 @@ def make_sample( seq_length: int = 64, modalities: dict = None, ) -> dict: - """Create a sample item for testing.""" + """ + Create a synthetic sample dictionary for tests containing token and modality fields. + + Parameters: + seq_length (int): Sequence length for 1D token tensors (default 64). + modalities (dict | None): Mapping of modality name to its tensors. If None, a default + vision modality with `pixel_values` of shape (3, 224, 224) is used. + + Returns: + dict: A sample with the following keys: + - `input_ids` (Tensor): random integers of shape (seq_length,). + - `labels` (Tensor): random integers of shape (seq_length,). + - `loss_mask` (Tensor): ones of dtype `float32` and shape (seq_length,). + - `attention_mask` (Tensor): ones of shape (seq_length,). 
+ - `position_ids` (Tensor): range tensor of shape (seq_length,). + - `modality_inputs` (dict): the provided or default modality tensors. + """ if modalities is None: modalities = {"vision": {"pixel_values": torch.randn(3, 224, 224)}} diff --git a/tests/unit_tests/data/mimo/test_dp_utils.py b/tests/unit_tests/data/mimo/test_dp_utils.py index 64650a8172..3ac6624a8a 100644 --- a/tests/unit_tests/data/mimo/test_dp_utils.py +++ b/tests/unit_tests/data/mimo/test_dp_utils.py @@ -33,11 +33,32 @@ def __init__(self, rank_offset: int, size: int, dp_rank: int, dp_size: int, pp_r } def get_pg(self, dims): + """ + Return the process group associated with the given parallelism dimension keys. + + Parameters: + dims (Iterable): Sequence of dimension identifiers used as the lookup key (e.g., ('dp',), ('pp',)). + + Returns: + FakePG: The process group mapped to the provided dimensions. + + Raises: + KeyError: If no process group exists for the given dimensions. + """ return self._pgs[tuple(dims)] def _make_mimo_cfg() -> MimoParallelismConfig: - """Create test MIMO config for heterogeneous deployment.""" + """ + Constructs a MIMO parallelism configuration for tests with heterogeneous module settings. + + Creates a MimoParallelismConfig whose module_parallelisms map contains: + - "vision": tensor_model_parallel_size=1, data_parallel_size=2, rank_offset=0 + - "language": tensor_model_parallel_size=1, data_parallel_size=4, rank_offset=4 + + Returns: + MimoParallelismConfig: Configuration with the above per-module parallelism settings. 
+ """ module_parallelisms = { "vision": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=2, rank_offset=0), "language": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=4, rank_offset=4), diff --git a/tests/unit_tests/data/mimo/test_hf_provider.py b/tests/unit_tests/data/mimo/test_hf_provider.py index c5c34b2ba0..75350b6eca 100644 --- a/tests/unit_tests/data/mimo/test_hf_provider.py +++ b/tests/unit_tests/data/mimo/test_hf_provider.py @@ -50,11 +50,40 @@ def test_build_datasets_happy_path(monkeypatch): calls = Calls() def fake_is_safe_repo(trust_remote_code, hf_path): + """ + Test helper that records an invocation and always reports the repository as not safe. + + Parameters: + trust_remote_code: Ignored. + hf_path: Ignored. + + Returns: + False: Indicates the repository is not considered safe. Also increments `calls.is_safe_repo` as a side effect. + """ del trust_remote_code, hf_path calls.is_safe_repo += 1 return False def fake_load_dataset(path, name=None, split=None, trust_remote_code=None, data_files=None): + """ + Test stub that simulates loading a dataset for unit tests. + + Parameters: + path (str): Ignored. + name (str | None): Ignored. + split (str | None): If equal to "validation", a `ValueError` is raised to simulate a missing split. + trust_remote_code: Ignored. + data_files: Ignored. + + Side effects: + Increments `calls.load_dataset` to record an invocation. + + Returns: + list[dict]: A single-record dataset: [{"text": "hello", "image": "image_0.jpg"}]. + + Raises: + ValueError: If `split` is "validation". 
+ """ del path, name, trust_remote_code, data_files calls.load_dataset += 1 if split == "validation": diff --git a/tests/unit_tests/data/mimo/test_loaders.py b/tests/unit_tests/data/mimo/test_loaders.py index 218d5cc8c2..b85050aab6 100644 --- a/tests/unit_tests/data/mimo/test_loaders.py +++ b/tests/unit_tests/data/mimo/test_loaders.py @@ -10,6 +10,13 @@ class FakeMimoModelProvider: def __init__(self, mimo_parallelism_config, grids=None): + """ + Initialize the fake MimoModelProvider test double. + + Parameters: + mimo_parallelism_config: Configuration object describing MIMO parallelism; stored as `mimo_parallelism_config`. + grids (optional): Mapping of grid names to grid objects; stored as `_grids`. Provide `None` when no grids are available. + """ self.mimo_parallelism_config = mimo_parallelism_config self._grids = grids @@ -50,6 +57,11 @@ def test_build_mimo_data_loaders_raises_when_model_not_mimo(monkeypatch): def test_build_mimo_data_loaders_raises_when_parallelism_missing(monkeypatch): + """ + Verifies that build_mimo_data_loaders raises a ValueError when the model's mimo_parallelism_config is None. + + Sets up a fake MIMO provider class and a config whose model has mimo_parallelism_config=None and asserts that calling build_mimo_data_loaders raises ValueError with message matching "mimo_parallelism_config must be set". 
+ """ _patch_mimo_provider_class(monkeypatch) cfg = SimpleNamespace( model=FakeMimoModelProvider(mimo_parallelism_config=None, grids={"llm": object()}), diff --git a/tests/unit_tests/models/mimo/test_mimo_provider.py b/tests/unit_tests/models/mimo/test_mimo_provider.py index 6b7bc1c733..56942c598b 100644 --- a/tests/unit_tests/models/mimo/test_mimo_provider.py +++ b/tests/unit_tests/models/mimo/test_mimo_provider.py @@ -175,7 +175,11 @@ def test_build_infra_with_parallelism(self, mock_build_grids, mock_get_rank, moc @patch("torch.distributed.get_rank") @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids") def test_build_infra_is_idempotent(self, mock_build_grids, mock_get_rank, mock_get_pg_ranks, mock_new_group): - """Test build_infra() can be called multiple times.""" + """ + Verifies that calling build_infra() multiple times produces consistent participating_modules across calls. + + Ensures repeated invocations return infra objects with equivalent participating_modules when a mimo_parallelism_config is present. + """ mock_get_rank.return_value = 0 mock_get_pg_ranks.return_value = [0, 1] mock_new_group.return_value = MagicMock() @@ -656,7 +660,11 @@ def test_pg_collection_middle_stage_no_embedding_groups( @patch("megatron.bridge.models.mimo.mimo_provider.populate_embedding_and_position_groups") @patch("torch.distributed.get_rank") def test_pg_collection_includes_composite_groups(self, mock_get_rank, mock_populate, mock_is_first, mock_is_last): - """Test that pg_collection includes mp, tp_ep_pp, and expt_dp composite groups.""" + """ + Verify that the process-group collection for the "language" module contains the expected primitive and composite process groups. + + Asserts that `tp`, `dp`, `pp`, `cp`, `ep`, and the composite groups `dp_cp`, `mp` (from `("tp","pp")`), and `tp_ep_pp` (from `("tp","ep","pp")`) are present and mapped to the corresponding process-group objects. 
+ """ mock_get_rank.return_value = 0 mock_populate.return_value = (MagicMock(), MagicMock()) mock_is_first.return_value = True diff --git a/tests/unit_tests/training/mimo/test_pretrain_mimo.py b/tests/unit_tests/training/mimo/test_pretrain_mimo.py index 27320ca926..fb1a5c6526 100644 --- a/tests/unit_tests/training/mimo/test_pretrain_mimo.py +++ b/tests/unit_tests/training/mimo/test_pretrain_mimo.py @@ -8,6 +8,20 @@ def _make_cfg(): + """ + Create and return a test mock configuration with sensible default training fields. + + The returned object is a MagicMock with a `train` SimpleNamespace containing: + - `rampup_batch_size = None` + - `global_batch_size = 1` + - `micro_batch_size = 1` + - `decrease_batch_size_if_needed = False` + + Also sets `data_parallel_size = 1` on the mock. + + Returns: + MagicMock: A mock configuration object with the `train` namespace and `data_parallel_size` set. + """ cfg = MagicMock() cfg.train = SimpleNamespace( rampup_batch_size=None, @@ -20,6 +34,21 @@ def _make_cfg(): def _make_setup_output(module_to_grid_map): + """ + Create a SimpleNamespace that mimics the structure returned by setup_mimo for unit tests. + + Parameters: + module_to_grid_map (dict): Mapping from module name to grid object used by tests (e.g., {"vision": grid}). + + Returns: + SimpleNamespace: Namespace with the following test-oriented attributes: + - model: a MagicMock representing the model. + - mimo_infra: SimpleNamespace with `module_to_grid_map` set to the provided mapping. + - multimodule_communicator: a MagicMock representing inter-module communication. + - train_data_iterator: an empty iterator for training data. + - valid_data_iterator: None (no validation iterator). + - global_state: a MagicMock representing global training state. + """ return SimpleNamespace( model=MagicMock(), mimo_infra=SimpleNamespace(module_to_grid_map=module_to_grid_map),