diff --git a/.gitignore b/.gitignore
index 954db7041d..55a992fece 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,6 +22,8 @@ ckpts/
# Test
coverage.json
.coverage*
+unit_results.json
+unit_results/
test_assets/
# Cache
diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py
index 277619bb92..e300aec54b 100644
--- a/nemo_rl/distributed/ray_actor_environment_registry.py
+++ b/nemo_rl/distributed/ray_actor_environment_registry.py
@@ -21,7 +21,9 @@
"nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker": PY_EXECUTABLES.VLLM,
"nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker": PY_EXECUTABLES.MCORE,
"nemo_rl.environments.math_environment.MathEnvironment": PY_EXECUTABLES.SYSTEM,
+ "nemo_rl.environments.code_environment.CodeEnvironment": PY_EXECUTABLES.SYSTEM,
"nemo_rl.environments.games.sliding_puzzle.SlidingPuzzleEnv": PY_EXECUTABLES.SYSTEM,
+ "nemo_rl.environments.tools.retriever.RAGEnvironment": PY_EXECUTABLES.SYSTEM,
}
diff --git a/nemo_rl/environments/code_environment.py b/nemo_rl/environments/code_environment.py
new file mode 100644
index 0000000000..029b9cd1c0
--- /dev/null
+++ b/nemo_rl/environments/code_environment.py
@@ -0,0 +1,259 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import ast
+import builtins
+import os
+import re
+from collections.abc import Mapping, Sequence
+from contextlib import contextmanager
+from copy import copy
+from io import IOBase
+from pprint import pformat
+from types import ModuleType
+from typing import Any, Dict, List, Optional, Tuple, TypedDict
+
+import ray
+import torch
+
+from nemo_rl.data.interfaces import LLMMessageLogType
+from nemo_rl.distributed.batched_data_dict import BatchedDataDict
+from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES
+from nemo_rl.environments.interfaces import EnvironmentInterface, EnvironmentReturn
+from nemo_rl.environments.utils import chunk_list_to_workers
+
+
+class CodeEnvConfig(TypedDict):
+ num_workers: int
+ # whether to terminate the execution after expression evaluation
+ # if you want to execute multiple rounds of code, set this to False
+ # and wrap CodeEnvironment in another environment that terminates the generation
+ terminate_on_evaluation: bool
+
+
+class CodeEnvMetadata(TypedDict):
+ context: Dict[str, Any] # Hold functions and variables defined in the code
+ working_dir: str # Working directory for file operations
+
+
+@ray.remote # pragma: no cover
+class CodeExecutionWorker:
+ """Helper class to process individual code execution steps."""
+
+ def __init__(self):
+ # Create sandbox with safe builtins
+ builtin_dict = {k: getattr(builtins, k) for k in dir(builtins)}
+ builtin_dict["open"] = self.safe_open
+ builtin_dict["__import__"] = self.safe_import
+ self.sandbox = {"__builtins__": builtin_dict}
+
+ def sanitize(self, obj: Any) -> Any:
+ # TODO: better handling of unpicklable objects: custom __getstate__ and __setstate__
+ # recursively remove all file objects as they are not picklable by ray
+ if isinstance(obj, (IOBase, ModuleType)):
+ # replace unpickable objects with a string representation
+ return repr(obj)
+ if isinstance(obj, Mapping):
+ return obj.__class__(
+ {self.sanitize(k): self.sanitize(v) for k, v in obj.items()}
+ )
+ if isinstance(obj, Sequence) and not isinstance(obj, str):
+ return obj.__class__(self.sanitize(v) for v in obj)
+ if hasattr(obj, "__dict__"):
+ new_obj = copy(obj)
+ new_obj.__dict__ = {
+ self.sanitize(k): self.sanitize(v) for k, v in obj.__dict__.items()
+ }
+ return new_obj
+ return obj
+
+ def format_result(
+ self, result: Any, code: Optional[str] = None, lookahead: Optional[str] = None
+ ) -> str:
+ if result is None:
+ # no return value
+ return ""
+ result = pformat(result)
+ multiline = (code and "\n" in code) or "\n" in result
+ if multiline:
+ # multi-line format
+ result = f"\n\n\n{result}\n"
+ else:
+ # inline format
+ result = f"{result}"
+ if lookahead:
+ if result.startswith(lookahead):
+ # The generation may look like "\n" if ">\n" is a single token.
+ # We trim \n from the result if the model has already generated it.
+ result = result[len(lookahead) :]
+ return result
+
+    def execute(self, message_batch: List[str], metadata_batch: List[CodeEnvMetadata]) -> Tuple[List[Dict[str, str]], List[bool], List[CodeEnvMetadata]]:
+ """Execute code in a sandboxed environment."""
+ results = []
+ terminateds = []
+
+ for message, metadata in zip(message_batch, metadata_batch):
+ match = re.search(r"(.*)(.*)", message, re.DOTALL)
+ if not match:
+ results.append("")
+ terminateds.append(False)
+ continue
+
+ code, lookahead = match.groups()
+ tree = ast.parse(code)
+
+ if tree.body and isinstance(tree.body[-1], ast.Expr):
+ # Interactive mode
+ exec_code = ast.unparse(tree.body[:-1])
+ eval_code = ast.unparse(tree.body[-1])
+ else:
+ # Silent mode
+ exec_code = code
+ eval_code = None
+
+ result = None
+ terminated = False
+ with self.chdir(metadata["working_dir"]):
+ try:
+ # isolate the code in a sandbox
+ # capture local variables in metadata["context"]
+ exec(exec_code, self.sandbox, metadata["context"])
+ if eval_code:
+ result = eval(eval_code, self.sandbox, metadata["context"])
+ terminated = True
+ except Exception as err:
+ result = err
+
+ result = self.format_result(result, code, lookahead)
+ results.append(result)
+ terminateds.append(terminated)
+
+ observations = [
+ {"role": "environment", "content": result} for result in results
+ ]
+ metadata_batch = self.sanitize(metadata_batch)
+
+ return observations, terminateds, metadata_batch
+
+    @contextmanager
+    def chdir(self, path: str):
+        """Temporarily change to the given directory for file operations."""
+        current_dir = os.getcwd()
+        os.chdir(path)
+        try:
+            yield
+        finally:
+            os.chdir(current_dir)
+
+ def safe_open(self, file: str, *args, **kwargs):
+ """Safe version of open() that only allows access to temporary directory."""
+ real_file = os.path.realpath(file)
+ working_dir = os.path.realpath(os.getcwd())
+ if os.path.commonpath([real_file, working_dir]) != working_dir:
+ raise PermissionError(
+ "Access beyond the temporary working directory is blocked"
+ )
+ return open(file, *args, **kwargs)
+
+ def safe_import(self, name: str, *args, **kwargs):
+ """Safe version of import that blocks risky modules."""
+ risky_modules = {
+ "os",
+ "shutil", # erase filesystem
+ "sys",
+ "signal", # exit the current program
+ "socket", # network communication
+ "subprocess",
+ "threading",
+ "multiprocessing", # spawn threads or processes
+ "builtins",
+ "importlib", # bypass current blockers
+ }
+ if name in risky_modules:
+ raise PermissionError("Importing system and network modules is blocked")
+ return builtins.__import__(name, *args, **kwargs)
+
+
+@ray.remote # pragma: no cover
+class CodeEnvironment(EnvironmentInterface):
+ """Code execution environment that maintains state between steps."""
+
+ def __init__(self, cfg: CodeEnvConfig):
+ self.cfg = cfg
+ self.num_workers = cfg["num_workers"]
+ self.terminate_on_evaluation = cfg["terminate_on_evaluation"]
+ self.workers = [
+ CodeExecutionWorker.options(
+ runtime_env={"py_executable": PY_EXECUTABLES.SYSTEM}
+ ).remote()
+ for _ in range(self.num_workers)
+ ]
+
+ def step(
+ self,
+ message_log_batch: List[LLMMessageLogType],
+ metadata_batch: List[CodeEnvMetadata],
+ ) -> EnvironmentReturn:
+ """Process a batch of code execution steps."""
+ message_batch = [ml[-1]["content"] for ml in message_log_batch]
+ chunked_message_batch = chunk_list_to_workers(message_batch, self.num_workers)
+ chunked_metadata_batch = chunk_list_to_workers(metadata_batch, self.num_workers)
+
+ # Process each chunk in parallel
+ futures = [
+ self.workers[i].execute.remote(message_chunk, metadata_chunk)
+ for i, (message_chunk, metadata_chunk) in enumerate(
+ zip(chunked_message_batch, chunked_metadata_batch)
+ )
+ ]
+
+ results = ray.get(futures)
+
+ # Unpack results
+ observations = []
+ terminateds = []
+ new_metadata_batch = []
+
+ for obs, term, meta in results:
+ observations += obs
+ terminateds += term
+ new_metadata_batch += meta
+
+ if self.terminate_on_evaluation:
+ terminated_tensor = torch.tensor(terminateds, dtype=torch.bool)
+ else:
+ terminated_tensor = torch.zeros(len(terminateds), dtype=torch.bool)
+ rewards_tensor = torch.zeros_like(terminated_tensor, dtype=torch.float32)
+
+ next_stop_strings = [[""]] * len(message_log_batch)
+
+ return EnvironmentReturn(
+ observations=observations,
+ metadata=new_metadata_batch,
+ next_stop_strings=next_stop_strings,
+ rewards=rewards_tensor,
+ terminateds=terminated_tensor,
+ )
+
+ def shutdown(self):
+ # shutdown all workers
+ for worker in self.workers:
+ ray.kill(worker)
+
+ def global_post_process_and_metrics(
+ self, batch: BatchedDataDict
+ ) -> Tuple[BatchedDataDict, dict]:
+ """Compute metrics for the batch."""
+ # No specific metrics for code execution
+ return batch, {}
diff --git a/nemo_rl/environments/tools/retriever.py b/nemo_rl/environments/tools/retriever.py
new file mode 100644
index 0000000000..8f408fc92b
--- /dev/null
+++ b/nemo_rl/environments/tools/retriever.py
@@ -0,0 +1,206 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import re
+from collections import Counter
+from typing import Any, Dict, List, Optional, TypedDict
+
+import ray
+import torch
+from datasets import load_dataset
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from nemo_rl.data.interfaces import LLMMessageLogType
+from nemo_rl.distributed.batched_data_dict import BatchedDataDict
+from nemo_rl.environments.interfaces import EnvironmentInterface, EnvironmentReturn
+
+
+class RAGEnvConfig(TypedDict):
+ dataset_name: str # Name of the dataset to load
+ dataset_split: str # Split of the dataset to use
+ text_column: str # Column name containing the text to retrieve
+ num_results: int # Number of documents to retrieve
+ k1: float # BM25 parameter
+ b: float # BM25 parameter
+ device: str # Device to compute BM25
+
+
+class BM25Retriever:
+ """Sparse BM25 retriever.
+
+ Args:
+ documents: list of documents to retrieve from
+ num_result: retrieve top-k documents
+ k1: parameter of BM25. Values in [1.2, 2.0] are recommended.
+ b: parameter of BM25. 0.75 is recommended.
+ device: device to compute BM25
+ """
+
+ def __init__(
+ self,
+        documents: Optional[List[str]] = None,
+ num_result: int = 10,
+ k1: float = 1.5,
+ b: float = 0.75,
+ device: str = "cpu",
+ ):
+ if documents is None:
+ dataset = load_dataset("wikimedia/wikipedia", "20231101.en")
+ self.documents = [sample["text"] for sample in dataset["train"]]
+ else:
+ self.documents = documents
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ "bert-base-uncased", use_fast=True
+ )
+ self.num_result = num_result
+ self.k1 = k1
+ self.b = b
+ self.device = device
+ self.corpus_size = len(self.documents)
+ self.vocab_size = self.tokenizer.vocab_size
+
+ self.build_index()
+
+ def build_index(self):
+ doc_ids = []
+ token_ids = []
+ tfs = []
+ lengths = []
+
+ for i, document in enumerate(
+ tqdm(self.documents, "Build index for BM25Retriever")
+ ):
+ input_ids = self.tokenizer.encode(document, add_special_tokens=False)
+ token2cnt = Counter(input_ids)
+ token_ids += token2cnt.keys()
+ tfs += token2cnt.values()
+ doc_ids += [i] * len(token2cnt)
+ lengths.append(len(input_ids))
+
+ avg_dl = sum(lengths) / self.corpus_size
+ for i, doc_id in enumerate(doc_ids):
+ tfs[i] = (
+ tfs[i]
+ * (self.k1 + 1)
+ / (tfs[i] + self.k1 * (1 - self.b + self.b * lengths[doc_id] / avg_dl))
+ )
+
+ indices = torch.tensor([doc_ids, token_ids], device=self.device)
+ values = torch.tensor(tfs, device=self.device)
+ self.doc_tfs = torch.sparse_coo_tensor(
+ indices, values, (self.corpus_size, self.vocab_size)
+ )
+
+ idfs = [0] * self.vocab_size
+ token2df = Counter(token_ids)
+ for token_id, df in token2df.items():
+ idfs[token_id] = math.log((self.corpus_size - df + 0.5) / (df + 0.5) + 1)
+ self.idfs = idfs
+
+ def __call__(self, query: str) -> List[str]:
+ input_ids = self.tokenizer.encode(query, add_special_tokens=False)
+ token2cnt = Counter(input_ids)
+ token_ids = []
+ query_idfs = []
+ for token_id, query_tf in token2cnt.items():
+ token_ids.append(token_id)
+ query_idfs.append(query_tf * self.idfs[token_id])
+
+ indices = torch.tensor([token_ids, [0] * len(token_ids)], device=self.device)
+ values = torch.tensor(query_idfs, device=self.device)
+ query_idfs = torch.sparse_coo_tensor(indices, values, (self.vocab_size, 1))
+
+ scores = torch.sparse.mm(self.doc_tfs, query_idfs)
+ scores = scores.to_dense().squeeze(-1)
+ results = []
+ for i in scores.topk(k=self.num_result).indices.tolist():
+ results.append(self.documents[i])
+
+ return results
+
+
+@ray.remote # pragma: no cover
+class RAGEnvironment(EnvironmentInterface):
+ """RAG environment that uses BM25 for document retrieval."""
+
+ def __init__(self, cfg: RAGEnvConfig):
+ self.cfg = cfg
+
+ # Load dataset
+ dataset = load_dataset(cfg["dataset_name"], split=cfg["dataset_split"])
+ documents = [sample[cfg["text_column"]] for sample in dataset]
+
+ # Initialize BM25 retriever
+ self.retriever = BM25Retriever(
+ documents=documents,
+ num_result=cfg["num_results"],
+ k1=cfg["k1"],
+ b=cfg["b"],
+ device=cfg["device"],
+ )
+
+ def format_result(self, retrieved_docs: List[str]) -> str:
+ result = "\n"
+ for i, doc in enumerate(retrieved_docs):
+ result += f"<{i + 1}>\n{doc}\n{i + 1}>\n"
+ result += "\n"
+ return result
+
+ def step(
+ self,
+ message_log_batch: List[LLMMessageLogType],
+ metadata_batch: List[Dict[str, Any]],
+ ) -> EnvironmentReturn:
+ """Process a batch of retrieval steps."""
+ # Extract queries from the last message in each log
+ messages = [ml[-1]["content"] for ml in message_log_batch]
+
+ # Retrieve documents for each query
+ results = []
+ for message in messages:
+ match = re.search(r"(.*)", message, re.DOTALL)
+ if not match:
+ results.append(
+ {"role": "environment", "content": "No retrieval query found!"}
+ )
+ continue
+ query = match.group(1)
+ retrieved_docs = self.retriever(query)
+ result = self.format_result(retrieved_docs)
+ results.append({"role": "environment", "content": result})
+
+ batch_size = len(message_log_batch)
+ rewards_tensor = torch.zeros(batch_size, dtype=torch.float32)
+ terminated_tensor = torch.ones(batch_size, dtype=torch.bool)
+ next_stop_strings = [[""]] * batch_size
+
+ return EnvironmentReturn(
+ observations=results,
+ metadata=metadata_batch,
+ next_stop_strings=next_stop_strings,
+ rewards=rewards_tensor,
+ terminateds=terminated_tensor,
+ )
+
+ def shutdown(self):
+ """Clean up resources."""
+ pass
+
+ def global_post_process_and_metrics(
+ self, batch: BatchedDataDict
+ ) -> tuple[BatchedDataDict, dict]:
+ """Compute metrics for the batch."""
+ # No specific metrics for RAG
+ return batch, {}
diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py
index 1e0a027bcb..b9627e8a60 100644
--- a/nemo_rl/experience/rollouts.py
+++ b/nemo_rl/experience/rollouts.py
@@ -410,6 +410,8 @@ def run_multi_turn_rollout(
tokenized_obs = tokenizer(
env_obs_content, return_tensors="pt", add_special_tokens=False
).input_ids[0]
+ # tokenizer returns torch.float32 when env_obs_content is empty
+ tokenized_obs = tokenized_obs.to(dtype=torch.int64)
# check if new message overflows max_seq_len
if (
diff --git a/tests/unit/environments/test_code_environment.py b/tests/unit/environments/test_code_environment.py
new file mode 100644
index 0000000000..27ada6d9bf
--- /dev/null
+++ b/tests/unit/environments/test_code_environment.py
@@ -0,0 +1,220 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tempfile import TemporaryDirectory
+
+import pytest
+import ray
+from transformers import AutoTokenizer
+
+from nemo_rl.distributed.batched_data_dict import BatchedDataDict
+from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
+from nemo_rl.environments.code_environment import (
+ CodeEnvConfig,
+ CodeEnvironment,
+ CodeEnvMetadata,
+)
+from nemo_rl.experience.rollouts import run_multi_turn_rollout
+from nemo_rl.models.generation import configure_generation_config
+from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration
+
+MODEL_NAME = "meta-llama/Llama-3.2-1B"
+
+cfg: CodeEnvConfig = {
+ "num_workers": 2,
+ "terminate_on_evaluation": True,
+}
+
+# Define basic vLLM test config
+basic_vllm_test_config: VllmConfig = {
+ "backend": "vllm",
+ "model_name": MODEL_NAME,
+ "tokenizer_name": None,
+ "dtype": "bfloat16",
+ "max_new_tokens": 100,
+ "temperature": 1.0,
+ "top_p": 1.0,
+ "top_k": None,
+ "stop_token_ids": None,
+ "stop_strings": None,
+ "vllm_cfg": {
+ "async_engine": False,
+ "precision": "bfloat16",
+ "tensor_parallel_size": 1,
+ "pipeline_parallel_size": 1,
+ "max_model_len": 1024,
+ "disable_log_stats": True,
+ "disable_log_requests": True,
+ "gpu_memory_utilization": 0.6,
+ "enforce_eager": "False",
+ },
+ "colocated": {
+ "enabled": True,
+ "resources": {
+ "gpus_per_node": None,
+ "num_nodes": None,
+ },
+ },
+}
+
+
+@pytest.fixture(scope="function")
+def code_env():
+ """Create a code environment for testing."""
+    env_actor = None
+    try:
+        yield (env_actor := CodeEnvironment.remote(cfg))
+ finally:
+ if env_actor:
+ ray.kill(env_actor)
+
+
+@pytest.fixture(scope="function")
+def tokenizer():
+ """Loads the tokenizer for the tests."""
+ print(f"Loading tokenizer: {MODEL_NAME}")
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+ print(
+ f"Tokenizer loaded. Pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id}), EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})"
+ )
+ return tokenizer
+
+
+@pytest.fixture(scope="function")
+def cluster():
+ """Create a virtual cluster for testing."""
+ cluster_instance = None
+ cluster_name = f"test-code-cluster-{id(cluster_instance)}"
+ print(f"\nCreating virtual cluster '{cluster_name}'...")
+ try:
+ cluster_instance = RayVirtualCluster(
+ name=cluster_name,
+ bundle_ct_per_node_list=[1],
+ use_gpus=True,
+ num_gpus_per_node=1,
+ max_colocated_worker_groups=2,
+ )
+ yield cluster_instance
+ finally:
+ print(f"\nCleaning up cluster '{cluster_name}'...")
+ if cluster_instance:
+ cluster_instance.shutdown()
+
+
+def test_untrusted_code(code_env):
+ """Test whether the code environment can block untrusted code."""
+ codes = [
+ "with open('allowed_file.txt', 'w') as fout:\n"
+ " fout.write('some content')\n"
+ "with open('allowed_file.txt') as fin:\n"
+ " content = fin.read()\n"
+ "content",
+ "with open('/etc/passwd', 'r') as fin:\n fin.read()",
+ "import math\nround(math.sqrt(8))",
+ "import os",
+ ]
+ results = [
+ "\n\n\n'some content'\n",
+ "\n\n\nPermissionError('Access beyond the temporary working directory is blocked')\n",
+ "\n\n\n3\n",
+ "PermissionError('Importing system and network modules is blocked')",
+ ]
+
+ message_log_batch = [
+ [{"role": "user", "content": f"{code}"}] for code in codes
+ ]
+ temp_dirs = [TemporaryDirectory() for _ in codes]
+ metadata_batch = [
+ CodeEnvMetadata(
+ context={},
+ working_dir=temp_dir.name,
+ )
+ for temp_dir in temp_dirs
+ ]
+
+ # Execute the code
+ output = ray.get(code_env.step.remote(message_log_batch, metadata_batch))
+ responses = [obs["content"] for obs in output.observations]
+
+ assert responses == results, f"Got wrong output {responses}"
+
+
+@pytest.mark.hf_gated
+def test_vllm_execute_code(cluster, tokenizer, code_env):
+ """Test that vLLM can call the code executor."""
+ # Prepare test data
+ codes = [
+ "x = 3; y = 4\nThis is some regular text.\nx + y\n",
+ "\ndef f(x):\n return x * x\n\nf(2)\n\n",
+ ]
+ results = ["7", "\n\n4\n"]
+
+ # Create message logs
+ message_logs = []
+ metadata_batch = []
+ temp_dirs = []
+ for code in codes:
+ # Tokenize the message content
+ prompt = code * 4
+ token_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)[
+ "input_ids"
+ ][0]
+ temp_dir = TemporaryDirectory()
+ message_logs.append(
+ [{"role": "user", "content": prompt, "token_ids": token_ids}]
+ )
+ metadata_batch.append(CodeEnvMetadata(context={}, working_dir=temp_dir.name))
+ temp_dirs.append(temp_dir)
+
+ # Create initial batch
+ initial_batch = BatchedDataDict(
+ {
+ "message_log": message_logs,
+ "extra_env_info": metadata_batch,
+ "task_name": ["code_execution"] * len(codes),
+ "stop_strings": [[""]] * len(codes),
+ }
+ )
+
+ # Create vLLM generation
+ vllm_config = basic_vllm_test_config.copy()
+ vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True)
+ vllm_generation = VllmGeneration(cluster, vllm_config)
+
+ # Create code environment
+ task_to_env = {"code_execution": code_env}
+
+ # Run rollout
+ vllm_generation.prepare_for_generation()
+ final_batch, _ = run_multi_turn_rollout(
+ policy_generation=vllm_generation,
+ input_batch=initial_batch,
+ tokenizer=tokenizer,
+ task_to_env=task_to_env,
+ max_seq_len=256,
+ max_rollout_turns=2,
+ greedy=True,
+ )
+ vllm_generation.finish_generation()
+
+ # Check results
+ for i, msg_log in enumerate(final_batch["message_log"]):
+ # Get the last message which should contain the result
+ last_msg = msg_log[-1]
+ assert last_msg["role"] == "environment"
+ assert last_msg["content"] == results[i], (
+ f"Expected {results[i]}, got {last_msg['content']}"
+ )
diff --git a/tests/unit/environments/test_retriever.py b/tests/unit/environments/test_retriever.py
new file mode 100644
index 0000000000..824c09b041
--- /dev/null
+++ b/tests/unit/environments/test_retriever.py
@@ -0,0 +1,179 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import ray
+from transformers import AutoTokenizer
+
+from nemo_rl.distributed.batched_data_dict import BatchedDataDict
+from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
+from nemo_rl.environments.tools.retriever import RAGEnvConfig, RAGEnvironment
+from nemo_rl.experience.rollouts import run_multi_turn_rollout
+from nemo_rl.models.generation import configure_generation_config
+from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration
+
+MODEL_NAME = "meta-llama/Llama-3.2-1B"
+
+cfg: RAGEnvConfig = {
+ "dataset_name": "rahular/simple-wikipedia",
+ "dataset_split": "train",
+ "text_column": "text",
+ "num_results": 1,
+ "k1": 1.5,
+ "b": 0.75,
+ "device": "cpu",
+}
+
+# Define basic vLLM test config
+basic_vllm_test_config: VllmConfig = {
+ "backend": "vllm",
+ "model_name": MODEL_NAME,
+ "tokenizer_name": None,
+ "dtype": "bfloat16",
+ "max_new_tokens": 100,
+ "temperature": 1.0,
+ "top_p": 1.0,
+ "top_k": None,
+ "stop_token_ids": None,
+ "stop_strings": None,
+ "vllm_cfg": {
+ "async_engine": False,
+ "precision": "bfloat16",
+ "tensor_parallel_size": 1,
+ "pipeline_parallel_size": 1,
+ "max_model_len": 1024,
+ "disable_log_stats": True,
+ "disable_log_requests": True,
+ "gpu_memory_utilization": 0.6,
+ "enforce_eager": "False",
+ },
+ "colocated": {
+ "enabled": True,
+ "resources": {
+ "gpus_per_node": None,
+ "num_nodes": None,
+ },
+ },
+}
+
+
+@pytest.fixture(scope="function")
+def rag_env():
+ """Create a RAG environment for testing."""
+    env_actor = None
+    try:
+        yield (env_actor := RAGEnvironment.remote(cfg))
+ finally:
+ if env_actor:
+ ray.kill(env_actor)
+
+
+@pytest.fixture(scope="function")
+def tokenizer():
+ """Loads the tokenizer for the tests."""
+ print(f"Loading tokenizer: {MODEL_NAME}")
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+ print(
+ f"Tokenizer loaded. Pad token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id}), EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})"
+ )
+ return tokenizer
+
+
+@pytest.fixture(scope="function")
+def cluster():
+ """Create a virtual cluster for testing."""
+ cluster_instance = None
+ cluster_name = f"test-rag-cluster-{id(cluster_instance)}"
+ print(f"\nCreating virtual cluster '{cluster_name}'...")
+ try:
+ cluster_instance = RayVirtualCluster(
+ name=cluster_name,
+ bundle_ct_per_node_list=[1],
+ use_gpus=True,
+ num_gpus_per_node=1,
+ max_colocated_worker_groups=2,
+ )
+ yield cluster_instance
+ finally:
+ print(f"\nCleaning up cluster '{cluster_name}'...")
+ if cluster_instance:
+ cluster_instance.shutdown()
+
+
+@pytest.mark.hf_gated
+def test_vllm_retrieve(cluster, tokenizer, rag_env):
+ """Test that vLLM can use the RAG environment for document retrieval."""
+ # Prepare test data
+ queries = [
+ "Jen-Hsun Huang\n",
+ ]
+ expected_results = [
+ "\n<1>\n"
+ "Nvidia was established in 1993 by Jen-Hsun Huang, Curtis Priem, and Chris Malachowsky. In 2000 Nvidia took intellectual possession of 3dfx, one of the biggest GPU producers in 1990s.\n"
+ "1>\n\n",
+ ]
+
+ # Create message logs
+ message_logs = []
+ for query in queries:
+ # Tokenize the message content
+ prompt = query * 4
+ token_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)[
+ "input_ids"
+ ][0]
+ message_logs.append(
+ [{"role": "user", "content": prompt, "token_ids": token_ids}]
+ )
+
+ # Create initial batch
+ initial_batch = BatchedDataDict(
+ {
+ "message_log": message_logs,
+ "extra_env_info": [{}] * len(queries), # No metadata needed for RAG
+ "task_name": ["document_retrieval"] * len(queries),
+ "stop_strings": [[""]] * len(queries),
+ }
+ )
+
+ # Create vLLM generation
+ vllm_config = basic_vllm_test_config.copy()
+ vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True)
+ vllm_generation = VllmGeneration(cluster, vllm_config)
+
+ # Create RAG environment
+ task_to_env = {"document_retrieval": rag_env}
+
+ # Run rollout
+ vllm_generation.prepare_for_generation()
+ final_batch, _ = run_multi_turn_rollout(
+ policy_generation=vllm_generation,
+ input_batch=initial_batch,
+ tokenizer=tokenizer,
+ task_to_env=task_to_env,
+ max_seq_len=256,
+ max_rollout_turns=1,
+ greedy=True,
+ )
+ vllm_generation.finish_generation()
+
+ # Check results
+ for i, msg_log in enumerate(final_batch["message_log"]):
+ # Get the last message which should contain the result
+ last_msg = msg_log[-1]
+ assert last_msg["role"] == "environment"
+ assert last_msg["content"] == expected_results[i], (
+ f"Expected {expected_results[i]}, got {last_msg['content']}"
+ )