diff --git a/.gitignore b/.gitignore index 954db7041d..55a992fece 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,8 @@ ckpts/ # Test coverage.json .coverage* +unit_results.json +unit_results/ test_assets/ # Cache diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py index 277619bb92..e300aec54b 100644 --- a/nemo_rl/distributed/ray_actor_environment_registry.py +++ b/nemo_rl/distributed/ray_actor_environment_registry.py @@ -21,7 +21,9 @@ "nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker": PY_EXECUTABLES.VLLM, "nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker": PY_EXECUTABLES.MCORE, "nemo_rl.environments.math_environment.MathEnvironment": PY_EXECUTABLES.SYSTEM, + "nemo_rl.environments.code_environment.CodeEnvironment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.games.sliding_puzzle.SlidingPuzzleEnv": PY_EXECUTABLES.SYSTEM, + "nemo_rl.environments.tools.retriever.RAGEnvironment": PY_EXECUTABLES.SYSTEM, } diff --git a/nemo_rl/environments/code_environment.py b/nemo_rl/environments/code_environment.py new file mode 100644 index 0000000000..029b9cd1c0 --- /dev/null +++ b/nemo_rl/environments/code_environment.py @@ -0,0 +1,259 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import ast +import builtins +import os +import re +from collections.abc import Mapping, Sequence +from contextlib import contextmanager +from copy import copy +from io import IOBase +from pprint import pformat +from types import ModuleType +from typing import Any, Dict, List, Optional, Tuple, TypedDict + +import ray +import torch + +from nemo_rl.data.interfaces import LLMMessageLogType +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES +from nemo_rl.environments.interfaces import EnvironmentInterface, EnvironmentReturn +from nemo_rl.environments.utils import chunk_list_to_workers + + +class CodeEnvConfig(TypedDict): + num_workers: int + # whether to terminate the execution after expression evaluation + # if you want to execute multiple rounds of code, set this to False + # and wrap CodeEnvironment in another environment that terminates the generation + terminate_on_evaluation: bool + + +class CodeEnvMetadata(TypedDict): + context: Dict[str, Any] # Hold functions and variables defined in the code + working_dir: str # Working directory for file operations + + +@ray.remote # pragma: no cover +class CodeExecutionWorker: + """Helper class to process individual code execution steps.""" + + def __init__(self): + # Create sandbox with safe builtins + builtin_dict = {k: getattr(builtins, k) for k in dir(builtins)} + builtin_dict["open"] = self.safe_open + builtin_dict["__import__"] = self.safe_import + self.sandbox = {"__builtins__": builtin_dict} + + def sanitize(self, obj: Any) -> Any: + # TODO: better handling of unpicklable objects: custom __getstate__ and __setstate__ + # recursively remove all file objects as they are not picklable by ray + if isinstance(obj, (IOBase, ModuleType)): + # replace unpickable objects with a string representation + return repr(obj) + if isinstance(obj, Mapping): + return obj.__class__( + {self.sanitize(k): self.sanitize(v) for k, v in obj.items()} + ) + if isinstance(obj, Sequence) and not isinstance(obj, str): + return obj.__class__(self.sanitize(v) for v in obj) + if hasattr(obj, "__dict__"): + new_obj = copy(obj) + new_obj.__dict__ = { + self.sanitize(k): self.sanitize(v) for k, v in obj.__dict__.items() + } + return new_obj + return obj + + def format_result( + self, result: Any, code: Optional[str] = None, lookahead: Optional[str] = None + ) -> str: + if result is None: + # no return value + return "" + result = pformat(result) + multiline = (code and "\n" in code) or "\n" in result + if multiline: + # multi-line format + result = f"\n\n\n{result}\n" + else: + # inline format + result = f"{result}" + if lookahead: + if result.startswith(lookahead): + # The generation may look like "\n" if ">\n" is a single token. + # We trim \n from the result if the model has already generated it. + result = result[len(lookahead) :] + return result + + def execute(self, message_batch: str, metadata_batch: List[CodeEnvMetadata]) -> str: + """Execute code in a sandboxed environment.""" + results = [] + terminateds = [] + + for message, metadata in zip(message_batch, metadata_batch): + match = re.search(r"(.*)(.*)", message, re.DOTALL) + if not match: + results.append("") + terminateds.append(False) + continue + + code, lookahead = match.groups() + tree = ast.parse(code) + + if tree.body and isinstance(tree.body[-1], ast.Expr): + # Interactive mode + exec_code = ast.unparse(tree.body[:-1]) + eval_code = ast.unparse(tree.body[-1]) + else: + # Silent mode + exec_code = code + eval_code = None + + result = None + terminated = False + with self.chdir(metadata["working_dir"]): + try: + # isolate the code in a sandbox + # capture local variables in metadata["context"] + exec(exec_code, self.sandbox, metadata["context"]) + if eval_code: + result = eval(eval_code, self.sandbox, metadata["context"]) + terminated = True + except Exception as err: + result = err + + result = self.format_result(result, code, lookahead) + results.append(result) + terminateds.append(terminated) + + observations = [ + {"role": "environment", "content": result} for result in results + ] + metadata_batch = self.sanitize(metadata_batch) + + return observations, terminateds, metadata_batch + + @contextmanager + def chdir(self, dir: str): + """Change to temporary directory for file operations.""" + current_dir = os.getcwd() + os.chdir(dir) + try: + yield + finally: + os.chdir(current_dir) + + def safe_open(self, file: str, *args, **kwargs): + """Safe version of open() that only allows access to temporary directory.""" + real_file = os.path.realpath(file) + working_dir = os.path.realpath(os.getcwd()) + if os.path.commonpath([real_file, working_dir]) != working_dir: + raise PermissionError( + "Access beyond the temporary working directory is blocked" + ) + return open(file, *args, **kwargs) + + def safe_import(self, name: str, *args, **kwargs): + """Safe version of import that blocks risky modules.""" + risky_modules = { + "os", + "shutil", # erase filesystem + "sys", + "signal", # exit the current program + "socket", # network communication + "subprocess", + "threading", + "multiprocessing", # spawn threads or processes + "builtins", + "importlib", # bypass current blockers + } + if name in risky_modules: + raise PermissionError("Importing system and network modules is blocked") + return builtins.__import__(name, *args, **kwargs) + + +@ray.remote # pragma: no cover +class CodeEnvironment(EnvironmentInterface): + """Code execution environment that maintains state between steps.""" + + def __init__(self, cfg: CodeEnvConfig): + self.cfg = cfg + self.num_workers = cfg["num_workers"] + self.terminate_on_evaluation = cfg["terminate_on_evaluation"] + self.workers = [ + CodeExecutionWorker.options( + runtime_env={"py_executable": PY_EXECUTABLES.SYSTEM} + ).remote() + for _ in range(self.num_workers) + ] + + def step( + self, + message_log_batch: List[LLMMessageLogType], + metadata_batch: List[CodeEnvMetadata], + ) -> EnvironmentReturn: + """Process a batch of code execution steps.""" + message_batch = [ml[-1]["content"] for ml in message_log_batch] + chunked_message_batch = chunk_list_to_workers(message_batch, self.num_workers) + chunked_metadata_batch = chunk_list_to_workers(metadata_batch, self.num_workers) + + # Process each chunk in parallel + futures = [ + self.workers[i].execute.remote(message_chunk, metadata_chunk) + for i, (message_chunk, metadata_chunk) in enumerate( + zip(chunked_message_batch, chunked_metadata_batch) + ) + ] + + results = ray.get(futures) + + # Unpack results + observations = [] + terminateds = [] + new_metadata_batch = [] + + for obs, term, meta in results: + observations += obs + terminateds += term + new_metadata_batch += meta + + if self.terminate_on_evaluation: + terminated_tensor = torch.tensor(terminateds, dtype=torch.bool) + else: + terminated_tensor = torch.zeros(len(terminateds), dtype=torch.bool) + rewards_tensor = torch.zeros_like(terminated_tensor, dtype=torch.float32) + + next_stop_strings = [[""]] * len(message_log_batch) + + return EnvironmentReturn( + observations=observations, + metadata=new_metadata_batch, + next_stop_strings=next_stop_strings, + rewards=rewards_tensor, + terminateds=terminated_tensor, + ) + + def shutdown(self): + # shutdown all workers + for worker in self.workers: + ray.kill(worker) + + def global_post_process_and_metrics( + self, batch: BatchedDataDict + ) -> Tuple[BatchedDataDict, dict]: + """Compute metrics for the batch.""" + # No specific metrics for code execution + return batch, {} diff --git a/nemo_rl/environments/tools/retriever.py b/nemo_rl/environments/tools/retriever.py new file mode 100644 index 0000000000..8f408fc92b --- /dev/null +++ b/nemo_rl/environments/tools/retriever.py @@ -0,0 +1,206 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import re +from collections import Counter +from typing import Any, Dict, List, TypedDict + +import ray +import torch +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +from nemo_rl.data.interfaces import LLMMessageLogType +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.environments.interfaces import EnvironmentInterface, EnvironmentReturn + + +class RAGEnvConfig(TypedDict): + dataset_name: str # Name of the dataset to load + dataset_split: str # Split of the dataset to use + text_column: str # Column name containing the text to retrieve + num_results: int # Number of documents to retrieve + k1: float # BM25 parameter + b: float # BM25 parameter + device: str # Device to compute BM25 + + +class BM25Retriever: + """Sparse BM25 retriever. + + Args: + documents: list of documents to retrieve from + num_result: retrieve top-k documents + k1: parameter of BM25. Values in [1.2, 2.0] are recommended. + b: parameter of BM25. 0.75 is recommended. + device: device to compute BM25 + """ + + def __init__( + self, + documents: List[str] = None, + num_result: int = 10, + k1: float = 1.5, + b: float = 0.75, + device: str = "cpu", + ): + if documents is None: + dataset = load_dataset("wikimedia/wikipedia", "20231101.en") + self.documents = [sample["text"] for sample in dataset["train"]] + else: + self.documents = documents + self.tokenizer = AutoTokenizer.from_pretrained( + "bert-base-uncased", use_fast=True + ) + self.num_result = num_result + self.k1 = k1 + self.b = b + self.device = device + self.corpus_size = len(self.documents) + self.vocab_size = self.tokenizer.vocab_size + + self.build_index() + + def build_index(self): + doc_ids = [] + token_ids = [] + tfs = [] + lengths = [] + + for i, document in enumerate( + tqdm(self.documents, "Build index for BM25Retriever") + ): + input_ids = self.tokenizer.encode(document, add_special_tokens=False) + token2cnt = Counter(input_ids) + token_ids += token2cnt.keys() + tfs += token2cnt.values() + doc_ids += [i] * len(token2cnt) + lengths.append(len(input_ids)) + + avg_dl = sum(lengths) / self.corpus_size + for i, doc_id in enumerate(doc_ids): + tfs[i] = ( + tfs[i] + * (self.k1 + 1) + / (tfs[i] + self.k1 * (1 - self.b + self.b * lengths[doc_id] / avg_dl)) + ) + + indices = torch.tensor([doc_ids, token_ids], device=self.device) + values = torch.tensor(tfs, device=self.device) + self.doc_tfs = torch.sparse_coo_tensor( + indices, values, (self.corpus_size, self.vocab_size) + ) + + idfs = [0] * self.vocab_size + token2df = Counter(token_ids) + for token_id, df in token2df.items(): + idfs[token_id] = math.log((self.corpus_size - df + 0.5) / (df + 0.5) + 1) + self.idfs = idfs + + def __call__(self, query: str) -> List[str]: + input_ids = self.tokenizer.encode(query, add_special_tokens=False) + token2cnt = Counter(input_ids) + token_ids = [] + query_idfs = [] + for token_id, query_tf in token2cnt.items(): + token_ids.append(token_id) + query_idfs.append(query_tf * self.idfs[token_id]) + + indices = torch.tensor([token_ids, [0] * len(token_ids)], device=self.device) + values = torch.tensor(query_idfs, device=self.device) + query_idfs = torch.sparse_coo_tensor(indices, values, (self.vocab_size, 1)) + + scores = torch.sparse.mm(self.doc_tfs, query_idfs) + scores = scores.to_dense().squeeze(-1) + results = [] + for i in scores.topk(k=self.num_result).indices.tolist(): + results.append(self.documents[i]) + + return results + + +@ray.remote # pragma: no cover +class RAGEnvironment(EnvironmentInterface): + """RAG environment that uses BM25 for document retrieval.""" + + def __init__(self, cfg: RAGEnvConfig): + self.cfg = cfg + + # Load dataset + dataset = load_dataset(cfg["dataset_name"], split=cfg["dataset_split"]) + documents = [sample[cfg["text_column"]] for sample in dataset] + + # Initialize BM25 retriever + self.retriever = BM25Retriever( + documents=documents, + num_result=cfg["num_results"], + k1=cfg["k1"], + b=cfg["b"], + device=cfg["device"], + ) + + def format_result(self, retrieved_docs: List[str]) -> str: + result = "\n" + for i, doc in enumerate(retrieved_docs): + result += f"<{i + 1}>\n{doc}\n\n" + result += "\n" + return result + + def step( + self, + message_log_batch: List[LLMMessageLogType], + metadata_batch: List[Dict[str, Any]], + ) -> EnvironmentReturn: + """Process a batch of retrieval steps.""" + # Extract queries from the last message in each log + messages = [ml[-1]["content"] for ml in message_log_batch] + + # Retrieve documents for each query + results = [] + for message in messages: + match = re.search(r"(.*)", message, re.DOTALL) + if not match: + results.append( + {"role": "environment", "content": "No retrieval query found!"} + ) + continue + query = match.group(1) + retrieved_docs = self.retriever(query) + result = self.format_result(retrieved_docs) + results.append({"role": "environment", "content": result}) + + batch_size = len(message_log_batch) + rewards_tensor = torch.zeros(batch_size, dtype=torch.float32) + terminated_tensor = torch.ones(batch_size, dtype=torch.bool) + next_stop_strings = [[""]] * batch_size + + return EnvironmentReturn( + observations=results, + metadata=metadata_batch, + next_stop_strings=next_stop_strings, + rewards=rewards_tensor, + terminateds=terminated_tensor, + ) + + def shutdown(self): + """Clean up resources.""" + pass + + def global_post_process_and_metrics( + self, batch: BatchedDataDict + ) -> tuple[BatchedDataDict, dict]: + """Compute metrics for the batch.""" + # No specific metrics for RAG + return batch, {} diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index 1e0a027bcb..b9627e8a60 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -410,6 +410,8 @@ def run_multi_turn_rollout( tokenized_obs = tokenizer( env_obs_content, return_tensors="pt", add_special_tokens=False ).input_ids[0] + # tokenizer returns torch.float32 when env_obs_content is empty + tokenized_obs = tokenized_obs.to(dtype=torch.int64) # check if new message overflows max_seq_len if ( diff --git a/tests/unit/environments/test_code_environment.py b/tests/unit/environments/test_code_environment.py new file mode 100644 index 0000000000..27ada6d9bf --- /dev/null +++ b/tests/unit/environments/test_code_environment.py @@ -0,0 +1,220 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tempfile import TemporaryDirectory + +import pytest +import ray +from transformers import AutoTokenizer + +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.environments.code_environment import ( + CodeEnvConfig, + CodeEnvironment, + CodeEnvMetadata, +) +from nemo_rl.experience.rollouts import run_multi_turn_rollout +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration + +MODEL_NAME = "meta-llama/Llama-3.2-1B" + +cfg: CodeEnvConfig = { + "num_workers": 2, + "terminate_on_evaluation": True, +} + +# Define basic vLLM test config +basic_vllm_test_config: VllmConfig = { + "backend": "vllm", + "model_name": MODEL_NAME, + "tokenizer_name": None, + "dtype": "bfloat16", + "max_new_tokens": 100, + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, + "stop_token_ids": None, + "stop_strings": None, + "vllm_cfg": { + "async_engine": False, + "precision": "bfloat16", + "tensor_parallel_size": 1, + "pipeline_parallel_size": 1, + "max_model_len": 1024, + "disable_log_stats": True, + "disable_log_requests": True, + "gpu_memory_utilization": 0.6, + "enforce_eager": "False", + }, + "colocated": { + "enabled": True, + "resources": { + "gpus_per_node": None, + "num_nodes": None, + }, + }, +} + + +@pytest.fixture(scope="function") +def code_env(): + """Create a code environment for testing.""" + try: + env_actor = CodeEnvironment.remote(cfg) + yield env_actor + finally: + if env_actor: + ray.kill(env_actor) + + +@pytest.fixture(scope="function") +def tokenizer(): + """Loads the tokenizer for the tests.""" + print(f"Loading tokenizer: {MODEL_NAME}") + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + print( + f"Tokenizer loaded. Pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id}), EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})" + ) + return tokenizer + + +@pytest.fixture(scope="function") +def cluster(): + """Create a virtual cluster for testing.""" + cluster_instance = None + cluster_name = f"test-code-cluster-{id(cluster_instance)}" + print(f"\nCreating virtual cluster '{cluster_name}'...") + try: + cluster_instance = RayVirtualCluster( + name=cluster_name, + bundle_ct_per_node_list=[1], + use_gpus=True, + num_gpus_per_node=1, + max_colocated_worker_groups=2, + ) + yield cluster_instance + finally: + print(f"\nCleaning up cluster '{cluster_name}'...") + if cluster_instance: + cluster_instance.shutdown() + + +def test_untrusted_code(code_env): + """Test whether the code environment can block untrusted code.""" + codes = [ + "with open('allowed_file.txt', 'w') as fout:\n" + " fout.write('some content')\n" + "with open('allowed_file.txt') as fin:\n" + " content = fin.read()\n" + "content", + "with open('/etc/passwd', 'r') as fin:\n fin.read()", + "import math\nround(math.sqrt(8))", + "import os", + ] + results = [ + "\n\n\n'some content'\n", + "\n\n\nPermissionError('Access beyond the temporary working directory is blocked')\n", + "\n\n\n3\n", + "PermissionError('Importing system and network modules is blocked')", + ] + + message_log_batch = [ + [{"role": "user", "content": f"{code}"}] for code in codes + ] + temp_dirs = [TemporaryDirectory() for _ in codes] + metadata_batch = [ + CodeEnvMetadata( + context={}, + working_dir=temp_dir.name, + ) + for temp_dir in temp_dirs + ] + + # Execute the code + output = ray.get(code_env.step.remote(message_log_batch, metadata_batch)) + responses = [obs["content"] for obs in output.observations] + + assert responses == results, f"Got wrong output {responses}" + + +@pytest.mark.hf_gated +def test_vllm_execute_code(cluster, tokenizer, code_env): + """Test that vLLM can call the code executor.""" + # Prepare test data + codes = [ + "x = 3; y = 4\nThis is some regular text.\nx + y\n", + "\ndef f(x):\n return x * x\n\nf(2)\n\n", + ] + results = ["7", "\n\n4\n"] + + # Create message logs + message_logs = [] + metadata_batch = [] + temp_dirs = [] + for code in codes: + # Tokenize the message content + prompt = code * 4 + token_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)[ + "input_ids" + ][0] + temp_dir = TemporaryDirectory() + message_logs.append( + [{"role": "user", "content": prompt, "token_ids": token_ids}] + ) + metadata_batch.append(CodeEnvMetadata(context={}, working_dir=temp_dir.name)) + temp_dirs.append(temp_dir) + + # Create initial batch + initial_batch = BatchedDataDict( + { + "message_log": message_logs, + "extra_env_info": metadata_batch, + "task_name": ["code_execution"] * len(codes), + "stop_strings": [[""]] * len(codes), + } + ) + + # Create vLLM generation + vllm_config = basic_vllm_test_config.copy() + vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True) + vllm_generation = VllmGeneration(cluster, vllm_config) + + # Create code environment + task_to_env = {"code_execution": code_env} + + # Run rollout + vllm_generation.prepare_for_generation() + final_batch, _ = run_multi_turn_rollout( + policy_generation=vllm_generation, + input_batch=initial_batch, + tokenizer=tokenizer, + task_to_env=task_to_env, + max_seq_len=256, + max_rollout_turns=2, + greedy=True, + ) + vllm_generation.finish_generation() + + # Check results + for i, msg_log in enumerate(final_batch["message_log"]): + # Get the last message which should contain the result + last_msg = msg_log[-1] + assert last_msg["role"] == "environment" + assert last_msg["content"] == results[i], ( + f"Expected {results[i]}, got {last_msg['content']}" + ) diff --git a/tests/unit/environments/test_retriever.py b/tests/unit/environments/test_retriever.py new file mode 100644 index 0000000000..824c09b041 --- /dev/null +++ b/tests/unit/environments/test_retriever.py @@ -0,0 +1,179 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import ray +from transformers import AutoTokenizer + +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.environments.tools.retriever import RAGEnvConfig, RAGEnvironment +from nemo_rl.experience.rollouts import run_multi_turn_rollout +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration + +MODEL_NAME = "meta-llama/Llama-3.2-1B" + +cfg: RAGEnvConfig = { + "dataset_name": "rahular/simple-wikipedia", + "dataset_split": "train", + "text_column": "text", + "num_results": 1, + "k1": 1.5, + "b": 0.75, + "device": "cpu", +} + +# Define basic vLLM test config +basic_vllm_test_config: VllmConfig = { + "backend": "vllm", + "model_name": MODEL_NAME, + "tokenizer_name": None, + "dtype": "bfloat16", + "max_new_tokens": 100, + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, + "stop_token_ids": None, + "stop_strings": None, + "vllm_cfg": { + "async_engine": False, + "precision": "bfloat16", + "tensor_parallel_size": 1, + "pipeline_parallel_size": 1, + "max_model_len": 1024, + "disable_log_stats": True, + "disable_log_requests": True, + "gpu_memory_utilization": 0.6, + "enforce_eager": "False", + }, + "colocated": { + "enabled": True, + "resources": { + "gpus_per_node": None, + "num_nodes": None, + }, + }, +} + + +@pytest.fixture(scope="function") +def rag_env(): + """Create a RAG environment for testing.""" + try: + env_actor = RAGEnvironment.remote(cfg) + yield env_actor + finally: + if env_actor: + ray.kill(env_actor) + + +@pytest.fixture(scope="function") +def tokenizer(): + """Loads the tokenizer for the tests.""" + print(f"Loading tokenizer: {MODEL_NAME}") + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + print( + f"Tokenizer loaded. Pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id}), EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})" + ) + return tokenizer + + +@pytest.fixture(scope="function") +def cluster(): + """Create a virtual cluster for testing.""" + cluster_instance = None + cluster_name = f"test-rag-cluster-{id(cluster_instance)}" + print(f"\nCreating virtual cluster '{cluster_name}'...") + try: + cluster_instance = RayVirtualCluster( + name=cluster_name, + bundle_ct_per_node_list=[1], + use_gpus=True, + num_gpus_per_node=1, + max_colocated_worker_groups=2, + ) + yield cluster_instance + finally: + print(f"\nCleaning up cluster '{cluster_name}'...") + if cluster_instance: + cluster_instance.shutdown() + + +@pytest.mark.hf_gated +def test_vllm_retrieve(cluster, tokenizer, rag_env): + """Test that vLLM can use the RAG environment for document retrieval.""" + # Prepare test data + queries = [ + "Jen-Hsun Huang\n", + ] + expected_results = [ + "\n<1>\n" + "Nvidia was established in 1993 by Jen-Hsun Huang, Curtis Priem, and Chris Malachowsky. In 2000 Nvidia took intellectual possession of 3dfx, one of the biggest GPU producers in 1990s.\n" + "\n\n", + ] + + # Create message logs + message_logs = [] + for query in queries: + # Tokenize the message content + prompt = query * 4 + token_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)[ + "input_ids" + ][0] + message_logs.append( + [{"role": "user", "content": prompt, "token_ids": token_ids}] + ) + + # Create initial batch + initial_batch = BatchedDataDict( + { + "message_log": message_logs, + "extra_env_info": [{}] * len(queries), # No metadata needed for RAG + "task_name": ["document_retrieval"] * len(queries), + "stop_strings": [[""]] * len(queries), + } + ) + + # Create vLLM generation + vllm_config = basic_vllm_test_config.copy() + vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True) + vllm_generation = VllmGeneration(cluster, vllm_config) + + # Create RAG environment + task_to_env = {"document_retrieval": rag_env} + + # Run rollout + vllm_generation.prepare_for_generation() + final_batch, _ = run_multi_turn_rollout( + policy_generation=vllm_generation, + input_batch=initial_batch, + tokenizer=tokenizer, + task_to_env=task_to_env, + max_seq_len=256, + max_rollout_turns=1, + greedy=True, + ) + vllm_generation.finish_generation() + + # Check results + for i, msg_log in enumerate(final_batch["message_log"]): + # Get the last message which should contain the result + last_msg = msg_log[-1] + assert last_msg["role"] == "environment" + assert last_msg["content"] == expected_results[i], ( + f"Expected {expected_results[i]}, got {last_msg['content']}" + )