81 changes: 81 additions & 0 deletions docs/guides/grpo-audio.md
@@ -0,0 +1,81 @@
# Audio GRPO on AVQA

This guide explains how to use NeMo RL to train [Qwen2.5-Omni-3B](https://huggingface.co/Qwen/Qwen2.5-Omni-3B) with GRPO on the [AVQA](https://huggingface.co/datasets/avqa) audio question-answering dataset, following the approach described in the [R1-AQA paper](https://arxiv.org/abs/2503.11197), and then evaluate the trained model on the [MMAU benchmark](https://huggingface.co/datasets/TwinkStart/MMAU).

## 1. Train the Model

Run GRPO training with the provided config:

```
uv run examples/run_grpo.py --config examples/configs/audio_grpo_3B_megatron.yaml
```

Config: `examples/configs/audio_grpo_3B_megatron.yaml`

Key hyperparameters:

| Parameter | Value |
| --- | --- |
| Model | Qwen2.5-Omni-3B |
| Dataset | AVQA (train split) |
| GPUs | 8 x 1 node, Megatron backend |
| Learning rate | 1e-6 |
| KL penalty | 0.01 |
| Generations per prompt | 8 |
| Prompts per step | 8 |
| Max steps | 200 |
| Save period | 100 |
| Reward | format (0.2) + exact_alnum (0.8) |
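
The training reward blends a format check (weight 0.2) with an alphanumeric exact-match check (weight 0.8). A rough sketch of how such a blend behaves (the function bodies and the normalization rule here are illustrative guesses, not NeMo RL's actual `format` and `exact_alnum` implementations):

```python
import re

def format_reward(response: str) -> float:
    # Hypothetical format check: reward responses that wrap the answer in <answer> tags.
    return 1.0 if re.search(r"<answer>.*</answer>", response, re.DOTALL) else 0.0

def exact_alnum_reward(response: str, target: str) -> float:
    # Compare only lowercase alphanumeric characters, so punctuation and
    # casing differences (e.g. "Dog!" vs "dog") do not break the match.
    norm = lambda s: re.sub(r"[^a-z0-9]", "", s.lower())
    return 1.0 if norm(target) and norm(target) in norm(response) else 0.0

def combined_reward(response: str, target: str) -> float:
    # Weights mirror the config: format (0.2) + exact_alnum (0.8).
    return 0.2 * format_reward(response) + 0.8 * exact_alnum_reward(response, target)

print(combined_reward("<answer>dog</answer>", "dog"))  # 1.0
```

A well-formatted but wrong answer still earns the 0.2 format reward, which keeps the policy producing parseable output while it learns correctness.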

## 2. Convert Checkpoint (Megatron to HF)

Throughout training, checkpoints are saved to the `results/audio_grpo_3B_megatron` directory (specified by `checkpointing.checkpoint_dir`). To evaluate a checkpoint, first convert it from Megatron format to Hugging Face format:

```
uv run --extra mcore python examples/converters/convert_megatron_to_hf.py \
--config results/audio_grpo_3B_megatron/step_200/config.yaml \
--megatron-ckpt-path results/audio_grpo_3B_megatron/step_200/policy/weights/iter_0000200 \
--hf-ckpt-path results/audio_grpo_3B_megatron/step_200/hf
```

Replace the step number with the checkpoint you want to evaluate. Note the `--extra mcore` flag is required for the Megatron converter.
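
With `save_period: 100` and 200 total steps, two checkpoints are saved, so the conversion command can be generated per step rather than edited by hand. A small sketch, assuming the checkpoint layout shown above:

```python
from pathlib import Path

CKPT_ROOT = Path("results/audio_grpo_3B_megatron")

def convert_cmd(step: int) -> list[str]:
    """Build the Megatron-to-HF conversion command for one checkpoint step."""
    step_dir = CKPT_ROOT / f"step_{step}"
    return [
        "uv", "run", "--extra", "mcore", "python",
        "examples/converters/convert_megatron_to_hf.py",
        "--config", str(step_dir / "config.yaml"),
        # Megatron iteration directories are zero-padded to seven digits.
        "--megatron-ckpt-path",
        str(step_dir / "policy" / "weights" / f"iter_{step:07d}"),
        "--hf-ckpt-path", str(step_dir / "hf"),
    ]

for step in (100, 200):
    print(" ".join(convert_cmd(step)))
```

Pass each list to `subprocess.run` to execute the conversions sequentially.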

## 3. Evaluate on MMAU

Evaluate the converted checkpoint on the [MMAU benchmark](https://huggingface.co/datasets/TwinkStart/MMAU):

```
uv run examples/run_eval.py \
--config=examples/configs/evals/mmau.yaml \
generation.model_name=results/audio_grpo_3B_megatron/step_200/hf \
data.dataset_name=TwinkStart/MMAU
```

Config: `examples/configs/evals/mmau.yaml`

Use `generation.model_name` to specify the path to the converted Hugging Face checkpoint.

## 4. Results

Evaluating the step-200 checkpoint on MMAU, we get the following result:

```
============================================================
model_name='hf_iter_0000000' dataset_name='MMAU'
max_new_tokens=8000 temperature=0.0 top_p=1.0 top_k=-1 seed=42

metric=pass@1 num_tests_per_prompt=1

score=0.7210 (721.0/1000)
============================================================
```
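
Since `num_tests_per_prompt=1`, pass@1 here reduces to plain accuracy: the fraction of the 1000 MMAU items answered correctly. A trivial sketch of the score computation:

```python
def pass_at_1(num_correct: int, num_total: int) -> float:
    # With one sample per prompt, pass@1 is just accuracy.
    return num_correct / num_total

print(f"score={pass_at_1(721, 1000):.4f}")  # score=0.7210
```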

For reference, the table below compares the baseline, the vanilla Hugging Face implementation from the [R1-AQA](https://arxiv.org/abs/2503.11197) paper, and NeMo RL:

| Model | MMAU Score |
| --- | --- |
| Qwen2.5-Omni-3B (baseline) | 69.8 |
| Qwen2.5-Omni-3B + GRPO (HF vanilla) | 71.6 |
| Qwen2.5-Omni-3B + GRPO (NeMo-RL) | 72.1 |

The NeMo-RL result (72.1) is slightly higher than the Hugging Face Transformers reference implementation (71.6), confirming that the training pipeline reproduces the expected improvement over the baseline.
8 changes: 8 additions & 0 deletions docs/index.md
@@ -107,6 +107,13 @@ Step-by-step guide for supervised fine-tuning on the OpenMathInstruct2 dataset.
Create custom reward environments and integrate them with NeMo RL training pipelines.
:::

:::{grid-item-card} {octicon}`unmute` Audio GRPO on AVQA
:link: guides/grpo-audio
:link-type: doc

Train Qwen2.5-Omni-3B with GRPO on AVQA and evaluate on MMAU, following the R1-AQA approach.
:::

:::{grid-item-card} {octicon}`plus-circle` Adding New Models
:link: adding-new-models
:link-type: doc
@@ -213,6 +220,7 @@ guides/prorlv2.md
guides/grpo.md
guides/grpo-deepscaler.md
guides/grpo-sliding-puzzle.md
guides/grpo-audio.md
guides/rm.md
guides/environments.md
guides/eval.md
83 changes: 83 additions & 0 deletions examples/configs/audio_grpo_3B_megatron.yaml
@@ -0,0 +1,83 @@
# Audio GRPO 3B Megatron Configuration
# Inherits from grpo_math_1B_megatron.yaml and overrides audio-specific settings.
defaults: "grpo_math_1B_megatron.yaml"

grpo:
  num_prompts_per_step: 8
  num_generations_per_prompt: 8
  max_num_steps: 200
  max_val_samples: 32
  val_batch_size: 32

checkpointing:
  enabled: true
  checkpoint_dir: results/audio_grpo_3B_megatron
  keep_top_k: 10
  save_period: 100

policy:
  model_name: Qwen/Qwen2.5-Omni-3B
  train_global_batch_size: 32
  train_micro_batch_size: 1
  generation_batch_size: 32
  logprob_batch_size: 4
  max_total_sequence_length: 2048

  sequence_packing:
    enabled: false

  generation:
    max_new_tokens: 1024
    vllm_cfg:
      # Audio/multimodal models require tokenizer to be initialized before generation
      skip_tokenizer_init: False
      limit_mm_per_prompt:
        audio: 1

  megatron_cfg:
    converter_type: Qwen2_5OmniForConditionalGeneration
    apply_rope_fusion: false
    optimizer:
      lr: 1.0e-6
      min_lr: 1.0e-7
    scheduler:
      lr_warmup_iters: 10
      lr_warmup_init: 1.0e-7
    distributed_data_parallel_config:
      overlap_grad_reduce: false

data:
  train:
    dataset_name: avqa
    split: train
  validation:
    dataset_name: avqa
    split: validation
  default:
    prompt_file: null
    system_prompt_file: null
    processor: "vlm_hf_data_processor"
    env_name: "vlm"

env:
  vlm:
    num_workers: 8
    reward_functions:
      - name: format
        weight: 0.2
      - name: exact_alnum
        weight: 0.8

logger:
  wandb_enabled: true
  tensorboard_enabled: true
  monitor_gpus: false
  wandb:
    project: grpo-dev
    name: audio-grpo-3b-megatron-large-lr
  swanlab:
    project: grpo-dev
    name: audio-grpo-3b-megatron-large-lr

cluster:
  gpus_per_node: 8
58 changes: 58 additions & 0 deletions examples/configs/evals/mmau.yaml
@@ -0,0 +1,58 @@
eval:
  metric: "pass@k"
  num_tests_per_prompt: 1
  seed: 42
  k_value: 1
  save_path: null

generation:
  backend: "vllm"
  max_new_tokens: 2048
  temperature: 0.0
  top_p: 1.0
  top_k: -1
  num_prompts_per_step: -1
  model_name: "Qwen/Qwen2.5-Omni-3B"
  stop_token_ids: null
  stop_strings: null
  vllm_cfg:
    async_engine: false
    precision: "bfloat16"
    tensor_parallel_size: 1
    pipeline_parallel_size: 1
    expert_parallel_size: 1
    gpu_memory_utilization: 0.9
    max_model_len: 8000
    enforce_eager: False
    skip_tokenizer_init: False
    limit_mm_per_prompt:
      audio: 1
  colocated:
    enabled: true
    resources:
      gpus_per_node: null
      num_nodes: null

tokenizer:
  name: ${generation.model_name}
  chat_template: "default"
  chat_template_kwargs: null

data:
  max_input_seq_length: ${generation.vllm_cfg.max_model_len}
  prompt_file: null
  system_prompt_file: null
  dataset_name: "TwinkStart/MMAU"
  split: "v05.15.25"
  env_name: vlm

env:
  mmau:
    num_workers: 8
    reward_functions:
      - name: exact_alnum
        weight: 1.0

cluster:
  gpus_per_node: 1
  num_nodes: 1
27 changes: 12 additions & 15 deletions examples/run_eval.py
@@ -20,19 +20,16 @@
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from omegaconf import OmegaConf
-from transformers import AutoTokenizer, PreTrainedTokenizerBase

from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.data.datasets import AllTaskProcessedDataset, load_eval_dataset
-from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env
+from nemo_rl.data.datasets.eval_datasets import _is_multimodal_dataset
from nemo_rl.distributed.virtual_cluster import init_ray
-from nemo_rl.environments.math_environment import MathEnvironment
+from nemo_rl.environments.utils import create_env
from nemo_rl.evals.eval import MasterConfig, run_env_eval, setup
from nemo_rl.models.generation import configure_generation_config
from nemo_rl.utils.config import load_config

-TokenizerType = PreTrainedTokenizerBase


def parse_args():
"""Parse command line arguments."""
@@ -50,26 +47,25 @@ def parse_args():
    return args, overrides


-def setup_data(tokenizer: AutoTokenizer, data_config, env_configs):
+def setup_data(tokenizer, data_config, env_configs):
print("Setting up data...")

# load dataset
base_dataset = load_eval_dataset(data_config)
rekeyed_ds = base_dataset.rekeyed_ds

-    env = MathEnvironment.options(
-        runtime_env={
-            "py_executable": get_actor_python_env(
-                "nemo_rl.environments.math_environment.MathEnvironment"
-            )
-        }
-    ).remote(env_configs["math"])
+    # Determine env from config: use explicit env_name if provided,
+    # otherwise fall back to the single key in env_configs.
+    env_key = next(iter(env_configs))
+    env_name = data_config.get("env_name", env_key)
+    env = create_env(env_name=env_name, env_config=env_configs[env_key])

    dataset = AllTaskProcessedDataset(
        dataset=rekeyed_ds,
        tokenizer=tokenizer,
        default_task_data_spec=base_dataset.task_spec,
        task_data_processors=base_dataset.processor,
+        task_data_preprocessors=getattr(base_dataset, "preprocessor", None),
        max_seq_length=data_config["max_input_seq_length"],
    )

@@ -104,8 +100,9 @@ def main():
    # Init ray
    init_ray()

-    # Setup tokenizer
-    tokenizer = get_tokenizer(config["tokenizer"])
+    # Setup tokenizer — get_tokenizer handles both text-only and multimodal
+    is_multimodal = _is_multimodal_dataset(config["data"]["dataset_name"])
+    tokenizer = get_tokenizer(config["tokenizer"], get_processor=is_multimodal)
    config["generation"] = configure_generation_config(
        config["generation"], tokenizer, is_eval=True
    )
18 changes: 18 additions & 0 deletions nemo_rl/data/collate_fn.py
@@ -55,9 +55,11 @@ def rl_collate_fn(data_batch: list[DatumSpec]) -> BatchedDataDict[Any]:
    ]
    vllm_images = [datum_spec.get("vllm_images", []) for datum_spec in data_batch]
    vllm_videos = [datum_spec.get("vllm_videos", []) for datum_spec in data_batch]
+    vllm_audios = [datum_spec.get("vllm_audios", []) for datum_spec in data_batch]
    extra_args["vllm_content"] = vllm_content
    extra_args["vllm_images"] = vllm_images
    extra_args["vllm_videos"] = vllm_videos
+    extra_args["vllm_audios"] = vllm_audios

    output: BatchedDataDict[Any] = BatchedDataDict(
        message_log=message_log,
@@ -116,10 +118,26 @@ def eval_collate_fn(data_batch: list[DatumSpec]) -> BatchedDataDict[Any]:
extra_env_info = [datum_spec["extra_env_info"] for datum_spec in data_batch]
idx = [datum_spec["idx"] for datum_spec in data_batch]

# Check if any of the data batch has vllm content (multimodal data)
extra_args = {}
if any(
datum_spec.get("vllm_content", None) is not None for datum_spec in data_batch
):
extra_args["vllm_content"] = [
datum_spec.get("vllm_content", None) for datum_spec in data_batch
]
extra_args["vllm_images"] = [
datum_spec.get("vllm_images", []) for datum_spec in data_batch
]
extra_args["vllm_audios"] = [
datum_spec.get("vllm_audios", []) for datum_spec in data_batch
]

output: BatchedDataDict[Any] = BatchedDataDict(
message_log=message_log,
extra_env_info=extra_env_info,
idx=idx,
**extra_args,
)
return output

19 changes: 19 additions & 0 deletions nemo_rl/data/datasets/eval_datasets/__init__.py
@@ -16,9 +16,18 @@
from nemo_rl.data.datasets.eval_datasets.gpqa import GPQADataset
from nemo_rl.data.datasets.eval_datasets.local_math_dataset import LocalMathDataset
from nemo_rl.data.datasets.eval_datasets.math import MathDataset
+from nemo_rl.data.datasets.eval_datasets.mmau import MMAUDataset
from nemo_rl.data.datasets.eval_datasets.mmlu import MMLUDataset
from nemo_rl.data.datasets.eval_datasets.mmlu_pro import MMLUProDataset

+# Dataset names that require multimodal (VLM) processing
+MULTIMODAL_DATASETS = {"mmau", "TwinkStart/MMAU"}
+
+
+def _is_multimodal_dataset(dataset_name):
+    """Check if the dataset requires multimodal processing."""
+    return dataset_name in MULTIMODAL_DATASETS


def load_eval_dataset(data_config):
"""Loads evaluation dataset."""
@@ -82,6 +91,13 @@ def load_eval_dataset(data_config):
prompt_file=data_config["prompt_file"],
system_prompt_file=data_config["system_prompt_file"],
)
# mmau
elif dataset_name in ("mmau", "TwinkStart/MMAU"):
split = data_config.get("split", "v05.15.25")
base_dataset = MMAUDataset(
dataset_name="TwinkStart/MMAU",
split=split,
)
# fall back to local dataset
else:
print(f"Loading dataset from {dataset_name}...")
@@ -103,6 +119,9 @@ def load_eval_dataset(data_config):
"GPQADataset",
"LocalMathDataset",
"MathDataset",
"MMAUDataset",
"MMLUDataset",
"MMLUProDataset",
"MULTIMODAL_DATASETS",
"_is_multimodal_dataset",
]