63 changes: 63 additions & 0 deletions docs/guides/grpo.md
@@ -164,6 +164,69 @@ def my_data_processor(

We have an example of this as `math_data_processor` in [processors.py](../../nemo_rl/data/processors.py).

#### Multiple Dataloaders

By default, NeMo RL uses a single dataloader that aggregates data from all configured datasets. For scenarios that require fine-grained control over how many prompts are drawn from each dataset, NeMo RL also supports multiple dataloaders.

The following example demonstrates how to configure multiple dataloaders:

```bash
uv run examples/run_grpo.py \
--config examples/configs/grpo_multiple_datasets.yaml \
grpo.num_prompts_per_step=32 \
data.use_multiple_dataloader=true \
data.num_prompts_per_dataloader=16 \
data.custom_dataloader=examples.custom_dataloader.custom_dataloader.example_custom_dataloader
```

The examples below use `example_custom_dataloader`, which samples data from each dataloader in turn.

Given two datasets:
- Dataset 1: `[a, b, c, d]`
- Dataset 2: `[1, 2, 3, 4, 5, 6, 7, 8]`

With `data.use_multiple_dataloader=false` and `grpo.num_prompts_per_step=4`:
```
Batch 1: [a, b, c, d]
Batch 2: [1, 2, 3, 4]
Batch 3: [5, 6, 7, 8]
```

With `data.use_multiple_dataloader=true`, `grpo.num_prompts_per_step=4`, and `data.num_prompts_per_dataloader=2`:
```
Batch 1: [a, b, 1, 2]
Batch 2: [c, d, 3, 4]
Batch 3: [a, b, 5, 6]
```
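The interleaving above can be reproduced with a small round-robin sketch (toy lists stand in for the real dataloaders; the function name here is illustrative, not part of the NeMo RL API):

```python
def round_robin_batches(datasets, prompts_per_dataloader, num_batches):
    """Draw `prompts_per_dataloader` items from each dataset per batch,
    restarting a dataset's iterator whenever it is exhausted."""
    iterators = {name: iter(data) for name, data in datasets.items()}
    batches = []
    for _ in range(num_batches):
        batch = []
        for name, data in datasets.items():
            for _ in range(prompts_per_dataloader):
                try:
                    batch.append(next(iterators[name]))
                except StopIteration:
                    iterators[name] = iter(data)  # reset the exhausted iterator
                    batch.append(next(iterators[name]))
        batches.append(batch)
    return batches


datasets = {"dataset1": ["a", "b", "c", "d"], "dataset2": [1, 2, 3, 4, 5, 6, 7, 8]}
print(round_robin_batches(datasets, prompts_per_dataloader=2, num_batches=3))
# → [['a', 'b', 1, 2], ['c', 'd', 3, 4], ['a', 'b', 5, 6]]
```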

**Custom Dataloader**

The file `examples/custom_dataloader/custom_dataloader.py` provides a reference implementation that samples `data.num_prompts_per_dataloader` entries from each dataloader.

When a single dataloader is exhausted, the data iterator must be reset in the custom dataloader function (as demonstrated in `examples/custom_dataloader/custom_dataloader.py`).
This design ensures that the [MultipleDataloaderWrapper](../../nemo_rl/data/dataloader.py) behaves as an infinite iterator: `__next__()` never raises `StopIteration`, and `__len__()` is not supported.
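The infinite-iterator contract can be sketched as a thin wrapper. This is a simplified stand-in assuming only the behavior described above, not the actual `MultipleDataloaderWrapper` implementation:

```python
from typing import Any, Callable, Iterator


class InfiniteDataloaderWrapper:
    """Simplified sketch of an infinite multi-dataloader wrapper.

    `__next__` never raises StopIteration because the custom dataloader
    function is responsible for resetting exhausted iterators;
    `__len__` is deliberately unsupported.
    """

    def __init__(self, dataloaders: dict[str, Any], custom_fn: Callable):
        self.dataloaders = dataloaders
        self.data_iterators = {k: iter(v) for k, v in dataloaders.items()}
        self.custom_fn = custom_fn
        self.records: dict[str, Any] = {}

    def set_records(self, records: dict[str, Any]) -> None:
        # Metrics recorded by the training loop, passed through as kwargs.
        self.records = records

    def __iter__(self) -> Iterator:
        return self

    def __next__(self):
        batch, self.data_iterators = self.custom_fn(
            self.data_iterators, self.dataloaders, **self.records
        )
        return batch

    def __len__(self) -> int:
        raise NotImplementedError("infinite iterator has no length")
```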

Additionally, custom dataloaders can access metrics recorded by the training loop. Call `wrapped_dataloader.set_records()` in `nemo_rl/algorithms/grpo.py` to store the relevant values, which your custom dataloader then receives as keyword arguments:

```python
# In nemo_rl/algorithms/grpo.py
wrapped_dataloader.set_records({"reward": ...})

# In custom_dataloader.py
def example_custom_dataloader(
data_iterators: dict[str, Iterator],
dataloaders: dict[str, StatefulDataLoader],
**kwargs,
) -> tuple[BatchedDataDict, dict[str, Iterator]]:
...
reward = kwargs["reward"]
...
```

**num_prompts_per_dataloader**

This parameter sets the number of prompts drawn from each dataloader per step. Ensure that `grpo.num_prompts_per_step` is a multiple of `data.num_prompts_per_dataloader` so that exactly `grpo.num_prompts_per_step` prompts are available for each training step.
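This constraint can be checked up front; a sketch of such a validation (not code from the repo):

```python
def validate_prompt_split(num_prompts_per_step: int, num_prompts_per_dataloader: int) -> int:
    """Check the divisibility constraint and return how many dataloader
    draws of `num_prompts_per_dataloader` prompts fit into one step."""
    if num_prompts_per_step % num_prompts_per_dataloader != 0:
        raise ValueError(
            f"grpo.num_prompts_per_step ({num_prompts_per_step}) must be a "
            f"multiple of data.num_prompts_per_dataloader ({num_prompts_per_dataloader})"
        )
    return num_prompts_per_step // num_prompts_per_dataloader


# The CLI example above: 32 prompts per step, 16 per dataloader → 2 draws per step.
validate_prompt_split(32, 16)
```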

### Task–Dataset Mapping

- task_name (unique task identifier):
4 changes: 4 additions & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -292,6 +292,10 @@ data:
shuffle: true
num_workers: 1

# use multiple dataloaders for training
# see https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#multiple-dataloaders for more details.
use_multiple_dataloader: false

# dataset
train:
dataset_name: OpenMathInstruct-2
6 changes: 6 additions & 0 deletions examples/configs/grpo_multiple_datasets.yaml
@@ -8,6 +8,12 @@ data:
shuffle: true
num_workers: 1

# use multiple dataloaders for training
# see https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#multiple-dataloaders for more details.
use_multiple_dataloader: false
num_prompts_per_dataloader: 16
custom_dataloader: examples.custom_dataloader.custom_dataloader.example_custom_dataloader

# dataset
# See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details.
train:
3 changes: 3 additions & 0 deletions examples/configs/vlm_grpo_3B.yaml
@@ -249,6 +249,9 @@ data:
shuffle: true
num_workers: 1

# use multiple dataloaders for training
use_multiple_dataloader: false

# dataset
train:
dataset_name: clevr-cogent
4 changes: 4 additions & 0 deletions examples/configs/vlm_grpo_3B_megatron.yaml
@@ -199,6 +199,10 @@ data:
max_input_seq_length: ${policy.max_total_sequence_length}
shuffle: true
num_workers: 1

# use multiple dataloaders for training
use_multiple_dataloader: false

# dataset
train:
dataset_name: clevr-cogent
104 changes: 104 additions & 0 deletions examples/custom_dataloader/custom_dataloader.py
@@ -0,0 +1,104 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterator

from torchdata.stateful_dataloader import StatefulDataLoader

from nemo_rl.distributed.batched_data_dict import BatchedDataDict


def example_custom_dataloader(
data_iterators: dict[str, Iterator],
dataloaders: dict[str, StatefulDataLoader],
**kwargs,
) -> tuple[BatchedDataDict, dict[str, Iterator]]:
    """An example of a custom dataloader function.

    This function samples one batch from each dataloader and merges the results.

When a single dataloader is exhausted, the data iterator must be reset (as demonstrated here).
This design ensures that the MultipleDataloaderWrapper operates as an infinite iterator.

Args:
data_iterators: A dictionary of data iterators.
dataloaders: A dictionary of dataloaders. It is used to reset the data iterator when it is exhausted.
**kwargs: Additional arguments to pass to the custom dataloader function.

    Returns:
        A merged batch of data from the dataloaders.
        The updated data iterators (reset whenever a dataloader was exhausted).
"""
# sample data from each dataloader
result = []
for task_name, data_iterator in data_iterators.items():
try:
result.append(next(data_iterator))
except StopIteration:
data_iterators[task_name] = iter(dataloaders[task_name])
result.append(next(data_iterators[task_name]))

# merge results
result = BatchedDataDict.from_batches(result)
return result, data_iterators


def example_custom_dataloader_with_chosen_task(
data_iterators: dict[str, Iterator],
dataloaders: dict[str, StatefulDataLoader],
chosen_task: list[str],
expected_num_prompts: int,
**kwargs,
) -> tuple[BatchedDataDict, dict[str, Iterator]]:
    """An example of a custom dataloader function that samples from chosen tasks.

    This function cycles through the chosen tasks in round-robin order until the
    expected number of prompts has been collected.

    This function requires the training loop in `nemo_rl/algorithms/grpo.py` to call `wrapped_dataloader.set_records({"chosen_task": ..., "expected_num_prompts": ...})` so that these arguments are available.
A usage example is shown in the test case `test_multiple_dataloader_with_records` in `tests/unit/data/test_multiple_dataloader.py`.

When a single dataloader is exhausted, the data iterator must be reset (as demonstrated here).
This design ensures that the MultipleDataloaderWrapper operates as an infinite iterator.

Args:
data_iterators: A dictionary of data iterators.
dataloaders: A dictionary of dataloaders. It is used to reset the data iterator when it is exhausted.
chosen_task: A list of task names to sample data from.
expected_num_prompts: The expected number of prompts to sample.

    Returns:
        A merged batch of data from the chosen tasks.
        The updated data iterators (reset whenever a dataloader was exhausted).
"""
# sample data from the chosen task
result = []
current_task_idx = 0
current_num_prompts = 0
while current_num_prompts < expected_num_prompts:
task_name = chosen_task[current_task_idx]
try:
data = next(data_iterators[task_name])
except StopIteration:
data_iterators[task_name] = iter(dataloaders[task_name])
data = next(data_iterators[task_name])

result.append(data)
current_num_prompts += len(data["message_log"])
current_task_idx = (current_task_idx + 1) % len(chosen_task)

# merge results
result = BatchedDataDict.from_batches(result)
return result, data_iterators
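A self-contained toy run of the chosen-task sampling loop above (plain lists stand in for `StatefulDataLoader`, and a list of dicts replaces `BatchedDataDict`; everything here is illustrative):

```python
from typing import Iterator


def sample_chosen_tasks(
    data_iterators: dict[str, Iterator],
    datasets: dict[str, list],
    chosen_task: list[str],
    expected_num_prompts: int,
):
    """Toy version of the chosen-task sampler: cycle through `chosen_task`,
    drawing one batch per task and resetting exhausted iterators, until at
    least `expected_num_prompts` prompts have been collected."""
    result = []
    task_idx = 0
    num_prompts = 0
    while num_prompts < expected_num_prompts:
        name = chosen_task[task_idx]
        try:
            batch = next(data_iterators[name])
        except StopIteration:
            data_iterators[name] = iter(datasets[name])
            batch = next(data_iterators[name])
        result.append(batch)
        num_prompts += len(batch["message_log"])
        task_idx = (task_idx + 1) % len(chosen_task)
    return result, data_iterators


# Two toy "dataloaders", each yielding batches of two prompts.
datasets = {
    "math": [{"message_log": ["m1", "m2"]}, {"message_log": ["m3", "m4"]}],
    "code": [{"message_log": ["c1", "c2"]}],
}
iterators = {name: iter(batches) for name, batches in datasets.items()}
batches, iterators = sample_chosen_tasks(iterators, datasets, ["math", "code"], 6)
print([b["message_log"] for b in batches])
# → [['m1', 'm2'], ['c1', 'c2'], ['m3', 'm4']]
```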
6 changes: 6 additions & 0 deletions examples/run_grpo.py
@@ -133,6 +133,12 @@ def main() -> None:
f"{feature} is not supported with async GRPO"
)

# Async GRPO does not support multiple dataloaders
if config["data"]["use_multiple_dataloader"]:
raise NotImplementedError(
"use_multiple_dataloader is not supported with async GRPO"
)

from nemo_rl.algorithms.grpo import async_grpo_train

print("🚀 Running async GRPO training")