diff --git a/docs/guides/grpo-audio.md b/docs/guides/grpo-audio.md
new file mode 100644
index 0000000000..5e3d6e7dbb
--- /dev/null
+++ b/docs/guides/grpo-audio.md
@@ -0,0 +1,81 @@
+# Audio GRPO on AVQA
+
+This guide explains how to use NeMo RL to train [Qwen2.5-Omni-3B](https://huggingface.co/Qwen/Qwen2.5-Omni-3B) with GRPO on the [AVQA](https://huggingface.co/datasets/avqa) audio question-answering dataset, following the approach described in the [R1-AQA paper](https://arxiv.org/abs/2503.11197), and then evaluate the trained model on the [MMAU benchmark](https://huggingface.co/datasets/TwinkStart/MMAU).
+
+## 1. Train the Model
+
+Run GRPO training with the provided config:
+
+```
+uv run examples/run_grpo.py --config examples/configs/audio_grpo_3B_megatron.yaml
+```
+
+Config: `examples/configs/audio_grpo_3B_megatron.yaml`
+
+Key hyperparameters:
+
+| Parameter | Value |
+| --- | --- |
+| Model | Qwen2.5-Omni-3B |
+| Dataset | AVQA (train split) |
+| GPUs | 8 (1 node), Megatron backend |
+| Learning rate | 1e-6 |
+| KL penalty | 0.01 |
+| Generations per prompt | 8 |
+| Prompts per step | 8 |
+| Max steps | 200 |
+| Save period | 100 |
+| Reward | format (0.2) + exact_alnum (0.8) |
+
+## 2. Convert Checkpoint (Megatron to HF)
+
+Throughout training, checkpoints are saved to the `results/audio_grpo_3B_megatron` directory (specified by `checkpointing.checkpoint_dir`). To evaluate a checkpoint, first convert it from Megatron format to Hugging Face format:
+
+```
+uv run --extra mcore python examples/converters/convert_megatron_to_hf.py \
+    --config results/audio_grpo_3B_megatron/step_200/config.yaml \
+    --megatron-ckpt-path results/audio_grpo_3B_megatron/step_200/policy/weights/iter_0000200 \
+    --hf-ckpt-path results/audio_grpo_3B_megatron/step_200/hf
+```
+
+Replace the step number with the checkpoint you want to evaluate. Note that the `--extra mcore` flag is required for the Megatron converter.
+
+## 3. Evaluate on MMAU
+
+Evaluate the converted checkpoint on the [MMAU benchmark](https://huggingface.co/datasets/TwinkStart/MMAU):
+
+```
+uv run examples/run_eval.py \
+    --config=examples/configs/evals/mmau.yaml \
+    generation.model_name=results/audio_grpo_3B_megatron/step_200/hf \
+    data.dataset_name=TwinkStart/MMAU
+```
+
+Config: `examples/configs/evals/mmau.yaml`
+
+Use `generation.model_name` to specify the path to the converted Hugging Face checkpoint.
+
+## 4. Results
+
+Evaluating the step-200 checkpoint on MMAU, we get the following result:
+
+```
+============================================================
+model_name='hf_iter_0000000' dataset_name='MMAU'
+max_new_tokens=8000 temperature=0.0 top_p=1.0 top_k=-1 seed=42
+
+metric=pass@1 num_tests_per_prompt=1
+
+score=0.7210 (721.0/1000)
+============================================================
+```
+
+For reference, here are results comparing the baseline, the [R1-AQA](https://arxiv.org/abs/2503.11197) vanilla Hugging Face implementation, and NeMo-RL:
+
+| Model | MMAU Score |
+| --- | --- |
+| Qwen2.5-Omni-3B (baseline) | 69.8 |
+| Qwen2.5-Omni-3B + GRPO (HF vanilla) | 71.6 |
+| Qwen2.5-Omni-3B + GRPO (NeMo-RL) | 72.1 |
+
+The NeMo-RL result (72.1) slightly exceeds the Hugging Face Transformers reference implementation (71.6), confirming that the training pipeline reproduces the expected improvement over the baseline.
diff --git a/docs/index.md b/docs/index.md
index 7cb4e655ad..ed7c6f9802 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -107,6 +107,13 @@ Step-by-step guide for supervised fine-tuning on the OpenMathInstruct2 dataset.
 Create custom reward environments and integrate them with NeMo RL training pipelines.
 :::
 
+:::{grid-item-card} {octicon}`unmute` Audio GRPO on AVQA
+:link: guides/grpo-audio
+:link-type: doc
+
+Train Qwen2.5-Omni-3B with GRPO on AVQA and evaluate on MMAU, following the R1-AQA approach.
+:::
+
 :::{grid-item-card} {octicon}`plus-circle` Adding New Models
 :link: adding-new-models
 :link-type: doc
@@ -213,6 +220,7 @@ guides/prorlv2.md
 guides/grpo.md
 guides/grpo-deepscaler.md
 guides/grpo-sliding-puzzle.md
+guides/grpo-audio.md
 guides/rm.md
 guides/environments.md
 guides/eval.md
diff --git a/examples/configs/audio_grpo_3B_megatron.yaml b/examples/configs/audio_grpo_3B_megatron.yaml
new file mode 100644
index 0000000000..e90fdea1f6
--- /dev/null
+++ b/examples/configs/audio_grpo_3B_megatron.yaml
@@ -0,0 +1,83 @@
+# Audio GRPO 3B Megatron Configuration
+# Inherits from grpo_math_1B_megatron.yaml and overrides audio-specific settings.
+defaults: "grpo_math_1B_megatron.yaml"
+
+grpo:
+  num_prompts_per_step: 8
+  num_generations_per_prompt: 8
+  max_num_steps: 200
+  max_val_samples: 32
+  val_batch_size: 32
+
+checkpointing:
+  enabled: true
+  checkpoint_dir: results/audio_grpo_3B_megatron
+  keep_top_k: 10
+  save_period: 100
+
+policy:
+  model_name: Qwen/Qwen2.5-Omni-3B
+  train_global_batch_size: 32
+  train_micro_batch_size: 1
+  generation_batch_size: 32
+  logprob_batch_size: 4
+  max_total_sequence_length: 2048
+
+  sequence_packing:
+    enabled: false
+
+  generation:
+    max_new_tokens: 1024
+    vllm_cfg:
+      # Audio/multimodal models require tokenizer to be initialized before generation
+      skip_tokenizer_init: False
+      limit_mm_per_prompt:
+        audio: 1
+
+  megatron_cfg:
+    converter_type: Qwen2_5OmniForConditionalGeneration
+    apply_rope_fusion: false
+    optimizer:
+      lr: 1.0e-6
+      min_lr: 1.0e-7
+    scheduler:
+      lr_warmup_iters: 10
+      lr_warmup_init: 1.0e-7
+    distributed_data_parallel_config:
+      overlap_grad_reduce: false
+
+data:
+  train:
+    dataset_name: avqa
+    split: train
+  validation:
+    dataset_name: avqa
+    split: validation
+  default:
+    prompt_file: null
+    system_prompt_file: null
+    processor: "vlm_hf_data_processor"
+    env_name: "vlm"
+
+env:
+  vlm:
+    num_workers: 8
+    reward_functions:
+      - name: format
+        weight: 0.2
+      - name: exact_alnum
+        weight: 0.8
+
+logger:
+  wandb_enabled: true
+  tensorboard_enabled: true
+  monitor_gpus: false
+  wandb:
+    project: grpo-dev
+    name: audio-grpo-3b-megatron-large-lr
+  swanlab:
+    project: grpo-dev
+    name: audio-grpo-3b-megatron-large-lr
+
+cluster:
+  gpus_per_node: 8
diff --git a/examples/configs/evals/mmau.yaml b/examples/configs/evals/mmau.yaml
new file mode 100644
index 0000000000..0338937f9b
--- /dev/null
+++ b/examples/configs/evals/mmau.yaml
@@ -0,0 +1,58 @@
+eval:
+  metric: "pass@k"
+  num_tests_per_prompt: 1
+  seed: 42
+  k_value: 1
+  save_path: null
+
+generation:
+  backend: "vllm"
+  max_new_tokens: 2048
+  temperature: 0.0
+  top_p: 1.0
+  top_k: -1
+  num_prompts_per_step: -1
+  model_name: "Qwen/Qwen2.5-Omni-3B"
+  stop_token_ids: null
+  stop_strings: null
+  vllm_cfg:
+    async_engine: false
+    precision: "bfloat16"
+    tensor_parallel_size: 1
+    pipeline_parallel_size: 1
+    expert_parallel_size: 1
+    gpu_memory_utilization: 0.9
+    max_model_len: 8000
+    enforce_eager: False
+    skip_tokenizer_init: False
+    limit_mm_per_prompt:
+      audio: 1
+  colocated:
+    enabled: true
+    resources:
+      gpus_per_node: null
+      num_nodes: null
+
+tokenizer:
+  name: ${generation.model_name}
+  chat_template: "default"
+  chat_template_kwargs: null
+
+data:
+  max_input_seq_length: ${generation.vllm_cfg.max_model_len}
+  prompt_file: null
+  system_prompt_file: null
+  dataset_name: "TwinkStart/MMAU"
+  split: "v05.15.25"
+  env_name: vlm
+
+env:
+  mmau:
+    num_workers: 8
+    reward_functions:
+      - name: exact_alnum
+        weight: 1.0
+
+cluster:
+  gpus_per_node: 1
+  num_nodes: 1
diff --git a/examples/run_eval.py b/examples/run_eval.py
index 8966938632..3d678466ed 100644
--- a/examples/run_eval.py
+++ b/examples/run_eval.py
@@ -20,19 +20,16 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 from omegaconf import OmegaConf
-from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 from nemo_rl.algorithms.utils import get_tokenizer
 from nemo_rl.data.datasets import AllTaskProcessedDataset, load_eval_dataset
-from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env
+from nemo_rl.data.datasets.eval_datasets import _is_multimodal_dataset
 from nemo_rl.distributed.virtual_cluster import init_ray
-from nemo_rl.environments.math_environment import MathEnvironment
+from nemo_rl.environments.utils import create_env
 from nemo_rl.evals.eval import MasterConfig, run_env_eval, setup
 from nemo_rl.models.generation import configure_generation_config
 from nemo_rl.utils.config import load_config
 
-TokenizerType = PreTrainedTokenizerBase
-
 
 def parse_args():
     """Parse command line arguments."""
@@ -50,26 +47,25 @@ def parse_args():
     return args, overrides
 
 
-def setup_data(tokenizer: AutoTokenizer, data_config, env_configs):
+def setup_data(tokenizer, data_config, env_configs):
     print("Setting up data...")
     # load dataset
     base_dataset = load_eval_dataset(data_config)
     rekeyed_ds = base_dataset.rekeyed_ds
 
-    env = MathEnvironment.options(
-        runtime_env={
-            "py_executable": get_actor_python_env(
-                "nemo_rl.environments.math_environment.MathEnvironment"
-            )
-        }
-    ).remote(env_configs["math"])
+    # Determine env from config: use explicit env_name if provided,
+    # otherwise fall back to the single key in env_configs.
+    env_key = next(iter(env_configs))
+    env_name = data_config.get("env_name", env_key)
+    env = create_env(env_name=env_name, env_config=env_configs[env_key])
 
     dataset = AllTaskProcessedDataset(
         dataset=rekeyed_ds,
         tokenizer=tokenizer,
         default_task_data_spec=base_dataset.task_spec,
         task_data_processors=base_dataset.processor,
+        task_data_preprocessors=getattr(base_dataset, "preprocessor", None),
         max_seq_length=data_config["max_input_seq_length"],
     )
 
@@ -104,8 +100,9 @@ def main():
     # Init ray
     init_ray()
 
-    # Setup tokenizer
-    tokenizer = get_tokenizer(config["tokenizer"])
+    # Setup tokenizer — get_tokenizer handles both text-only and multimodal
+    is_multimodal = _is_multimodal_dataset(config["data"]["dataset_name"])
+    tokenizer = get_tokenizer(config["tokenizer"], get_processor=is_multimodal)
     config["generation"] = configure_generation_config(
         config["generation"], tokenizer, is_eval=True
     )
diff --git a/nemo_rl/data/collate_fn.py b/nemo_rl/data/collate_fn.py
index 09d2cf766a..6f4291aa43 100644
--- a/nemo_rl/data/collate_fn.py
+++ b/nemo_rl/data/collate_fn.py
@@ -55,9 +55,11 @@ def rl_collate_fn(data_batch: list[DatumSpec]) -> BatchedDataDict[Any]:
         ]
         vllm_images = [datum_spec.get("vllm_images", []) for datum_spec in data_batch]
         vllm_videos = [datum_spec.get("vllm_videos", []) for datum_spec in data_batch]
+        vllm_audios = [datum_spec.get("vllm_audios", []) for datum_spec in data_batch]
         extra_args["vllm_content"] = vllm_content
         extra_args["vllm_images"] = vllm_images
         extra_args["vllm_videos"] = vllm_videos
+        extra_args["vllm_audios"] = vllm_audios
 
     output: BatchedDataDict[Any] = BatchedDataDict(
         message_log=message_log,
@@ -116,10 +118,26 @@ def eval_collate_fn(data_batch: list[DatumSpec]) -> BatchedDataDict[Any]:
     extra_env_info = [datum_spec["extra_env_info"] for datum_spec in data_batch]
     idx = [datum_spec["idx"] for datum_spec in data_batch]
 
+    # Check if any of the data batch has vllm content (multimodal data)
+    extra_args = {}
+    if any(
+        datum_spec.get("vllm_content", None) is not None for datum_spec in data_batch
+    ):
+        extra_args["vllm_content"] = [
+            datum_spec.get("vllm_content", None) for datum_spec in data_batch
+        ]
+        extra_args["vllm_images"] = [
+            datum_spec.get("vllm_images", []) for datum_spec in data_batch
+        ]
+        extra_args["vllm_audios"] = [
+            datum_spec.get("vllm_audios", []) for datum_spec in data_batch
+        ]
+
     output: BatchedDataDict[Any] = BatchedDataDict(
         message_log=message_log,
         extra_env_info=extra_env_info,
         idx=idx,
+        **extra_args,
     )
     return output
diff --git a/nemo_rl/data/datasets/eval_datasets/__init__.py b/nemo_rl/data/datasets/eval_datasets/__init__.py
index 8386286c83..d813ed040c 100644
--- a/nemo_rl/data/datasets/eval_datasets/__init__.py
+++ b/nemo_rl/data/datasets/eval_datasets/__init__.py
@@ -16,9 +16,18 @@
 from nemo_rl.data.datasets.eval_datasets.gpqa import GPQADataset
 from nemo_rl.data.datasets.eval_datasets.local_math_dataset import LocalMathDataset
 from nemo_rl.data.datasets.eval_datasets.math import MathDataset
+from nemo_rl.data.datasets.eval_datasets.mmau import MMAUDataset
 from nemo_rl.data.datasets.eval_datasets.mmlu import MMLUDataset
 from nemo_rl.data.datasets.eval_datasets.mmlu_pro import MMLUProDataset
 
+# Dataset names that require multimodal (VLM) processing
+MULTIMODAL_DATASETS = {"mmau", "TwinkStart/MMAU"}
+
+
+def _is_multimodal_dataset(dataset_name):
+    """Check if the dataset requires multimodal processing."""
+    return dataset_name in MULTIMODAL_DATASETS
+
 
 def load_eval_dataset(data_config):
     """Loads evaluation dataset."""
@@ -82,6 +91,13 @@ def load_eval_dataset(data_config):
             prompt_file=data_config["prompt_file"],
             system_prompt_file=data_config["system_prompt_file"],
         )
+    # mmau
+    elif dataset_name in ("mmau", "TwinkStart/MMAU"):
+        split = data_config.get("split", "v05.15.25")
+        base_dataset = MMAUDataset(
+            dataset_name="TwinkStart/MMAU",
+            split=split,
+        )
     # fall back to local dataset
     else:
         print(f"Loading dataset from {dataset_name}...")
@@ -103,6 +119,9 @@ def load_eval_dataset(data_config):
     "GPQADataset",
     "LocalMathDataset",
     "MathDataset",
+    "MMAUDataset",
     "MMLUDataset",
     "MMLUProDataset",
+    "MULTIMODAL_DATASETS",
+    "_is_multimodal_dataset",
 ]
diff --git a/nemo_rl/data/datasets/eval_datasets/mmau.py b/nemo_rl/data/datasets/eval_datasets/mmau.py
new file mode 100644
index 0000000000..8f23e94b10
--- /dev/null
+++ b/nemo_rl/data/datasets/eval_datasets/mmau.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""MMAU (Massive Multitask Audio Understanding) evaluation dataset."""
+
+from typing import Any
+
+from datasets import load_dataset
+
+from nemo_rl.data.datasets.response_datasets.avqa import _resample_audio
+from nemo_rl.data.interfaces import TaskDataSpec
+from nemo_rl.data.processors import vlm_hf_data_processor
+
+DEFAULT_TEMPLATE = (
+    "{question} Please choose the answer from the following options: {choices}. "
+    "Output the final answer in <answer> </answer>."
+)
+
+
+class MMAUDataset:
+    """MMAU evaluation dataset.
+
+    Loads the TwinkStart/MMAU HF dataset and formats each item into the
+    messages format expected by vlm_hf_data_processor.
+
+    Args:
+        dataset_name: HuggingFace dataset name.
+        split: Dataset split to load.
+    """
+
+    def __init__(
+        self,
+        dataset_name: str = "TwinkStart/MMAU",
+        split: str = "v05.15.25",
+    ):
+        ds = load_dataset(dataset_name, split=split)
+
+        self.rekeyed_ds = ds
+        self.task_spec = TaskDataSpec(task_name="mmau")
+        self.processor = vlm_hf_data_processor
+        self.preprocessor = self.format_data
+
+    def format_data(self, data: dict[str, Any]) -> dict[str, Any]:
+        """Convert a raw MMAU item into messages format for vlm_hf_data_processor."""
+        audio_data = data["audio"]
+        audio_array = audio_data["array"]
+
+        # Resample to 16kHz if needed
+        if audio_data["sampling_rate"] != 16000:
+            audio_array = _resample_audio(
+                audio_array, audio_data["sampling_rate"], 16000
+            )
+
+        question = data["question"]
+        choices = data["choices"]
+        answer = data["answer"]
+
+        prompt_text = DEFAULT_TEMPLATE.format(question=question, choices=choices)
+
+        user_content = [
+            {"type": "audio", "audio": audio_array},
+            {"type": "text", "text": prompt_text},
+        ]
+        return {
+            "messages": [
+                {"role": "user", "content": user_content},
+                {"role": "assistant", "content": answer},
+            ],
+            "task_name": "mmau",
+            "choices": choices,
+        }
diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py
index eb48bb5204..85107495c1 100644
--- a/nemo_rl/data/datasets/response_datasets/__init__.py
+++ b/nemo_rl/data/datasets/response_datasets/__init__.py
@@ -14,6 +14,7 @@
 from nemo_rl.data import ResponseDatasetConfig
 from nemo_rl.data.datasets.response_datasets.aime24 import AIME2024Dataset
+from nemo_rl.data.datasets.response_datasets.avqa import AVQADataset
 from nemo_rl.data.datasets.response_datasets.clevr import CLEVRCoGenTDataset
 from nemo_rl.data.datasets.response_datasets.daily_omni import DailyOmniDataset
 from nemo_rl.data.datasets.response_datasets.dapo_math import (
@@ -41,6 +42,7 @@
 DATASET_REGISTRY = {
     # built-in datasets
+    "avqa": AVQADataset,
     "AIME2024": AIME2024Dataset,
     "clevr-cogent": CLEVRCoGenTDataset,
     "daily-omni": DailyOmniDataset,
@@ -88,6 +90,7 @@ def load_response_dataset(data_config: ResponseDatasetConfig):
 
 __all__ = [
+    "AVQADataset",
     "AIME2024Dataset",
     "CLEVRCoGenTDataset",
     "DailyOmniDataset",
diff --git a/nemo_rl/data/datasets/response_datasets/avqa.py b/nemo_rl/data/datasets/response_datasets/avqa.py
new file mode 100644
index 0000000000..e05648fe52
--- /dev/null
+++ b/nemo_rl/data/datasets/response_datasets/avqa.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Any
+
+import numpy as np
+import torch
+import torchaudio
+from datasets import Dataset, load_dataset
+
+from nemo_rl.data.datasets.raw_dataset import RawDataset
+
+DEFAULT_TEMPLATE = (
+    "{question} Please choose the answer from the following options: {choices}. "
+    "Output the final answer in <answer> </answer>."
+)
+
+
+def _resample_audio(audio_array, orig_sr, target_sr=16000):
+    """Resample audio to target sample rate."""
+    if isinstance(audio_array, np.ndarray):
+        waveform = torch.from_numpy(audio_array).float()
+    else:
+        waveform = audio_array.float()
+
+    if waveform.dim() == 1:
+        waveform = waveform.unsqueeze(0)
+
+    resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
+    resampled = resampler(waveform)
+    return resampled[0].numpy()
+
+
+def _parse_question(question_text):
+    r"""Parse the HF dataset question format.
+
+    Input: "How many animals are there in the video?\nChoices:\nA. 3\nB. One\nC. 4\nD. 2"
+    Returns: (question, choices_list)
+    """
+    parts = question_text.split("\nChoices:\n")
+    if len(parts) == 2:
+        question = parts[0]
+        choices = []
+        for line in parts[1].strip().split("\n"):
+            line = line.strip()
+            if line:
+                match = re.match(r"^[A-Z]\.\s*(.+)$", line)
+                choices.append(match.group(1) if match else line)
+        return question, choices
+    return question_text, []
+
+
+class AVQADataset(RawDataset):
+    """Wrapper around the AVQA (Audio-Visual Question Answering) dataset.
+
+    Formats audio samples into OpenAI-style messages for audio QA
+    fine-tuning with Qwen2.5-Omni.
+
+    Args:
+        split: Split name for the dataset. Supported: "train", "validation".
+    """
+
+    task_name = "avqa"
+
+    def __init__(
+        self,
+        split: str = "train",
+        split_validation_size: float = 0,
+        seed: int = 42,
+        max_samples: int | None = None,
+        **kwargs,
+    ):
+        VALID_SPLITS = ("train", "validation")
+        if split not in VALID_SPLITS:
+            raise ValueError(
+                f"Invalid split: {split}. Please use one of {VALID_SPLITS}."
+            )
+
+        if max_samples is not None:
+            ds = load_dataset("gijs/avqa-processed", split=split, streaming=True)
+            self.dataset = Dataset.from_list(list(ds.take(max_samples)))
+        else:
+            self.dataset = load_dataset("gijs/avqa-processed", split=split)
+
+        self.dataset = self.dataset.add_column(
+            "task_name", [self.task_name] * len(self.dataset)
+        )
+
+        self.preprocessor = self.format_data
+
+        # `self.val_dataset` is used (not None) only when current dataset is used for both training and validation
+        self.val_dataset = None
+        self.split_train_validation(split_validation_size, seed)
+
+    def format_data(self, data: dict[str, Any]) -> dict[str, Any]:
+        audio_data = data["audio"]
+        audio_array = audio_data["array"]
+
+        # Resample to 16kHz if needed
+        if audio_data["sampling_rate"] != 16000:
+            audio_array = _resample_audio(
+                audio_array, audio_data["sampling_rate"], 16000
+            )
+
+        # Parse question and build prompt
+        question, choices = _parse_question(data["question"])
+        question = question.replace("video", "audio")
+
+        prompt_text = DEFAULT_TEMPLATE.format(question=question, choices=choices)
+
+        # Strip letter prefix from answer (e.g., "B. Yacht consignment" -> "Yacht consignment")
+        answer = data["answer"]
+        answer_match = re.match(r"^[A-Z]\.\s*(.+)$", answer)
+        if answer_match:
+            answer = answer_match.group(1)
+
+        user_content = [
+            {"type": "audio", "audio": audio_array},
+            {"type": "text", "text": prompt_text},
+        ]
+        return {
+            "messages": [
+                {"role": "user", "content": user_content},
+                {"role": "assistant", "content": answer},
+            ],
+            "task_name": self.task_name,
+        }
diff --git a/nemo_rl/data/datasets/utils.py b/nemo_rl/data/datasets/utils.py
index 5cb407dbdd..4a54f9ca39 100644
--- a/nemo_rl/data/datasets/utils.py
+++ b/nemo_rl/data/datasets/utils.py
@@ -34,18 +34,24 @@ def assert_no_double_bos(token_ids: torch.Tensor, tokenizer: TokenizerType) -> N
         token_ids: List of token IDs
         tokenizer: Tokenizer
     """
-    if tokenizer.bos_token_id is not None:
+    # AutoProcessor wraps a tokenizer; unwrap if needed
+    if isinstance(tokenizer, PreTrainedTokenizerBase):
+        _tok = tokenizer
+    elif isinstance(tokenizer, AutoProcessor):
+        _tok = tokenizer.tokenizer
+    else:
+        raise TypeError(f"Unsupported tokenizer type: {type(tokenizer)}")
+
+    if _tok.bos_token_id is not None:
         token_ids_list = token_ids.tolist()
         if len(token_ids_list) > 1:
             assert not (
-                token_ids_list[0] == tokenizer.bos_token_id
-                and token_ids_list[1] == tokenizer.bos_token_id
+                token_ids_list[0] == _tok.bos_token_id
+                and token_ids_list[1] == _tok.bos_token_id
             ), "Found double BOS token in the first two positions of the message."
         else:
-            # `name_or_path` is not available for AutoProcessor, temp fix in get_tokenizer
-            print(
-                f"skip assert_start_single_bos since Tokenizer {tokenizer.name_or_path} has no BOS token"
-            )
+            name = getattr(_tok, "name_or_path", str(type(_tok).__name__))
+            print(f"skip assert_start_single_bos since Tokenizer {name} has no BOS token")
 
 
 def pil_to_base64(image: Image.Image, format: str = "PNG") -> str:
diff --git a/nemo_rl/data/multimodal_utils.py b/nemo_rl/data/multimodal_utils.py
index 0513ec9760..6a16e18fec 100644
--- a/nemo_rl/data/multimodal_utils.py
+++ b/nemo_rl/data/multimodal_utils.py
@@ -214,7 +214,7 @@ def get_multimodal_keys_from_processor(processor) -> list[str]:
         all_keys.update(processor.video_processor.model_input_names)
     if hasattr(processor, "feature_extractor"):
         all_keys.update(processor.feature_extractor.model_input_names)
-    # all_keys.update(processor.model_input_names)
+    all_keys.update(processor.model_input_names)
     all_keys.difference_update(set(processor.tokenizer.model_input_names))
     return list(all_keys)
 
diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py
index 52ac9bf67d..7d0d4f753b 100644
--- a/nemo_rl/data/processors.py
+++ b/nemo_rl/data/processors.py
@@ -467,29 +467,30 @@ def vlm_hf_data_processor(
         datum_dict = format_refcoco_dataset(datum_dict)
     elif datum_dict["task_name"] == "geometry3k":
         datum_dict = format_geometry3k_dataset(datum_dict)
+    elif datum_dict["task_name"] == "avqa":
+        pass  # AVQA data is already formatted by AVQADataset.format_data
+    elif datum_dict["task_name"] == "mmau":
+        pass  # MMAU data is already formatted by MMAUDataset.format_data
     else:
         raise ValueError(f"No data processor for task {datum_dict['task_name']}")
 
     user_message = datum_dict["messages"]
     problem = user_message[0]["content"]
     extra_env_info = {"ground_truth": user_message[1]["content"]}
+    if "choices" in datum_dict:
+        extra_env_info["choices"] = datum_dict["choices"]
 
     message_log: VLMMessageLogType = []
 
     ### only one round of interaction is assumed, this can easily be extended to a conversational setting
     user_message: dict[str, Any] = {"role": "user", "content": []}
 
     # images = []
+    audios = []
     if isinstance(problem, list):
         for content in problem:
-            # for image, video, just append it
+            # for image, video, audio, just append it
             # for text, format the prompt to the problem
-            if content["type"] != "text":
-                user_message["content"].append(content)
-                if content["type"] == "image":
-                    images.append(content["image"])
-                else:
-                    raise ValueError(f"Unsupported content type: {content['type']}")
-            elif content["type"] == "text":
+            if content["type"] == "text":
                 user_message["content"].append(
                     {
                         "type": "text",
@@ -498,6 +499,15 @@
                         else content["text"],
                     }
                 )
+            elif content["type"] == "image":
+                user_message["content"].append(content)
+                images.append(content["image"])
+            elif content["type"] == "audio":
+                user_message["content"].append(content)
+                # Store as (audio_array, sample_rate) tuple for vLLM
+                audios.append((content["audio"], 16000))
+            else:
+                raise ValueError(f"Unsupported content type: {content['type']}")
     else:
         # conversation consists of a text-only message
         user_message["content"] = task_data_spec.prompt.format(problem)
@@ -552,6 +562,7 @@
     vllm_kwargs = {
         "vllm_content": None,
        "vllm_images": [],
+        "vllm_audios": [],
    }
 
     # make smaller and mask out
@@ -564,11 +575,11 @@
             chat_message[key] = PackedTensor.empty_like(value)
         loss_multiplier = 0.0
     else:
-        # get the prompt content! (use this for vllm-backend that needs formatted dialog and list of images) for the entire conversation
-        # add images for vllm serving
+        # get the prompt content! (use this for vllm-backend that needs formatted dialog and list of images/audios) for the entire conversation
         vllm_kwargs = {
             "vllm_content": string_formatted_dialog,
             "vllm_images": images,
+            "vllm_audios": audios,
         }
 
     output: DatumSpec = {
diff --git a/nemo_rl/environments/vlm_environment.py b/nemo_rl/environments/vlm_environment.py
index 7e4943c3b2..34e32b4e98 100644
--- a/nemo_rl/environments/vlm_environment.py
+++ b/nemo_rl/environments/vlm_environment.py
@@ -144,6 +144,7 @@ def step(  # type: ignore[override]
         self,
         message_log_batch: list[list[dict[str, str]]],
         metadata: list[VLMEnvironmentMetadata],
+        return_extracted_answer: bool = False,
     ) -> EnvironmentReturn:
         """Runs a step in the vlm environment.
diff --git a/nemo_rl/evals/eval.py b/nemo_rl/evals/eval.py
index d67255ef1e..94723d5ebc 100644
--- a/nemo_rl/evals/eval.py
+++ b/nemo_rl/evals/eval.py
@@ -322,11 +322,33 @@ async def _run_env_eval_impl(
         batch = batch.repeat_interleave(num_tests_per_prompt)
 
     # get input prompt from message_log
+    is_multimodal = "vllm_content" in batch
     prompts = []
-    for message_log in batch["message_log"]:
-        content = [message["content"] for message in message_log]
-        content = "\n".join(content)
-        prompts.append(content)
+    prompts_for_display = []
+    for i, message_log in enumerate(batch["message_log"]):
+        if is_multimodal and batch["vllm_content"][i] is not None:
+            vllm_content = batch["vllm_content"][i]
+            prompt_dict = {"prompt": vllm_content}
+            multi_modal_data = {}
+            audios = batch.get("vllm_audios", None)
+            if audios is not None and len(audios[i]) > 0:
+                multi_modal_data["audio"] = (
+                    audios[i][0] if len(audios[i]) == 1 else audios[i]
+                )
+            images = batch.get("vllm_images", None)
+            if images is not None and len(images[i]) > 0:
+                multi_modal_data["image"] = (
+                    images[i][0] if len(images[i]) == 1 else images[i]
+                )
+            if multi_modal_data:
+                prompt_dict["multi_modal_data"] = multi_modal_data
+            prompts.append(prompt_dict)
+            prompts_for_display.append(vllm_content)
+        else:
+            content = [message["content"] for message in message_log]
+            content = "\n".join(content)
+            prompts.append(content)
+            prompts_for_display.append(content)
 
     # generate by vllm
     inputs = BatchedDataDict({"prompts": prompts})
@@ -353,7 +375,7 @@
     # Collect data for JSON file
     for i, (prompt, output, message_log, reward, extra_info) in enumerate(
         zip(
-            prompts,
+            prompts_for_display,
             outputs,
             batch["message_log"],
             rewards.tolist(),
diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py
index 603a972095..4093e4370f 100644
--- a/nemo_rl/experience/rollouts.py
+++ b/nemo_rl/experience/rollouts.py
@@ -430,6 +430,8 @@ def run_multi_turn_rollout(
             generation_input_data["vllm_images"] = active_batch["vllm_images"]
         if "vllm_videos" in active_batch:
             generation_input_data["vllm_videos"] = active_batch["vllm_videos"]
+        if "vllm_audios" in active_batch:
+            generation_input_data["vllm_audios"] = active_batch["vllm_audios"]
 
         # generate_responses updates active_batch["message_log"] in-place
         active_batch, generated_ids, gen_metrics = generate_responses(
diff --git a/nemo_rl/models/generation/vllm/utils.py b/nemo_rl/models/generation/vllm/utils.py
index 4be7d95117..f9ecb3523b 100644
--- a/nemo_rl/models/generation/vllm/utils.py
+++ b/nemo_rl/models/generation/vllm/utils.py
@@ -66,15 +66,22 @@ def _get_regular_prompt(index: int):
                 continue
             # init prompt dict
             prompt_dict = {"prompt": msg}
-            # add additional data if present
+            # collect multi_modal_data from images and audios
+            multi_modal_data = {}
             images = data.get("vllm_images", None)
-            if images is None or len(images[i]) == 0:
+            if images is not None and len(images[i]) > 0:
+                multi_modal_data["image"] = (
+                    images[i][0] if len(images[i]) == 1 else images[i]
+                )
+            audios = data.get("vllm_audios", None)
+            if audios is not None and len(audios[i]) > 0:
+                multi_modal_data["audio"] = (
+                    audios[i][0] if len(audios[i]) == 1 else audios[i]
+                )
+            if not multi_modal_data:
                 prompts.append(_get_regular_prompt(i))
                 continue
-            else:
-                prompt_dict["multi_modal_data"] = {
-                    "image": images[i][0] if len(images[i]) == 1 else images[i]
-                }
+            prompt_dict["multi_modal_data"] = multi_modal_data
             prompts.append(prompt_dict)
     else:
         # Regular LLM generation using token_ids
diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py
index 1e3f1420d3..13232b3e45 100644
--- a/nemo_rl/models/megatron/setup.py
+++ b/nemo_rl/models/megatron/setup.py
@@ -693,6 +693,8 @@ def freeze_moe_router(megatron_model):
         if isinstance(model_module, Float16Module):
             model_module = model_module.module
         # Handle VLM models
+        if hasattr(model_module, "thinker"):
+            model_module = model_module.thinker
         if hasattr(model_module, "language_model"):
             model_module = model_module.language_model
         for layer in model_module.decoder.layers:
diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py
index a4214032f7..a65e6dd77d 100644
--- a/nemo_rl/utils/logger.py
+++ b/nemo_rl/utils/logger.py
@@ -1002,7 +1002,10 @@ def log_batched_dict_as_jsonl(
             for key, value in sample.items():
                 if isinstance(value, torch.Tensor):
                     sample[key] = value.tolist()
-            f.write(json.dumps({**sample, "idx": i}) + "\n")
+                elif isinstance(value, np.ndarray):
+                    sample[key] = value.tolist()
+            # default=str is a fallback for non-JSON-serializable types (e.g., datetime, custom objects)
+            f.write(json.dumps({**sample, "idx": i}, default=str) + "\n")
 
     print(f"Logged data to {filepath}")
 
diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh
index bee4d8d2eb..41ce679680 100644
--- a/tests/functional/L1_Functional_Tests_GPU.sh
+++ b/tests/functional/L1_Functional_Tests_GPU.sh
@@ -38,6 +38,7 @@ run_test() {
 
 run_test bash ./tests/functional/grpo_frozen_env.sh
 run_test bash ./tests/functional/test_frozen_env.sh
+run_test uv run --no-sync bash ./tests/functional/audio_grpo_megatron.sh
 run_test fast uv run --no-sync bash ./tests/functional/distillation.sh
 run_test uv run --no-sync bash ./tests/functional/distillation_megatron.sh
 run_test fast uv run --no-sync bash ./tests/functional/dpo.sh
@@ -45,6 +46,7 @@
 run_test uv run --no-sync bash ./tests/functional/dpo_automodel_lora.sh
 run_test uv run --no-sync bash ./tests/functional/dpo_megatron.sh
 run_test uv run --no-sync bash ./tests/functional/eval.sh
 run_test uv run --no-sync bash ./tests/functional/eval_async.sh
+run_test uv run --no-sync bash ./tests/functional/eval_audio.sh
 run_test fast uv run --no-sync bash ./tests/functional/grpo.sh
 run_test fast uv run --no-sync bash ./tests/functional/grpo_async_gym.sh
 run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh
diff --git a/tests/functional/audio_grpo_megatron.sh b/tests/functional/audio_grpo_megatron.sh
new file mode 100644
index 0000000000..2e52a7fa29
--- /dev/null
+++ b/tests/functional/audio_grpo_megatron.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
+# Mark the current repo as safe, since wandb fetches metadata about the repo
+git config --global --add safe.directory $PROJECT_ROOT
+
+set -eou pipefail
+
+EXP_NAME=$(basename $0 .sh)
+EXP_DIR=$SCRIPT_DIR/$EXP_NAME
+LOG_DIR=$EXP_DIR/logs
+JSON_METRICS=$EXP_DIR/metrics.json
+RUN_LOG=$EXP_DIR/run.log
+export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}
+
+rm -rf $EXP_DIR $LOG_DIR
+mkdir -p $EXP_DIR $LOG_DIR
+
+cd $PROJECT_ROOT
+uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
+    $PROJECT_ROOT/examples/run_vlm_grpo.py \
+    --config $PROJECT_ROOT/examples/configs/audio_grpo_3B_megatron.yaml \
+    policy.model_name=Qwen/Qwen2.5-Omni-3B \
+    grpo.num_prompts_per_step=2 \
+    grpo.num_generations_per_prompt=4 \
+    policy.train_global_batch_size=4 \
+    policy.train_micro_batch_size=1 \
+    cluster.gpus_per_node=2 \
+    grpo.max_num_steps=2 \
+    logger.tensorboard_enabled=true \
+    logger.log_dir=$LOG_DIR \
+    logger.wandb_enabled=false \
+    logger.monitor_gpus=false \
+    checkpointing.enabled=false \
+    $@ \
+    2>&1 | tee $RUN_LOG
+
+uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+uv run tests/check_metrics.py $JSON_METRICS \
+    'max(data["train/token_mult_prob_error"]) < 1.05' \
+    'mean(data["train/token_mult_prob_error"]) < 1.05'
diff --git a/tests/functional/eval_audio.sh b/tests/functional/eval_audio.sh
new file mode 100644
index 0000000000..6bf8d2d98d
--- /dev/null
+++ b/tests/functional/eval_audio.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
+# Mark the current repo as safe, since wandb fetches metadata about the repo
+git config --global --add safe.directory $PROJECT_ROOT
+
+set -eou pipefail
+
+EXP_NAME=$(basename $0 .sh)
+EXP_DIR=$SCRIPT_DIR/$EXP_NAME
+LOG_DIR=$EXP_DIR/logs
+JSON_METRICS=$EXP_DIR/metrics.json
+RUN_LOG=$EXP_DIR/run.log
+export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}
+
+rm -rf $EXP_DIR $LOG_DIR
+mkdir -p $EXP_DIR $LOG_DIR
+
+cd $PROJECT_ROOT
+uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
+    $PROJECT_ROOT/examples/run_eval.py \
+    --config $PROJECT_ROOT/examples/configs/evals/mmau.yaml \
+    cluster.gpus_per_node=2 \
+    $@ \
+    2>&1 | tee $RUN_LOG
+
+cat $RUN_LOG | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"score": \1}/' > $JSON_METRICS
+
+uv run tests/check_metrics.py $JSON_METRICS \
+    'data["score"] >= 0.0'
diff --git a/tests/unit/data/datasets/test_avqa_dataset.py b/tests/unit/data/datasets/test_avqa_dataset.py
new file mode 100644
index 0000000000..0f7a85160b
--- /dev/null
+++ b/tests/unit/data/datasets/test_avqa_dataset.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from nemo_rl.data.datasets.eval_datasets import (
+    MULTIMODAL_DATASETS,
+    _is_multimodal_dataset,
+)
+
+
+class TestIsMultimodalDataset:
+    """Tests for _is_multimodal_dataset and MULTIMODAL_DATASETS."""
+
+    def test_mmau_is_multimodal(self):
+        assert _is_multimodal_dataset("mmau") is True
+
+    def test_twinkstart_mmau_is_multimodal(self):
+        assert _is_multimodal_dataset("TwinkStart/MMAU") is True
+
+    def test_math_is_not_multimodal(self):
+        assert _is_multimodal_dataset("math") is False
+
+    def test_gpqa_is_not_multimodal(self):
+        assert _is_multimodal_dataset("gpqa") is False
+
+    def test_empty_string_is_not_multimodal(self):
+        assert _is_multimodal_dataset("") is False
+
+    def test_multimodal_datasets_is_a_set(self):
+        assert isinstance(MULTIMODAL_DATASETS, set)
+
+    def test_multimodal_datasets_contains_expected(self):
+        assert "mmau" in MULTIMODAL_DATASETS
+        assert "TwinkStart/MMAU" in MULTIMODAL_DATASETS
+
+
+class TestAVQADataset:
+    """Tests for AVQADataset loading and format_data."""
+
+    def test_avqa_dataset_loads(self):
+        from nemo_rl.data.datasets.response_datasets.avqa import AVQADataset
+
+        dataset = AVQADataset(split="train", max_samples=2)
+
+        assert dataset.task_name == "avqa"
+        assert len(dataset.dataset) > 0
+        assert dataset.preprocessor is not None
+        assert dataset.val_dataset is None
+
+    def test_avqa_dataset_with_split_validation(self):
+        from nemo_rl.data.datasets.response_datasets.avqa import AVQADataset
+
+        dataset = AVQADataset(
+            split="train", split_validation_size=0.5, seed=42, max_samples=4
+        )
+
+        assert dataset.task_name == "avqa"
+        assert len(dataset.dataset) > 0
+        assert dataset.val_dataset is not None
+        assert len(dataset.val_dataset) > 0
+
+    def test_avqa_dataset_format_data(self):
+        from nemo_rl.data.datasets.response_datasets.avqa import AVQADataset
+
+        dataset = AVQADataset(split="train", max_samples=2)
+
+        # Get a raw example and format it
+        raw_example = dataset.dataset[0]
+        formatted = dataset.preprocessor(raw_example)
+
+        assert "messages" in formatted
+        assert "task_name" in formatted
+        assert formatted["task_name"] == "avqa"
+
+        # Check message structure
+        messages = formatted["messages"]
+        assert len(messages) == 2
+        assert messages[0]["role"] == "user"
+        assert messages[1]["role"] == "assistant"
+
+        # Check user content is multimodal (has audio + text)
+        user_content = messages[0]["content"]
+        assert isinstance(user_content, list)
+        content_types = [c["type"] for c in user_content]
+        assert "audio" in content_types
+        assert "text" in content_types
+
+    def test_avqa_dataset_invalid_split(self):
+        from nemo_rl.data.datasets.response_datasets.avqa import AVQADataset
+
+        with pytest.raises(ValueError, match="Invalid split"):
+            AVQADataset(split="test")
+
+    def test_avqa_dataset_has_task_name_column(self):
+        from nemo_rl.data.datasets.response_datasets.avqa import AVQADataset
+
+        dataset = AVQADataset(split="train", max_samples=2)
+
+        # Verify the task_name column was added
+        raw_example = dataset.dataset[0]
+        assert "task_name" in raw_example
+        assert raw_example["task_name"] == "avqa"
+
+    def test_avqa_load_via_registry(self):
+        from nemo_rl.data.datasets import load_response_dataset
+
+        data_config = {"dataset_name": "avqa", "split": "train", "max_samples": 2}
+        dataset = load_response_dataset(data_config)
+
+        assert dataset.task_name == "avqa"
+        assert len(dataset.dataset) > 0
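The `nemo_rl/utils/logger.py` change in this diff converts `np.ndarray` values with `.tolist()` and passes `default=str` to `json.dumps` so that samples with non-JSON-serializable fields no longer crash the JSONL writer. A minimal standalone sketch of that serialization behavior (the `to_jsonl_record` helper and the sample data here are hypothetical illustrations, not NeMo-RL APIs; the real code also handles `torch.Tensor` the same way):

```python
import json
from datetime import datetime

import numpy as np


def to_jsonl_record(sample: dict, idx: int) -> str:
    """Serialize one sample the way the patched log_batched_dict_as_jsonl does.

    numpy arrays are converted to plain lists up front; default=str is a
    fallback for anything else json can't encode (e.g. datetime objects).
    """
    out = {}
    for key, value in sample.items():
        if isinstance(value, np.ndarray):
            out[key] = value.tolist()
        else:
            out[key] = value
    return json.dumps({**out, "idx": idx}, default=str)


# hypothetical sample mixing an ndarray with a non-serializable value
record = to_jsonl_record(
    {"reward": np.array([0.2, 0.8]), "timestamp": datetime(2025, 1, 1)}, idx=0
)
print(record)
```

Without the `default=str` fallback, the `datetime` field alone would raise `TypeError: Object of type datetime is not JSON serializable`; with it, the record round-trips cleanly through `json.loads`.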