diff --git a/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml b/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml index f3e3dcccc8..07f5949288 100644 --- a/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml +++ b/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml @@ -45,6 +45,7 @@ loss_fn: checkpointing: enabled: true + checkpoint_dir: "results/grpo" metric_name: "val:accuracy" higher_is_better: true keep_top_k: 3 @@ -229,13 +230,27 @@ policy: num_nodes: null # Decides number of nodes to be dedicated to generation data: - # Using the prepared train and validation datasets (downloaded from HuggingFace and split 90/10) - # Train: 1129 samples, Validation: 126 samples - train_jsonl_fpath: 3rdparty/Gym-workspace/Gym/data/workplace_assistant/train.jsonl - validation_jsonl_fpath: 3rdparty/Gym-workspace/Gym/data/workplace_assistant/validation.jsonl + # The complete trajectory is managed by gym, which will pass the entire thing to vllm. + # We just catch the vllm max seq len error and return an empty response with + # finish_reason=stop or incomplete_details=max_length that can be used to stop the trajectory, + # so max_input_seq_length is not used in this case. + max_input_seq_length: null # nemogym dataset doesn't use this parameter shuffle: true num_workers: 0 + # Using the prepared train and validation datasets (downloaded from HuggingFace and split 90/10) + # Train: 1129 samples, Validation: 126 samples + train: + data_path: 3rdparty/Gym-workspace/Gym/data/workplace_assistant/train.jsonl + validation: + data_path: 3rdparty/Gym-workspace/Gym/data/workplace_assistant/validation.jsonl + default: + dataset_name: NemoGymDataset + env_name: "nemo_gym" + prompt_file: null # nemogym dataset doesn't use this parameter + system_prompt_file: null # nemogym dataset doesn't use this parameter + processor: "nemo_gym_data_processor" + env: should_use_nemo_gym: true should_log_nemo_gym_responses: true # If you have low logging storage, set this to false diff --git a/examples/nemo_gym/run_grpo_nemo_gym.py b/examples/nemo_gym/run_grpo_nemo_gym.py index c8d2c911e2..dbef2c5791 100644 --- a/examples/nemo_gym/run_grpo_nemo_gym.py +++ b/examples/nemo_gym/run_grpo_nemo_gym.py @@ -13,11 +13,8 @@ # limitations under the License. import argparse -import json import os import pprint -from itertools import chain, repeat -from typing import Optional # Increase the W&B single object size warning threshold. Initially 100_000 (100 KB) -> 10_000_000 (10 MB) import wandb.util @@ -42,18 +39,13 @@ setup, ) from nemo_rl.algorithms.utils import get_tokenizer -from nemo_rl.data.datasets import AllTaskProcessedDataset -from nemo_rl.data.interfaces import DatumSpec -from nemo_rl.distributed.ray_actor_environment_registry import ( - get_actor_python_env, -) +from nemo_rl.data.utils import setup_response_data from nemo_rl.distributed.virtual_cluster import init_ray from nemo_rl.environments.nemo_gym import ( - NemoGym, NemoGymConfig, - nemo_gym_example_to_nemo_rl_datum_spec, setup_nemo_gym_config, ) +from nemo_rl.environments.utils import create_env from nemo_rl.experience.rollouts import run_async_nemo_gym_rollout from nemo_rl.models.generation import configure_generation_config from nemo_rl.utils.config import load_config, parse_hydra_overrides @@ -75,40 +67,6 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]: return args, overrides -def setup_single_nemo_gym_dataset( - jsonl_fpath: str, tokenizer, num_repeats: Optional[int] = None -): - with open(jsonl_fpath) as f: - nemo_gym_examples = list(map(json.loads, f)) - - print(f"Loaded data at {jsonl_fpath}. Found {len(nemo_gym_examples)} examples") - - if num_repeats: - previous_length = len(nemo_gym_examples) - nemo_gym_examples = list( - chain.from_iterable( - repeat(nemo_gym_example, num_repeats) - for nemo_gym_example in nemo_gym_examples - ) - ) - print( - f"Repeating examples (in a pattern of abc to aabbcc) for {jsonl_fpath} from {previous_length} to {len(nemo_gym_examples)}!" - ) - - nemo_rl_compatible_examples: list[DatumSpec] = [ - nemo_gym_example_to_nemo_rl_datum_spec(nemo_gym_example, idx) - for idx, nemo_gym_example in enumerate(nemo_gym_examples) - ] - - passthrough_task_processor = lambda datum_dict, *args, **kwargs: datum_dict - return AllTaskProcessedDataset( - nemo_rl_compatible_examples, - tokenizer, - None, - passthrough_task_processor, - ) - - # These types are directly imported from grpo_train since if something about the architecture changes we want to immediately fail. def collect_trajectories( policy: ColocatablePolicyInterface, @@ -165,7 +123,7 @@ def main() -> None: if not args.config: args.config = os.path.join( os.path.dirname(__file__), - "grpo_dapo17k_bytedtsinghua_qwen3_4binstruct_nf.yaml", + "grpo_workplace_assistant_nemotron_nano_v2_9b.yaml", ) config = load_config(args.config) @@ -201,14 +159,10 @@ def main() -> None: # We assert here since this is right after the final config has been materialized. assert _should_use_nemo_gym(config) + # NeMo-Gym environment needs to get dp_openai_server_base_urls from policy_generation, so we don't setup env here. print("\n▶ Setting up data...") - train_dataset = setup_single_nemo_gym_dataset( - jsonl_fpath=config["data"]["train_jsonl_fpath"], - tokenizer=tokenizer, - ) - val_dataset = setup_single_nemo_gym_dataset( - jsonl_fpath=config["data"]["validation_jsonl_fpath"], - tokenizer=tokenizer, + train_dataset, val_dataset = setup_response_data( + tokenizer, config["data"], env_configs=None ) # Validation dataset config setup. @@ -221,11 +175,12 @@ def main() -> None: The validation set you pass in will directly be used for validation with no additional preprocessing. If you want to have some number of repetitions, please include that in your dataset, via ``num_repeats``, in your dataset config and `ng_prepare_data` will prepare it accordingly.""" ) - print( - f"Setting `grpo.max_val_samples` and `grpo.val_batch_size` to the length of the validation dataset, which is {len(val_dataset)}" - ) - config["grpo"]["max_val_samples"] = len(val_dataset) - config["grpo"]["val_batch_size"] = config["grpo"]["max_val_samples"] + if val_dataset is not None: + print( + f"Setting `grpo.max_val_samples` and `grpo.val_batch_size` to the length of the validation dataset, which is {len(val_dataset)}" + ) + config["grpo"]["max_val_samples"] = len(val_dataset) + config["grpo"]["val_batch_size"] = config["grpo"]["max_val_samples"] # Print config print("Final config:") @@ -254,15 +209,12 @@ def main() -> None: base_urls=policy_generation.dp_openai_server_base_urls, initial_global_config_dict=config["env"]["nemo_gym"], ) - nemo_gym = NemoGym.options( - runtime_env={ - "py_executable": get_actor_python_env( - "nemo_rl.environments.nemo_gym.NemoGym" - ), - } - ).remote(nemo_gym_config) + nemo_gym = create_env(env_name="nemo_gym", env_config=nemo_gym_config) # Blocking wait for NeMo-Gym to spin up ray.get(nemo_gym.health_check.remote()) + + # Bind task_to_env and val_task_to_env for nemo_gym env + # Hardcode here to match `run_async_nemo_gym_rollout` task_to_env = {"nemo_gym": nemo_gym} val_task_to_env = task_to_env diff --git a/examples/run_distillation.py b/examples/run_distillation.py index 60cf58909d..a57a0d2fed 100644 --- a/examples/run_distillation.py +++ b/examples/run_distillation.py @@ -19,7 +19,7 @@ from nemo_rl.algorithms.distillation import MasterConfig, distillation_train, setup from nemo_rl.algorithms.utils import get_tokenizer -from nemo_rl.data.utils import setup_data_with_envs +from nemo_rl.data.utils import setup_response_data from nemo_rl.distributed.virtual_cluster import init_ray from nemo_rl.models.generation import configure_generation_config from nemo_rl.utils.config import load_config, parse_hydra_overrides @@ -79,7 +79,7 @@ def main() -> None: val_dataset, task_to_env, val_task_to_env, - ) = setup_data_with_envs(tokenizer, config["data"], config["env"]) + ) = setup_response_data(tokenizer, config["data"], config["env"]) ( student_policy, diff --git a/examples/run_grpo.py b/examples/run_grpo.py index 83fd9f1d97..b05721d6a1 100644 --- a/examples/run_grpo.py +++ b/examples/run_grpo.py @@ -20,7 +20,7 @@ from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup from nemo_rl.algorithms.utils import get_tokenizer -from nemo_rl.data.utils import setup_data_with_envs +from nemo_rl.data.utils import setup_response_data from nemo_rl.distributed.virtual_cluster import init_ray from nemo_rl.models.generation import configure_generation_config from nemo_rl.utils.config import load_config, parse_hydra_overrides @@ -91,7 +91,7 @@ def main() -> None: val_dataset, task_to_env, val_task_to_env, - ) = setup_data_with_envs(tokenizer, config["data"], config["env"]) + ) = setup_response_data(tokenizer, config["data"], config["env"]) ( policy, diff --git a/examples/run_vlm_grpo.py b/examples/run_vlm_grpo.py index 65613ddee6..d75d23a343 100644 --- a/examples/run_vlm_grpo.py +++ b/examples/run_vlm_grpo.py @@ -20,7 +20,7 @@ from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup from nemo_rl.algorithms.utils import get_tokenizer -from nemo_rl.data.utils import setup_data_with_envs +from nemo_rl.data.utils import setup_response_data from nemo_rl.distributed.virtual_cluster import init_ray from nemo_rl.models.generation import configure_generation_config from nemo_rl.utils.config import load_config, parse_hydra_overrides @@ -97,7 +97,7 @@ def main() -> None: val_dataset, task_to_env, val_task_to_env, - ) = setup_data_with_envs(processor, config["data"], config["env"], is_vlm=True) + ) = setup_response_data(processor, config["data"], config["env"], is_vlm=True) ( policy, diff --git a/nemo_rl/data/__init__.py b/nemo_rl/data/__init__.py index 2fb26ebd90..ad70f95c75 100644 --- a/nemo_rl/data/__init__.py +++ b/nemo_rl/data/__init__.py @@ -44,7 +44,7 @@ class PreferenceDatasetConfig(TypedDict): class DataConfig(TypedDict): - max_input_seq_length: int + max_input_seq_length: int | None add_bos: NotRequired[bool] add_eos: NotRequired[bool] add_generation_prompt: NotRequired[bool] diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py index 524b854c3a..961b7b9ba8 100644 --- a/nemo_rl/data/datasets/response_datasets/__init__.py +++ b/nemo_rl/data/datasets/response_datasets/__init__.py @@ -22,6 +22,7 @@ from nemo_rl.data.datasets.response_datasets.deepscaler import DeepScalerDataset from nemo_rl.data.datasets.response_datasets.geometry3k import Geometry3KDataset from nemo_rl.data.datasets.response_datasets.helpsteer3 import HelpSteer3Dataset +from nemo_rl.data.datasets.response_datasets.nemogym_dataset import NemoGymDataset from nemo_rl.data.datasets.response_datasets.oai_format_dataset import ( OpenAIFormatDataset, ) @@ -50,6 +51,7 @@ "tulu3_sft_mixture": Tulu3SftMixtureDataset, # load from local JSONL file or HuggingFace "openai_format": OpenAIFormatDataset, + "NemoGymDataset": NemoGymDataset, "ResponseDataset": ResponseDataset, } @@ -87,6 +89,7 @@ def load_response_dataset(data_config: ResponseDatasetConfig): "DeepScalerDataset", "Geometry3KDataset", "HelpSteer3Dataset", + "NemoGymDataset", "OasstDataset", "OpenAIFormatDataset", "OpenMathInstruct2Dataset", diff --git a/nemo_rl/data/datasets/response_datasets/nemogym_dataset.py b/nemo_rl/data/datasets/response_datasets/nemogym_dataset.py new file mode 100644 index 0000000000..a1f97e1a5e --- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/nemogym_dataset.py @@ -0,0 +1,48 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datasets import Dataset + +from nemo_rl.data.datasets.raw_dataset import RawDataset + + +class NemoGymDataset(RawDataset): + """Simple wrapper around the Nemo Gym dataset. + + Args: + data_path: Path to the dataset JSONL file + repeat: Number of times to repeat the dataset, default is 1 + """ + + def __init__(self, data_path: str, repeat: int = 1, **kwargs) -> None: + self.task_name = "-".join(data_path.split("/")[-2:]).split(".")[0] + if self.task_name[0] == "-": + self.task_name = self.task_name[1:] + + # load raw line from jsonl + # will use `json.loads` to load to dict format at `nemo_gym_data_processor` later since `Dataset` cannot handle nested structure well + with open(data_path) as f: + self.dataset = [raw_line for raw_line in f] + + # format the dataset + self.dataset = Dataset.from_dict( + { + "extra_env_info": self.dataset, + "task_name": [self.task_name] * len(self.dataset), + } + ) + + # repeat the dataset + if repeat > 1: + self.dataset = self.dataset.repeat(repeat) diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py index 8958bef259..52ac9bf67d 100644 --- a/nemo_rl/data/processors.py +++ b/nemo_rl/data/processors.py @@ -14,6 +14,7 @@ """Contains data processors for evaluation.""" +import json import logging from typing import Any, Dict, cast @@ -663,6 +664,27 @@ def multichoice_qa_processor( return output +def nemo_gym_data_processor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer: TokenizerType, + max_seq_length: int | None, + idx: int, +) -> DatumSpec: + """Process a datum dictionary (directly loaded from dataset) into a DatumSpec for Nemo Gym.""" + output: DatumSpec = { + # load to dict format here since `Dataset` cannot handle nested structure well in `NemoGymDataset` + "extra_env_info": json.loads(datum_dict["extra_env_info"]), + "loss_multiplier": 1.0, + "idx": idx, + "task_name": datum_dict["task_name"], + # fake keys for compatibility with the current GRPO implementation + "message_log": [{"role": "user", "content": "", "token_ids": torch.tensor([])}], + "length": 0, + } + return output + + # Processor registry. Key is the processor name, value is the processor function. # Note: We cast the literal dict to Dict[str, TaskDataProcessFnCallable] because # type checkers see each concrete function's signature as a distinct callable type. @@ -679,6 +701,7 @@ def multichoice_qa_processor( "multichoice_qa_processor": multichoice_qa_processor, "sft_processor": sft_processor, "vlm_hf_data_processor": vlm_hf_data_processor, + "nemo_gym_data_processor": nemo_gym_data_processor, }, ) diff --git a/nemo_rl/data/utils.py b/nemo_rl/data/utils.py index c7f97b0592..4b11e80e7f 100644 --- a/nemo_rl/data/utils.py +++ b/nemo_rl/data/utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, Union from datasets import concatenate_datasets from transformers import AutoProcessor, AutoTokenizer @@ -31,16 +31,19 @@ # TODO: @yukih: unify to setup_data after dataset refactored -def setup_data_with_envs( +def setup_response_data( tokenizer: AutoProcessor | AutoTokenizer, data_config: DataConfig, - env_configs: dict[str, Any], + env_configs: Optional[dict[str, Any]] = None, is_vlm: bool = False, -) -> tuple[ - AllTaskProcessedDataset, - Optional[AllTaskProcessedDataset], - dict[str, EnvironmentInterface], - dict[str, EnvironmentInterface], +) -> Union[ + tuple[AllTaskProcessedDataset, Optional[AllTaskProcessedDataset]], + tuple[ + AllTaskProcessedDataset, + Optional[AllTaskProcessedDataset], + dict[str, EnvironmentInterface], + dict[str, EnvironmentInterface], + ], ]: """Setup data with environments. @@ -50,24 +53,33 @@ def setup_data_with_envs( tokenizer: Tokenizer or processor. data_config: Data config. env_configs: Environment configs. + If None, no environments will be created. This is used for: + - Algorithms like SFT which do not need environments. + - Environments like NeMo-Gym which need to handle the environment creation outside of this function. is_vlm: Whether to use VLM training or not. Returns: - A tuple of (train dataset, validation dataset, task to environment, task to validation environment). + If env_configs is not None: + A tuple of (train dataset, validation dataset, task to environment, task to validation environment). + If env_configs is None: + A tuple of (train dataset, validation dataset). """ assert "train" in data_config, ( "The dataset config structure is updated. Please refer to https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#dataset " "and the Migrate Guide in https://github.com/NVIDIA-NeMo/RL/pull/1649 to update the dataset config." ) - print("\n▶ Setting up envs...") - env_name_list = extract_necessary_env_names(data_config) - envs = {} - for env_name in env_name_list: - registered_env_name = "vlm" if is_vlm else env_name - envs[env_name] = create_env( - env_name=registered_env_name, env_config=env_configs[env_name] - ) + # setup environments if needed + has_envs = env_configs is not None + if has_envs: + print("\n▶ Setting up envs...") + env_name_list = extract_necessary_env_names(data_config) + envs = {} + for env_name in env_name_list: + registered_env_name = "vlm" if is_vlm else env_name + envs[env_name] = create_env( + env_name=registered_env_name, env_config=env_configs[env_name] + ) print("\n▶ Setting up data...") # setup train dataset @@ -87,7 +99,8 @@ def setup_data_with_envs( # bind task_name to task_data_processors and task_to_env task_name = data.task_name task_data_processors[task_name] = (data.task_spec, data.processor) - task_to_env[task_name] = envs[cfg["env_name"]] + if has_envs: + task_to_env[task_name] = envs[cfg["env_name"]] merged_data = concatenate_datasets([data.dataset for data in data_list]) dataset = AllTaskProcessedDataset( @@ -111,7 +124,8 @@ def setup_data_with_envs( # bind task_name to task_data_processors and task_to_env task_name = data.task_name val_task_data_processors[task_name] = task_data_processors[task_name] - val_task_to_env[task_name] = task_to_env[task_name] + if has_envs: + val_task_to_env[task_name] = task_to_env[task_name] # validation dataset from config if "validation" in data_config and data_config["validation"] is not None: @@ -130,7 +144,8 @@ def setup_data_with_envs( val_data.task_spec, val_data.processor, ) - val_task_to_env[task_name] = envs[cfg["env_name"]] + if has_envs: + val_task_to_env[task_name] = envs[cfg["env_name"]] val_dataset = None if len(val_data_list) > 0: @@ -144,7 +159,10 @@ def setup_data_with_envs( ) print(f" ✓ Validation dataset loaded with {len(val_dataset)} samples.") - return dataset, val_dataset, task_to_env, val_task_to_env + if has_envs: + return dataset, val_dataset, task_to_env, val_task_to_env + else: + return dataset, val_dataset # TODO: @yukih: unify to setup_data after dataset refactored diff --git a/nemo_rl/environments/nemo_gym.py b/nemo_rl/environments/nemo_gym.py index 5ec15c3cef..6694d4f4d1 100644 --- a/nemo_rl/environments/nemo_gym.py +++ b/nemo_rl/environments/nemo_gym.py @@ -18,7 +18,6 @@ import torch from transformers import PreTrainedTokenizerBase -from nemo_rl.data.interfaces import DatumSpec from nemo_rl.distributed.virtual_cluster import _get_free_port_local, _get_node_ip_local from nemo_rl.environments.interfaces import EnvironmentInterface from nemo_rl.utils.timer import Timer @@ -236,27 +235,3 @@ def setup_nemo_gym_config(config, tokenizer) -> None: # Stop strings or token ids are not supported generation_config["stop_strings"] = None generation_config["stop_token_ids"] = None - - -######################################## -# Data utils -######################################## - - -# We do some light preprocessing here to make our data format compatible with nemo rl format -def nemo_gym_example_to_nemo_rl_datum_spec( - nemo_gym_example: dict, idx: int -) -> DatumSpec: - return DatumSpec( - message_log=[ - {"role": "user", "content": "", "token_ids": torch.tensor([])} - ], # Fake message - length=0, - extra_env_info=nemo_gym_example, - loss_multiplier=1.0, # Fix to 1.0 to backprop on all examples - idx=idx, - task_name="nemo_gym", - stop_strings=None, - # Extra vars - token_ids=[], # Just need this empty key to be compatible with the current NeMo RL GRPO impl - ) diff --git a/nemo_rl/environments/utils.py b/nemo_rl/environments/utils.py index 99fe9eda1a..9b4f4d6279 100644 --- a/nemo_rl/environments/utils.py +++ b/nemo_rl/environments/utils.py @@ -46,6 +46,9 @@ class EnvRegistryEntry(TypedDict, total=False): "vlm": { "actor_class_fqn": "nemo_rl.environments.vlm_environment.VLMEnvironment", }, + "nemo_gym": { + "actor_class_fqn": "nemo_rl.environments.nemo_gym.NemoGym", + }, } diff --git a/pyrefly.toml b/pyrefly.toml index 32e67b658a..93adf3fe88 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -65,6 +65,7 @@ project-includes = [ "nemo_rl/data/datasets/response_datasets/deepscaler.py", "nemo_rl/data/datasets/response_datasets/geometry3k.py", "nemo_rl/data/datasets/response_datasets/helpsteer3.py", + "nemo_rl/data/datasets/response_datasets/nemogym_dataset.py", "nemo_rl/data/datasets/response_datasets/oai_format_dataset.py", "nemo_rl/data/datasets/response_datasets/oasst.py", "nemo_rl/data/datasets/response_datasets/openmathinstruct2.py", @@ -102,16 +103,15 @@ project-includes = [ "nemo_rl/models/dtensor/parallelize.py", "nemo_rl/models/generation/__init__.py", "nemo_rl/models/generation/interfaces.py", + "nemo_rl/models/generation/sglang/__init__.py", + "nemo_rl/models/generation/sglang/config.py", "nemo_rl/models/generation/vllm/__init__.py", "nemo_rl/models/generation/vllm/config.py", "nemo_rl/models/generation/vllm/quantization/fp8_train_utils.py", "nemo_rl/models/generation/vllm/utils.py", "nemo_rl/models/generation/vllm/vllm_backend.py", - "nemo_rl/models/generation/sglang/__init__.py", - "nemo_rl/models/generation/sglang/config.py", "nemo_rl/models/huggingface/__init__.py", "nemo_rl/models/megatron/__init__.py", - "nemo_rl/models/megatron/community_import.py", "nemo_rl/models/policy/__init__.py", "nemo_rl/models/policy/interfaces.py", "nemo_rl/models/policy/utils.py", diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py index 117fdca936..f3486de21e 100644 --- a/tests/unit/experience/test_rollouts.py +++ b/tests/unit/experience/test_rollouts.py @@ -13,6 +13,8 @@ # limitations under the License. import gc +import json +import tempfile from copy import deepcopy from dataclasses import asdict @@ -22,8 +24,10 @@ from transformers import AutoTokenizer from nemo_rl.data.collate_fn import rl_collate_fn +from nemo_rl.data.datasets.response_datasets import NemoGymDataset from nemo_rl.data.interfaces import DatumSpec from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message +from nemo_rl.data.processors import nemo_gym_data_processor from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.environments.games.sliding_puzzle import ( @@ -32,7 +36,6 @@ SlidingPuzzleGameLogic, SlidingPuzzleMetadata, ) -from nemo_rl.environments.nemo_gym import nemo_gym_example_to_nemo_rl_datum_spec from nemo_rl.experience.rollouts import ( _calculate_single_metric, run_async_multi_turn_rollout, @@ -794,10 +797,21 @@ def test_run_async_nemo_gym_rollout( nemo_gym_sanity_test_data, # noqa: F811 nemo_gym_tokenizer, # noqa: F811 ): + # only keep the input part of the data for the test + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for data in nemo_gym_sanity_test_data["input"]: + f.write(json.dumps(data) + "\n") + data_path = f.name + + # load the dataset and convert to compatible format for Nemo RL + nemo_gym_sanity_test_data = NemoGymDataset(data_path) nemo_rl_compatible_examples: list[DatumSpec] = [ - nemo_gym_example_to_nemo_rl_datum_spec(nemo_gym_example, idx) - for idx, nemo_gym_example in enumerate(nemo_gym_sanity_test_data["input"]) + nemo_gym_data_processor( + nemo_gym_sanity_test_data.dataset[idx], None, None, None, idx + ) + for idx in range(len(nemo_gym_sanity_test_data.dataset)) ] + input_batch: BatchedDataDict[DatumSpec] = rl_collate_fn(nemo_rl_compatible_examples) actual_result = run_async_nemo_gym_rollout( policy_generation=nemo_gym_vllm_generation,