diff --git a/docs/guides/grpo-audio.md b/docs/guides/grpo-audio.md
new file mode 100644
index 0000000000..5e3d6e7dbb
--- /dev/null
+++ b/docs/guides/grpo-audio.md
@@ -0,0 +1,81 @@
+# Audio GRPO on AVQA
+
+This guide explains how to use NeMo RL to train [Qwen2.5-Omni-3B](https://huggingface.co/Qwen/Qwen2.5-Omni-3B) with GRPO on the [AVQA](https://huggingface.co/datasets/avqa) audio question-answering dataset, following the approach described in the [R1-AQA paper](https://arxiv.org/abs/2503.11197), and then evaluate the trained model on the [MMAU benchmark](https://huggingface.co/datasets/TwinkStart/MMAU).
+
+## 1. Train the Model
+
+Run GRPO training with the provided config:
+
+```
+uv run examples/run_grpo.py --config examples/configs/audio_grpo_3B_megatron.yaml
+```
+
+Config: `examples/configs/audio_grpo_3B_megatron.yaml`
+
+Key hyperparameters:
+
+| Parameter | Value |
+| --- | --- |
+| Model | Qwen2.5-Omni-3B |
+| Dataset | AVQA (train split) |
+| GPUs | 8 x 1 node, Megatron backend |
+| Learning rate | 1e-6 |
+| KL penalty | 0.01 |
+| Generations per prompt | 8 |
+| Prompts per step | 8 |
+| Max steps | 200 |
+| Save period | 100 |
+| Reward | format (0.2) + exact_alnum (0.8) |
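+
+The reward in the last row is a weighted sum of a format check and an exact-match check. The helper names below are illustrative, not NeMo RL's actual reward functions; this sketch assumes the format reward checks for `<answer></answer>` tags and that `exact_alnum` compares answers after stripping non-alphanumeric characters:
+
+```python
+import re
+
+def format_reward(response: str) -> float:
+    # 1.0 if the response wraps its final answer in <answer>...</answer> tags.
+    return 1.0 if re.search(r"<answer>.*?</answer>", response, re.DOTALL) else 0.0
+
+def exact_alnum_reward(response: str, ground_truth: str) -> float:
+    # Compare the extracted answer to the ground truth, keeping only
+    # lowercase alphanumeric characters.
+    def norm(s: str) -> str:
+        m = re.search(r"<answer>(.*?)</answer>", s, re.DOTALL)
+        if m:
+            s = m.group(1)
+        return re.sub(r"[^a-z0-9]", "", s.lower())
+    return 1.0 if norm(response) == norm(ground_truth) else 0.0
+
+def total_reward(response: str, ground_truth: str) -> float:
+    return 0.2 * format_reward(response) + 0.8 * exact_alnum_reward(response, ground_truth)
+```
+
+Under this weighting, a correct and well-formatted response scores 1.0, while a correct answer without the expected tags scores 0.8.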
+
+## 2. Convert Checkpoint (Megatron to HF)
+
+Throughout training, checkpoints are saved to the `results/audio_grpo_3B_megatron` directory (specified by `checkpointing.checkpoint_dir`). To evaluate a checkpoint, first convert it from Megatron format to Hugging Face format:
+
+```
+uv run --extra mcore python examples/converters/convert_megatron_to_hf.py \
+ --config results/audio_grpo_3B_megatron/step_200/config.yaml \
+ --megatron-ckpt-path results/audio_grpo_3B_megatron/step_200/policy/weights/iter_0000200 \
+ --hf-ckpt-path results/audio_grpo_3B_megatron/step_200/hf
+```
+
+Replace the step number with the checkpoint you want to evaluate. Note the `--extra mcore` flag is required for the Megatron converter.
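+
+If you keep multiple checkpoints (with `save_period: 100` and `max_num_steps: 200`, training saves `step_100` and `step_200`), a small loop can convert them all. This is a sketch assuming the directory layout shown above:
+
+```sh
+# Convert every saved Megatron checkpoint to Hugging Face format.
+for step_dir in results/audio_grpo_3B_megatron/step_*; do
+    step=$(basename "$step_dir" | sed 's/^step_//')
+    iter=$(printf "iter_%07d" "$step")
+    uv run --extra mcore python examples/converters/convert_megatron_to_hf.py \
+        --config "$step_dir/config.yaml" \
+        --megatron-ckpt-path "$step_dir/policy/weights/$iter" \
+        --hf-ckpt-path "$step_dir/hf"
+done
+```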
+
+## 3. Evaluate on MMAU
+
+Evaluate the converted checkpoint on the [MMAU benchmark](https://huggingface.co/datasets/TwinkStart/MMAU):
+
+```
+uv run examples/run_eval.py \
+ --config=examples/configs/evals/mmau.yaml \
+ generation.model_name=results/audio_grpo_3B_megatron/step_200/hf \
+ data.dataset_name=TwinkStart/MMAU
+```
+
+Config: `examples/configs/evals/mmau.yaml`
+
+Use `generation.model_name` to specify the path to the converted Hugging Face checkpoint.
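+
+The eval config uses the `pass@k` metric with `k_value: 1` and one test per prompt, which reduces to plain accuracy. Assuming the standard unbiased pass@k estimator (illustrative only; NeMo RL's implementation may differ):
+
+```python
+from math import comb
+
+def pass_at_k(num_samples: int, num_correct: int, k: int) -> float:
+    # Unbiased pass@k estimator: 1 - C(n - c, k) / C(n, k).
+    # With num_samples == k == 1 it reduces to 1.0 if the single
+    # sample is correct, else 0.0.
+    if num_samples - num_correct < k:
+        return 1.0
+    return 1.0 - comb(num_samples - num_correct, k) / comb(num_samples, k)
+```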
+
+## 4. Results
+
+Evaluating the step-200 checkpoint on MMAU, we get the following result:
+
+```
+============================================================
+model_name='hf_iter_0000000' dataset_name='MMAU'
+max_new_tokens=8000 temperature=0.0 top_p=1.0 top_k=-1 seed=42
+
+metric=pass@1 num_tests_per_prompt=1
+
+score=0.7210 (721.0/1000)
+============================================================
+```
+
+As a reference, here are results comparing the baseline, the [R1-AQA](https://arxiv.org/abs/2503.11197) HuggingFace vanilla implementation, and NeMo-RL:
+
+| Model | MMAU Score |
+| --- | --- |
+| Qwen2.5-Omni-3B (baseline) | 69.8 |
+| Qwen2.5-Omni-3B + GRPO (HF vanilla) | 71.6 |
+| Qwen2.5-Omni-3B + GRPO (NeMo-RL) | 72.1 |
+
+The NeMo-RL result (72.1) slightly exceeds the Hugging Face Transformers reference implementation (71.6), confirming that the training pipeline reproduces the expected improvement over the baseline.
diff --git a/docs/index.md b/docs/index.md
index 7cb4e655ad..ed7c6f9802 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -107,6 +107,13 @@ Step-by-step guide for supervised fine-tuning on the OpenMathInstruct2 dataset.
Create custom reward environments and integrate them with NeMo RL training pipelines.
:::
+:::{grid-item-card} {octicon}`unmute` Audio GRPO on AVQA
+:link: guides/grpo-audio
+:link-type: doc
+
+Train Qwen2.5-Omni-3B with GRPO on AVQA and evaluate on MMAU, following the R1-AQA approach.
+:::
+
:::{grid-item-card} {octicon}`plus-circle` Adding New Models
:link: adding-new-models
:link-type: doc
@@ -213,6 +220,7 @@ guides/prorlv2.md
guides/grpo.md
guides/grpo-deepscaler.md
guides/grpo-sliding-puzzle.md
+guides/grpo-audio.md
guides/rm.md
guides/environments.md
guides/eval.md
diff --git a/examples/configs/audio_grpo_3B_megatron.yaml b/examples/configs/audio_grpo_3B_megatron.yaml
new file mode 100644
index 0000000000..e90fdea1f6
--- /dev/null
+++ b/examples/configs/audio_grpo_3B_megatron.yaml
@@ -0,0 +1,83 @@
+# Audio GRPO 3B Megatron Configuration
+# Inherits from grpo_math_1B_megatron.yaml and overrides audio-specific settings.
+defaults: "grpo_math_1B_megatron.yaml"
+
+grpo:
+ num_prompts_per_step: 8
+ num_generations_per_prompt: 8
+ max_num_steps: 200
+ max_val_samples: 32
+ val_batch_size: 32
+
+checkpointing:
+ enabled: true
+ checkpoint_dir: results/audio_grpo_3B_megatron
+ keep_top_k: 10
+ save_period: 100
+
+policy:
+ model_name: Qwen/Qwen2.5-Omni-3B
+ train_global_batch_size: 32
+ train_micro_batch_size: 1
+ generation_batch_size: 32
+ logprob_batch_size: 4
+ max_total_sequence_length: 2048
+
+ sequence_packing:
+ enabled: false
+
+ generation:
+ max_new_tokens: 1024
+ vllm_cfg:
+ # Audio/multimodal models require tokenizer to be initialized before generation
+ skip_tokenizer_init: False
+ limit_mm_per_prompt:
+ audio: 1
+
+ megatron_cfg:
+ converter_type: Qwen2_5OmniForConditionalGeneration
+ apply_rope_fusion: false
+ optimizer:
+ lr: 1.0e-6
+ min_lr: 1.0e-7
+ scheduler:
+ lr_warmup_iters: 10
+ lr_warmup_init: 1.0e-7
+ distributed_data_parallel_config:
+ overlap_grad_reduce: false
+
+data:
+ train:
+ dataset_name: avqa
+ split: train
+ validation:
+ dataset_name: avqa
+ split: validation
+ default:
+ prompt_file: null
+ system_prompt_file: null
+ processor: "vlm_hf_data_processor"
+ env_name: "vlm"
+
+env:
+ vlm:
+ num_workers: 8
+ reward_functions:
+ - name: format
+ weight: 0.2
+ - name: exact_alnum
+ weight: 0.8
+
+logger:
+ wandb_enabled: true
+ tensorboard_enabled: true
+ monitor_gpus: false
+ wandb:
+ project: grpo-dev
+ name: audio-grpo-3b-megatron-large-lr
+ swanlab:
+ project: grpo-dev
+ name: audio-grpo-3b-megatron-large-lr
+
+cluster:
+ gpus_per_node: 8
diff --git a/examples/configs/evals/mmau.yaml b/examples/configs/evals/mmau.yaml
new file mode 100644
index 0000000000..0338937f9b
--- /dev/null
+++ b/examples/configs/evals/mmau.yaml
@@ -0,0 +1,58 @@
+eval:
+ metric: "pass@k"
+ num_tests_per_prompt: 1
+ seed: 42
+ k_value: 1
+ save_path: null
+
+generation:
+ backend: "vllm"
+ max_new_tokens: 2048
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ num_prompts_per_step: -1
+ model_name: "Qwen/Qwen2.5-Omni-3B"
+ stop_token_ids: null
+ stop_strings: null
+ vllm_cfg:
+ async_engine: false
+ precision: "bfloat16"
+ tensor_parallel_size: 1
+ pipeline_parallel_size: 1
+ expert_parallel_size: 1
+ gpu_memory_utilization: 0.9
+ max_model_len: 8000
+ enforce_eager: False
+ skip_tokenizer_init: False
+ limit_mm_per_prompt:
+ audio: 1
+ colocated:
+ enabled: true
+ resources:
+ gpus_per_node: null
+ num_nodes: null
+
+tokenizer:
+ name: ${generation.model_name}
+ chat_template: "default"
+ chat_template_kwargs: null
+
+data:
+ max_input_seq_length: ${generation.vllm_cfg.max_model_len}
+ prompt_file: null
+ system_prompt_file: null
+ dataset_name: "TwinkStart/MMAU"
+ split: "v05.15.25"
+ env_name: vlm
+
+env:
+ mmau:
+ num_workers: 8
+ reward_functions:
+ - name: exact_alnum
+ weight: 1.0
+
+cluster:
+ gpus_per_node: 1
+ num_nodes: 1
diff --git a/examples/run_eval.py b/examples/run_eval.py
index 8966938632..3d678466ed 100644
--- a/examples/run_eval.py
+++ b/examples/run_eval.py
@@ -20,19 +20,16 @@
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from omegaconf import OmegaConf
-from transformers import AutoTokenizer, PreTrainedTokenizerBase
from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.data.datasets import AllTaskProcessedDataset, load_eval_dataset
-from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env
+from nemo_rl.data.datasets.eval_datasets import _is_multimodal_dataset
from nemo_rl.distributed.virtual_cluster import init_ray
-from nemo_rl.environments.math_environment import MathEnvironment
+from nemo_rl.environments.utils import create_env
from nemo_rl.evals.eval import MasterConfig, run_env_eval, setup
from nemo_rl.models.generation import configure_generation_config
from nemo_rl.utils.config import load_config
-TokenizerType = PreTrainedTokenizerBase
-
def parse_args():
"""Parse command line arguments."""
@@ -50,26 +47,25 @@ def parse_args():
return args, overrides
-def setup_data(tokenizer: AutoTokenizer, data_config, env_configs):
+def setup_data(tokenizer, data_config, env_configs):
print("Setting up data...")
# load dataset
base_dataset = load_eval_dataset(data_config)
rekeyed_ds = base_dataset.rekeyed_ds
- env = MathEnvironment.options(
- runtime_env={
- "py_executable": get_actor_python_env(
- "nemo_rl.environments.math_environment.MathEnvironment"
- )
- }
- ).remote(env_configs["math"])
+    # Determine the environment: data_config's env_name selects the environment
+    # class, while the single entry in env_configs supplies its configuration.
+    # Fall back to that entry's key when env_name is absent.
+    env_key = next(iter(env_configs))
+    env_name = data_config.get("env_name", env_key)
+    env = create_env(env_name=env_name, env_config=env_configs[env_key])
dataset = AllTaskProcessedDataset(
dataset=rekeyed_ds,
tokenizer=tokenizer,
default_task_data_spec=base_dataset.task_spec,
task_data_processors=base_dataset.processor,
+ task_data_preprocessors=getattr(base_dataset, "preprocessor", None),
max_seq_length=data_config["max_input_seq_length"],
)
@@ -104,8 +100,9 @@ def main():
# Init ray
init_ray()
- # Setup tokenizer
- tokenizer = get_tokenizer(config["tokenizer"])
+    # Setup tokenizer: get_tokenizer handles both text-only and multimodal
+ is_multimodal = _is_multimodal_dataset(config["data"]["dataset_name"])
+ tokenizer = get_tokenizer(config["tokenizer"], get_processor=is_multimodal)
config["generation"] = configure_generation_config(
config["generation"], tokenizer, is_eval=True
)
diff --git a/nemo_rl/data/collate_fn.py b/nemo_rl/data/collate_fn.py
index 09d2cf766a..6f4291aa43 100644
--- a/nemo_rl/data/collate_fn.py
+++ b/nemo_rl/data/collate_fn.py
@@ -55,9 +55,11 @@ def rl_collate_fn(data_batch: list[DatumSpec]) -> BatchedDataDict[Any]:
]
vllm_images = [datum_spec.get("vllm_images", []) for datum_spec in data_batch]
vllm_videos = [datum_spec.get("vllm_videos", []) for datum_spec in data_batch]
+ vllm_audios = [datum_spec.get("vllm_audios", []) for datum_spec in data_batch]
extra_args["vllm_content"] = vllm_content
extra_args["vllm_images"] = vllm_images
extra_args["vllm_videos"] = vllm_videos
+ extra_args["vllm_audios"] = vllm_audios
output: BatchedDataDict[Any] = BatchedDataDict(
message_log=message_log,
@@ -116,10 +118,26 @@ def eval_collate_fn(data_batch: list[DatumSpec]) -> BatchedDataDict[Any]:
extra_env_info = [datum_spec["extra_env_info"] for datum_spec in data_batch]
idx = [datum_spec["idx"] for datum_spec in data_batch]
+ # Check if any of the data batch has vllm content (multimodal data)
+ extra_args = {}
+ if any(
+ datum_spec.get("vllm_content", None) is not None for datum_spec in data_batch
+ ):
+ extra_args["vllm_content"] = [
+ datum_spec.get("vllm_content", None) for datum_spec in data_batch
+ ]
+ extra_args["vllm_images"] = [
+ datum_spec.get("vllm_images", []) for datum_spec in data_batch
+ ]
+ extra_args["vllm_audios"] = [
+ datum_spec.get("vllm_audios", []) for datum_spec in data_batch
+ ]
+
output: BatchedDataDict[Any] = BatchedDataDict(
message_log=message_log,
extra_env_info=extra_env_info,
idx=idx,
+ **extra_args,
)
return output
diff --git a/nemo_rl/data/datasets/eval_datasets/__init__.py b/nemo_rl/data/datasets/eval_datasets/__init__.py
index 8386286c83..d813ed040c 100644
--- a/nemo_rl/data/datasets/eval_datasets/__init__.py
+++ b/nemo_rl/data/datasets/eval_datasets/__init__.py
@@ -16,9 +16,18 @@
from nemo_rl.data.datasets.eval_datasets.gpqa import GPQADataset
from nemo_rl.data.datasets.eval_datasets.local_math_dataset import LocalMathDataset
from nemo_rl.data.datasets.eval_datasets.math import MathDataset
+from nemo_rl.data.datasets.eval_datasets.mmau import MMAUDataset
from nemo_rl.data.datasets.eval_datasets.mmlu import MMLUDataset
from nemo_rl.data.datasets.eval_datasets.mmlu_pro import MMLUProDataset
+# Dataset names that require multimodal (VLM) processing
+MULTIMODAL_DATASETS = {"mmau", "TwinkStart/MMAU"}
+
+
+def _is_multimodal_dataset(dataset_name: str) -> bool:
+    """Check if the dataset requires multimodal processing."""
+    return dataset_name in MULTIMODAL_DATASETS
+
def load_eval_dataset(data_config):
"""Loads evaluation dataset."""
@@ -82,6 +91,13 @@ def load_eval_dataset(data_config):
prompt_file=data_config["prompt_file"],
system_prompt_file=data_config["system_prompt_file"],
)
+ # mmau
+ elif dataset_name in ("mmau", "TwinkStart/MMAU"):
+ split = data_config.get("split", "v05.15.25")
+ base_dataset = MMAUDataset(
+ dataset_name="TwinkStart/MMAU",
+ split=split,
+ )
# fall back to local dataset
else:
print(f"Loading dataset from {dataset_name}...")
@@ -103,6 +119,9 @@ def load_eval_dataset(data_config):
"GPQADataset",
"LocalMathDataset",
"MathDataset",
+ "MMAUDataset",
"MMLUDataset",
"MMLUProDataset",
+ "MULTIMODAL_DATASETS",
+ "_is_multimodal_dataset",
]
diff --git a/nemo_rl/data/datasets/eval_datasets/mmau.py b/nemo_rl/data/datasets/eval_datasets/mmau.py
new file mode 100644
index 0000000000..8f23e94b10
--- /dev/null
+++ b/nemo_rl/data/datasets/eval_datasets/mmau.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""MMAU (Massive Multitask Audio Understanding) evaluation dataset."""
+
+from typing import Any
+
+from datasets import load_dataset
+
+from nemo_rl.data.datasets.response_datasets.avqa import _resample_audio
+from nemo_rl.data.interfaces import TaskDataSpec
+from nemo_rl.data.processors import vlm_hf_data_processor
+
+DEFAULT_TEMPLATE = (
+    "{question} Please choose the answer from the following options: {choices}. "
+    "Output the final answer in <answer> </answer>."
+)
+
+
+class MMAUDataset:
+ """MMAU evaluation dataset.
+
+ Loads the TwinkStart/MMAU HF dataset and formats each item into the
+ messages format expected by vlm_hf_data_processor.
+
+ Args:
+ dataset_name: HuggingFace dataset name.
+ split: Dataset split to load.
+ """
+
+ def __init__(
+ self,
+ dataset_name: str = "TwinkStart/MMAU",
+ split: str = "v05.15.25",
+ ):
+ ds = load_dataset(dataset_name, split=split)
+
+ self.rekeyed_ds = ds
+ self.task_spec = TaskDataSpec(task_name="mmau")
+ self.processor = vlm_hf_data_processor
+ self.preprocessor = self.format_data
+
+ def format_data(self, data: dict[str, Any]) -> dict[str, Any]:
+ """Convert a raw MMAU item into messages format for vlm_hf_data_processor."""
+ audio_data = data["audio"]
+ audio_array = audio_data["array"]
+
+ # Resample to 16kHz if needed
+ if audio_data["sampling_rate"] != 16000:
+ audio_array = _resample_audio(
+ audio_array, audio_data["sampling_rate"], 16000
+ )
+
+ question = data["question"]
+ choices = data["choices"]
+ answer = data["answer"]
+
+ prompt_text = DEFAULT_TEMPLATE.format(question=question, choices=choices)
+
+ user_content = [
+ {"type": "audio", "audio": audio_array},
+ {"type": "text", "text": prompt_text},
+ ]
+ return {
+ "messages": [
+ {"role": "user", "content": user_content},
+ {"role": "assistant", "content": answer},
+ ],
+ "task_name": "mmau",
+ "choices": choices,
+ }
diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py
index eb48bb5204..85107495c1 100644
--- a/nemo_rl/data/datasets/response_datasets/__init__.py
+++ b/nemo_rl/data/datasets/response_datasets/__init__.py
@@ -14,6 +14,7 @@
from nemo_rl.data import ResponseDatasetConfig
from nemo_rl.data.datasets.response_datasets.aime24 import AIME2024Dataset
+from nemo_rl.data.datasets.response_datasets.avqa import AVQADataset
from nemo_rl.data.datasets.response_datasets.clevr import CLEVRCoGenTDataset
from nemo_rl.data.datasets.response_datasets.daily_omni import DailyOmniDataset
from nemo_rl.data.datasets.response_datasets.dapo_math import (
@@ -41,6 +42,7 @@
DATASET_REGISTRY = {
# built-in datasets
+ "avqa": AVQADataset,
"AIME2024": AIME2024Dataset,
"clevr-cogent": CLEVRCoGenTDataset,
"daily-omni": DailyOmniDataset,
@@ -88,6 +90,7 @@ def load_response_dataset(data_config: ResponseDatasetConfig):
__all__ = [
+ "AVQADataset",
"AIME2024Dataset",
"CLEVRCoGenTDataset",
"DailyOmniDataset",
diff --git a/nemo_rl/data/datasets/response_datasets/avqa.py b/nemo_rl/data/datasets/response_datasets/avqa.py
new file mode 100644
index 0000000000..e05648fe52
--- /dev/null
+++ b/nemo_rl/data/datasets/response_datasets/avqa.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Any
+
+import numpy as np
+import torch
+import torchaudio
+from datasets import Dataset, load_dataset
+
+from nemo_rl.data.datasets.raw_dataset import RawDataset
+
+DEFAULT_TEMPLATE = (
+    "{question} Please choose the answer from the following options: {choices}. "
+    "Output the final answer in <answer> </answer>."
+)
+
+
+def _resample_audio(audio_array, orig_sr, target_sr=16000):
+ """Resample audio to target sample rate."""
+ if isinstance(audio_array, np.ndarray):
+ waveform = torch.from_numpy(audio_array).float()
+ else:
+ waveform = audio_array.float()
+
+ if waveform.dim() == 1:
+ waveform = waveform.unsqueeze(0)
+
+ resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
+ resampled = resampler(waveform)
+ return resampled[0].numpy()
+
+
+def _parse_question(question_text):
+ r"""Parse the HF dataset question format.
+
+ Input: "How many animals are there in the video?\nChoices:\nA. 3\nB. One\nC. 4\nD. 2"
+ Returns: (question, choices_list)
+ """
+ parts = question_text.split("\nChoices:\n")
+ if len(parts) == 2:
+ question = parts[0]
+ choices = []
+ for line in parts[1].strip().split("\n"):
+ line = line.strip()
+ if line:
+ match = re.match(r"^[A-Z]\.\s*(.+)$", line)
+ choices.append(match.group(1) if match else line)
+ return question, choices
+ return question_text, []
+
+
+class AVQADataset(RawDataset):
+ """Wrapper around the AVQA (Audio-Visual Question Answering) dataset.
+
+ Formats audio samples into OpenAI-style messages for audio QA
+ fine-tuning with Qwen2.5-Omni.
+
+ Args:
+ split: Split name for the dataset. Supported: "train", "validation".
+ """
+
+ task_name = "avqa"
+
+ def __init__(
+ self,
+ split: str = "train",
+ split_validation_size: float = 0,
+ seed: int = 42,
+ max_samples: int | None = None,
+ **kwargs,
+ ):
+ VALID_SPLITS = ("train", "validation")
+ if split not in VALID_SPLITS:
+ raise ValueError(
+ f"Invalid split: {split}. Please use one of {VALID_SPLITS}."
+ )
+
+ if max_samples is not None:
+ ds = load_dataset("gijs/avqa-processed", split=split, streaming=True)
+ self.dataset = Dataset.from_list(list(ds.take(max_samples)))
+ else:
+ self.dataset = load_dataset("gijs/avqa-processed", split=split)
+
+ self.dataset = self.dataset.add_column(
+ "task_name", [self.task_name] * len(self.dataset)
+ )
+
+ self.preprocessor = self.format_data
+
+ # `self.val_dataset` is used (not None) only when current dataset is used for both training and validation
+ self.val_dataset = None
+ self.split_train_validation(split_validation_size, seed)
+
+ def format_data(self, data: dict[str, Any]) -> dict[str, Any]:
+ audio_data = data["audio"]
+ audio_array = audio_data["array"]
+
+ # Resample to 16kHz if needed
+ if audio_data["sampling_rate"] != 16000:
+ audio_array = _resample_audio(
+ audio_array, audio_data["sampling_rate"], 16000
+ )
+
+ # Parse question and build prompt
+ question, choices = _parse_question(data["question"])
+ question = question.replace("video", "audio")
+
+ prompt_text = DEFAULT_TEMPLATE.format(question=question, choices=choices)
+
+ # Strip letter prefix from answer (e.g., "B. Yacht consignment" -> "Yacht consignment")
+ answer = data["answer"]
+ answer_match = re.match(r"^[A-Z]\.\s*(.+)$", answer)
+ if answer_match:
+ answer = answer_match.group(1)
+
+ user_content = [
+ {"type": "audio", "audio": audio_array},
+ {"type": "text", "text": prompt_text},
+ ]
+ return {
+ "messages": [
+ {"role": "user", "content": user_content},
+ {"role": "assistant", "content": answer},
+ ],
+ "task_name": self.task_name,
+ }
diff --git a/nemo_rl/data/datasets/utils.py b/nemo_rl/data/datasets/utils.py
index 5cb407dbdd..4a54f9ca39 100644
--- a/nemo_rl/data/datasets/utils.py
+++ b/nemo_rl/data/datasets/utils.py
@@ -34,18 +34,24 @@ def assert_no_double_bos(token_ids: torch.Tensor, tokenizer: TokenizerType) -> N
token_ids: List of token IDs
tokenizer: Tokenizer
"""
- if tokenizer.bos_token_id is not None:
+    # Processors created via AutoProcessor wrap a tokenizer; unwrap if needed.
+    # Note: `isinstance(x, AutoProcessor)` would never match, since AutoProcessor
+    # is a factory class, so duck-type on the wrapped `tokenizer` attribute.
+    if isinstance(tokenizer, PreTrainedTokenizerBase):
+        _tok = tokenizer
+    elif hasattr(tokenizer, "tokenizer"):
+        _tok = tokenizer.tokenizer
+    else:
+        raise TypeError(f"Unsupported tokenizer type: {type(tokenizer)}")
+
+ if _tok.bos_token_id is not None:
token_ids_list = token_ids.tolist()
if len(token_ids_list) > 1:
assert not (
- token_ids_list[0] == tokenizer.bos_token_id
- and token_ids_list[1] == tokenizer.bos_token_id
+ token_ids_list[0] == _tok.bos_token_id
+ and token_ids_list[1] == _tok.bos_token_id
), "Found double BOS token in the first two positions of the message."
else:
- # `name_or_path` is not available for AutoProcessor, temp fix in get_tokenizer
- print(
- f"skip assert_start_single_bos since Tokenizer {tokenizer.name_or_path} has no BOS token"
- )
+        name = getattr(_tok, "name_or_path", type(_tok).__name__)
+        print(f"skip assert_start_single_bos since Tokenizer {name} has no BOS token")
def pil_to_base64(image: Image.Image, format: str = "PNG") -> str:
diff --git a/nemo_rl/data/multimodal_utils.py b/nemo_rl/data/multimodal_utils.py
index 0513ec9760..6a16e18fec 100644
--- a/nemo_rl/data/multimodal_utils.py
+++ b/nemo_rl/data/multimodal_utils.py
@@ -214,7 +214,7 @@ def get_multimodal_keys_from_processor(processor) -> list[str]:
all_keys.update(processor.video_processor.model_input_names)
if hasattr(processor, "feature_extractor"):
all_keys.update(processor.feature_extractor.model_input_names)
- # all_keys.update(processor.model_input_names)
+ all_keys.update(processor.model_input_names)
all_keys.difference_update(set(processor.tokenizer.model_input_names))
return list(all_keys)
diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py
index 52ac9bf67d..7d0d4f753b 100644
--- a/nemo_rl/data/processors.py
+++ b/nemo_rl/data/processors.py
@@ -467,29 +467,30 @@ def vlm_hf_data_processor(
datum_dict = format_refcoco_dataset(datum_dict)
elif datum_dict["task_name"] == "geometry3k":
datum_dict = format_geometry3k_dataset(datum_dict)
+ elif datum_dict["task_name"] == "avqa":
+ pass # AVQA data is already formatted by AVQADataset.format_data
+ elif datum_dict["task_name"] == "mmau":
+ pass # MMAU data is already formatted by MMAUDataset.format_data
else:
raise ValueError(f"No data processor for task {datum_dict['task_name']}")
user_message = datum_dict["messages"]
problem = user_message[0]["content"]
extra_env_info = {"ground_truth": user_message[1]["content"]}
+ if "choices" in datum_dict:
+ extra_env_info["choices"] = datum_dict["choices"]
message_log: VLMMessageLogType = []
### only one round of interaction is assumed, this can easily be extended to a conversational setting
user_message: dict[str, Any] = {"role": "user", "content": []}
#
images = []
+ audios = []
if isinstance(problem, list):
for content in problem:
- # for image, video, just append it
+ # for image, video, audio, just append it
# for text, format the prompt to the problem
- if content["type"] != "text":
- user_message["content"].append(content)
- if content["type"] == "image":
- images.append(content["image"])
- else:
- raise ValueError(f"Unsupported content type: {content['type']}")
- elif content["type"] == "text":
+ if content["type"] == "text":
user_message["content"].append(
{
"type": "text",
@@ -498,6 +499,15 @@ def vlm_hf_data_processor(
else content["text"],
}
)
+ elif content["type"] == "image":
+ user_message["content"].append(content)
+ images.append(content["image"])
+ elif content["type"] == "audio":
+ user_message["content"].append(content)
+ # Store as (audio_array, sample_rate) tuple for vLLM
+ audios.append((content["audio"], 16000))
+ else:
+ raise ValueError(f"Unsupported content type: {content['type']}")
else:
# conversation consists of a text-only message
user_message["content"] = task_data_spec.prompt.format(problem)
@@ -552,6 +562,7 @@ def vlm_hf_data_processor(
vllm_kwargs = {
"vllm_content": None,
"vllm_images": [],
+ "vllm_audios": [],
}
# make smaller and mask out
@@ -564,11 +575,11 @@ def vlm_hf_data_processor(
chat_message[key] = PackedTensor.empty_like(value)
loss_multiplier = 0.0
else:
- # get the prompt content! (use this for vllm-backend that needs formatted dialog and list of images) for the entire conversation
- # add images for vllm serving
+        # Prompt content for the vLLM backend, which needs the formatted dialog
+        # plus the lists of images/audios for the entire conversation.
vllm_kwargs = {
"vllm_content": string_formatted_dialog,
"vllm_images": images,
+ "vllm_audios": audios,
}
output: DatumSpec = {
diff --git a/nemo_rl/environments/vlm_environment.py b/nemo_rl/environments/vlm_environment.py
index 7e4943c3b2..34e32b4e98 100644
--- a/nemo_rl/environments/vlm_environment.py
+++ b/nemo_rl/environments/vlm_environment.py
@@ -144,6 +144,7 @@ def step( # type: ignore[override]
self,
message_log_batch: list[list[dict[str, str]]],
metadata: list[VLMEnvironmentMetadata],
+ return_extracted_answer: bool = False,
) -> EnvironmentReturn:
"""Runs a step in the vlm environment.
diff --git a/nemo_rl/evals/eval.py b/nemo_rl/evals/eval.py
index d67255ef1e..94723d5ebc 100644
--- a/nemo_rl/evals/eval.py
+++ b/nemo_rl/evals/eval.py
@@ -322,11 +322,33 @@ async def _run_env_eval_impl(
batch = batch.repeat_interleave(num_tests_per_prompt)
# get input prompt from message_log
+ is_multimodal = "vllm_content" in batch
prompts = []
- for message_log in batch["message_log"]:
- content = [message["content"] for message in message_log]
- content = "\n".join(content)
- prompts.append(content)
+ prompts_for_display = []
+ for i, message_log in enumerate(batch["message_log"]):
+ if is_multimodal and batch["vllm_content"][i] is not None:
+ vllm_content = batch["vllm_content"][i]
+ prompt_dict = {"prompt": vllm_content}
+ multi_modal_data = {}
+ audios = batch.get("vllm_audios", None)
+ if audios is not None and len(audios[i]) > 0:
+ multi_modal_data["audio"] = (
+ audios[i][0] if len(audios[i]) == 1 else audios[i]
+ )
+ images = batch.get("vllm_images", None)
+ if images is not None and len(images[i]) > 0:
+ multi_modal_data["image"] = (
+ images[i][0] if len(images[i]) == 1 else images[i]
+ )
+ if multi_modal_data:
+ prompt_dict["multi_modal_data"] = multi_modal_data
+ prompts.append(prompt_dict)
+ prompts_for_display.append(vllm_content)
+ else:
+ content = [message["content"] for message in message_log]
+ content = "\n".join(content)
+ prompts.append(content)
+ prompts_for_display.append(content)
# generate by vllm
inputs = BatchedDataDict({"prompts": prompts})
@@ -353,7 +375,7 @@ async def _run_env_eval_impl(
# Collect data for JSON file
for i, (prompt, output, message_log, reward, extra_info) in enumerate(
zip(
- prompts,
+ prompts_for_display,
outputs,
batch["message_log"],
rewards.tolist(),
diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py
index 603a972095..4093e4370f 100644
--- a/nemo_rl/experience/rollouts.py
+++ b/nemo_rl/experience/rollouts.py
@@ -430,6 +430,8 @@ def run_multi_turn_rollout(
generation_input_data["vllm_images"] = active_batch["vllm_images"]
if "vllm_videos" in active_batch:
generation_input_data["vllm_videos"] = active_batch["vllm_videos"]
+ if "vllm_audios" in active_batch:
+ generation_input_data["vllm_audios"] = active_batch["vllm_audios"]
# generate_responses updates active_batch["message_log"] in-place
active_batch, generated_ids, gen_metrics = generate_responses(
diff --git a/nemo_rl/models/generation/vllm/utils.py b/nemo_rl/models/generation/vllm/utils.py
index 4be7d95117..f9ecb3523b 100644
--- a/nemo_rl/models/generation/vllm/utils.py
+++ b/nemo_rl/models/generation/vllm/utils.py
@@ -66,15 +66,22 @@ def _get_regular_prompt(index: int):
continue
# init prompt dict
prompt_dict = {"prompt": msg}
- # add additional data if present
+ # collect multi_modal_data from images and audios
+ multi_modal_data = {}
images = data.get("vllm_images", None)
- if images is None or len(images[i]) == 0:
+ if images is not None and len(images[i]) > 0:
+ multi_modal_data["image"] = (
+ images[i][0] if len(images[i]) == 1 else images[i]
+ )
+ audios = data.get("vllm_audios", None)
+ if audios is not None and len(audios[i]) > 0:
+ multi_modal_data["audio"] = (
+ audios[i][0] if len(audios[i]) == 1 else audios[i]
+ )
+ if not multi_modal_data:
prompts.append(_get_regular_prompt(i))
continue
- else:
- prompt_dict["multi_modal_data"] = {
- "image": images[i][0] if len(images[i]) == 1 else images[i]
- }
+ prompt_dict["multi_modal_data"] = multi_modal_data
prompts.append(prompt_dict)
else:
# Regular LLM generation using token_ids
diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py
index 1e3f1420d3..13232b3e45 100644
--- a/nemo_rl/models/megatron/setup.py
+++ b/nemo_rl/models/megatron/setup.py
@@ -693,6 +693,8 @@ def freeze_moe_router(megatron_model):
if isinstance(model_module, Float16Module):
model_module = model_module.module
# Handle VLM models
+ if hasattr(model_module, "thinker"):
+ model_module = model_module.thinker
if hasattr(model_module, "language_model"):
model_module = model_module.language_model
for layer in model_module.decoder.layers:
diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py
index a4214032f7..a65e6dd77d 100644
--- a/nemo_rl/utils/logger.py
+++ b/nemo_rl/utils/logger.py
@@ -1002,7 +1002,10 @@ def log_batched_dict_as_jsonl(
for key, value in sample.items():
if isinstance(value, torch.Tensor):
sample[key] = value.tolist()
- f.write(json.dumps({**sample, "idx": i}) + "\n")
+ elif isinstance(value, np.ndarray):
+ sample[key] = value.tolist()
+ # default=str is a fallback for non-JSON-serializable types (e.g., datetime, custom objects)
+ f.write(json.dumps({**sample, "idx": i}, default=str) + "\n")
print(f"Logged data to {filepath}")
diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh
index bee4d8d2eb..41ce679680 100644
--- a/tests/functional/L1_Functional_Tests_GPU.sh
+++ b/tests/functional/L1_Functional_Tests_GPU.sh
@@ -38,6 +38,7 @@ run_test() {
run_test bash ./tests/functional/grpo_frozen_env.sh
run_test bash ./tests/functional/test_frozen_env.sh
+run_test uv run --no-sync bash ./tests/functional/audio_grpo_megatron.sh
run_test fast uv run --no-sync bash ./tests/functional/distillation.sh
run_test uv run --no-sync bash ./tests/functional/distillation_megatron.sh
run_test fast uv run --no-sync bash ./tests/functional/dpo.sh
@@ -45,6 +46,7 @@ run_test uv run --no-sync bash ./tests/functional/dpo_automodel_lora.sh
run_test uv run --no-sync bash ./tests/functional/dpo_megatron.sh
run_test uv run --no-sync bash ./tests/functional/eval.sh
run_test uv run --no-sync bash ./tests/functional/eval_async.sh
+run_test uv run --no-sync bash ./tests/functional/eval_audio.sh
run_test fast uv run --no-sync bash ./tests/functional/grpo.sh
run_test fast uv run --no-sync bash ./tests/functional/grpo_async_gym.sh
run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh
diff --git a/tests/functional/audio_grpo_megatron.sh b/tests/functional/audio_grpo_megatron.sh
new file mode 100644
index 0000000000..2e52a7fa29
--- /dev/null
+++ b/tests/functional/audio_grpo_megatron.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
+# Mark the current repo as safe, since wandb fetches metadata about the repo
+git config --global --add safe.directory $PROJECT_ROOT
+
+set -eou pipefail
+
+EXP_NAME=$(basename $0 .sh)
+EXP_DIR=$SCRIPT_DIR/$EXP_NAME
+LOG_DIR=$EXP_DIR/logs
+JSON_METRICS=$EXP_DIR/metrics.json
+RUN_LOG=$EXP_DIR/run.log
+export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}
+
+rm -rf $EXP_DIR $LOG_DIR
+mkdir -p $EXP_DIR $LOG_DIR
+
+cd $PROJECT_ROOT
+uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
+ $PROJECT_ROOT/examples/run_vlm_grpo.py \
+ --config $PROJECT_ROOT/examples/configs/audio_grpo_3B_megatron.yaml \
+ policy.model_name=Qwen/Qwen2.5-Omni-3B \
+ grpo.num_prompts_per_step=2 \
+ grpo.num_generations_per_prompt=4 \
+ policy.train_global_batch_size=4 \
+ policy.train_micro_batch_size=1 \
+ cluster.gpus_per_node=2 \
+ grpo.max_num_steps=2 \
+ logger.tensorboard_enabled=true \
+ logger.log_dir=$LOG_DIR \
+ logger.wandb_enabled=false \
+ logger.monitor_gpus=false \
+ checkpointing.enabled=false \
+ "$@" \
+ 2>&1 | tee $RUN_LOG
+
+uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+uv run tests/check_metrics.py $JSON_METRICS \
+ 'max(data["train/token_mult_prob_error"]) < 1.05' \
+ 'mean(data["train/token_mult_prob_error"]) < 1.05'
diff --git a/tests/functional/eval_audio.sh b/tests/functional/eval_audio.sh
new file mode 100644
index 0000000000..6bf8d2d98d
--- /dev/null
+++ b/tests/functional/eval_audio.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
+# Mark the current repo as safe, since wandb fetches metadata about the repo
+git config --global --add safe.directory $PROJECT_ROOT
+
+set -eou pipefail
+
+EXP_NAME=$(basename $0 .sh)
+EXP_DIR=$SCRIPT_DIR/$EXP_NAME
+LOG_DIR=$EXP_DIR/logs
+JSON_METRICS=$EXP_DIR/metrics.json
+RUN_LOG=$EXP_DIR/run.log
+export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}
+
+rm -rf $EXP_DIR $LOG_DIR
+mkdir -p $EXP_DIR $LOG_DIR
+
+cd $PROJECT_ROOT
+uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
+ $PROJECT_ROOT/examples/run_eval.py \
+ --config $PROJECT_ROOT/examples/configs/evals/mmau.yaml \
+ cluster.gpus_per_node=2 \
+ "$@" \
+ 2>&1 | tee $RUN_LOG
+
+# Keep only the last score line so $JSON_METRICS stays a single JSON object
+grep "score=" $RUN_LOG | tail -n 1 | sed 's/.*score=\([^ ]*\).*/{"score": \1}/' > $JSON_METRICS
+
+uv run tests/check_metrics.py $JSON_METRICS \
+ 'data["score"] >= 0.0'
diff --git a/tests/unit/data/datasets/test_avqa_dataset.py b/tests/unit/data/datasets/test_avqa_dataset.py
new file mode 100644
index 0000000000..0f7a85160b
--- /dev/null
+++ b/tests/unit/data/datasets/test_avqa_dataset.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from nemo_rl.data.datasets.eval_datasets import (
+ MULTIMODAL_DATASETS,
+ _is_multimodal_dataset,
+)
+
+
+class TestIsMultimodalDataset:
+ """Tests for _is_multimodal_dataset and MULTIMODAL_DATASETS."""
+
+ def test_mmau_is_multimodal(self):
+ assert _is_multimodal_dataset("mmau") is True
+
+ def test_twinkstart_mmau_is_multimodal(self):
+ assert _is_multimodal_dataset("TwinkStart/MMAU") is True
+
+ def test_math_is_not_multimodal(self):
+ assert _is_multimodal_dataset("math") is False
+
+ def test_gpqa_is_not_multimodal(self):
+ assert _is_multimodal_dataset("gpqa") is False
+
+ def test_empty_string_is_not_multimodal(self):
+ assert _is_multimodal_dataset("") is False
+
+ def test_multimodal_datasets_is_a_set(self):
+ assert isinstance(MULTIMODAL_DATASETS, set)
+
+ def test_multimodal_datasets_contains_expected(self):
+ assert "mmau" in MULTIMODAL_DATASETS
+ assert "TwinkStart/MMAU" in MULTIMODAL_DATASETS
+
+
+class TestAVQADataset:
+ """Tests for AVQADataset loading and format_data."""
+
+ def test_avqa_dataset_loads(self):
+ from nemo_rl.data.datasets.response_datasets.avqa import AVQADataset
+
+ dataset = AVQADataset(split="train", max_samples=2)
+
+ assert dataset.task_name == "avqa"
+ assert len(dataset.dataset) > 0
+ assert dataset.preprocessor is not None
+ assert dataset.val_dataset is None
+
+ def test_avqa_dataset_with_split_validation(self):
+ from nemo_rl.data.datasets.response_datasets.avqa import AVQADataset
+
+ dataset = AVQADataset(
+ split="train", split_validation_size=0.5, seed=42, max_samples=4
+ )
+
+ assert dataset.task_name == "avqa"
+ assert len(dataset.dataset) > 0
+ assert dataset.val_dataset is not None
+ assert len(dataset.val_dataset) > 0
+
+ def test_avqa_dataset_format_data(self):
+ from nemo_rl.data.datasets.response_datasets.avqa import AVQADataset
+
+ dataset = AVQADataset(split="train", max_samples=2)
+
+ # Get a raw example and format it
+ raw_example = dataset.dataset[0]
+ formatted = dataset.preprocessor(raw_example)
+
+ assert "messages" in formatted
+ assert "task_name" in formatted
+ assert formatted["task_name"] == "avqa"
+
+ # Check message structure
+ messages = formatted["messages"]
+ assert len(messages) == 2
+ assert messages[0]["role"] == "user"
+ assert messages[1]["role"] == "assistant"
+
+ # Check user content is multimodal (has audio + text)
+ user_content = messages[0]["content"]
+ assert isinstance(user_content, list)
+ content_types = [c["type"] for c in user_content]
+ assert "audio" in content_types
+ assert "text" in content_types
+
+ def test_avqa_dataset_invalid_split(self):
+ from nemo_rl.data.datasets.response_datasets.avqa import AVQADataset
+
+ with pytest.raises(ValueError, match="Invalid split"):
+ AVQADataset(split="test")
+
+ def test_avqa_dataset_has_task_name_column(self):
+ from nemo_rl.data.datasets.response_datasets.avqa import AVQADataset
+
+ dataset = AVQADataset(split="train", max_samples=2)
+
+ # Verify the task_name column was added
+ raw_example = dataset.dataset[0]
+ assert "task_name" in raw_example
+ assert raw_example["task_name"] == "avqa"
+
+ def test_avqa_load_via_registry(self):
+ from nemo_rl.data.datasets import load_response_dataset
+
+ data_config = {"dataset_name": "avqa", "split": "train", "max_samples": 2}
+ dataset = load_response_dataset(data_config)
+
+ assert dataset.task_name == "avqa"
+ assert len(dataset.dataset) > 0