From b7b22f0d87783f9ba56e29c8e122c9d5b5cd55a1 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 23 Dec 2025 23:28:28 -0800 Subject: [PATCH 01/19] support custom_dataloader Signed-off-by: Yuki Huang --- examples/configs/grpo_math_1B.yaml | 5 + examples/configs/grpo_multiple_datasets.yaml | 5 + .../custom_dataloader/custom_dataloader.py | 45 +++++++++ examples/run_grpo.py | 6 ++ nemo_rl/algorithms/grpo.py | 97 ++++++++++++++----- nemo_rl/data/dataloader.py | 47 +++++++++ nemo_rl/data/utils.py | 57 ++++++++--- 7 files changed, 222 insertions(+), 40 deletions(-) create mode 100644 examples/custom_dataloader/custom_dataloader.py create mode 100644 nemo_rl/data/dataloader.py diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 740f9ad24b..9efb906552 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -292,6 +292,11 @@ data: shuffle: true num_workers: 1 + # use multiple dataloader for train + use_multiple_dataloader: false + num_prompts_per_dataloader: ${grpo.num_prompts_per_step} + custom_dataloader: null + # dataset train: dataset_name: OpenMathInstruct-2 diff --git a/examples/configs/grpo_multiple_datasets.yaml b/examples/configs/grpo_multiple_datasets.yaml index 9ed039fa96..97ca0fa20f 100644 --- a/examples/configs/grpo_multiple_datasets.yaml +++ b/examples/configs/grpo_multiple_datasets.yaml @@ -8,6 +8,11 @@ data: shuffle: true num_workers: 1 + # use multiple dataloader for train + use_multiple_dataloader: false + num_prompts_per_dataloader: ${grpo.num_prompts_per_step} + custom_dataloader: null + # dataset # See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details. train: diff --git a/examples/custom_dataloader/custom_dataloader.py b/examples/custom_dataloader/custom_dataloader.py new file mode 100644 index 0000000000..d7adde7edb --- /dev/null +++ b/examples/custom_dataloader/custom_dataloader.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterator + +from torchdata.stateful_dataloader import StatefulDataLoader + + +def example_custom_dataloader( + data_iterators: dict[str, Iterator], + dataloaders: dict[str, StatefulDataLoader], + **kwargs, +): + """An example of custom dataloader function. + + This function is used to sample data from multiple dataloaders using a custom dataloader function. + In this example, we simply sample data from each dataloader. + + Args: + dataloaders: A dictionary of dataloaders. + **kwargs: Additional arguments to pass to the custom dataloader function. + + Returns: + Data from the dataloaders. + Updated data iterators (may update if the data iterator is exhausted). + """ + result = [] + for task_name, data_iterator in data_iterators.items(): + try: + result.append(next(data_iterator)) + except: + data_iterators[task_name] = iter(dataloaders[task_name]) + result.append(next(data_iterators[task_name])) + return result, data_iterators diff --git a/examples/run_grpo.py b/examples/run_grpo.py index 4ab7c1266d..6130b99018 100644 --- a/examples/run_grpo.py +++ b/examples/run_grpo.py @@ -133,6 +133,12 @@ def main() -> None: f"{feature} is not supported with async GRPO" ) + # Async GRPO does not support multiple dataloaders + if config["data"]["use_multiple_dataloader"]: + raise NotImplementedError( + "use_multiple_dataloader is not supported with async GRPO" + ) + from nemo_rl.algorithms.grpo import async_grpo_train print("🚀 Running async GRPO training") diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 6772739655..ed369b4343 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -49,6 +49,7 @@ ) from nemo_rl.data import DataConfig from nemo_rl.data.collate_fn import rl_collate_fn +from nemo_rl.data.dataloader import MultipleDataloaderWrapper from nemo_rl.data.datasets import AllTaskProcessedDataset from nemo_rl.data.interfaces import DatumSpec from nemo_rl.data.llm_message_utils import ( @@ -214,14 +215,14 @@ class MasterConfig(TypedDict): def setup( master_config: MasterConfig, tokenizer: TokenizerType, - dataset: AllTaskProcessedDataset, + dataset: AllTaskProcessedDataset | dict[str, AllTaskProcessedDataset], val_dataset: Optional[AllTaskProcessedDataset], processor: Optional[AutoProcessor] = None, ) -> tuple[ ColocatablePolicyInterface, Optional[GenerationInterface], tuple[RayVirtualCluster, RayVirtualCluster], - StatefulDataLoader, + StatefulDataLoader | dict[str, StatefulDataLoader], Optional[StatefulDataLoader], ClippedPGLossFn, Logger, @@ -274,9 +275,14 @@ def setup( # ========================== # Data # ========================== - # Validate batch_multiplier + # Validate dataloader batch size + dataloader_batch_size = data_config["num_prompts_per_dataloader"] + if not data_config["use_multiple_dataloader"]: + assert dataloader_batch_size == grpo_config["num_prompts_per_step"], ( + "data.num_prompts_per_dataloader must be equal to grpo.num_prompts_per_step if not using multiple dataloaders (data.use_multiple_dataloader=false)" + ) + batch_multiplier = grpo_config["batch_multiplier"] - dataloader_batch_size = grpo_config["num_prompts_per_step"] if not grpo_config["use_dynamic_sampling"]: assert batch_multiplier == 1, ( "batch_multiplier>1 can only be used if use_dynamic_sampling=True" @@ -284,21 +290,35 @@ def setup( else: dataloader_batch_size = int(dataloader_batch_size * batch_multiplier) - dataloader = StatefulDataLoader( - dataset, - batch_size=dataloader_batch_size, - shuffle=data_config["shuffle"], - collate_fn=rl_collate_fn, - drop_last=True, - num_workers=data_config["num_workers"], - ) - if last_checkpoint_path is not None: - dataloader_state_dict = torch.load( - os.path.join(last_checkpoint_path, "train_dataloader.pt") + # Load train dataset + def init_dataloader(dataset, suffix: str = ""): + dataloader = StatefulDataLoader( + dataset, + batch_size=dataloader_batch_size, + shuffle=data_config["shuffle"], + collate_fn=rl_collate_fn, + drop_last=True, + num_workers=data_config["num_workers"], ) - dataloader.load_state_dict(dataloader_state_dict) + if last_checkpoint_path is not None: + dataloader_state_dict = torch.load( + os.path.join(last_checkpoint_path, f"train_dataloader{suffix}.pt") + ) + dataloader.load_state_dict(dataloader_state_dict) + return dataloader - print(f" ✓ Training dataloader loaded with {len(dataset)} samples", flush=True) + if data_config["use_multiple_dataloader"]: + dataloader = { + task_name: init_dataloader(task_dataset, f"_{task_name}") + for task_name, task_dataset in dataset.items() + } + sample_count = sum( + len(task_dataloader) for task_dataloader in dataloader.values() + ) + else: + dataloader = init_dataloader(dataset) + sample_count = len(dataloader) + print(f" ✓ Training dataloader loaded with {sample_count} samples", flush=True) # Load validation dataset if provided val_dataloader: Optional[StatefulDataLoader] = None @@ -1255,7 +1275,7 @@ def compute_and_apply_seq_logprob_error_masking( def grpo_train( policy: ColocatablePolicyInterface, policy_generation: Optional[GenerationInterface], - dataloader: StatefulDataLoader, + dataloader: StatefulDataLoader | dict[str, StatefulDataLoader], val_dataloader: Optional[StatefulDataLoader], tokenizer: TokenizerType, loss_fn: LossFunction, @@ -1343,6 +1363,14 @@ def grpo_train( logger.log_metrics(val_metrics, current_step, prefix="validation") logger.log_metrics(validation_timings, current_step, prefix="timing/validation") + # Wrap dataloader if using multiple dataloaders + if master_config["data"]["use_multiple_dataloader"]: + dataloader = MultipleDataloaderWrapper( + master_config["grpo"]["num_prompts_per_step"], + master_config["data"], + dataloader, + ) + while current_epoch < max_num_epochs and total_steps < max_num_steps: memory_tracker.snapshot_start_of_stage("Preparing batch", dir()) print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}") @@ -1357,10 +1385,17 @@ def grpo_train( metrics_logging_data = dict() metrics = dict() - print( - f"\n{'=' * 25} Step {current_step + 1}/{min(len(dataloader), max_num_steps)} {'=' * 25}", - flush=True, - ) + if master_config["data"]["use_multiple_dataloader"]: + print( + f"\n{'=' * 25} Step {current_step + 1}/{max_num_steps} {'=' * 25}", + flush=True, + ) + else: + print( + f"\n{'=' * 25} Step {current_step + 1}/{min(len(dataloader), max_num_steps)} {'=' * 25}", + flush=True, + ) + maybe_gpu_profile_step(policy, total_steps + 1) if policy != policy_generation: maybe_gpu_profile_step(policy_generation, total_steps + 1) @@ -1951,10 +1986,20 @@ def grpo_train( ), checkpointing_cfg=master_config["checkpointing"], ) - torch.save( - dataloader.state_dict(), - os.path.join(checkpoint_path, "train_dataloader.pt"), - ) + if master_config["data"]["use_multiple_dataloader"]: + for task_name, task_dataloader in dataloader.items(): + torch.save( + task_dataloader.state_dict(), + os.path.join( + checkpoint_path, + f"train_dataloader_{task_name}.pt", + ), + ) + else: + torch.save( + dataloader.state_dict(), + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) checkpointer.finalize_checkpoint(checkpoint_path) # Logging diff --git a/nemo_rl/data/dataloader.py b/nemo_rl/data/dataloader.py new file mode 100644 index 0000000000..1b362689fd --- /dev/null +++ b/nemo_rl/data/dataloader.py @@ -0,0 +1,47 @@ +from hydra.utils import get_class +from torchdata.stateful_dataloader import StatefulDataLoader + + +class MultipleDataloaderWrapper: + """Wrapper for multiple dataloaders. + + This wrapper is used to sample data from multiple dataloaders using a custom dataloader function. + """ + + def __init__( + self, + num_prompts_per_step: int, + data_config: dict, + dataloaders: dict[str, StatefulDataLoader], + ): + self.num_prompts_per_step = num_prompts_per_step + self.data_config = data_config + self.dataloaders = dataloaders + + # init data iterators + self.data_iterators = { + task_name: iter(dataloader) for task_name, dataloader in dataloaders.items() + } + + self.custom_dataloader_func = get_class(data_config["custom_dataloader"]) + self.records = {} + + def __iter__(self): + result, self.data_iterators = self.custom_dataloader_func( + self.data_iterators, self.dataloaders, **self.records + ) + assert len(result) == self.num_prompts_per_step, ( + f"Expected {self.num_prompts_per_step} prompts, but got {len(result)}" + ) + + # reset records + self.records = {} + + return result + + def set_records(self, records: dict): + """Set the records for the custom dataloader. + + Records are used to pass additional information to the custom dataloader to decide how to sample the data from the dataloaders. + """ + self.records.update(records) diff --git a/nemo_rl/data/utils.py b/nemo_rl/data/utils.py index 7fe335140e..3a563b80eb 100644 --- a/nemo_rl/data/utils.py +++ b/nemo_rl/data/utils.py @@ -29,6 +29,9 @@ from nemo_rl.environments.interfaces import EnvironmentInterface from nemo_rl.environments.utils import create_env +TrainDatasetType = Union[AllTaskProcessedDataset, dict[str, AllTaskProcessedDataset]] +ValidationDatasetType = Optional[AllTaskProcessedDataset] + # TODO: @yukih: unify to setup_data after dataset refactored def setup_response_data( @@ -69,7 +72,9 @@ def setup_response_data( "and the Migrate Guide in https://github.com/NVIDIA-NeMo/RL/pull/1649 to update the dataset config." ) - # setup environments if needed + # ========================== + # Setup Environments + # ========================== has_envs = env_configs is not None if has_envs: print("\n▶ Setting up envs...") @@ -81,8 +86,10 @@ def setup_response_data( env_name=registered_env_name, env_config=env_configs[env_name] ) + # ========================== + # Setup Train Dataset + # ========================== print("\n▶ Setting up data...") - # setup train dataset task_data_processors = {} task_data_preprocessors = {} task_to_env = {} @@ -105,18 +112,37 @@ def setup_response_data( if has_envs: task_to_env[task_name] = envs[cfg["env_name"]] - merged_data = concatenate_datasets([data.dataset for data in data_list]) - dataset = AllTaskProcessedDataset( - merged_data, - tokenizer, - None, - task_data_processors, - task_data_preprocessors=task_data_preprocessors, - max_seq_length=data_config["max_input_seq_length"], - ) - print(f" ✓ Training dataset loaded with {len(dataset)} samples.") + # merge datasets + if data_config["use_multiple_dataloader"]: + # merge datasets into a dictionary of task name to dataset + dataset = { + data.task_name: AllTaskProcessedDataset( + data.dataset, + tokenizer, + None, + task_data_processors, + task_data_preprocessors=task_data_preprocessors, + max_seq_length=data_config["max_input_seq_length"], + ) + for data in data_list + } + else: + # merge datasets into a single dataset + merged_data = concatenate_datasets([data.dataset for data in data_list]) + dataset = AllTaskProcessedDataset( + merged_data, + tokenizer, + None, + task_data_processors, + task_data_preprocessors=task_data_preprocessors, + max_seq_length=data_config["max_input_seq_length"], + ) + sample_count = sum(len(data.dataset) for data in data_list) + print(f" ✓ Training dataset loaded with {sample_count} samples.") - # setup validation dataset + # ========================== + # Setup Validation Dataset + # ========================== val_task_data_processors = {} val_task_data_preprocessors = {} val_task_to_env = {} @@ -158,6 +184,7 @@ def setup_response_data( if has_envs: val_task_to_env[task_name] = envs[cfg["env_name"]] + # merge datasets val_dataset = None if len(val_data_list) > 0: merged_val_data = concatenate_datasets(val_data_list) @@ -178,7 +205,9 @@ def setup_response_data( # TODO: @yukih: unify to setup_data after dataset refactored -def setup_preference_data(tokenizer: AutoTokenizer, data_config: DataConfig): +def setup_preference_data( + tokenizer: AutoTokenizer, data_config: DataConfig +) -> tuple[AllTaskProcessedDataset, Optional[AllTaskProcessedDataset]]: """Setup preference data. This function is used to setup the preference data for the training and validation datasets. From 8c5e6bf4b854c8bd548886ac7781f98e41fb7531 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 4 Feb 2026 06:48:36 -0800 Subject: [PATCH 02/19] fix path and iter Signed-off-by: Yuki Huang --- .../custom_dataloader/custom_dataloader.py | 6 +++ nemo_rl/algorithms/grpo.py | 18 +++++--- nemo_rl/data/dataloader.py | 46 ++++++++++++++++--- 3 files changed, 56 insertions(+), 14 deletions(-) diff --git a/examples/custom_dataloader/custom_dataloader.py b/examples/custom_dataloader/custom_dataloader.py index d7adde7edb..1f0f031374 100644 --- a/examples/custom_dataloader/custom_dataloader.py +++ b/examples/custom_dataloader/custom_dataloader.py @@ -16,6 +16,8 @@ from torchdata.stateful_dataloader import StatefulDataLoader +from nemo_rl.distributed.batched_data_dict import BatchedDataDict + def example_custom_dataloader( data_iterators: dict[str, Iterator], @@ -35,6 +37,7 @@ def example_custom_dataloader( Data from the dataloaders. Updated data iterators (may update if the data iterator is exhausted). """ + # sample data from each dataloader result = [] for task_name, data_iterator in data_iterators.items(): try: @@ -42,4 +45,7 @@ def example_custom_dataloader( except: data_iterators[task_name] = iter(dataloaders[task_name]) result.append(next(data_iterators[task_name])) + + # merge results + result = BatchedDataDict.from_batches(result) return result, data_iterators diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index ed369b4343..9dd9baed07 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1365,10 +1365,12 @@ def grpo_train( # Wrap dataloader if using multiple dataloaders if master_config["data"]["use_multiple_dataloader"]: + num_prompts_per_step = master_config["grpo"]["num_prompts_per_step"] + batch_multiplier = master_config["grpo"]["batch_multiplier"] dataloader = MultipleDataloaderWrapper( - master_config["grpo"]["num_prompts_per_step"], - master_config["data"], - dataloader, + expected_num_prompts=num_prompts_per_step * batch_multiplier, + data_config=master_config["data"], + dataloaders=dataloader, ) while current_epoch < max_num_epochs and total_steps < max_num_steps: @@ -1794,10 +1796,12 @@ def grpo_train( # Set generation as stale to force refit with new scales POLICY_GENERATION_STALE = True - is_last_step = (total_steps + 1 >= max_num_steps) or ( - (current_epoch + 1 == max_num_epochs) - and (current_step + 1 == len(dataloader)) - ) + is_last_step = total_steps + 1 >= max_num_steps + if not master_config["data"]["use_multiple_dataloader"]: + is_last_step = is_last_step or ( + (current_epoch + 1 == max_num_epochs) + and (current_step + 1 == len(dataloader)) + ) # Run validation if it's a validation step or last step with val_at_end if (val_period > 0 and (total_steps + 1) % val_period == 0) or ( diff --git a/nemo_rl/data/dataloader.py b/nemo_rl/data/dataloader.py index 1b362689fd..690a9fa2fa 100644 --- a/nemo_rl/data/dataloader.py +++ b/nemo_rl/data/dataloader.py @@ -1,4 +1,17 @@ -from hydra.utils import get_class +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from torchdata.stateful_dataloader import StatefulDataLoader @@ -10,11 +23,11 @@ class MultipleDataloaderWrapper: def __init__( self, - num_prompts_per_step: int, + expected_num_prompts: int, data_config: dict, dataloaders: dict[str, StatefulDataLoader], ): - self.num_prompts_per_step = num_prompts_per_step + self.expected_num_prompts = expected_num_prompts self.data_config = data_config self.dataloaders = dataloaders @@ -23,15 +36,34 @@ def __init__( task_name: iter(dataloader) for task_name, dataloader in dataloaders.items() } - self.custom_dataloader_func = get_class(data_config["custom_dataloader"]) + # custom dataloader function to decide how to sample the data from the dataloaders + self.custom_dataloader_func = self._load_custom_dataloader_func() + # records to pass additional information to the custom dataloader function self.records = {} + def _load_custom_dataloader_func(self): + import sys + from pathlib import Path + + from hydra.utils import get_method + + project_root_path = Path(__file__).absolute().parents[2] + sys.path = [str(project_root_path)] + sys.path + + return get_method(self.data_config["custom_dataloader"]) + def __iter__(self): + return self + + def __next__(self): + # sample data from the dataloaders result, self.data_iterators = self.custom_dataloader_func( self.data_iterators, self.dataloaders, **self.records ) - assert len(result) == self.num_prompts_per_step, ( - f"Expected {self.num_prompts_per_step} prompts, but got {len(result)}" + + # check if the number of prompts is expected + assert len(result["message_log"]) == self.expected_num_prompts, ( + f"Expected {self.expected_num_prompts} prompts, but got {len(result['message_log'])}" ) # reset records @@ -42,6 +74,6 @@ def __iter__(self): def set_records(self, records: dict): """Set the records for the custom dataloader. - Records are used to pass additional information to the custom dataloader to decide how to sample the data from the dataloaders. + Records are used to pass additional information to the custom dataloader function to decide how to sample the data from the dataloaders. """ self.records.update(records) From eb69ac7b7e6cc94820eb95a74ea44ec64f9b74b7 Mon Sep 17 00:00:00 2001 From: ruit Date: Thu, 1 Jan 2026 20:35:29 -0800 Subject: [PATCH 03/19] update config Signed-off-by: ruit --- examples/configs/vlm_grpo_3B.yaml | 5 +++++ examples/configs/vlm_grpo_3B_megatron.yaml | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml index 4cad631c85..f7b859188c 100644 --- a/examples/configs/vlm_grpo_3B.yaml +++ b/examples/configs/vlm_grpo_3B.yaml @@ -249,6 +249,11 @@ data: shuffle: true num_workers: 1 + # use multiple dataloader for train + use_multiple_dataloader: false + num_prompts_per_dataloader: ${grpo.num_prompts_per_step} + custom_dataloader: null + # dataset train: dataset_name: clevr-cogent diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml index 336d97d79b..0cded6bf87 100644 --- a/examples/configs/vlm_grpo_3B_megatron.yaml +++ b/examples/configs/vlm_grpo_3B_megatron.yaml @@ -199,6 +199,12 @@ data: max_input_seq_length: ${policy.max_total_sequence_length} shuffle: true num_workers: 1 + + # use multiple dataloader for train + use_multiple_dataloader: false + num_prompts_per_dataloader: ${grpo.num_prompts_per_step} + custom_dataloader: null + # dataset train: dataset_name: clevr-cogent From 61711ed80b3d23f7f0a17177e1062cfaeb244822 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 5 Jan 2026 02:19:43 -0800 Subject: [PATCH 04/19] fix mcore sample count Signed-off-by: Yuki Huang --- nemo_rl/algorithms/grpo.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 9dd9baed07..bef3a4673f 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -312,13 +312,15 @@ def init_dataloader(dataset, suffix: str = ""): task_name: init_dataloader(task_dataset, f"_{task_name}") for task_name, task_dataset in dataset.items() } - sample_count = sum( + train_sample_count = sum( len(task_dataloader) for task_dataloader in dataloader.values() ) else: dataloader = init_dataloader(dataset) - sample_count = len(dataloader) - print(f" ✓ Training dataloader loaded with {sample_count} samples", flush=True) + train_sample_count = len(dataloader) + print( + f" ✓ Training dataloader loaded with {train_sample_count} samples", flush=True + ) # Load validation dataset if provided val_dataloader: Optional[StatefulDataLoader] = None @@ -526,7 +528,7 @@ def init_dataloader(dataset, suffix: str = ""): ## NOTE: this is equal to the total number of scheduler steps total_train_iters = min( grpo_config["max_num_steps"], - grpo_config["max_num_epochs"] * len(dataloader), + grpo_config["max_num_epochs"] * train_sample_count, ) policy_config["megatron_cfg"]["train_iters"] = total_train_iters From 4af6268fcbf87f6fa57909c257b2071d466b8b54 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 4 Feb 2026 06:56:08 -0800 Subject: [PATCH 05/19] pyrefly Signed-off-by: Yuki Huang --- pyrefly.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyrefly.toml b/pyrefly.toml index 8ad2261252..89da2cd4eb 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -34,6 +34,7 @@ project-includes = [ "docs/helpers.py", "examples/converters/convert_dcp_to_hf.py", "examples/converters/convert_megatron_to_hf.py", + "examples/custom_dataloader/custom_dataloader.py", "examples/custom_parallel/custom_parallel.py", "examples/custom_parallel/llama_nemotron_super_49b_custom_plan.py", "nemo_rl/algorithms/__init__.py", @@ -44,6 +45,7 @@ project-includes = [ "nemo_rl/data/__init__.py", "nemo_rl/data/chat_templates.py", "nemo_rl/data/collate_fn.py", + "nemo_rl/data/dataloader.py", "nemo_rl/data/datasets/__init__.py", "nemo_rl/data/datasets/eval_datasets/__init__.py", "nemo_rl/data/datasets/eval_datasets/aime.py", From be8262f2c55860229cc1b29ccc6f5a775b350ea8 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 4 Feb 2026 07:50:52 -0800 Subject: [PATCH 06/19] check batch size Signed-off-by: Yuki Huang --- nemo_rl/algorithms/grpo.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index bef3a4673f..0b42efcb56 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -275,13 +275,12 @@ def setup( # ========================== # Data # ========================== - # Validate dataloader batch size - dataloader_batch_size = data_config["num_prompts_per_dataloader"] - if not data_config["use_multiple_dataloader"]: - assert dataloader_batch_size == grpo_config["num_prompts_per_step"], ( - "data.num_prompts_per_dataloader must be equal to grpo.num_prompts_per_step if not using multiple dataloaders (data.use_multiple_dataloader=false)" - ) + if data_config["use_multiple_dataloader"]: + dataloader_batch_size = data_config["num_prompts_per_dataloader"] + else: + dataloader_batch_size = grpo_config["num_prompts_per_step"] + # Validate batch_multiplier batch_multiplier = grpo_config["batch_multiplier"] if not grpo_config["use_dynamic_sampling"]: assert batch_multiplier == 1, ( @@ -1367,10 +1366,25 @@ def grpo_train( # Wrap dataloader if using multiple dataloaders if master_config["data"]["use_multiple_dataloader"]: + # Validate expected number of prompts num_prompts_per_step = master_config["grpo"]["num_prompts_per_step"] batch_multiplier = master_config["grpo"]["batch_multiplier"] + expected_num_prompts = int(num_prompts_per_step * batch_multiplier) + + num_prompts_per_dataloader = master_config["data"]["num_prompts_per_dataloader"] + real_num_prompts_per_dataloader = int( + num_prompts_per_dataloader * batch_multiplier + ) + + assert expected_num_prompts % real_num_prompts_per_dataloader == 0, ( + "Expected int(num_prompts_per_step * batch_multiplier) to be a multiple of int(num_prompts_per_dataloader * batch_multiplier), " + f"but got {expected_num_prompts} and {real_num_prompts_per_dataloader}. " + "Please check the configuration of num_prompts_per_step, num_prompts_per_dataloader, and batch_multiplier." + ) + + # Wrap dataloader dataloader = MultipleDataloaderWrapper( - expected_num_prompts=num_prompts_per_step * batch_multiplier, + expected_num_prompts=expected_num_prompts, data_config=master_config["data"], dataloaders=dataloader, ) From 8ae14e96599d9d0f08e0e88c2f3673466a4a9da4 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 4 Feb 2026 19:36:21 -0800 Subject: [PATCH 07/19] fix rebase Signed-off-by: Yuki Huang --- nemo_rl/data/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo_rl/data/utils.py b/nemo_rl/data/utils.py index 3a563b80eb..9fd6d3c6f0 100644 --- a/nemo_rl/data/utils.py +++ b/nemo_rl/data/utils.py @@ -29,9 +29,6 @@ from nemo_rl.environments.interfaces import EnvironmentInterface from nemo_rl.environments.utils import create_env -TrainDatasetType = Union[AllTaskProcessedDataset, dict[str, AllTaskProcessedDataset]] -ValidationDatasetType = Optional[AllTaskProcessedDataset] - # TODO: @yukih: unify to setup_data after dataset refactored def setup_response_data( @@ -40,9 +37,12 @@ def setup_response_data( env_configs: Optional[dict[str, Any]] = None, is_vlm: bool = False, ) -> Union[ - tuple[AllTaskProcessedDataset, Optional[AllTaskProcessedDataset]], tuple[ - AllTaskProcessedDataset, + Union[AllTaskProcessedDataset, dict[str, AllTaskProcessedDataset]], + Optional[AllTaskProcessedDataset], + ], + tuple[ + Union[AllTaskProcessedDataset, dict[str, AllTaskProcessedDataset]], Optional[AllTaskProcessedDataset], dict[str, EnvironmentInterface], dict[str, EnvironmentInterface], From 11021a505fe56ab9aaf3d39bc0eb9513fb1c9002 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 4 Feb 2026 19:52:58 -0800 Subject: [PATCH 08/19] move wrap dataloader to setup Signed-off-by: Yuki Huang --- nemo_rl/algorithms/grpo.py | 72 ++++++++++++++++++-------------------- 1 file changed, 34 insertions(+), 38 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 0b42efcb56..5e769ea057 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -222,7 +222,7 @@ def setup( ColocatablePolicyInterface, Optional[GenerationInterface], tuple[RayVirtualCluster, RayVirtualCluster], - StatefulDataLoader | dict[str, StatefulDataLoader], + StatefulDataLoader | MultipleDataloaderWrapper, Optional[StatefulDataLoader], ClippedPGLossFn, Logger, @@ -275,19 +275,19 @@ def setup( # ========================== # Data # ========================== + batch_multiplier = grpo_config["batch_multiplier"] if data_config["use_multiple_dataloader"]: dataloader_batch_size = data_config["num_prompts_per_dataloader"] else: dataloader_batch_size = grpo_config["num_prompts_per_step"] # Validate batch_multiplier - batch_multiplier = grpo_config["batch_multiplier"] - if not grpo_config["use_dynamic_sampling"]: + if grpo_config["use_dynamic_sampling"]: + dataloader_batch_size = int(dataloader_batch_size * batch_multiplier) + else: assert batch_multiplier == 1, ( "batch_multiplier>1 can only be used if use_dynamic_sampling=True" ) - else: - dataloader_batch_size = int(dataloader_batch_size * batch_multiplier) # Load train dataset def init_dataloader(dataset, suffix: str = ""): @@ -307,12 +307,30 @@ def init_dataloader(dataset, suffix: str = ""): return dataloader if data_config["use_multiple_dataloader"]: - dataloader = { + # Validate number of prompts per step + num_prompts_per_step = grpo_config["num_prompts_per_step"] + expected_num_prompts = int(num_prompts_per_step * batch_multiplier) + + assert expected_num_prompts % dataloader_batch_size == 0, ( + "Expected int(num_prompts_per_step * batch_multiplier) to be a multiple of int(num_prompts_per_dataloader * batch_multiplier), " + f"but got {expected_num_prompts} and {dataloader_batch_size}. " + "Please check the configuration of num_prompts_per_step, num_prompts_per_dataloader, and batch_multiplier." + ) + + # Initialize dataloaders + dataloaders = { task_name: init_dataloader(task_dataset, f"_{task_name}") for task_name, task_dataset in dataset.items() } train_sample_count = sum( - len(task_dataloader) for task_dataloader in dataloader.values() + len(task_dataloader) for task_dataloader in dataloaders.values() + ) + + # Wrap dataloader + dataloader = MultipleDataloaderWrapper( + expected_num_prompts=expected_num_prompts, + data_config=data_config, + dataloaders=dataloaders, ) else: dataloader = init_dataloader(dataset) @@ -1276,7 +1294,7 @@ def compute_and_apply_seq_logprob_error_masking( def grpo_train( policy: ColocatablePolicyInterface, policy_generation: Optional[GenerationInterface], - dataloader: StatefulDataLoader | dict[str, StatefulDataLoader], + wrapped_dataloader: StatefulDataLoader | MultipleDataloaderWrapper, val_dataloader: Optional[StatefulDataLoader], tokenizer: TokenizerType, loss_fn: LossFunction, @@ -1364,31 +1382,6 @@ def grpo_train( logger.log_metrics(val_metrics, current_step, prefix="validation") logger.log_metrics(validation_timings, current_step, prefix="timing/validation") - # Wrap dataloader if using multiple dataloaders - if master_config["data"]["use_multiple_dataloader"]: - # Validate expected number of prompts - num_prompts_per_step = master_config["grpo"]["num_prompts_per_step"] - batch_multiplier = master_config["grpo"]["batch_multiplier"] - expected_num_prompts = int(num_prompts_per_step * batch_multiplier) - - num_prompts_per_dataloader = master_config["data"]["num_prompts_per_dataloader"] - real_num_prompts_per_dataloader = int( - num_prompts_per_dataloader * batch_multiplier - ) - - assert expected_num_prompts % real_num_prompts_per_dataloader == 0, ( - "Expected int(num_prompts_per_step * batch_multiplier) to be a multiple of int(num_prompts_per_dataloader * batch_multiplier), " - f"but got {expected_num_prompts} and {real_num_prompts_per_dataloader}. " - "Please check the configuration of num_prompts_per_step, num_prompts_per_dataloader, and batch_multiplier." - ) - - # Wrap dataloader - dataloader = MultipleDataloaderWrapper( - expected_num_prompts=expected_num_prompts, - data_config=master_config["data"], - dataloaders=dataloader, - ) - while current_epoch < max_num_epochs and total_steps < max_num_steps: memory_tracker.snapshot_start_of_stage("Preparing batch", dir()) print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}") @@ -1398,7 +1391,7 @@ def grpo_train( dynamic_sampling_num_gen_batches = 0 # Run grpo/dapo training loop (single-turn) - for batch in dataloader: + for batch in wrapped_dataloader: # A central place to store logging data that won't be deleted until the loop ends metrics_logging_data = dict() metrics = dict() @@ -1410,7 +1403,7 @@ def grpo_train( ) else: print( - f"\n{'=' * 25} Step {current_step + 1}/{min(len(dataloader), max_num_steps)} {'=' * 25}", + f"\n{'=' * 25} Step {current_step + 1}/{min(len(wrapped_dataloader), max_num_steps)} {'=' * 25}", flush=True, ) @@ -1816,7 +1809,7 @@ def grpo_train( if not master_config["data"]["use_multiple_dataloader"]: is_last_step = is_last_step or ( (current_epoch + 1 == max_num_epochs) - and (current_step + 1 == len(dataloader)) + and (current_step + 1 == len(wrapped_dataloader)) ) # Run validation if it's a validation step or last step with val_at_end @@ -2007,7 +2000,10 @@ def grpo_train( checkpointing_cfg=master_config["checkpointing"], ) if master_config["data"]["use_multiple_dataloader"]: - for task_name, task_dataloader in dataloader.items(): + for ( + task_name, + task_dataloader, + ) in wrapped_dataloader.dataloaders.items(): torch.save( task_dataloader.state_dict(), os.path.join( @@ -2017,7 +2013,7 @@ def grpo_train( ) else: torch.save( - dataloader.state_dict(), + wrapped_dataloader.state_dict(), os.path.join(checkpoint_path, "train_dataloader.pt"), ) checkpointer.finalize_checkpoint(checkpoint_path) From e84d94cd40026ea4b4117bfbe24ebea9f725d717 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 4 Feb 2026 20:16:17 -0800 Subject: [PATCH 09/19] reuse multiple datasets functional test and update config Signed-off-by: Yuki Huang --- examples/configs/grpo_math_1B.yaml | 3 +-- examples/configs/grpo_multiple_datasets.yaml | 4 ++-- examples/configs/vlm_grpo_3B.yaml | 2 -- examples/configs/vlm_grpo_3B_megatron.yaml | 2 -- nemo_rl/data/__init__.py | 4 ++++ ...grpo_multiple_datasets.sh => grpo_multiple_dataloaders.sh} | 3 +++ 6 files changed, 10 insertions(+), 8 deletions(-) rename tests/functional/{grpo_multiple_datasets.sh => grpo_multiple_dataloaders.sh} (89%) diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 9efb906552..9088bf62c8 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -293,9 +293,8 @@ data: num_workers: 1 # use multiple dataloader for train + # see examples/configs/grpo_multiple_datasets.yaml for an example use_multiple_dataloader: false - num_prompts_per_dataloader: ${grpo.num_prompts_per_step} - custom_dataloader: null # dataset train: diff --git a/examples/configs/grpo_multiple_datasets.yaml b/examples/configs/grpo_multiple_datasets.yaml index 97ca0fa20f..1133a70b76 100644 --- a/examples/configs/grpo_multiple_datasets.yaml +++ b/examples/configs/grpo_multiple_datasets.yaml @@ -10,8 +10,8 @@ data: # use multiple dataloader for train use_multiple_dataloader: false - num_prompts_per_dataloader: ${grpo.num_prompts_per_step} - custom_dataloader: null + num_prompts_per_dataloader: 16 + custom_dataloader: examples.custom_dataloader.custom_dataloader.example_custom_dataloader # dataset # See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details. diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml index f7b859188c..81d4bf8dce 100644 --- a/examples/configs/vlm_grpo_3B.yaml +++ b/examples/configs/vlm_grpo_3B.yaml @@ -251,8 +251,6 @@ data: # use multiple dataloader for train use_multiple_dataloader: false - num_prompts_per_dataloader: ${grpo.num_prompts_per_step} - custom_dataloader: null # dataset train: diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml index 0cded6bf87..47c449932e 100644 --- a/examples/configs/vlm_grpo_3B_megatron.yaml +++ b/examples/configs/vlm_grpo_3B_megatron.yaml @@ -202,8 +202,6 @@ data: # use multiple dataloader for train use_multiple_dataloader: false - num_prompts_per_dataloader: ${grpo.num_prompts_per_step} - custom_dataloader: null # dataset train: diff --git a/nemo_rl/data/__init__.py b/nemo_rl/data/__init__.py index a792d2d0f4..8d58a0d1f0 100644 --- a/nemo_rl/data/__init__.py +++ b/nemo_rl/data/__init__.py @@ -57,6 +57,10 @@ class DataConfig(TypedDict): # This saturates CPU threads without consuming too much memory # However, setting it too high might cause memory issues for long seqlens. num_workers: NotRequired[int] + # multiple dataloader configs + use_multiple_dataloader: bool + num_prompts_per_dataloader: NotRequired[int] + custom_dataloader: NotRequired[str] # dataset configs train: ResponseDatasetConfig | PreferenceDatasetConfig | list[ResponseDatasetConfig] validation: NotRequired[ diff --git a/tests/functional/grpo_multiple_datasets.sh b/tests/functional/grpo_multiple_dataloaders.sh similarity index 89% rename from tests/functional/grpo_multiple_datasets.sh rename to tests/functional/grpo_multiple_dataloaders.sh index 2b2fced916..c90c2a9603 100755 --- a/tests/functional/grpo_multiple_datasets.sh +++ b/tests/functional/grpo_multiple_dataloaders.sh @@ -22,6 +22,9 @@ uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJE $PROJECT_ROOT/examples/run_grpo.py \ --config $PROJECT_ROOT/examples/configs/grpo_multiple_datasets.yaml \ policy.model_name=Qwen/Qwen3-0.6B \ + data.use_multiple_dataloader=true \ + data.num_prompts_per_dataloader=1 \ + data.custom_dataloader=examples.custom_dataloader.custom_dataloader.example_custom_dataloader \ grpo.val_at_start=true \ grpo.max_val_samples=4 \ grpo.val_batch_size=4 \ From cba8f3a2605f51187bd27d2274c077e90575c23c Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Thu, 5 Feb 2026 01:55:26 -0800 Subject: [PATCH 10/19] add doc Signed-off-by: Yuki Huang --- docs/guides/grpo.md | 40 +++++++++++++++++++ examples/configs/grpo_math_1B.yaml | 2 +- examples/configs/grpo_multiple_datasets.yaml | 1 + .../custom_dataloader/custom_dataloader.py | 3 +- 4 files changed, 44 insertions(+), 2 deletions(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 08cf4ac101..b8105b87b5 100755 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -164,6 +164,46 @@ def my_data_processor( We have an example of this as `math_data_processor` in [processors.py](../../nemo_rl/data/processors.py). +#### Multiple Dataloaders + +By default, NeMo RL uses a single dataloader that aggregates data from multiple datasets. For scenarios requiring fine-grained control over the number of prompts loaded from each dataset, NeMo RL provides support for multiple dataloaders. + +The following example demonstrates how to configure multiple dataloaders: + +```bash +uv run examples/run_grpo.py \ + --config examples/configs/grpo_multiple_datasets.yaml \ + grpo.num_prompts_per_step=32 \ + data.use_multiple_dataloader=true \ + data.num_prompts_per_dataloader=16 \ + data.custom_dataloader=examples.custom_dataloader.custom_dataloader.example_custom_dataloader +``` + +**Custom Dataloader** + +The file `examples/custom_dataloader/custom_dataloader.py` provides a reference implementation that samples `data.num_prompts_per_dataloader` entries from each dataloader. + +Additionally, custom dataloaders can access recorded metrics from the training loop. Use `wrapped_dataloader.set_records()` in `nemo_rl/algorithms/grpo.py` to store relevant information, which can then be retrieved in your custom dataloader implementation: + +```python +# In nemo_rl/algorithms/grpo.py +wrapped_dataloader.set_records({"reward": ...}) + +# In custom_dataloader.py +def example_custom_dataloader( + data_iterators: dict[str, Iterator], + dataloaders: dict[str, StatefulDataLoader], + **kwargs, +): + ... + reward = kwargs["reward"] + ... +``` + +**num_prompts_per_dataloader** + +This parameter specifies the number of prompts generated by each dataloader per iteration. Ensure that `grpo.num_prompts_per_step` is a multiple of `data.num_prompts_per_dataloader` to guarantee that exactly `grpo.num_prompts_per_step` prompts are available for each training step. + ### Task–Dataset Mapping - task_name (unique task identifier): diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 9088bf62c8..e3dc097a38 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -293,7 +293,7 @@ data: num_workers: 1 # use multiple dataloader for train - # see examples/configs/grpo_multiple_datasets.yaml for an example + # see https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#multiple-dataloaders for more details. use_multiple_dataloader: false # dataset diff --git a/examples/configs/grpo_multiple_datasets.yaml b/examples/configs/grpo_multiple_datasets.yaml index 1133a70b76..5fd0b40eed 100644 --- a/examples/configs/grpo_multiple_datasets.yaml +++ b/examples/configs/grpo_multiple_datasets.yaml @@ -9,6 +9,7 @@ data: num_workers: 1 # use multiple dataloader for train + # see https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#multiple-dataloaders for more details. use_multiple_dataloader: false num_prompts_per_dataloader: 16 custom_dataloader: examples.custom_dataloader.custom_dataloader.example_custom_dataloader diff --git a/examples/custom_dataloader/custom_dataloader.py b/examples/custom_dataloader/custom_dataloader.py index 1f0f031374..8e8fed4b1f 100644 --- a/examples/custom_dataloader/custom_dataloader.py +++ b/examples/custom_dataloader/custom_dataloader.py @@ -30,7 +30,8 @@ def example_custom_dataloader( In this example, we simply sample data from each dataloader. Args: - dataloaders: A dictionary of dataloaders. + data_iterators: A dictionary of data iterators. + dataloaders: A dictionary of dataloaders. It is used to reset the data iterator when it is exhausted. **kwargs: Additional arguments to pass to the custom dataloader function. Returns: From e3117096f3c95996d7d4fa83d5c8ba9dcfed5cef Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sun, 8 Feb 2026 19:07:53 -0800 Subject: [PATCH 11/19] fix type Signed-off-by: Yuki Huang --- nemo_rl/data/__init__.py | 3 ++- nemo_rl/data/utils.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/nemo_rl/data/__init__.py b/nemo_rl/data/__init__.py index 8d58a0d1f0..93eea10108 100644 --- a/nemo_rl/data/__init__.py +++ b/nemo_rl/data/__init__.py @@ -58,7 +58,8 @@ class DataConfig(TypedDict): # However, setting it too high might cause memory issues for long seqlens. num_workers: NotRequired[int] # multiple dataloader configs - use_multiple_dataloader: bool + # currently only supported for GRPO + use_multiple_dataloader: NotRequired[bool] num_prompts_per_dataloader: NotRequired[int] custom_dataloader: NotRequired[str] # dataset configs diff --git a/nemo_rl/data/utils.py b/nemo_rl/data/utils.py index 9fd6d3c6f0..91f6db9858 100644 --- a/nemo_rl/data/utils.py +++ b/nemo_rl/data/utils.py @@ -207,7 +207,7 @@ def setup_response_data( # TODO: @yukih: unify to setup_data after dataset refactored def setup_preference_data( tokenizer: AutoTokenizer, data_config: DataConfig -) -> tuple[AllTaskProcessedDataset, Optional[AllTaskProcessedDataset]]: +) -> tuple[AllTaskProcessedDataset, dict[str, AllTaskProcessedDataset]]: """Setup preference data. This function is used to setup the preference data for the training and validation datasets. From a837cf8feb0913c4b472814fcde6f29a5ebc3c32 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sun, 8 Feb 2026 22:33:57 -0800 Subject: [PATCH 12/19] fix unit test Signed-off-by: Yuki Huang --- tests/unit/algorithms/test_grpo.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index 2a6d91cd87..73a75fe64e 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -755,7 +755,12 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_single_node(): "use_dynamic_sampling": False, "batch_multiplier": 1, }, - "data": {"shuffle": False, "num_workers": 1, "env_name": None}, + "data": { + "shuffle": False, + "num_workers": 1, + "env_name": None, + "use_multiple_dataloader": False, + }, "logger": {}, # Config extraction requires this key "checkpointing": {}, # Config extraction requires this key "cluster": { @@ -829,7 +834,12 @@ def test_noncolocated_inference_requires_explicit_gpus_per_node_multi_node(): "use_dynamic_sampling": False, "batch_multiplier": 1, }, - "data": {"shuffle": False, "num_workers": 1, "env_name": None}, + "data": { + "shuffle": False, + "num_workers": 1, + "env_name": None, + "use_multiple_dataloader": False, + }, "logger": {}, # Config extraction requires this key "checkpointing": {}, # Config extraction requires this key "cluster": { @@ -993,7 +1003,12 @@ def init_collective(self, *_args, **_kwargs): "reward_shaping": {"enabled": False}, "overlong_filtering": False, }, - "data": {"shuffle": False, "num_workers": 0, "env_name": None}, + "data": { + "shuffle": False, + "num_workers": 0, + "env_name": None, + "use_multiple_dataloader": False, + }, "logger": {"num_val_samples_to_print": 0}, "checkpointing": {"enabled": False}, "cluster": {"num_nodes": 1, "gpus_per_node": 4}, @@ -1343,6 +1358,9 @@ def val_iter(self): "logger": { "num_val_samples_to_print": 5, }, + "data": { + "use_multiple_dataloader": False, + }, } return { From c61553e618654bb4ccb321ce23ed1a34ca672a64 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sun, 22 Feb 2026 23:19:10 -0800 Subject: [PATCH 13/19] fix signature and add unit test Signed-off-by: Yuki Huang --- docs/guides/grpo.md | 2 +- .../custom_dataloader/custom_dataloader.py | 52 +++++++- nemo_rl/data/dataloader.py | 2 +- tests/unit/data/test_multiple_dataloader.py | 113 ++++++++++++++++++ 4 files changed, 164 insertions(+), 5 deletions(-) create mode 100644 tests/unit/data/test_multiple_dataloader.py diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index b8105b87b5..f42cbf4aca 100755 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -194,7 +194,7 @@ def example_custom_dataloader( data_iterators: dict[str, Iterator], dataloaders: dict[str, StatefulDataLoader], **kwargs, -): +) -> tuple[BatchedDataDict, dict[str, Iterator]]: ... reward = kwargs["reward"] ... diff --git a/examples/custom_dataloader/custom_dataloader.py b/examples/custom_dataloader/custom_dataloader.py index 8e8fed4b1f..5a23d6d0fb 100644 --- a/examples/custom_dataloader/custom_dataloader.py +++ b/examples/custom_dataloader/custom_dataloader.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ def example_custom_dataloader( data_iterators: dict[str, Iterator], dataloaders: dict[str, StatefulDataLoader], **kwargs, -): +) -> tuple[BatchedDataDict, dict[str, Iterator]]: """An example of custom dataloader function. This function is used to sample data from multiple dataloaders using a custom dataloader function. @@ -43,10 +43,56 @@ def example_custom_dataloader( for task_name, data_iterator in data_iterators.items(): try: result.append(next(data_iterator)) - except: + except StopIteration: data_iterators[task_name] = iter(dataloaders[task_name]) result.append(next(data_iterators[task_name])) # merge results result = BatchedDataDict.from_batches(result) return result, data_iterators + + +def example_custom_dataloader_with_chosen_task( + data_iterators: dict[str, Iterator], + dataloaders: dict[str, StatefulDataLoader], + chosen_task: list[str], + expected_num_prompts: int, + **kwargs, +) -> tuple[BatchedDataDict, dict[str, Iterator]]: + """An example of custom dataloader function with chosen task. + + This function is used to sample data from multiple dataloaders using a custom dataloader function. + In this example, we sample data from the chosen task. + + This function will need to call `wrapped_dataloader.set_records({"chosen_task": ..., "expected_num_prompts": ...})` to set the records in `nemo_rl/algorithms/grpo.py`. + A usage example is shown in the test case `test_multiple_dataloader_with_records` in `tests/unit/data/test_multiple_dataloader.py`. + + Args: + data_iterators: A dictionary of data iterators. + dataloaders: A dictionary of dataloaders. It is used to reset the data iterator when it is exhausted. + chosen_task: A list of task names to sample data from. + expected_num_prompts: The expected number of prompts to sample. + + Returns: + Data from the dataloaders. + Updated data iterators (may update if the data iterator is exhausted). + """ + # sample data from the chosen task + result = [] + current_task_idx = 0 + current_num_prompts = 0 + while current_num_prompts < expected_num_prompts: + task_name = chosen_task[current_task_idx] + try: + data = next(data_iterators[task_name]) + except StopIteration: + data_iterators[task_name] = iter(dataloaders[task_name]) + data = next(data_iterators[task_name]) + + result.append(data) + current_num_prompts += len(data["message_log"]) + current_task_idx = (current_task_idx + 1) % len(chosen_task) + + # merge results + result = BatchedDataDict.from_batches(result) + return result, data_iterators diff --git a/nemo_rl/data/dataloader.py b/nemo_rl/data/dataloader.py index 690a9fa2fa..430ea14804 100644 --- a/nemo_rl/data/dataloader.py +++ b/nemo_rl/data/dataloader.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/unit/data/test_multiple_dataloader.py b/tests/unit/data/test_multiple_dataloader.py new file mode 100644 index 0000000000..cd276cf6cd --- /dev/null +++ b/tests/unit/data/test_multiple_dataloader.py @@ -0,0 +1,113 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from torchdata.stateful_dataloader import StatefulDataLoader + +from nemo_rl.data.dataloader import MultipleDataloaderWrapper + + +@pytest.fixture(scope="function") +def dataloaders() -> dict[str, StatefulDataLoader]: + dataset1 = [ + {"message_log": [{"role": "user", "content": str(x)}]} for x in range(2) + ] + dataset2 = [ + {"message_log": [{"role": "user", "content": str(x)}]} for x in range(2, 6) + ] + + def collate_fn(data_batch: list[dict]) -> dict: + return { + "message_log": [datum["message_log"] for datum in data_batch], + } + + dataloaders = { + "dataloader1": StatefulDataLoader( + dataset=dataset1, batch_size=2, shuffle=False, collate_fn=collate_fn + ), + "dataloader2": StatefulDataLoader( + dataset=dataset2, batch_size=2, shuffle=False, collate_fn=collate_fn + ), + } + + yield dataloaders + + +def test_multiple_dataloader(dataloaders): + wrapped_dataloader = MultipleDataloaderWrapper( + expected_num_prompts=4, + data_config={ + "custom_dataloader": "examples.custom_dataloader.custom_dataloader.example_custom_dataloader" + }, + dataloaders=dataloaders, + ) + + iter_count = 0 + for data in wrapped_dataloader: + content = sorted([message[0]["content"] for message in data["message_log"]]) + + if iter_count == 0: + assert content == ["0", "1", "2", "3"] + elif iter_count == 1: + assert content == ["0", "1", "4", "5"] + + iter_count += 1 + if iter_count == 2: + break + + +def test_multiple_dataloader_with_records(dataloaders): + wrapped_dataloader = MultipleDataloaderWrapper( + expected_num_prompts=4, + data_config={ + "custom_dataloader": "examples.custom_dataloader.custom_dataloader.example_custom_dataloader_with_chosen_task" + }, + dataloaders=dataloaders, + ) + # set the records to sample data from all dataloaders + wrapped_dataloader.set_records( + { + "chosen_task": ["dataloader1", "dataloader2"], + "expected_num_prompts": wrapped_dataloader.expected_num_prompts, + } + ) + + iter_count = 0 + for data in wrapped_dataloader: + content = sorted([message[0]["content"] for message in data["message_log"]]) + + if iter_count == 0: + assert content == ["0", "1", "2", "3"] + # set the records to sample data from dataloader1 + wrapped_dataloader.set_records( + { + "chosen_task": ["dataloader1"], + "expected_num_prompts": wrapped_dataloader.expected_num_prompts, + } + ) + elif iter_count == 1: + assert content == ["0", "0", "1", "1"] + # set the records to sample data from dataloader2 + wrapped_dataloader.set_records( + { + "chosen_task": ["dataloader2"], + "expected_num_prompts": wrapped_dataloader.expected_num_prompts, + } + ) + elif iter_count == 2: + assert content == ["2", "3", "4", "5"] + + iter_count += 1 + if iter_count == 3: + break From 99b137981ec0cee6a539ea4e692600e1459c989f Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 23 Feb 2026 00:27:12 -0800 Subject: [PATCH 14/19] update functional test Signed-off-by: Yuki Huang --- tests/functional/L1_Functional_Tests_GPU.sh | 2 +- tests/functional/grpo_multiple_dataloaders.sh | 61 ++++++++++++++++--- 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index c3e15addff..773e57e9e0 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -37,7 +37,7 @@ time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_async.sh time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_non_colocated.sh time uv run --no-sync bash ./tests/functional/grpo_megatron.sh time uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh -time uv run --no-sync bash ./tests/functional/grpo_multiple_datasets.sh +time uv run --no-sync bash ./tests/functional/grpo_multiple_dataloaders.sh time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh time uv run --no-sync bash ./tests/functional/grpo_non_colocated.sh time uv run --no-sync bash ./tests/functional/grpo_rm_env.sh diff --git a/tests/functional/grpo_multiple_dataloaders.sh b/tests/functional/grpo_multiple_dataloaders.sh index c90c2a9603..e2b7fac765 100755 --- a/tests/functional/grpo_multiple_dataloaders.sh +++ b/tests/functional/grpo_multiple_dataloaders.sh @@ -10,14 +10,21 @@ set -eou pipefail EXP_NAME=$(basename $0 .sh) EXP_DIR=$SCRIPT_DIR/$EXP_NAME LOG_DIR=$EXP_DIR/logs -JSON_METRICS=$EXP_DIR/metrics.json +CKPT_DIR=$EXP_DIR/ckpts RUN_LOG=$EXP_DIR/run.log export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} rm -rf $EXP_DIR $LOG_DIR mkdir -p $EXP_DIR $LOG_DIR -cd $PROJECT_ROOT +# This test will run for 2 steps and make sure that 1+1 steps w/ resume leads to the same result. +# We use the checkpointing.checkpoint_must_save_by=0:0:0:1 feature to exit after 1 step. + +prefix_output() { + sed "s/^/$1/" +} + +train_cmd() { uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ $PROJECT_ROOT/examples/run_grpo.py \ --config $PROJECT_ROOT/examples/configs/grpo_multiple_datasets.yaml \ @@ -35,14 +42,52 @@ uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJE cluster.gpus_per_node=2 \ grpo.max_num_steps=2 \ logger.tensorboard_enabled=true \ - logger.log_dir=$LOG_DIR \ logger.wandb_enabled=false \ logger.monitor_gpus=true \ checkpointing.enabled=false \ - $@ \ - 2>&1 | tee $RUN_LOG + checkpointing.save_period=1 \ + $@ +} + +cd $PROJECT_ROOT + +# 2 step baseline +train_cmd logger.log_dir=$LOG_DIR/baseline $@ 2>&1 | prefix_output "[baseline 2step] " | tee ${RUN_LOG}.2step_baseline +uv run tests/json_dump_tb_logs.py $LOG_DIR/baseline --output_path $EXP_DIR/baseline.json +# 1+1 step +train_cmd logger.log_dir=$LOG_DIR/resume checkpointing.checkpoint_must_save_by=0:0:0:1 checkpointing.enabled=true checkpointing.checkpoint_dir=$CKPT_DIR/resume $@ 2>&1 | prefix_output "[resume 1step] " | tee ${RUN_LOG}.resume_1step +uv run tests/json_dump_tb_logs.py $LOG_DIR/resume --output_path $EXP_DIR/resume_1step.json +train_cmd logger.log_dir=$LOG_DIR/resume checkpointing.enabled=true checkpointing.checkpoint_dir=$CKPT_DIR/resume $@ 2>&1 | prefix_output "[resume 2step] " | tee ${RUN_LOG}.resume_2step +uv run tests/json_dump_tb_logs.py $LOG_DIR/resume --output_path $EXP_DIR/resume_2step.json + +uv run python - < Date: Mon, 23 Feb 2026 05:39:21 -0800 Subject: [PATCH 15/19] update print info Signed-off-by: Yuki Huang --- nemo_rl/algorithms/grpo.py | 25 ++++++++++++++++--------- nemo_rl/data/utils.py | 9 +++++++++ 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 5e769ea057..7fd7541ba9 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -290,7 +290,7 @@ def setup( ) # Load train dataset - def init_dataloader(dataset, suffix: str = ""): + def init_train_dataloader(dataset, suffix: str = ""): dataloader = StatefulDataLoader( dataset, batch_size=dataloader_batch_size, @@ -318,10 +318,16 @@ def init_dataloader(dataset, suffix: str = ""): ) # Initialize dataloaders - dataloaders = { - task_name: init_dataloader(task_dataset, f"_{task_name}") - for task_name, task_dataset in dataset.items() - } + dataloaders = {} + for task_name, task_dataset in dataset.items(): + dataloaders[task_name] = init_train_dataloader( + task_dataset, f"_{task_name}" + ) + print( + f" ✓ Training dataloader {task_name} loaded with {len(task_dataset)} samples", + flush=True, + ) + train_sample_count = sum( len(task_dataloader) for task_dataloader in dataloaders.values() ) @@ -333,11 +339,12 @@ def init_dataloader(dataset, suffix: str = ""): dataloaders=dataloaders, ) else: - dataloader = init_dataloader(dataset) + dataloader = init_train_dataloader(dataset) train_sample_count = len(dataloader) - print( - f" ✓ Training dataloader loaded with {train_sample_count} samples", flush=True - ) + print( + f" ✓ Training dataloader loaded with {train_sample_count} samples", + flush=True, + ) # Load validation dataset if provided val_dataloader: Optional[StatefulDataLoader] = None diff --git a/nemo_rl/data/utils.py b/nemo_rl/data/utils.py index 91f6db9858..2721ff6dd1 100644 --- a/nemo_rl/data/utils.py +++ b/nemo_rl/data/utils.py @@ -104,6 +104,9 @@ def setup_response_data( update_single_dataset_config(cfg, data_config["default"]) data = load_response_dataset(cfg) data_list.append(data) + print( + f" - Loaded training dataset {data.task_name} with {len(data.dataset)} samples." + ) # bind task_name to task_data_processors and task_to_env task_name = data.task_name task_data_processors[task_name] = (data.task_spec, data.processor) @@ -152,6 +155,9 @@ def setup_response_data( for data in data_list: if hasattr(data, "val_dataset") and data.val_dataset is not None: val_data_list.append(data.val_dataset) + print( + f" - Loaded validation dataset {data.task_name} with {len(data.val_dataset)} samples." + ) # bind task_name to task_data_processors and task_to_env task_name = data.task_name val_task_data_processors[task_name] = task_data_processors[task_name] @@ -173,6 +179,9 @@ def setup_response_data( update_single_dataset_config(cfg, data_config["default"]) val_data = load_response_dataset(cfg) val_data_list.append(val_data.dataset) + print( + f" - Loaded validation dataset {val_data.task_name} with {len(val_data.dataset)} samples." + ) # bind task_name to task_data_processors and task_to_env task_name = val_data.task_name val_task_data_processors[task_name] = ( From 4b849485bc35aabc143763defbf72dc4df08d515 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 23 Feb 2026 05:58:54 -0800 Subject: [PATCH 16/19] update assert place Signed-off-by: Yuki Huang --- nemo_rl/algorithms/grpo.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 7fd7541ba9..5cb8565633 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -275,20 +275,32 @@ def setup( # ========================== # Data # ========================== - batch_multiplier = grpo_config["batch_multiplier"] + # num_prompts_per_step and dataloader_batch_size will be different when using multiple dataloaders + num_prompts_per_step = grpo_config["num_prompts_per_step"] if data_config["use_multiple_dataloader"]: dataloader_batch_size = data_config["num_prompts_per_dataloader"] else: - dataloader_batch_size = grpo_config["num_prompts_per_step"] + dataloader_batch_size = num_prompts_per_step # Validate batch_multiplier + batch_multiplier = grpo_config["batch_multiplier"] if grpo_config["use_dynamic_sampling"]: + num_prompts_per_step = int(num_prompts_per_step * batch_multiplier) dataloader_batch_size = int(dataloader_batch_size * batch_multiplier) else: assert batch_multiplier == 1, ( "batch_multiplier>1 can only be used if use_dynamic_sampling=True" ) + # Validate number of prompts per step + if data_config["use_multiple_dataloader"]: + assert num_prompts_per_step % dataloader_batch_size == 0, ( + "Expected num_prompts_per_step to be a multiple of num_prompts_per_dataloader, " + f"but got {num_prompts_per_step} and {dataloader_batch_size}. " + "Please check the configuration of num_prompts_per_step and num_prompts_per_dataloader. " + "If use_dynamic_sampling is enabled and batch_multiplier is used, please also check the configuration of batch_multiplier." + ) + # Load train dataset def init_train_dataloader(dataset, suffix: str = ""): dataloader = StatefulDataLoader( @@ -307,16 +319,6 @@ def init_train_dataloader(dataset, suffix: str = ""): return dataloader if data_config["use_multiple_dataloader"]: - # Validate number of prompts per step - num_prompts_per_step = grpo_config["num_prompts_per_step"] - expected_num_prompts = int(num_prompts_per_step * batch_multiplier) - - assert expected_num_prompts % dataloader_batch_size == 0, ( - "Expected int(num_prompts_per_step * batch_multiplier) to be a multiple of int(num_prompts_per_dataloader * batch_multiplier), " - f"but got {expected_num_prompts} and {dataloader_batch_size}. " - "Please check the configuration of num_prompts_per_step, num_prompts_per_dataloader, and batch_multiplier." - ) - # Initialize dataloaders dataloaders = {} for task_name, task_dataset in dataset.items(): @@ -334,7 +336,7 @@ def init_train_dataloader(dataset, suffix: str = ""): # Wrap dataloader dataloader = MultipleDataloaderWrapper( - expected_num_prompts=expected_num_prompts, + expected_num_prompts=num_prompts_per_step, data_config=data_config, dataloaders=dataloaders, ) From ff9832c26e1558ac0ad1562859c252a0cd1b33e4 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 23 Feb 2026 06:15:22 -0800 Subject: [PATCH 17/19] add example in doc Signed-off-by: Yuki Huang --- docs/guides/grpo.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index f42cbf4aca..b72099abb8 100755 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -179,6 +179,26 @@ uv run examples/run_grpo.py \ data.custom_dataloader=examples.custom_dataloader.custom_dataloader.example_custom_dataloader ``` +For example, consider using `example_custom_dataloader`, which samples data from each dataloader sequentially. + +Given two datasets: +- Dataset 1: `[a, b, c, d]` +- Dataset 2: `[1, 2, 3, 4, 5, 6, 7, 8]` + +With `data.use_multiple_dataloader=false` and `grpo.num_prompts_per_step=4`: +``` +Batch 1: [a, b, c, d] +Batch 2: [1, 2, 3, 4] +Batch 3: [5, 6, 7, 8] +``` + +With `data.use_multiple_dataloader=true`, `grpo.num_prompts_per_step=4`, and `data.num_prompts_per_dataloader=2`: +``` +Batch 1: [a, b, 1, 2] +Batch 2: [c, d, 3, 4] +Batch 3: [a, b, 5, 6] +``` + **Custom Dataloader** The file `examples/custom_dataloader/custom_dataloader.py` provides a reference implementation that samples `data.num_prompts_per_dataloader` entries from each dataloader. From f464786e5d1145625f71586c438471b06e77d668 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Mon, 23 Feb 2026 18:28:52 -0800 Subject: [PATCH 18/19] fix config Signed-off-by: Yuki Huang --- nemo_rl/data/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nemo_rl/data/utils.py b/nemo_rl/data/utils.py index 2721ff6dd1..2819e27582 100644 --- a/nemo_rl/data/utils.py +++ b/nemo_rl/data/utils.py @@ -116,7 +116,10 @@ def setup_response_data( task_to_env[task_name] = envs[cfg["env_name"]] # merge datasets - if data_config["use_multiple_dataloader"]: + if ( + "use_multiple_dataloader" in data_config + and data_config["use_multiple_dataloader"] + ): # merge datasets into a dictionary of task name to dataset dataset = { data.task_name: AllTaskProcessedDataset( From 39272a239b19c204eb6b07308fee3e4f4fbb4dcc Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 25 Feb 2026 07:58:05 -0800 Subject: [PATCH 19/19] add comment and warning Signed-off-by: Yuki Huang --- docs/guides/grpo.md | 3 +++ examples/custom_dataloader/custom_dataloader.py | 6 ++++++ nemo_rl/algorithms/grpo.py | 7 +++++++ nemo_rl/data/dataloader.py | 3 +++ 4 files changed, 19 insertions(+) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index b72099abb8..2576b93303 100755 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -203,6 +203,9 @@ Batch 3: [a, b, 5, 6] The file `examples/custom_dataloader/custom_dataloader.py` provides a reference implementation that samples `data.num_prompts_per_dataloader` entries from each dataloader. +When a single dataloader is exhausted, the data iterator must be reset in the custom dataloader function (as demonstrated in `examples/custom_dataloader/custom_dataloader.py`). +This design ensures that the [MultipleDataloaderWrapper](../../nemo_rl/data/dataloader.py) operates as an infinite iterator, where `__next__()` will not raise StopIteration and `__len__()` is not supported. + Additionally, custom dataloaders can access recorded metrics from the training loop. Use `wrapped_dataloader.set_records()` in `nemo_rl/algorithms/grpo.py` to store relevant information, which can then be retrieved in your custom dataloader implementation: ```python diff --git a/examples/custom_dataloader/custom_dataloader.py b/examples/custom_dataloader/custom_dataloader.py index 5a23d6d0fb..a8281a7f8c 100644 --- a/examples/custom_dataloader/custom_dataloader.py +++ b/examples/custom_dataloader/custom_dataloader.py @@ -29,6 +29,9 @@ def example_custom_dataloader( This function is used to sample data from multiple dataloaders using a custom dataloader function. In this example, we simply sample data from each dataloader. + When a single dataloader is exhausted, the data iterator must be reset (as demonstrated here). + This design ensures that the MultipleDataloaderWrapper operates as an infinite iterator. + Args: data_iterators: A dictionary of data iterators. dataloaders: A dictionary of dataloaders. It is used to reset the data iterator when it is exhausted. @@ -67,6 +70,9 @@ def example_custom_dataloader_with_chosen_task( This function will need to call `wrapped_dataloader.set_records({"chosen_task": ..., "expected_num_prompts": ...})` to set the records in `nemo_rl/algorithms/grpo.py`. A usage example is shown in the test case `test_multiple_dataloader_with_records` in `tests/unit/data/test_multiple_dataloader.py`. + When a single dataloader is exhausted, the data iterator must be reset (as demonstrated here). + This design ensures that the MultipleDataloaderWrapper operates as an infinite iterator. + Args: data_iterators: A dictionary of data iterators. dataloaders: A dictionary of dataloaders. It is used to reset the data iterator when it is exhausted. diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 5cb8565633..3a995dbfb3 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1391,6 +1391,13 @@ def grpo_train( logger.log_metrics(val_metrics, current_step, prefix="validation") logger.log_metrics(validation_timings, current_step, prefix="timing/validation") + if master_config["data"]["use_multiple_dataloader"]: + warnings.warn( + "When using multiple dataloaders, MultipleDataloaderWrapper operates as an infinite iterator. " + "As a result, grpo.max_num_epochs will be ignored, and only grpo.max_num_steps will be used. " + "See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#multiple-dataloaders for more details." + ) + while current_epoch < max_num_epochs and total_steps < max_num_steps: memory_tracker.snapshot_start_of_stage("Preparing batch", dir()) print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}") diff --git a/nemo_rl/data/dataloader.py b/nemo_rl/data/dataloader.py index 430ea14804..b1da3406ce 100644 --- a/nemo_rl/data/dataloader.py +++ b/nemo_rl/data/dataloader.py @@ -19,6 +19,9 @@ class MultipleDataloaderWrapper: """Wrapper for multiple dataloaders. This wrapper is used to sample data from multiple dataloaders using a custom dataloader function. + + When a single dataloader is exhausted, the data iterator must be reset in the custom dataloader function (as demonstrated in `examples/custom_dataloader/custom_dataloader.py`). + This design ensures that the MultipleDataloaderWrapper operates as an infinite iterator, where __next__() will not raise StopIteration and __len__() is not supported. """ def __init__(