feat(mimo): Phase 4 - MiMo training, model/provider, data loading, heterogeneous parallelism #2869
base: main

Commits: 9642406, f845bf5, a3b9a88, 92dee32, c192de2, 5e0b771
New file (`@@ -0,0 +1,63 @@`):

```python
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""Base class for MIMO dataset providers."""

from __future__ import annotations

from abc import abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

from torch.utils.data import Dataset

from megatron.bridge.training.config import DatasetBuildContext, DatasetProvider


@dataclass(kw_only=True)
class MimoDatasetProvider(DatasetProvider):
    """Abstract base class for MIMO dataset providers.

    All MIMO dataset providers must inherit from this class and implement
    the required methods. This ensures a consistent interface for MIMO
    data loading.

    Required methods:
        - build_datasets: Build train/valid/test datasets
        - get_collate_fn: Return the collate function for batching

    Example:
        >>> class MyMimoProvider(MimoDatasetProvider):
        ...     def build_datasets(self, context):
        ...         # Build and return datasets
        ...         return train_ds, valid_ds, test_ds
        ...
        ...     def get_collate_fn(self):
        ...         # Return collate function
        ...         return my_collate_fn
    """

    @abstractmethod
    def build_datasets(
        self, context: DatasetBuildContext
    ) -> Tuple[Optional[Dataset], Optional[Dataset], Optional[Dataset]]:
        """Build train, validation, and test datasets.

        Args:
            context: Build context with sample counts.

        Returns:
            Tuple of (train_dataset, valid_dataset, test_dataset).
            Any element can be None if not needed.
        """
        ...

    @abstractmethod
    def get_collate_fn(self) -> Callable:
        """Return the collate function for batching.

        The collate function should handle the modality_inputs dict
        and batch them appropriately for the model.

        Returns:
            Callable that takes a list of samples and returns a batch dict.
        """
        ...
```
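To make the contract concrete, here is a runnable sketch of a minimal provider. `ProviderSketch` is a stand-in for the real base class (which inherits `DatasetProvider` from `megatron.bridge.training.config`), and the toy datasets are plain lists, so the snippet runs without Megatron or torch installed; it illustrates the interface, not the PR's implementation.

```python
from abc import ABC, abstractmethod
from typing import Callable

# Stand-in for MimoDatasetProvider so the sketch is self-contained;
# the real class inherits from megatron.bridge's DatasetProvider.
class ProviderSketch(ABC):
    @abstractmethod
    def build_datasets(self, context): ...

    @abstractmethod
    def get_collate_fn(self) -> Callable: ...


class ToyMimoProvider(ProviderSketch):
    """Minimal provider: wraps in-memory lists as 'datasets'."""

    def build_datasets(self, context):
        train = [{"input_ids": [1, 2]}, {"input_ids": [3, 4]}]
        valid = [{"input_ids": [5, 6]}]
        return train, valid, None  # test split not needed, so it is None

    def get_collate_fn(self):
        def collate(samples):
            # Batch by key; real code would torch.stack tensors instead.
            return {k: [s[k] for s in samples] for k in samples[0]}

        return collate


provider = ToyMimoProvider()
train_ds, valid_ds, test_ds = provider.build_datasets(context=None)
batch = provider.get_collate_fn()(train_ds)
print(batch["input_ids"])  # [[1, 2], [3, 4]]
```

The training loop only ever touches the two abstract methods, which is what lets heterogeneous MIMO data sources plug in behind one interface.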
Changes to `mimo_collate_fn`:

```diff
@@ -59,6 +59,7 @@ def mimo_collate_fn(
     labels = torch.stack([item["labels"] for item in batch])
     attention_mask = torch.stack([item["attention_mask"] for item in batch])
     position_ids = torch.stack([item["position_ids"] for item in batch])
+    loss_mask = torch.stack([item["loss_mask"] for item in batch])
```
**Review comment on lines +62 to 63** (Contributor):

Lines 62 and 114 document the collate function's arguments and return value; both docstrings should mention the new `loss_mask` entry.

📝 Suggested doc update:

```diff
 Args:
     batch: List of examples from MimoDataset, each containing:
         - input_ids: Token IDs with placeholder tokens
         - labels: Labels for causal LM training
+        - loss_mask: Per-token loss mask
         - attention_mask: Attention mask
         - position_ids: Position indices
         - modality_inputs: Dict[str, Dict[str, Any]] with preprocessed inputs
@@
 Returns:
     Dict containing:
         - input_ids: (batch, seq) stacked token IDs
         - labels: (batch, seq) stacked labels
+        - loss_mask: (batch, seq) stacked per-token loss mask
         - attention_mask: (batch, seq) attention mask
         - position_ids: (batch, seq) position indices
         - modality_inputs: Dict[str, Dict[str, Tensor]] with batched modality tensors
```

As per coding guidelines. Also applies to: 114-114.
```diff
     # Collate modality inputs
     modality_inputs: Dict[str, Dict[str, Any]] = {}

@@ -110,6 +111,7 @@ def mimo_collate_fn(
     return {
         "input_ids": input_ids,
         "labels": labels,
+        "loss_mask": loss_mask,
         "attention_mask": attention_mask,
         "position_ids": position_ids,
         "modality_inputs": modality_inputs,
```
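The batching pattern in this diff can be sketched in plain Python. Lists stand in for torch tensors (real code would `torch.stack` them), and `collate_sketch` is an illustrative helper, not the PR's `mimo_collate_fn`; it shows how flat fields like `loss_mask` are stacked per-batch while the nested `modality_inputs` dict is regrouped as modality → field → batched values.

```python
from typing import Any, Dict, List


def collate_sketch(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Stack flat per-sample fields (lists stand in for torch.stack).
    out: Dict[str, Any] = {
        key: [item[key] for item in batch]
        for key in ("input_ids", "labels", "loss_mask")
    }
    # Regroup nested modality inputs: modality -> field -> batched values.
    modality_inputs: Dict[str, Dict[str, Any]] = {}
    for item in batch:
        for modality, fields in item["modality_inputs"].items():
            dest = modality_inputs.setdefault(modality, {})
            for name, value in fields.items():
                dest.setdefault(name, []).append(value)
    out["modality_inputs"] = modality_inputs
    return out


samples = [
    {"input_ids": [1], "labels": [1], "loss_mask": [1],
     "modality_inputs": {"image": {"pixel_values": "img0"}}},
    {"input_ids": [2], "labels": [2], "loss_mask": [0],
     "modality_inputs": {"image": {"pixel_values": "img1"}}},
]
batch = collate_sketch(samples)
print(batch["modality_inputs"]["image"]["pixel_values"])  # ['img0', 'img1']
```

Forwarding `loss_mask` through the returned dict, as the diff does, is what lets the training step mask out placeholder and padding tokens when averaging the loss.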
**Review comment:**

Use the validation config here. The rest of this module reads evaluation settings from `cfg.validation`; using `cfg.train.eval_iters` on the MIMO path can raise or permanently disable `do_valid`/`do_test`.
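The failure mode this comment describes can be illustrated with hypothetical config dataclasses. All names below are assumptions for the sketch (the real fields live in Megatron-Bridge's training config): the point is that reading `eval_iters` from the wrong sub-config silently disables validation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainConfig:
    # Hypothetical: eval_iters may be absent on the MIMO path.
    eval_iters: Optional[int] = None


@dataclass
class ValidationConfig:
    # Hypothetical: the authoritative evaluation setting.
    eval_iters: int = 10


@dataclass
class Config:
    train: TrainConfig
    validation: ValidationConfig


cfg = Config(train=TrainConfig(), validation=ValidationConfig())

# Reading cfg.train.eval_iters yields None here, so a check like
# `eval_iters > 0` would raise or mark do_valid False forever.
# Reading cfg.validation, as the rest of the module does, works:
do_valid = cfg.validation.eval_iters > 0
print(do_valid)  # True
```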