From 16640e8ff7f0f97794053fba5d400aed6d806795 Mon Sep 17 00:00:00 2001 From: Christian Munley Date: Wed, 17 Dec 2025 23:18:04 +0000 Subject: [PATCH 01/51] nemo gym integration Signed-off-by: Christian Munley --- examples/scripts/nemo_gym/README.md | 27 ++ examples/scripts/nemo_gym/config.yaml | 29 ++ examples/scripts/nemo_gym/train.py | 646 ++++++++++++++++++++++++++ trl/scripts/vllm_serve.py | 197 +++++++- trl/trainer/grpo_trainer.py | 9 +- 5 files changed, 905 insertions(+), 3 deletions(-) create mode 100644 examples/scripts/nemo_gym/README.md create mode 100644 examples/scripts/nemo_gym/config.yaml create mode 100644 examples/scripts/nemo_gym/train.py diff --git a/examples/scripts/nemo_gym/README.md b/examples/scripts/nemo_gym/README.md new file mode 100644 index 00000000000..74d2ef61ff0 --- /dev/null +++ b/examples/scripts/nemo_gym/README.md @@ -0,0 +1,27 @@ +# NeMo Gym TRL GRPO integration + +Multi-step GRPO with TRL and NeMo Gym. + +## Setup + +1. Launch vLLM server: +```bash +CUDA_VISIBLE_DEVICES=4,5,6,7 trl vllm-serve \ + --model Qwen/Qwen3-4B-Instruct-2507 \ + --tensor-parallel-size 4 \ + --max-model-len 8192 \ + --trust-remote-code +``` + +2. Start NeMo Gym servers +``` +ng_run "+config_paths=[resources_servers/workplace_assistant/configs/workplace_assistant.yaml,responses_api_models/vllm_model/configs/vllm_model.yaml]" +``` + + +3. Run training: +```bash +CUDA_VISIBLE_DEVICES=0 python train.py --config config.yaml +``` + +can do dp=7 with 7/8 gpus for vllm server. Havent gotten multigpu training backend to work despite docs says it works https://huggingface.co/docs/trl/main/en/vllm_integration#modes-of-using-vllm-during-training \ No newline at end of file diff --git a/examples/scripts/nemo_gym/config.yaml b/examples/scripts/nemo_gym/config.yaml new file mode 100644 index 00000000000..35e137fb583 --- /dev/null +++ b/examples/scripts/nemo_gym/config.yaml @@ -0,0 +1,29 @@ +model_name: "Qwen/Qwen3-4B-Instruct-2507" +dataset_path: "train-workplace.jsonl" + +agent_name: "simple_agent" + +output_dir: "outputs/trl_nemo_gym_workplace" +run_name: "workplace_assistant_qwen3_4b_instruct_2507" +project_name: "cmunley-nemo-gym-trl-int" + +learning_rate: 1.0e-5 +max_steps: 1000000 + +# these params are confusing! i just want to set responses per prompt (num_generations), prompts per step, and global batch size like Nemo RL. +# i think the below is doing: rpp 16, pps 8, gbs 128 +num_generations: 16 +per_device_train_batch_size: 1 +gradient_accumulation_steps: 128 + +max_seq_length: 30000 + +temperature: 1.0 +top_p: 0.999 +weight_decay: 0.01 +warmup_ratio: 0.0 +lr_scheduler_type: "linear" +optim: "adamw_8bit" + +save_steps: 900000 +report_to: "wandb" \ No newline at end of file diff --git a/examples/scripts/nemo_gym/train.py b/examples/scripts/nemo_gym/train.py new file mode 100644 index 00000000000..dd0b22c76e8 --- /dev/null +++ b/examples/scripts/nemo_gym/train.py @@ -0,0 +1,646 @@ +import os +import sys +import numpy as np + +# trl_path = os.path.join(os.path.dirname(__file__), "..", "trl") +# if os.path.exists(trl_path): +# sys.path.insert(0, trl_path) +# print(f"Using local TRL from: {trl_path}") +# else: +# from trl import GRPOConfig, GRPOTrainer +from trl import GRPOConfig, GRPOTrainer + +import argparse +import asyncio +import aiohttp +import json +import yaml +import requests +from omegaconf import OmegaConf +from typing import Any, Dict, List, Optional +from dataclasses import dataclass + +from datasets import Dataset, load_dataset +from trl import GRPOConfig, GRPOTrainer + +from transformers import AutoTokenizer + + +def get_agent_server( + head_server_host: str = "127.0.0.1", + head_server_port: int = 11000, + agent_name: str = None, +) -> str: + try: + response = requests.get( + f"http://{head_server_host}:{head_server_port}/global_config_dict_yaml", + timeout=10 + ) + response.raise_for_status() + global_config_yaml = response.text + global_config_dict = OmegaConf.create(yaml.safe_load(global_config_yaml)) + + if agent_name: + for project_name, project_config in global_config_dict.items(): + if hasattr(project_config, 'responses_api_agents'): + agents = project_config.responses_api_agents + if hasattr(agents, agent_name): + agent_config = getattr(agents, agent_name) + agent_server = f"http://{agent_config.host}:{agent_config.port}" + return agent_server + + raise ValueError(f"Agent '{agent_name}' not found in any project's responses_api_agents") + + # If no agent_name specified, try to find it + for project_name, project_config in global_config_dict.items(): + if hasattr(project_config, 'responses_api_agents'): + agents = project_config.responses_api_agents + for name in agents.keys(): + agent_config = getattr(agents, name) + if hasattr(agent_config, 'host') and hasattr(agent_config, 'port'): + agent_server = f"http://{agent_config.host}:{agent_config.port}" + return agent_server + + raise ValueError("No agents found in global config") + + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Failed to connect to head server at {head_server_host}:{head_server_port}: {e}") + + +@dataclass +class TrainingConfig: + model_name: str + dataset_path: str + + agent_name: Optional[str] = None + + learning_rate: float = 5e-6 + max_steps: int = 100 + num_generations: int = 2 + per_device_train_batch_size: int = 2 + gradient_accumulation_steps: int = 16 + + max_seq_length: int = 1024 + max_prompt_length: int = None + + temperature: float = 1.0 + top_p: float = 1.0 + weight_decay: float = 0.01 + warmup_ratio: float = 0.1 + lr_scheduler_type: str = "linear" + optim: str = "adamw_8bit" + + output_dir: str = "outputs/trl_nemo_gym" + save_steps: int = 100 + report_to: str = "none" + run_name: str = None # Wandb + project_name: str = None # Wandb + +def reward_fn(completions: List[str], **kwargs) -> List[float]: + env_rewards = kwargs.get("env_reward", []) + + if not env_rewards: + print(f"WARNING: No rewards from Nemo Gym, returning zeros for {len(completions)} completions") + return [0.0] * len(completions) + + print(f"Received {len(env_rewards)} rewards from Nemo Gym") + print(f"Mean reward: {sum(env_rewards)/len(env_rewards):.3f}") + print(f"Reward std dev: {np.std(env_rewards):.3f}") + print(f"Min/max reward: {min(env_rewards):.3f}/{max(env_rewards):.3f}") + + return [float(r) for r in env_rewards] + + +def get_tool_result_tokens_via_eos( + tokenizer, + seen_token_ids: List[int], + new_prompt_ids: List[int], +) -> List[int]: + """ + Extract tool result tokens when simple prefix-slicing fails due to retokenization. + + The last EOS in seen_token_ids marks where the + previous model generation ended. Find that same EOS in new_prompt_ids, then return + everything after it (the new tool results / user messages). + + Based on NeMo RL's _replace_prefix_tokens: + https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 + """ + if not seen_token_ids or not new_prompt_ids: + return [] + + eos_token_id = tokenizer.eos_token_id + assert eos_token_id is not None, "Your tokenizer must have an EOS token ID!" + + # Find last EOS in new_prompt_ids within the "prefix" region (up to len(seen_token_ids)) + # search backwards from the prefix boundary + # EOS marks where the previous model generation ended + new_eos_pos = -1 + search_bound = min(len(seen_token_ids), len(new_prompt_ids)) + for pos in reversed(range(search_bound)): + if new_prompt_ids[pos] == eos_token_id: + new_eos_pos = pos + break + + if new_eos_pos < 0: + return [] + + new_content_start = new_eos_pos + 1 + if new_content_start >= len(new_prompt_ids): + return [] + + return new_prompt_ids[new_content_start:] + + +async def call_nemo_gym_agent( + prompts: List[str], + dataset_items: List[Dict[str, Any]], + agent_server: str, + timeout: float, + max_completion_length: int = 4096, + temperature: float = 1.0, + top_p: float = 0.999, +) -> List[Dict[str, Any]]: + print(f"Calling Nemo Gym agent: {agent_server}") + print(f"Number of prompts: {len(prompts)}") + print(f"Max completion length: {max_completion_length}") + + async with aiohttp.ClientSession(cookie_jar=aiohttp.CookieJar()) as session: + tasks = [] + for i, (prompt, item) in enumerate(zip(prompts, dataset_items)): + request_body = item.copy() + + if "responses_create_params" not in request_body: + request_body["responses_create_params"] = { + "input": [{"role": "user", "content": prompt}], + } + + params = request_body["responses_create_params"] + params.setdefault("max_output_tokens", max_completion_length) + params["temperature"] = temperature + params["top_p"] = top_p + + if i == 0: + print(f"First request keys: {list(params.keys())}") + + task = session.post( + f"{agent_server}/run", + json=request_body, + timeout=aiohttp.ClientTimeout(total=timeout), + ) + tasks.append(task) + + print(f"Awaiting {len(tasks)} HTTP requests...") + responses = await asyncio.gather(*tasks, return_exceptions=True) + print(f"Got {len(responses)} responses") + + results = [] + for i, response in enumerate(responses): + if isinstance(response, Exception): + print(f"WARNING: Request {i} failed: {response}") + results.append({"response": {"output": []}, "reward": 0.0, "error": str(response)}) + else: + try: + json_data = await response.json() + if isinstance(json_data, dict): + results.append(json_data) + else: + print(f"WARNING: Request {i} returned non-dict: {type(json_data)}") + results.append({"response": {"output": []}, "reward": 0.0, "error": f"Non-dict response"}) + except Exception as e: + print(f"WARNING: Failed to parse response {i}: {e}") + results.append({"response": {"output": []}, "reward": 0.0, "error": str(e)}) + + return results + + +def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, List]: + """ + Rollout function for Nemo Gym agent within TRL GRPOTrainer + + Builds interleaved action/observation sequence with masking of observations. + - prompt_ids: first turn's prompt only + - completion_ids: interleaved [model_gen1, tool_result1, model_gen2, tool_result2, ...] + - completion_mask: 1 for model tokens, 0 for tool results + - logprobs: for model tokens, 0.0 for tool result tokens + + This ensures: + 1. Logprobs are computed on the full context, including tool results + 2. Loss is only backpropagated on model-generated tokens + """ + + current_step = trainer.state.global_step if hasattr(trainer, 'state') else 0 + + print(f"\n{'='*80}") + print(f"[nemo_gym_rollout_func] Starting Nemo Gym rollout (Training Step: {current_step})") + print(f"[nemo_gym_rollout_func] Received {len(prompts)} prompts from TRL") + print(f"[nemo_gym_rollout_func] Num generations per prompt: {trainer.args.num_generations}") + + unique_prompts_set = set(prompts) + print(f"DEBUG: Number of unique prompts in input: {len(unique_prompts_set)}") + print(f"DEBUG: Total number prompts: {len(prompts)}") + + print(f"\nDEBUG: All unique prompts ({len(unique_prompts_set)} total):") + for i, prompt in enumerate(sorted(list(unique_prompts_set))[:10]): + print(f" [{i}] {prompt}") + + if len(unique_prompts_set) > 10: + print(f" ... and {len(unique_prompts_set) - 10} more unique prompts") + + print(f"{'='*80}\n") + + num_generations = trainer.args.num_generations + print(f"[nemo_gym_rollout_func] Expanding prompts for {num_generations} generations per prompt...") + + expanded_prompts = [] + expanded_dataset_items = [] + + for prompt in prompts: + matching_item = None + for item in trainer.train_dataset: + if item.get("prompt") == prompt: + matching_item = dict(item) + # Deserialize JSON strings back to dicts/lists + for key in ["responses_create_params", "expected_answers", "metadata", "ground_truth"]: + if key in matching_item and isinstance(matching_item[key], str): + try: + matching_item[key] = json.loads(matching_item[key]) + except: + pass + break + + if not matching_item: + print(f"WARNING: Could not find dataset item for prompt, using prompt only") + matching_item = {"prompt": prompt} + + for _ in range(num_generations): + expanded_prompts.append(prompt) + expanded_dataset_items.append(dict(matching_item)) + + print(f"[nemo_gym_rollout_func] Expanded to {len(expanded_prompts)} total requests ({len(prompts)} prompts × {num_generations} generations)") + + print("[nemo_gym_rollout_func] Calling Nemo Gym agent...") + print(f"[nemo_gym_rollout_func] Using temperature: {trainer.args.temperature}, top_p: {trainer.args.top_p}") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + responses = loop.run_until_complete( + call_nemo_gym_agent( + expanded_prompts, + expanded_dataset_items, + trainer.args.agent_server, + trainer.args.request_timeout, + trainer.args.max_completion_length, + temperature=trainer.args.temperature, + top_p=trainer.args.top_p, + ) + ) + finally: + loop.close() + + print(f"[nemo_gym_rollout_func] Received {len(responses)} responses from Nemo Gym") + + # Save trajectories to JSONL + trajectory_file = os.path.join(trainer.args.output_dir, "trajectories.jsonl") + os.makedirs(trainer.args.output_dir, exist_ok=True) + + with open(trajectory_file, 'a') as f: + for i, response in enumerate(responses): + trajectory_data = { + "step": current_step, + "rollout_idx": i, + "reward": response.get("reward", 0.0) if isinstance(response, dict) else 0.0, + "output": response.get("response", {}).get("output", []) if isinstance(response, dict) else [], + "error": response.get("error") if isinstance(response, dict) else str(response), + } + f.write(json.dumps(trajectory_data) + "\n") + + print(f"[Rollout] Saved {len(responses)} trajectories to {trajectory_file}") + + tokenizer = AutoTokenizer.from_pretrained(trainer.model.name_or_path) + + # interleaved completion with mask + prompt_ids: List[List[int]] = [] + completion_ids: List[List[int]] = [] + completion_mask: List[List[int]] = [] # 1 for model tokens, 0 for tool results + logprobs: List[List[float]] = [] + env_rewards: List[float] = [] + + failed_count = 0 + success_count = 0 + + for i, response in enumerate(responses): + if not isinstance(response, dict): + raise ValueError(f"Rollout {i} response is not a dict: {type(response)}") + + if "error" in response: + raise ValueError(f"Rollout {i} had error: {response['error']}") + + episode_reward = response.get("reward", 0.0) + output_items = response.get("response", {}).get("output", []) + + # Build interleaved completion: [model_gen1, tool_result1, model_gen2, tool_result2, ...] + # with mask: 1 for model tokens (train), 0 for tool results (don't train) + # Each turn gives us (prompt_ids, gen_ids). The prompt grows each turn as tool results + # are appended. We extract tool_result = current_prompt - previous_seen_tokens. + # trying to implement the same logic as NeMo RL's _replace_prefix_tokens in RL/nemo_rl/models/generation/vllm/vllm_worker_async.py + # for less token id mismatch and logprop error + + seen_token_ids: List[int] = [] + interleaved_completion: List[int] = [] + interleaved_mask: List[int] = [] + interleaved_logprobs: List[float] = [] + first_prompt = None + num_turns = 0 + + for item in output_items: + if "prompt_token_ids" not in item or "generation_token_ids" not in item: + continue + + num_turns += 1 + item_prompt_ids = item["prompt_token_ids"] + item_gen_ids = item["generation_token_ids"] + item_logprobs = item.get("generation_log_probs", []) + tool_result_tokens = [] + + if first_prompt is None: + first_prompt = item_prompt_ids + seen_token_ids = list(item_prompt_ids) + else: + # extract tool result tokens (delta between prompts) + if len(item_prompt_ids) > len(seen_token_ids): + if item_prompt_ids[:len(seen_token_ids)] == seen_token_ids: + # Simple case: prefix matches, just slice off the new tokens + tool_result_tokens = item_prompt_ids[len(seen_token_ids):] + else: + # Retokenization changed the prefix - use nemo RL _replace_prefix_tokens approach + tool_result_tokens = get_tool_result_tokens_via_eos( + tokenizer, seen_token_ids, item_prompt_ids + ) + if tool_result_tokens: + print(f"[Turn {num_turns}] Using nemo RL _replace_prefix_tokens approach to extract observation/tool result tokens: {len(tool_result_tokens)} observation tokens") + else: + print(f"[Turn {num_turns}] WARNING: Could not extract observation tokens") + + # Append tool results (mask=0) + if tool_result_tokens: + interleaved_completion.extend(tool_result_tokens) + interleaved_mask.extend([0] * len(tool_result_tokens)) + interleaved_logprobs.extend([0.0] * len(tool_result_tokens)) + + # Append model generation (mask=1) + interleaved_completion.extend(item_gen_ids) + interleaved_mask.extend([1] * len(item_gen_ids)) + interleaved_logprobs.extend( + item_logprobs if len(item_logprobs) == len(item_gen_ids) else [0.0] * len(item_gen_ids) + ) + + if tool_result_tokens: + seen_token_ids = seen_token_ids + tool_result_tokens + list(item_gen_ids) + else: + seen_token_ids = list(item_prompt_ids) + list(item_gen_ids) + + if not interleaved_completion or first_prompt is None: + raise ValueError(f"Rollout {i} has no valid turns") + + + success_count += 1 + + prompt_ids.append(first_prompt) + completion_ids.append(interleaved_completion) + completion_mask.append(interleaved_mask) + logprobs.append(interleaved_logprobs) + env_rewards.append(episode_reward) + + model_tokens = sum(interleaved_mask) + tool_tokens = len(interleaved_mask) - model_tokens + + print(f"\n{'='*60}") + print(f"[nemo_gym_rollout_func] Turns: {num_turns}, Reward: {episode_reward:.3f}") + print(f"[nemo_gym_rollout_func] Prompt tokens: {len(first_prompt)}") + print(f"[nemo_gym_rollout_func] Completion tokens: {len(interleaved_completion)} (model: {model_tokens}, tool: {tool_tokens})") + print(f"[nemo_gym_rollout_func] Completion preview: {tokenizer.decode(interleaved_completion)[:150]}...") + print(f"{'='*60}\n") + + print(f"\n{'='*80}") + print(f"[nemo_gym_rollout_func] Success: {success_count}, Failed: {failed_count}") + print(f"[nemo_gym_rollout_func] Total episodes: {len(completion_ids)}") + + if not prompt_ids: + raise RuntimeError( + "No valid rollouts. Check Nemo Gym and vLLM logs." + ) + + mean_reward = sum(env_rewards) / len(env_rewards) if env_rewards else 0.0 + total_model_tokens = sum(sum(m) for m in completion_mask) + total_tool_tokens = sum(len(m) - sum(m) for m in completion_mask) + print(f"[nemo_gym_rollout_func] Mean reward: {mean_reward:.3f}") + print(f"[nemo_gym_rollout_func] Total model generation tokens (not masked): {total_model_tokens}") + print(f"[nemo_gym_rollout_func] Total tool tokens (masked): {total_tool_tokens}") + + # We need to deduplicate prompt_ids to match TRL's current code that re-duplicates prompts + # TRL deduplicates here: https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L1266 so we had to duplicate prompts for num_generations + # TRL reduplicates here: https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L1314 so we need to dedup prompts + print(f"[nemo_gym_rollout_func] Deduplicating prompt_ids (keeping 1 per {num_generations} completions)...") + unique_prompt_ids = [] + for idx in range(0, len(prompt_ids), num_generations): + if idx < len(prompt_ids): + unique_prompt_ids.append(prompt_ids[idx]) + + print(f"[nemo_gym_rollout_func] Deduplicated: {len(prompt_ids)} → {len(unique_prompt_ids)} unique prompt_ids") + print(f"[nemo_gym_rollout_func] Final counts: {len(unique_prompt_ids)} prompt_ids, {len(completion_ids)} completion_ids") + print(f"[nemo_gym_rollout_func] Expected ratio: {len(completion_ids) / len(unique_prompt_ids) if unique_prompt_ids else 0:.1f} completions per prompt") + print(f"{'='*80}\n") + + return { + "prompt_ids": unique_prompt_ids, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "logprobs": logprobs, + "env_reward": env_rewards, + } + +def get_max_prompt_length(dataset: Dataset, tokenizer) -> int: + tokenizer = AutoTokenizer.from_pretrained(tokenizer) + prompt_lengths = [len(tokenizer.encode(item.get("prompt", ""))) for item in dataset if item.get("prompt", "")] + prompt_lengths.sort() + max_length = prompt_lengths[-1] + print(f"[get_max_prompt_length] Min length: {min(prompt_lengths)}") + print(f"[get_max_prompt_length] Max length: {max(prompt_lengths)}") + print(f"[get_max_prompt_length] Mean length: {sum(prompt_lengths) / len(prompt_lengths):.1f}") + return max_length + + +def load_dataset_from_jsonl(path: str) -> Dataset: + # TODO: standardize nemo gym dataset format or only accept 1 here (instructions field, answer field, jsonl structure...) + data = [] + with open(path, 'r') as f: + for line in f: + if line.strip(): + item = json.loads(line) + + # Extract prompt before serializing + if "prompt" not in item: + if "responses_create_params" in item and isinstance(item["responses_create_params"], dict): + responses_params = item["responses_create_params"] + input_data = responses_params.get("input") + instructions = responses_params.get("instructions", "") + + # Handle both message list format and string format + if isinstance(input_data, list) and len(input_data) > 0: + # Format as messages (e.g. reasoning_gym) + prompt_parts = [] + if instructions: + prompt_parts.append(f"system: {instructions}") + for msg in input_data: + if isinstance(msg, dict) and "role" in msg and "content" in msg: + prompt_parts.append(f"{msg['role']}: {msg['content']}") + item["prompt"] = "\n\n".join(prompt_parts) if prompt_parts else "" + elif isinstance(input_data, str): + # Format as string (e.g. google_search) + # Combine instructions field (system prompt) + input field (question) + prompt_parts = [] + if instructions: + prompt_parts.append(instructions) + if input_data: + prompt_parts.append(input_data) + item["prompt"] = "\n\n".join(prompt_parts) if prompt_parts else "" + else: + item["prompt"] = item.get("question", "") + else: + item["prompt"] = item.get("question", "") + + # Serialize problematic nested structures to JSON strings + for key in ["responses_create_params", "expected_answers", "metadata", "ground_truth"]: + if key in item and isinstance(item[key], (dict, list)): + item[key] = json.dumps(item[key]) + + data.append(item) + + print(f"Loaded {len(data)} examples from {path}") + + if len(data) < 100: + repeat_factor = 100 + print(f"Repeating dataset {repeat_factor}x: {len(data)} -> {len(data) * repeat_factor}") + data = data * repeat_factor + + dataset = Dataset.from_list(data) + # dataset = dataset.shuffle(seed=42) + + return dataset + + +def main(): + parser = argparse.ArgumentParser(description="") + parser.add_argument("--config", required=True, help="Path to config YAML file") + args = parser.parse_args() + + with open(args.config) as f: + config = TrainingConfig(**yaml.safe_load(f)) + + agent_server = get_agent_server( + head_server_host="127.0.0.1", + head_server_port=11000, + agent_name=config.agent_name, + ) + + if config.project_name: + os.environ["WANDB_PROJECT"] = config.project_name + + print(f"\n\nModel: {config.model_name}") + print(f"Dataset: {config.dataset_path}") + print(f"Nemo Gym Agent: {agent_server}") + print(f"vLLM Server: 127.0.0.1:8000") + print(f"Output dir: {config.output_dir}") + print(f"Max steps: {config.max_steps}") + print(f"Num generations: {config.num_generations}") + print(f"Batch size: {config.per_device_train_batch_size}") + print(f"Gradient accumulation: {config.gradient_accumulation_steps}") + + if config.dataset_path.endswith(('.jsonl', '.json')): + dataset = load_dataset_from_jsonl(config.dataset_path) + else: + dataset = load_dataset(config.dataset_path, split="train") + + print(f"Dataset has {len(dataset)} examples\n") + + if config.max_prompt_length is None: + config.max_prompt_length = get_max_prompt_length(dataset, config.model_name) + + training_args = GRPOConfig( + use_vllm=True, + vllm_mode="server", + vllm_server_host="127.0.0.1", + vllm_server_port=8000, + + temperature=config.temperature, + learning_rate=config.learning_rate, + weight_decay=config.weight_decay, + warmup_ratio=config.warmup_ratio, + lr_scheduler_type=config.lr_scheduler_type, + optim=config.optim, + + per_device_train_batch_size=config.per_device_train_batch_size, + gradient_accumulation_steps=config.gradient_accumulation_steps, + num_generations=config.num_generations, + + max_steps=config.max_steps, + save_steps=config.save_steps, + logging_steps=1, + report_to=config.report_to, + output_dir=config.output_dir, + run_name=config.run_name, # wandb + + epsilon=0.2, + loss_type="grpo", + mask_truncated_completions=True, + log_completions=False, + # wandb_log_unique_prompts=True, + + max_prompt_length=config.max_prompt_length, + max_completion_length=config.max_seq_length - config.max_prompt_length, + + shuffle_dataset=False, + ) + + training_args.agent_server = agent_server + training_args.request_timeout = 6000 + + print("\n" + "="*80) + print("GRPO Config:\n") + print(f"per_device_train_batch_size: {training_args.per_device_train_batch_size}") + print(f"gradient_accumulation_steps: {training_args.gradient_accumulation_steps}") + print(f"num_generations: {training_args.num_generations}") + print(f"steps_per_generation: {training_args.steps_per_generation if hasattr(training_args, 'steps_per_generation') else 'Not set (will default to gradient_accumulation_steps)'}") + print(f"generation_batch_size: {training_args.generation_batch_size if hasattr(training_args, 'generation_batch_size') else 'Not set (will be calculated)'}") + print(f"shuffle_dataset: {training_args.shuffle_dataset if hasattr(training_args, 'shuffle_dataset') else 'Not set (default: True)'}") + print(f"Dataset size: {len(dataset)}") + print("="*80 + "\n") + + print("Initializing GRPO Trainer...") + + trainer = GRPOTrainer( + model=config.model_name, + reward_funcs=reward_fn, + train_dataset=dataset, + rollout_func=nemo_gym_rollout_func, + args=training_args, + ) + + print("=" * 80) + print("Starting training...") + + trainer.train() + + print("=" * 80) + print("Training complete") + + output_dir = config.output_dir + "/final" + print(f"\nSaving model to {output_dir}") + trainer.save_model(output_dir) + trainer.processing_class.save_pretrained(output_dir) + + print("\nFinished saving model") + +if __name__ == "__main__": + main() diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index d5884b8290a..1858cf6e9bc 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -14,8 +14,12 @@ import argparse import base64 +import json import logging import os +import re +import time +import uuid from collections.abc import Sequence from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -628,6 +632,7 @@ class ChatRequest(BaseModel): guided_decoding_regex: str | None = None generation_kwargs: dict = field(default_factory=dict) chat_template_kwargs: dict = field(default_factory=dict) + tools: list[dict] | None = None class ChatResponse(BaseModel): prompt_ids: list[list[int]] @@ -732,7 +737,11 @@ async def chat(request: ChatRequest): "messages": messages, "sampling_params": sampling_params, "chat_template_kwargs": request.chat_template_kwargs, + "tools": request.tools if request.tools else None, + # "tool_choice": request.tool_choice, + # "parallel_tool_calls": request.parallel_tool_calls, } + connection.send({"type": "call", "method": "chat", "kwargs": kwargs}) # Receive results @@ -835,8 +844,192 @@ async def close_communicator(): connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}) return {"message": "Request received, closing communicator"} - # Start the server - uvicorn.run(app, host=script_args.host, port=script_args.port, log_level=script_args.log_level) + class ChatCompletionRequest(BaseModel): + messages: list[dict] + model: str | None = None + temperature: float = 1.0 + top_p: float = 1.0 + max_completion_tokens: int | None = None + max_tokens: int | None = None + n: int = 1 + stop: str | list[str] | None = None + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + logprobs: bool = False + top_logprobs: int | None = None + tools: list[dict] | None = None + tool_choice: str | dict = "auto" + parallel_tool_calls: bool = True + + @app.post("/v1/chat/completions") + async def chat_completions(request: ChatCompletionRequest): + completion_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + created_at = int(time.time()) + + messages = [] + for msg in request.messages: + role = msg.get("role", "") + if role not in ["system", "user", "assistant", "tool"]: + logger.warning(f"Unknown message role: {role}") + messages.append(msg) + + max_tokens = request.max_completion_tokens or request.max_tokens or 512 + + sampling_kwargs = { + "n": request.n, + "temperature": request.temperature, + "top_p": request.top_p, + "max_tokens": max_tokens, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "stop": request.stop, + } + + if request.logprobs or request.top_logprobs: + sampling_kwargs["logprobs"] = request.top_logprobs if request.top_logprobs else 1 + + sampling_params = SamplingParams(**sampling_kwargs) + + chat_template_kwargs = {} + if request.tool_choice and request.tool_choice != "auto": + chat_template_kwargs["tool_choice"] = request.tool_choice + + chunked_messages = chunk_list([messages], script_args.data_parallel_size) + + for connection, message_chunk in zip(connections, chunked_messages, strict=True): + if not message_chunk: + message_chunk = [[{"role": "user", "content": ""}]] + kwargs = { + "messages": message_chunk, + "sampling_params": sampling_params, + "tools": request.tools, + "chat_template_kwargs": chat_template_kwargs + } + connection.send({"type": "call", "method": "chat", "kwargs": kwargs}) + + all_outputs = [connection.recv() for connection in connections] + all_outputs = [output for output, msg_chunk in zip(all_outputs, chunked_messages, strict=True) if msg_chunk] + all_outputs = list(chain.from_iterable(all_outputs)) + + choices = [] + total_input_tokens = 0 + total_output_tokens = 0 + + for idx, output in enumerate(all_outputs): + total_input_tokens += len(output.prompt_token_ids) + + for gen_output in output.outputs: + total_output_tokens += len(gen_output.token_ids) + text = gen_output.text if hasattr(gen_output, "text") else "" + + tool_calls = None + finish_reason = "stop" + + if hasattr(gen_output, "tool_calls") and gen_output.tool_calls: + tool_calls = gen_output.tool_calls + finish_reason = "tool_calls" + elif request.tools and text: + # If no native tool call parser, try XML + # TODO: figure out how to use a tool call parser, or handle tool call parsing in Nemo Gym maybe?? + # or implement real async vllm engine and openai api server rather than faking the endpoint here + pattern = r'(.*?)' + matches = re.findall(pattern, text, re.DOTALL) + if matches: + tool_calls = [] + for match in matches: + try: + data = json.loads(match.strip()) + tool_calls.append({ + "id": f"call_{uuid.uuid4().hex[:24]}", + "type": "function", + "function": { + "name": data.get("name", ""), + "arguments": json.dumps(data.get("arguments", {})) + } + }) + except json.JSONDecodeError: + continue + if tool_calls: + finish_reason = "tool_calls" + text = re.sub(pattern, "", text, flags=re.DOTALL).strip() + + if not request.parallel_tool_calls and tool_calls and len(tool_calls) > 1: + tool_calls = [tool_calls[0]] + + logprobs_data = None + if request.logprobs and hasattr(gen_output, "logprobs") and gen_output.logprobs: + logprobs_data = { + "content": [ + { + "token": str(token_id), + "logprob": float(list(logprob_dict.values())[0].logprob) if logprob_dict else 0.0, + "bytes": None, + "top_logprobs": [] + } + for token_id, logprob_dict in zip(gen_output.token_ids, gen_output.logprobs) + ] + } + + choices.append({ + "index": idx, + "message": { + "role": "assistant", + "content": text if not tool_calls else None, + "tool_calls": tool_calls + }, + "logprobs": logprobs_data, + "finish_reason": finish_reason + }) + + return { + "id": completion_id, + "object": "chat.completion", + "created": created_at, + "model": request.model or script_args.model, + "choices": choices, + "usage": { + "prompt_tokens": total_input_tokens, + "completion_tokens": total_output_tokens, + "total_tokens": total_input_tokens + total_output_tokens + } + } + + class TokenizeRequest(BaseModel): + model: str | None = None + messages: list[dict] + tools: list[dict] | None = None + + @app.post("/tokenize") + async def tokenize(request: TokenizeRequest): + kwargs = { + "messages": [request.messages], + "tools": request.tools, + "add_generation_prompt": True, + "chat_template_kwargs": {} + } + + connections[0].send({"type": "call", "method": "preprocess_chat", "kwargs": kwargs}) + preprocessed_prompts = connections[0].recv() + + if preprocessed_prompts and len(preprocessed_prompts) > 1: + logger.warning(f"More than one tokenized message returned from preprocess_chat inside tokenize, double check results!") + + if preprocessed_prompts and len(preprocessed_prompts) > 0: + return { + "tokens": preprocessed_prompts[0]["prompt_token_ids"], + "model": request.model or script_args.model + } + return {"tokens": [], "model": request.model or script_args.model} + + uvicorn.run( + app, + host=script_args.host, + port=script_args.port, + log_level=script_args.log_level, + limit_concurrency=256, + backlog=4096, + timeout_keep_alive=60 + ) def make_parser(subparsers: argparse._SubParsersAction | None = None): diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index f5f333fa092..ebb0e100bfe 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1783,7 +1783,14 @@ def _generate_and_score_completions( prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] - completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + + # Allow custom completion_mask from rollout_func for multi-turn training + # This allows masking out non-trainable tokens (e.g., tool results, observations) in the completion + if "completion_mask" in extra_fields: + completion_mask_list = extra_fields.pop("completion_mask") + completion_mask = [torch.tensor(m, device=device, dtype=torch.long) for m in completion_mask_list] + else: + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") completion_mask = pad(completion_mask, padding_value=0, padding_side="right") if sampling_per_token_logps_list is not None: From 62617583bdc286b19674182f2189b9c72581b7b9 Mon Sep 17 00:00:00 2001 From: Christian Munley Date: Thu, 18 Dec 2025 03:36:03 +0000 Subject: [PATCH 02/51] couple updates Signed-off-by: Christian Munley --- examples/scripts/nemo_gym/config.yaml | 4 +-- examples/scripts/nemo_gym/train.py | 39 ++++++++++++++++++--------- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/examples/scripts/nemo_gym/config.yaml b/examples/scripts/nemo_gym/config.yaml index 35e137fb583..6c84585d7de 100644 --- a/examples/scripts/nemo_gym/config.yaml +++ b/examples/scripts/nemo_gym/config.yaml @@ -1,10 +1,10 @@ model_name: "Qwen/Qwen3-4B-Instruct-2507" dataset_path: "train-workplace.jsonl" +task: "workplace-assistant" # used in run_name if not set agent_name: "simple_agent" output_dir: "outputs/trl_nemo_gym_workplace" -run_name: "workplace_assistant_qwen3_4b_instruct_2507" project_name: "cmunley-nemo-gym-trl-int" learning_rate: 1.0e-5 @@ -16,7 +16,7 @@ num_generations: 16 per_device_train_batch_size: 1 gradient_accumulation_steps: 128 -max_seq_length: 30000 +max_seq_length: 16384 temperature: 1.0 top_p: 0.999 diff --git a/examples/scripts/nemo_gym/train.py b/examples/scripts/nemo_gym/train.py index dd0b22c76e8..e7c22d0a2c1 100644 --- a/examples/scripts/nemo_gym/train.py +++ b/examples/scripts/nemo_gym/train.py @@ -2,12 +2,6 @@ import sys import numpy as np -# trl_path = os.path.join(os.path.dirname(__file__), "..", "trl") -# if os.path.exists(trl_path): -# sys.path.insert(0, trl_path) -# print(f"Using local TRL from: {trl_path}") -# else: -# from trl import GRPOConfig, GRPOTrainer from trl import GRPOConfig, GRPOTrainer import argparse @@ -22,6 +16,7 @@ from datasets import Dataset, load_dataset from trl import GRPOConfig, GRPOTrainer +from tqdm import tqdm from transformers import AutoTokenizer @@ -72,6 +67,7 @@ class TrainingConfig: model_name: str dataset_path: str + task: Optional[str] = None agent_name: Optional[str] = None learning_rate: float = 5e-6 @@ -106,12 +102,11 @@ def reward_fn(completions: List[str], **kwargs) -> List[float]: print(f"Received {len(env_rewards)} rewards from Nemo Gym") print(f"Mean reward: {sum(env_rewards)/len(env_rewards):.3f}") print(f"Reward std dev: {np.std(env_rewards):.3f}") - print(f"Min/max reward: {min(env_rewards):.3f}/{max(env_rewards):.3f}") return [float(r) for r in env_rewards] -def get_tool_result_tokens_via_eos( +def replace_prefix_tokens( tokenizer, seen_token_ids: List[int], new_prompt_ids: List[int], @@ -190,9 +185,12 @@ async def call_nemo_gym_agent( ) tasks.append(task) - print(f"Awaiting {len(tasks)} HTTP requests...") - responses = await asyncio.gather(*tasks, return_exceptions=True) - print(f"Got {len(responses)} responses") + responses = [] + with tqdm(total=len(tasks), desc="Agent requests") as pbar: + for coro in asyncio.as_completed(tasks): + result = await coro + responses.append(result) + pbar.update(1) results = [] for i, response in enumerate(responses): @@ -374,7 +372,7 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, tool_result_tokens = item_prompt_ids[len(seen_token_ids):] else: # Retokenization changed the prefix - use nemo RL _replace_prefix_tokens approach - tool_result_tokens = get_tool_result_tokens_via_eos( + tool_result_tokens = replace_prefix_tokens( tokenizer, seen_token_ids, item_prompt_ids ) if tool_result_tokens: @@ -547,6 +545,22 @@ def main(): if config.project_name: os.environ["WANDB_PROJECT"] = config.project_name + if config.run_name is None: + task = config.task or os.path.basename(config.dataset_path).replace('.jsonl', '').replace('.json', '') + model_short = config.model_name.split("/")[-1] + config.run_name = ( + f"{task}_{model_short}" + f"_ng{config.num_generations}" + f"_dbs{config.per_device_train_batch_size}" + f"_ga{config.gradient_accumulation_steps}" + f"_maxlen{config.max_seq_length}" + f"_lr{config.learning_rate}" + f"_temp{config.temperature}" + f"_topp{config.top_p}" + f"_wd{config.weight_decay}" + f"_wu{config.warmup_ratio}" + ) + print(f"\n\nModel: {config.model_name}") print(f"Dataset: {config.dataset_path}") print(f"Nemo Gym Agent: {agent_server}") @@ -592,6 +606,7 @@ def main(): run_name=config.run_name, # wandb epsilon=0.2, + epsilon_high=0.28, loss_type="grpo", mask_truncated_completions=True, log_completions=False, From 41053404d5910ea025443a2273c9419c4f641a43 Mon Sep 17 00:00:00 2001 From: Christian Munley Date: Sat, 20 Dec 2025 18:27:59 +0000 Subject: [PATCH 03/51] baseline without on policy correction Signed-off-by: Christian Munley --- examples/scripts/nemo_gym/config.yaml | 5 +--- examples/scripts/nemo_gym/train.py | 36 ++++++++++----------------- 2 files changed, 14 insertions(+), 27 deletions(-) diff --git a/examples/scripts/nemo_gym/config.yaml b/examples/scripts/nemo_gym/config.yaml index 6c84585d7de..77244847b23 100644 --- a/examples/scripts/nemo_gym/config.yaml +++ b/examples/scripts/nemo_gym/config.yaml @@ -1,8 +1,7 @@ model_name: "Qwen/Qwen3-4B-Instruct-2507" dataset_path: "train-workplace.jsonl" -task: "workplace-assistant" # used in run_name if not set -agent_name: "simple_agent" +task: "workplace-assistant" # used in wandb run_name if not set explicitly output_dir: "outputs/trl_nemo_gym_workplace" project_name: "cmunley-nemo-gym-trl-int" @@ -10,8 +9,6 @@ project_name: "cmunley-nemo-gym-trl-int" learning_rate: 1.0e-5 max_steps: 1000000 -# these params are confusing! i just want to set responses per prompt (num_generations), prompts per step, and global batch size like Nemo RL. -# i think the below is doing: rpp 16, pps 8, gbs 128 num_generations: 16 per_device_train_batch_size: 1 gradient_accumulation_steps: 128 diff --git a/examples/scripts/nemo_gym/train.py b/examples/scripts/nemo_gym/train.py index e7c22d0a2c1..365d8d857c9 100644 --- a/examples/scripts/nemo_gym/train.py +++ b/examples/scripts/nemo_gym/train.py @@ -46,7 +46,7 @@ def get_agent_server( raise ValueError(f"Agent '{agent_name}' not found in any project's responses_api_agents") - # If no agent_name specified, try to find it + # If no agent_name specified, find it (usually is simple_agent) for project_name, project_config in global_config_dict.items(): if hasattr(project_config, 'responses_api_agents'): agents = project_config.responses_api_agents @@ -120,6 +120,8 @@ def replace_prefix_tokens( Based on NeMo RL's _replace_prefix_tokens: https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 + + NOTE: this is old and should go in vllm_serve.py not here. """ if not seen_token_ids or not new_prompt_ids: return [] @@ -214,6 +216,8 @@ async def call_nemo_gym_agent( def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, List]: """ + Baseline implementation that is missing on policy tokenid correction (this would go in vllm_serve.py though) + Rollout function for Nemo Gym agent within TRL GRPOTrainer Builds interleaved action/observation sequence with masking of observations. @@ -258,7 +262,6 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, for item in trainer.train_dataset: if item.get("prompt") == prompt: matching_item = dict(item) - # Deserialize JSON strings back to dicts/lists for key in ["responses_create_params", "expected_answers", "metadata", "ground_truth"]: if key in matching_item and isinstance(matching_item[key], str): try: @@ -298,7 +301,6 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, print(f"[nemo_gym_rollout_func] Received {len(responses)} responses from Nemo Gym") - # Save trajectories to JSONL trajectory_file = os.path.join(trainer.args.output_dir, "trajectories.jsonl") os.makedirs(trainer.args.output_dir, exist_ok=True) @@ -338,11 +340,9 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, output_items = response.get("response", {}).get("output", []) # Build interleaved completion: [model_gen1, tool_result1, model_gen2, tool_result2, ...] - # with mask: 1 for model tokens (train), 0 for tool results (don't train) - # Each turn gives us (prompt_ids, gen_ids). The prompt grows each turn as tool results - # are appended. We extract tool_result = current_prompt - previous_seen_tokens. - # trying to implement the same logic as NeMo RL's _replace_prefix_tokens in RL/nemo_rl/models/generation/vllm/vllm_worker_async.py - # for less token id mismatch and logprop error + # with mask: 1 for model tokens, 0 for tool results + # Each turn gives us (prompt_ids, gen_ids). + # tool_result = current_prompt - previous_seen_tokens. seen_token_ids: List[int] = [] interleaved_completion: List[int] = [] @@ -365,21 +365,12 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, first_prompt = item_prompt_ids seen_token_ids = list(item_prompt_ids) else: - # extract tool result tokens (delta between prompts) + # likely problematic due to retokenization. this is a baseline to compare against on-policy correction using _replace_prefix_tokens + # see https://docs.nvidia.com/nemo/gym/latest/contribute/rl-framework-integration/openai-compatible-http-server-on-policy-correction.html + # and https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 if len(item_prompt_ids) > len(seen_token_ids): - if item_prompt_ids[:len(seen_token_ids)] == seen_token_ids: - # Simple case: prefix matches, just slice off the new tokens - tool_result_tokens = item_prompt_ids[len(seen_token_ids):] - else: - # Retokenization changed the prefix - use nemo RL _replace_prefix_tokens approach - tool_result_tokens = replace_prefix_tokens( - tokenizer, seen_token_ids, item_prompt_ids - ) - if tool_result_tokens: - print(f"[Turn {num_turns}] Using nemo RL _replace_prefix_tokens approach to extract observation/tool result tokens: {len(tool_result_tokens)} observation tokens") - else: - print(f"[Turn {num_turns}] WARNING: Could not extract observation tokens") - + tool_result_tokens = item_prompt_ids[len(seen_token_ids):] + # Append tool results (mask=0) if tool_result_tokens: interleaved_completion.extend(tool_result_tokens) @@ -610,7 +601,6 @@ def main(): loss_type="grpo", mask_truncated_completions=True, log_completions=False, - # wandb_log_unique_prompts=True, max_prompt_length=config.max_prompt_length, max_completion_length=config.max_seq_length - config.max_prompt_length, From be5c156b410e0e5d77aca8814cb443bdd54ac750 Mon Sep 17 00:00:00 2001 From: Christian Munley Date: Sat, 20 Dec 2025 18:29:28 +0000 Subject: [PATCH 04/51] readme Signed-off-by: Christian Munley --- examples/scripts/nemo_gym/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/scripts/nemo_gym/README.md b/examples/scripts/nemo_gym/README.md index 74d2ef61ff0..bf686c6794f 100644 --- a/examples/scripts/nemo_gym/README.md +++ b/examples/scripts/nemo_gym/README.md @@ -24,4 +24,5 @@ ng_run "+config_paths=[resources_servers/workplace_assistant/configs/workplace_a CUDA_VISIBLE_DEVICES=0 python train.py --config config.yaml ``` -can do dp=7 with 7/8 gpus for vllm server. Havent gotten multigpu training backend to work despite docs says it works https://huggingface.co/docs/trl/main/en/vllm_integration#modes-of-using-vllm-during-training \ No newline at end of file +We should be able to do multinode, but im having issues with ngpu > 1 for training backend currently +https://huggingface.co/docs/trl/main/en/vllm_integration#modes-of-using-vllm-during-training \ No newline at end of file From 64b9ed47c7aefe59d50e63aa86285580796bf5d2 Mon Sep 17 00:00:00 2001 From: Christian Munley Date: Mon, 22 Dec 2025 06:38:14 +0000 Subject: [PATCH 05/51] wip Signed-off-by: Christian Munley --- examples/scripts/nemo_gym/config.yaml | 14 +- examples/scripts/nemo_gym/train.py | 266 ++++++++++++------------ pyproject.toml | 5 + trl/scripts/vllm_serve.py | 282 +++++++++++++++++++++++--- trl/trainer/grpo_trainer.py | 1 - 5 files changed, 409 insertions(+), 159 deletions(-) diff --git a/examples/scripts/nemo_gym/config.yaml b/examples/scripts/nemo_gym/config.yaml index 77244847b23..8d5eb3973c4 100644 --- a/examples/scripts/nemo_gym/config.yaml +++ b/examples/scripts/nemo_gym/config.yaml @@ -1,4 +1,7 @@ +# model_name: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" +# model_name: "Qwen/Qwen3-0.6B" model_name: "Qwen/Qwen3-4B-Instruct-2507" +# dataset_path: "/home/ubuntu/unsloth-gym-integration/Gym/resources_servers/reasoning_gym/data/train_knights_knaves.jsonl" dataset_path: "train-workplace.jsonl" task: "workplace-assistant" # used in wandb run_name if not set explicitly @@ -6,17 +9,18 @@ task: "workplace-assistant" # used in wandb run_name if not set explicitly output_dir: "outputs/trl_nemo_gym_workplace" project_name: "cmunley-nemo-gym-trl-int" -learning_rate: 1.0e-5 +learning_rate: 2.0e-5 max_steps: 1000000 -num_generations: 16 +num_generations: 12 + per_device_train_batch_size: 1 -gradient_accumulation_steps: 128 +gradient_accumulation_steps: 192 -max_seq_length: 16384 +max_seq_length: 8192 temperature: 1.0 -top_p: 0.999 +top_p: 0.99 weight_decay: 0.01 warmup_ratio: 0.0 lr_scheduler_type: "linear" diff --git a/examples/scripts/nemo_gym/train.py b/examples/scripts/nemo_gym/train.py index 365d8d857c9..1729983c64a 100644 --- a/examples/scripts/nemo_gym/train.py +++ b/examples/scripts/nemo_gym/train.py @@ -1,9 +1,5 @@ import os -import sys import numpy as np - -from trl import GRPOConfig, GRPOTrainer - import argparse import asyncio import aiohttp @@ -13,14 +9,12 @@ from omegaconf import OmegaConf from typing import Any, Dict, List, Optional from dataclasses import dataclass - from datasets import Dataset, load_dataset from trl import GRPOConfig, GRPOTrainer from tqdm import tqdm from transformers import AutoTokenizer - def get_agent_server( head_server_host: str = "127.0.0.1", head_server_port: int = 11000, @@ -36,7 +30,7 @@ def get_agent_server( global_config_dict = OmegaConf.create(yaml.safe_load(global_config_yaml)) if agent_name: - for project_name, project_config in global_config_dict.items(): + for _, project_config in global_config_dict.items(): if hasattr(project_config, 'responses_api_agents'): agents = project_config.responses_api_agents if hasattr(agents, agent_name): @@ -46,8 +40,7 @@ def get_agent_server( raise ValueError(f"Agent '{agent_name}' not found in any project's responses_api_agents") - # If no agent_name specified, find it (usually is simple_agent) - for project_name, project_config in global_config_dict.items(): + for _, project_config in global_config_dict.items(): if hasattr(project_config, 'responses_api_agents'): agents = project_config.responses_api_agents for name in agents.keys(): @@ -80,7 +73,7 @@ class TrainingConfig: max_prompt_length: int = None temperature: float = 1.0 - top_p: float = 1.0 + top_p: float = 0.999 weight_decay: float = 0.01 warmup_ratio: float = 0.1 lr_scheduler_type: str = "linear" @@ -99,56 +92,12 @@ def reward_fn(completions: List[str], **kwargs) -> List[float]: print(f"WARNING: No rewards from Nemo Gym, returning zeros for {len(completions)} completions") return [0.0] * len(completions) - print(f"Received {len(env_rewards)} rewards from Nemo Gym") - print(f"Mean reward: {sum(env_rewards)/len(env_rewards):.3f}") - print(f"Reward std dev: {np.std(env_rewards):.3f}") + print(f"[reward_fn] Mean reward: {sum(env_rewards)/len(env_rewards):.3f}") + print(f"[reward_fn] Reward std dev: {np.std(env_rewards):.3f}") return [float(r) for r in env_rewards] -def replace_prefix_tokens( - tokenizer, - seen_token_ids: List[int], - new_prompt_ids: List[int], -) -> List[int]: - """ - Extract tool result tokens when simple prefix-slicing fails due to retokenization. - - The last EOS in seen_token_ids marks where the - previous model generation ended. Find that same EOS in new_prompt_ids, then return - everything after it (the new tool results / user messages). - - Based on NeMo RL's _replace_prefix_tokens: - https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 - - NOTE: this is old and should go in vllm_serve.py not here. - """ - if not seen_token_ids or not new_prompt_ids: - return [] - - eos_token_id = tokenizer.eos_token_id - assert eos_token_id is not None, "Your tokenizer must have an EOS token ID!" - - # Find last EOS in new_prompt_ids within the "prefix" region (up to len(seen_token_ids)) - # search backwards from the prefix boundary - # EOS marks where the previous model generation ended - new_eos_pos = -1 - search_bound = min(len(seen_token_ids), len(new_prompt_ids)) - for pos in reversed(range(search_bound)): - if new_prompt_ids[pos] == eos_token_id: - new_eos_pos = pos - break - - if new_eos_pos < 0: - return [] - - new_content_start = new_eos_pos + 1 - if new_content_start >= len(new_prompt_ids): - return [] - - return new_prompt_ids[new_content_start:] - - async def call_nemo_gym_agent( prompts: List[str], dataset_items: List[Dict[str, Any]], @@ -158,9 +107,9 @@ async def call_nemo_gym_agent( temperature: float = 1.0, top_p: float = 0.999, ) -> List[Dict[str, Any]]: - print(f"Calling Nemo Gym agent: {agent_server}") - print(f"Number of prompts: {len(prompts)}") - print(f"Max completion length: {max_completion_length}") + print(f"[call_nemo_gym_agent] Calling Nemo Gym agent at {agent_server}") + print(f"[call_nemo_gym_agent] len(prompts): {len(prompts)}") + print(f"[call_nemo_gym_agent] max_completion_length: {max_completion_length}") async with aiohttp.ClientSession(cookie_jar=aiohttp.CookieJar()) as session: tasks = [] @@ -197,6 +146,8 @@ async def call_nemo_gym_agent( results = [] for i, response in enumerate(responses): if isinstance(response, Exception): + # should we error instead + # might be needed for truncated print(f"WARNING: Request {i} failed: {response}") results.append({"response": {"output": []}, "reward": 0.0, "error": str(response)}) else: @@ -216,8 +167,6 @@ async def call_nemo_gym_agent( def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, List]: """ - Baseline implementation that is missing on policy tokenid correction (this would go in vllm_serve.py though) - Rollout function for Nemo Gym agent within TRL GRPOTrainer Builds interleaved action/observation sequence with masking of observations. @@ -234,25 +183,9 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, current_step = trainer.state.global_step if hasattr(trainer, 'state') else 0 print(f"\n{'='*80}") - print(f"[nemo_gym_rollout_func] Starting Nemo Gym rollout (Training Step: {current_step})") - print(f"[nemo_gym_rollout_func] Received {len(prompts)} prompts from TRL") - print(f"[nemo_gym_rollout_func] Num generations per prompt: {trainer.args.num_generations}") - - unique_prompts_set = set(prompts) - print(f"DEBUG: Number of unique prompts in input: {len(unique_prompts_set)}") - print(f"DEBUG: Total number prompts: {len(prompts)}") - - print(f"\nDEBUG: All unique prompts ({len(unique_prompts_set)} total):") - for i, prompt in enumerate(sorted(list(unique_prompts_set))[:10]): - print(f" [{i}] {prompt}") - - if len(unique_prompts_set) > 10: - print(f" ... and {len(unique_prompts_set) - 10} more unique prompts") - - print(f"{'='*80}\n") - num_generations = trainer.args.num_generations - print(f"[nemo_gym_rollout_func] Expanding prompts for {num_generations} generations per prompt...") + prompts_dedup = set(prompts) + print(f"[nemo_gym_rollout_func] Got {len(prompts)} prompts, {len(prompts_dedup)} unique prompts, {trainer.args.num_generations} generations per prompt") expanded_prompts = [] expanded_dataset_items = [] @@ -262,6 +195,7 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, for item in trainer.train_dataset: if item.get("prompt") == prompt: matching_item = dict(item) + # Deserialize JSON strings back to dicts/lists for key in ["responses_create_params", "expected_answers", "metadata", "ground_truth"]: if key in matching_item and isinstance(matching_item[key], str): try: @@ -274,11 +208,11 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, print(f"WARNING: Could not find dataset item for prompt, using prompt only") matching_item = {"prompt": prompt} - for _ in range(num_generations): + for _ in range(trainer.args.num_generations): expanded_prompts.append(prompt) expanded_dataset_items.append(dict(matching_item)) - print(f"[nemo_gym_rollout_func] Expanded to {len(expanded_prompts)} total requests ({len(prompts)} prompts × {num_generations} generations)") + print(f"[nemo_gym_rollout_func] Expanded to {len(expanded_prompts)} total requests ({len(prompts)} prompts × {trainer.args.num_generations} generations)") print("[nemo_gym_rollout_func] Calling Nemo Gym agent...") print(f"[nemo_gym_rollout_func] Using temperature: {trainer.args.temperature}, top_p: {trainer.args.top_p}") @@ -299,7 +233,7 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, finally: loop.close() - print(f"[nemo_gym_rollout_func] Received {len(responses)} responses from Nemo Gym") + print(f"[nemo_gym_rollout_func] Got {len(responses)} responses from Nemo Gym agent") trajectory_file = os.path.join(trainer.args.output_dir, "trajectories.jsonl") os.makedirs(trainer.args.output_dir, exist_ok=True) @@ -315,34 +249,80 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, } f.write(json.dumps(trajectory_data) + "\n") - print(f"[Rollout] Saved {len(responses)} trajectories to {trajectory_file}") + print(f"[nemo_gym_rollout_func] Saved {len(responses)} trajectories to {trajectory_file}") - tokenizer = AutoTokenizer.from_pretrained(trainer.model.name_or_path) + tokenizer = trainer.processing_class # interleaved completion with mask prompt_ids: List[List[int]] = [] completion_ids: List[List[int]] = [] - completion_mask: List[List[int]] = [] # 1 for model tokens, 0 for tool results + completion_mask: List[List[int]] = [] # 1 for action, 0 for observation logprobs: List[List[float]] = [] env_rewards: List[float] = [] - + num_turns_list: List[int] = [] + failed_count = 0 success_count = 0 for i, response in enumerate(responses): - if not isinstance(response, dict): - raise ValueError(f"Rollout {i} response is not a dict: {type(response)}") + expected_prompt = expanded_prompts[i] + expected_prompt_ids = tokenizer.encode(expected_prompt, add_special_tokens=False) - if "error" in response: - raise ValueError(f"Rollout {i} had error: {response['error']}") + rollout_failed = False + failure_reason = None + + if not isinstance(response, dict): + rollout_failed = True + failure_reason = f"response is not a dict: {type(response)}" + elif "error" in response: + rollout_failed = True + failure_reason = f"had error: {response['error']}" + else: + output_items = response.get("response", {}).get("output", []) + + if not output_items: + rollout_failed = True + failure_reason = "has no output items (we are masking and returing just eos for truncated rollouts right now)" + else: + has_content = False + for item in output_items: + if item.get("type") == "message": + content_list = item.get("content", []) + for content_item in content_list: + if content_item.get("type") == "output_text": + text = content_item.get("text", "").strip() + if text: + has_content = True + break + elif item.get("type") == "function_call": + has_content = True + break + if has_content: + break + + if not has_content: + rollout_failed = True + failure_reason = "has empty content in all output items (we are masking and returing just eos for truncated rollouts right now)" + + # truncated or other failure - mask out + if rollout_failed: + failed_count += 1 + print(f"[nemo_gym_rollout_func] WARNING: Rollout {i} {failure_reason}. Filling with eos and zero reward.") + eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0 + prompt_ids.append(expected_prompt_ids) + completion_ids.append([eos_token_id]) + completion_mask.append([0]) + logprobs.append([0.0]) + env_rewards.append(0.0) + num_turns_list.append(0) + continue episode_reward = response.get("reward", 0.0) output_items = response.get("response", {}).get("output", []) - # Build interleaved completion: [model_gen1, tool_result1, model_gen2, tool_result2, ...] - # with mask: 1 for model tokens, 0 for tool results - # Each turn gives us (prompt_ids, gen_ids). - # tool_result = current_prompt - previous_seen_tokens. + # Make interleaved completion: [model_gen1, tool_result1, model_gen2, tool_result2, ...] with mask 1 for generations, 0 for tool results (aka action, observation) + # Each turn has prompt_ids, gen_ids. Find tool_result = current_prompt - previous_seen_tokens and mask it + # on policy tokenid correction replace_prefix_tokens is done in vllm server seen_token_ids: List[int] = [] interleaved_completion: List[int] = [] @@ -354,52 +334,83 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, for item in output_items: if "prompt_token_ids" not in item or "generation_token_ids" not in item: continue - + num_turns += 1 item_prompt_ids = item["prompt_token_ids"] item_gen_ids = item["generation_token_ids"] item_logprobs = item.get("generation_log_probs", []) tool_result_tokens = [] - + if first_prompt is None: first_prompt = item_prompt_ids seen_token_ids = list(item_prompt_ids) else: - # likely problematic due to retokenization. this is a baseline to compare against on-policy correction using _replace_prefix_tokens - # see https://docs.nvidia.com/nemo/gym/latest/contribute/rl-framework-integration/openai-compatible-http-server-on-policy-correction.html - # and https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 + # Extract tool result tokens if len(item_prompt_ids) > len(seen_token_ids): + # Assert prefix matches + if item_prompt_ids[:len(seen_token_ids)] != seen_token_ids: + diverge_idx = -1 + for idx in range(min(len(seen_token_ids), len(item_prompt_ids))): + if seen_token_ids[idx] != item_prompt_ids[idx]: + diverge_idx = idx + break + + # Show context around divergence + # prob can delete this, was for debugging right + context_window = 20 + start = max(0, diverge_idx - context_window) if diverge_idx >= 0 else 0 + end = min(len(seen_token_ids), diverge_idx + context_window) if diverge_idx >= 0 else min(50, len(seen_token_ids)) + + error_msg = ( + f"[Turn {num_turns}] Non-contiguous messages found! " + f"This may be a tokenization issue.\n" + f"Length of expected prefix: {len(seen_token_ids)}\n" + f"Length of new prompt: {len(item_prompt_ids)}\n" + ) + + if diverge_idx >= 0: + error_msg += ( + f"Tokens diverge at index {diverge_idx}:\n" + f"Expected[{start}:{end}]: {seen_token_ids[start:end]}\n" + f"Got[{start}:{end}]: {item_prompt_ids[start:end]}\n" + f"Expected token at [{diverge_idx}]: {seen_token_ids[diverge_idx]}\n" + f"Got token at [{diverge_idx}]: {item_prompt_ids[diverge_idx]}\n" + ) + else: + error_msg += ( + f"Prefix length mismatch but tokens match up to min length.\n" + f"Expected (first 50): {seen_token_ids[:50]}\n" + f"Got (first 50): {item_prompt_ids[:50]}\n" + ) + + raise ValueError(error_msg) tool_result_tokens = item_prompt_ids[len(seen_token_ids):] - # Append tool results (mask=0) if tool_result_tokens: interleaved_completion.extend(tool_result_tokens) interleaved_mask.extend([0] * len(tool_result_tokens)) interleaved_logprobs.extend([0.0] * len(tool_result_tokens)) - # Append model generation (mask=1) interleaved_completion.extend(item_gen_ids) interleaved_mask.extend([1] * len(item_gen_ids)) interleaved_logprobs.extend( item_logprobs if len(item_logprobs) == len(item_gen_ids) else [0.0] * len(item_gen_ids) ) - - if tool_result_tokens: - seen_token_ids = seen_token_ids + tool_result_tokens + list(item_gen_ids) - else: - seen_token_ids = list(item_prompt_ids) + list(item_gen_ids) + + seen_token_ids = list(item_prompt_ids) + list(item_gen_ids) if not interleaved_completion or first_prompt is None: raise ValueError(f"Rollout {i} has no valid turns") success_count += 1 - + prompt_ids.append(first_prompt) completion_ids.append(interleaved_completion) completion_mask.append(interleaved_mask) logprobs.append(interleaved_logprobs) env_rewards.append(episode_reward) + num_turns_list.append(num_turns) model_tokens = sum(interleaved_mask) tool_tokens = len(interleaved_mask) - model_tokens @@ -407,38 +418,39 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, print(f"\n{'='*60}") print(f"[nemo_gym_rollout_func] Turns: {num_turns}, Reward: {episode_reward:.3f}") print(f"[nemo_gym_rollout_func] Prompt tokens: {len(first_prompt)}") - print(f"[nemo_gym_rollout_func] Completion tokens: {len(interleaved_completion)} (model: {model_tokens}, tool: {tool_tokens})") - print(f"[nemo_gym_rollout_func] Completion preview: {tokenizer.decode(interleaved_completion)[:150]}...") + print(f"[nemo_gym_rollout_func] Rollout tokens: {len(interleaved_completion)} (model: {model_tokens}, tool: {tool_tokens})") + print(f"[nemo_gym_rollout_func] Rollout preview: {tokenizer.decode(interleaved_completion)[:150]}...") print(f"{'='*60}\n") print(f"\n{'='*80}") print(f"[nemo_gym_rollout_func] Success: {success_count}, Failed: {failed_count}") - print(f"[nemo_gym_rollout_func] Total episodes: {len(completion_ids)}") if not prompt_ids: raise RuntimeError( "No valid rollouts. Check Nemo Gym and vLLM logs." ) - mean_reward = sum(env_rewards) / len(env_rewards) if env_rewards else 0.0 - total_model_tokens = sum(sum(m) for m in completion_mask) - total_tool_tokens = sum(len(m) - sum(m) for m in completion_mask) - print(f"[nemo_gym_rollout_func] Mean reward: {mean_reward:.3f}") - print(f"[nemo_gym_rollout_func] Total model generation tokens (not masked): {total_model_tokens}") - print(f"[nemo_gym_rollout_func] Total tool tokens (masked): {total_tool_tokens}") - + print(f"[nemo_gym_rollout_func] Mean reward: {sum(env_rewards) / len(env_rewards) if env_rewards else 0.0:.3f}") + print(f"[nemo_gym_rollout_func] Total action tokens: {sum(sum(m) for m in completion_mask)}, total observation tokens: {sum(len(m) - sum(m) for m in completion_mask)}") + + # log num turns to wandb + if num_turns_list: + import wandb + wandb.log({ + "train/num_turns_mean": sum(num_turns_list) / len(num_turns_list), + "train/num_turns_min": min(num_turns_list), + "train/num_turns_max": max(num_turns_list), + }) + print(f"[nemo_gym_rollout_func] Num turns mean: {sum(num_turns_list) / len(num_turns_list):.2f}, min: {min(num_turns_list)}, max: {max(num_turns_list)}") + # We need to deduplicate prompt_ids to match TRL's current code that re-duplicates prompts # TRL deduplicates here: https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L1266 so we had to duplicate prompts for num_generations # TRL reduplicates here: https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L1314 so we need to dedup prompts - print(f"[nemo_gym_rollout_func] Deduplicating prompt_ids (keeping 1 per {num_generations} completions)...") unique_prompt_ids = [] - for idx in range(0, len(prompt_ids), num_generations): + for idx in range(0, len(prompt_ids), trainer.args.num_generations): if idx < len(prompt_ids): unique_prompt_ids.append(prompt_ids[idx]) - print(f"[nemo_gym_rollout_func] Deduplicated: {len(prompt_ids)} → {len(unique_prompt_ids)} unique prompt_ids") - print(f"[nemo_gym_rollout_func] Final counts: {len(unique_prompt_ids)} prompt_ids, {len(completion_ids)} completion_ids") - print(f"[nemo_gym_rollout_func] Expected ratio: {len(completion_ids) / len(unique_prompt_ids) if unique_prompt_ids else 0:.1f} completions per prompt") print(f"{'='*80}\n") return { @@ -447,6 +459,7 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, "completion_mask": completion_mask, "logprobs": logprobs, "env_reward": env_rewards, + "num_turns": num_turns_list, } def get_max_prompt_length(dataset: Dataset, tokenizer) -> int: @@ -454,9 +467,6 @@ def get_max_prompt_length(dataset: Dataset, tokenizer) -> int: prompt_lengths = [len(tokenizer.encode(item.get("prompt", ""))) for item in dataset if item.get("prompt", "")] prompt_lengths.sort() max_length = prompt_lengths[-1] - print(f"[get_max_prompt_length] Min length: {min(prompt_lengths)}") - print(f"[get_max_prompt_length] Max length: {max(prompt_lengths)}") - print(f"[get_max_prompt_length] Mean length: {sum(prompt_lengths) / len(prompt_lengths):.1f}") return max_length @@ -487,7 +497,6 @@ def load_dataset_from_jsonl(path: str) -> Dataset: item["prompt"] = "\n\n".join(prompt_parts) if prompt_parts else "" elif isinstance(input_data, str): # Format as string (e.g. google_search) - # Combine instructions field (system prompt) + input field (question) prompt_parts = [] if instructions: prompt_parts.append(instructions) @@ -518,10 +527,11 @@ def load_dataset_from_jsonl(path: str) -> Dataset: return dataset - def main(): parser = argparse.ArgumentParser(description="") parser.add_argument("--config", required=True, help="Path to config YAML file") + parser.add_argument("--vllm_server_host", type=str, default="127.0.0.1", + help="vLLM server hostname/IP") args = parser.parse_args() with open(args.config) as f: @@ -575,7 +585,7 @@ def main(): training_args = GRPOConfig( use_vllm=True, vllm_mode="server", - vllm_server_host="127.0.0.1", + vllm_server_host=args.vllm_server_host, vllm_server_port=8000, temperature=config.temperature, diff --git a/pyproject.toml b/pyproject.toml index 25993a5b109..99d7f5cbadf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,12 @@ requires-python = ">=3.10" dependencies = [ "accelerate>=1.4.0", "datasets>=3.0.0", + "fastapi>=0.124.4", + "omegaconf>=2.3.0", "transformers>=4.56.1", + "uvicorn>=0.38.0", + "vllm>=0.11.2", + "wandb>=0.23.1", ] dynamic = ["version"] diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 1858cf6e9bc..4b3edb1983a 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -31,8 +31,10 @@ import torch import torch.distributed.distributed_c10d as c10d from transformers import is_torch_xpu_available, is_vision_available +from transformers import AutoTokenizer from trl import TrlParser +# from trl.chat_template_utils import add_response_schema # For native tool call parsing from trl.import_utils import ( is_fastapi_available, is_pydantic_available, @@ -374,7 +376,23 @@ def llm_worker( method_name = command["method"] args, kwargs = command.get("args", ()), command.get("kwargs", {}) method = getattr(llm, method_name) - result = method(*args, **kwargs) + + try: + result = method(*args, **kwargs) + except ValueError as e: + error_msg = str(e) + if "longer than the maximum model length" in error_msg or "context length" in error_msg: + logger.error(f"[Worker] Context length exceeded: {error_msg}") + if method_name in ["generate", "chat"]: + result = [] + else: + raise + else: + raise + except Exception as e: + logger.error(f"[Worker] Unexpected error in {method_name}: {e}") + raise + if command["type"] == "call": connection.send(result) elif command["type"] == "shutdown": @@ -412,6 +430,80 @@ def sanitize_logprob(logprob): return value +def _replace_prefix_tokens( + tokenizer, + model_prefix_token_ids: list[int], + template_prefix_token_ids: list[int], + template_token_ids: list[int], +) -> list[int]: + """ + Fixes up chat template-tokenized messages to match model output tokenization. + + This preserves the monotonic tokens property for optimized multi-turn training by ensuring + that previously generated tokens are not retokenized differently by the chat template. + + Based on NeMo RL's _replace_prefix_tokens: + https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 + + When model generates text that gets re-tokenized by chat template in subsequent turns, + the token IDs may differ due to tokenization ambiguity (e.g., whitespace, context). + + Example: + Turn 1: Model outputs [220, 17] (decodes to " 4") + Turn 2: Template retokenizes as [1001] (also decodes to " 4") + + This creates off-policy issues where training logprobs don't match generation logprobs. + + This solves the issue by keeping exact model tokens up to the last EOS, then append new template tokens after EOS. + This ensures we preserve the actual tokens the model generated, not a retokenized version. + + Args: + tokenizer: The tokenizer instance, must have eos_token_id + model_prefix_token_ids: Token IDs from actual model generation to preserve + template_prefix_token_ids: Chat template applied up to last assistant message + template_token_ids: Chat template applied to full conversation + + Returns: + Combined token sequence: model_tokens[:-1] + template_tokens[after_eos:] + + Example: + model_prefix_token_ids = [1, 2, 3, 220, 17, 2] # Last 2 is EOS + template_prefix_token_ids = [1, 2, 3, 1001, 2] # Up to last assistant + template_token_ids = [1, 2, 3, 1001, 2, 21, 22] # Full conversation + + Output: [1, 2, 3, 220, 17, 2, 21, 22] # Keeps original [220, 17], adds new [21, 22] + """ + if not model_prefix_token_ids: + return template_token_ids + + eos_token_id = tokenizer.eos_token_id + if eos_token_id is None: + logger.warning("Tokenizer has no EOS token ID, cannot apply _replace_prefix_tokens") + return template_token_ids + + # Find where to cut the model prefix (before EOS if present at end) + model_cut_end = len(model_prefix_token_ids) + if model_prefix_token_ids and model_prefix_token_ids[-1] == eos_token_id: + model_cut_end -= 1 + + # Find the last EOS in template prefix + template_cut_start = -1 + for pos in reversed(range(len(template_prefix_token_ids))): + if template_token_ids[pos] == eos_token_id: + template_cut_start = pos + break + + if template_cut_start < 0: + logger.warning("No EOS token found in template prefix, cannot apply _replace_prefix_tokens") + return template_token_ids + + result = ( + model_prefix_token_ids[:model_cut_end] + + template_token_ids[template_cut_start:] + ) + + return result + def main(script_args: ScriptArguments): if not is_fastapi_available(): raise ImportError( @@ -442,8 +534,22 @@ def main(script_args: ScriptArguments): connections.append(parent_connection) processes.append(process) + cached_tokenizer = None + @asynccontextmanager async def lifespan(app: FastAPI): + nonlocal cached_tokenizer + + logger.info(f"Loading tokenizer for {script_args.model}...") + cached_tokenizer = AutoTokenizer.from_pretrained(script_args.model, trust_remote_code=script_args.trust_remote_code) + + # uncomment for native tool call parsing + # try: + # cached_tokenizer = add_response_schema(cached_tokenizer) + # logger.info("Response schema added - vLLM will use native tool call parsing") + # except (ValueError, AttributeError) as e: + # logger.warning(f"Could not add response schema: {e}. Will fall back to XML parsing if tools are used.") + # Wait for all workers to send "ready" ready_connections = set() while len(ready_connections) < script_args.data_parallel_size: @@ -894,23 +1000,103 @@ async def chat_completions(request: ChatCompletionRequest): if request.tool_choice and request.tool_choice != "auto": chat_template_kwargs["tool_choice"] = request.tool_choice - chunked_messages = chunk_list([messages], script_args.data_parallel_size) + has_prefix_token_ids = any( + msg.get("role") == "assistant" and "prompt_token_ids" in msg + for msg in messages + ) - for connection, message_chunk in zip(connections, chunked_messages, strict=True): - if not message_chunk: - message_chunk = [[{"role": "user", "content": ""}]] - kwargs = { - "messages": message_chunk, - "sampling_params": sampling_params, - "tools": request.tools, - "chat_template_kwargs": chat_template_kwargs - } - connection.send({"type": "call", "method": "chat", "kwargs": kwargs}) + if has_prefix_token_ids: + # do on policy token id correction and call generate instead of chat + # see https://docs.nvidia.com/nemo/gym/latest/contribute/rl-framework-integration/openai-compatible-http-server-on-policy-correction.html + # and https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 + logger.info(f"[/chat/completions] Detected prefix token IDs in assistant message, using _replace_prefix_tokens") + tokenizer = cached_tokenizer + + # preprocess full conversation + connections[0].send({"type": "call", "method": "preprocess_chat", "kwargs": { + "messages": [messages], "chat_template_kwargs": chat_template_kwargs, + "tools": request.tools, "add_generation_prompt": True}}) + template_prompts = connections[0].recv() + template_prompt = template_prompts[0] + + # extract model prefix tokens from last assistant message + model_prefix_tokens = None + last_assistant_idx = None + for i in reversed(range(len(messages))): + if messages[i].get("role") == "assistant": + last_assistant_idx = i + if "prompt_token_ids" in messages[i]: + model_prefix_tokens = messages[i]["prompt_token_ids"] + messages[i].get("generation_token_ids", []) + break + + if model_prefix_tokens and last_assistant_idx is not None: + messages_to_last_assistant = messages[:last_assistant_idx + 1] + connections[0].send({"type": "call", "method": "preprocess_chat", "kwargs": { + "messages": [messages_to_last_assistant], "chat_template_kwargs": chat_template_kwargs, + "tools": request.tools, "add_generation_prompt": False}}) + template_prefix_prompts = connections[0].recv() + template_prefix_token_ids = template_prefix_prompts[0]["prompt_token_ids"] + + logger.info(f"[/chat/completions] Calling _replace_prefix_tokens model_prefix len: {len(model_prefix_tokens)}, template_prefix len: {len(template_prefix_token_ids)}, template_full len: {len(template_prompt['prompt_token_ids'])}") + + corrected_token_ids = _replace_prefix_tokens( + tokenizer, + model_prefix_tokens, + template_prefix_token_ids, + template_prompt["prompt_token_ids"] + ) + + logger.info(f"[/chat/completions] final len = {len(corrected_token_ids)}") + else: + logger.info(f"[/chat/completions] Skipping _replace_prefix_tokens, model_prefix_tokens={model_prefix_tokens is not None}, last_assistant_idx={last_assistant_idx}") + corrected_token_ids = template_prompt["prompt_token_ids"] + + corrected_prompt = {"prompt_token_ids": corrected_token_ids} + chunked_prompts = chunk_list([corrected_prompt], script_args.data_parallel_size) + + for connection, prompts in zip(connections, chunked_prompts, strict=True): + if not prompts: + prompts = [{"prompt_token_ids": [tokenizer.eos_token_id]}] + connection.send({"type": "call", "method": "generate", "kwargs": { + "prompts": prompts, "sampling_params": sampling_params}}) + else: + # no prefix token IDs, use chat() + chunked_messages = chunk_list([messages], script_args.data_parallel_size) + + for connection, message_chunk in zip(connections, chunked_messages, strict=True): + if not message_chunk: + message_chunk = [[{"role": "user", "content": ""}]] + kwargs = { + "messages": message_chunk, + "sampling_params": sampling_params, + "tools": request.tools, + "chat_template_kwargs": chat_template_kwargs + } + connection.send({"type": "call", "method": "chat", "kwargs": kwargs}) all_outputs = [connection.recv() for connection in connections] - all_outputs = [output for output, msg_chunk in zip(all_outputs, chunked_messages, strict=True) if msg_chunk] + if has_prefix_token_ids: + all_outputs = [o for o in all_outputs if o] + else: + all_outputs = [output for output, msg_chunk in zip(all_outputs, chunked_messages, strict=True) if msg_chunk] all_outputs = list(chain.from_iterable(all_outputs)) + if not all_outputs: + logger.warning("[/chat/completions] All workers returned empty - max seq len probably exceeded. Returning empty msg with finish_reason=length") + return { + "id": completion_id, + "object": "chat.completion", + "created": created_at, + "model": request.model or script_args.model, + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": ""}, + "finish_reason": "length", + "logprobs": None + }], + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + } + choices = [] total_input_tokens = 0 total_output_tokens = 0 @@ -929,9 +1115,6 @@ async def chat_completions(request: ChatCompletionRequest): tool_calls = gen_output.tool_calls finish_reason = "tool_calls" elif request.tools and text: - # If no native tool call parser, try XML - # TODO: figure out how to use a tool call parser, or handle tool call parsing in Nemo Gym maybe?? - # or implement real async vllm engine and openai api server rather than faking the endpoint here pattern = r'(.*?)' matches = re.findall(pattern, text, re.DOTALL) if matches: @@ -1001,8 +1184,15 @@ class TokenizeRequest(BaseModel): @app.post("/tokenize") async def tokenize(request: TokenizeRequest): + messages = request.messages + + has_prefix_token_ids = any( + msg.get("role") == "assistant" and "prompt_token_ids" in msg + for msg in messages + ) + kwargs = { - "messages": [request.messages], + "messages": [messages], "tools": request.tools, "add_generation_prompt": True, "chat_template_kwargs": {} @@ -1013,13 +1203,55 @@ async def tokenize(request: TokenizeRequest): if preprocessed_prompts and len(preprocessed_prompts) > 1: logger.warning(f"More than one tokenized message returned from preprocess_chat inside tokenize, double check results!") - - if preprocessed_prompts and len(preprocessed_prompts) > 0: - return { - "tokens": preprocessed_prompts[0]["prompt_token_ids"], - "model": request.model or script_args.model - } - return {"tokens": [], "model": request.model or script_args.model} + + if not preprocessed_prompts or len(preprocessed_prompts) == 0: + return {"tokens": [], "model": request.model or script_args.model} + + template_prompt = preprocessed_prompts[0] + result_tokens = template_prompt["prompt_token_ids"] + + if has_prefix_token_ids: + tokenizer = cached_tokenizer + + # Extract model prefix tokens from last assistant message + model_prefix_tokens = None + last_assistant_idx = None + for i in reversed(range(len(messages))): + if messages[i].get("role") == "assistant": + last_assistant_idx = i + if "prompt_token_ids" in messages[i]: + model_prefix_tokens = messages[i]["prompt_token_ids"] + messages[i].get("generation_token_ids", []) + break + + if model_prefix_tokens and last_assistant_idx is not None: + # Preprocess up to last assistant + messages_to_last_assistant = messages[:last_assistant_idx + 1] + connections[0].send({"type": "call", "method": "preprocess_chat", "kwargs": { + "messages": [messages_to_last_assistant], + "tools": request.tools, + "add_generation_prompt": False, + "chat_template_kwargs": {} + }}) + template_prefix_prompts = connections[0].recv() + template_prefix_token_ids = template_prefix_prompts[0]["prompt_token_ids"] + + logger.info(f"[/tokenize] Calling _replace_prefix_tokens, model_prefix length: {len(model_prefix_tokens)}, template_prefix length: {len(template_prefix_token_ids)}, template_full length: {len(template_prompt['prompt_token_ids'])}") + + result_tokens = _replace_prefix_tokens( + tokenizer, + model_prefix_tokens, + template_prefix_token_ids, + template_prompt["prompt_token_ids"] + ) + + logger.info(f"[/tokenize] final length = {len(result_tokens)}") + else: + logger.info(f"[/tokenize] Skipping _replace_prefix_tokens, one of model_prefix_tokens={model_prefix_tokens is not None}, last_assistant_idx={last_assistant_idx} is None") + + return { + "tokens": result_tokens, + "model": request.model or script_args.model + } uvicorn.run( app, @@ -1028,7 +1260,7 @@ async def tokenize(request: TokenizeRequest): log_level=script_args.log_level, limit_concurrency=256, backlog=4096, - timeout_keep_alive=60 + timeout_keep_alive=600 ) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ebb0e100bfe..be4ac452395 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1785,7 +1785,6 @@ def _generate_and_score_completions( completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] # Allow custom completion_mask from rollout_func for multi-turn training - # This allows masking out non-trainable tokens (e.g., tool results, observations) in the completion if "completion_mask" in extra_fields: completion_mask_list = extra_fields.pop("completion_mask") completion_mask = [torch.tensor(m, device=device, dtype=torch.long) for m in completion_mask_list] From 948869f85a9eab2578db89335ec0cd354a226016 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Tue, 6 Jan 2026 22:09:06 -0800 Subject: [PATCH 06/51] fixes Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/train.py | 351 ++++++++--------------------- 1 file changed, 89 insertions(+), 262 deletions(-) diff --git a/examples/scripts/nemo_gym/train.py b/examples/scripts/nemo_gym/train.py index 1729983c64a..be6f0d624db 100644 --- a/examples/scripts/nemo_gym/train.py +++ b/examples/scripts/nemo_gym/train.py @@ -1,5 +1,4 @@ import os -import numpy as np import argparse import asyncio import aiohttp @@ -12,13 +11,12 @@ from datasets import Dataset, load_dataset from trl import GRPOConfig, GRPOTrainer from tqdm import tqdm - +import wandb from transformers import AutoTokenizer def get_agent_server( head_server_host: str = "127.0.0.1", head_server_port: int = 11000, - agent_name: str = None, ) -> str: try: response = requests.get( @@ -28,29 +26,21 @@ def get_agent_server( response.raise_for_status() global_config_yaml = response.text global_config_dict = OmegaConf.create(yaml.safe_load(global_config_yaml)) - - if agent_name: - for _, project_config in global_config_dict.items(): - if hasattr(project_config, 'responses_api_agents'): - agents = project_config.responses_api_agents - if hasattr(agents, agent_name): - agent_config = getattr(agents, agent_name) - agent_server = f"http://{agent_config.host}:{agent_config.port}" - return agent_server - - raise ValueError(f"Agent '{agent_name}' not found in any project's responses_api_agents") - + for _, project_config in global_config_dict.items(): if hasattr(project_config, 'responses_api_agents'): agents = project_config.responses_api_agents for name in agents.keys(): agent_config = getattr(agents, name) if hasattr(agent_config, 'host') and hasattr(agent_config, 'port'): - agent_server = f"http://{agent_config.host}:{agent_config.port}" + agent_host = agent_config.host + if agent_host in ("127.0.0.1", "0.0.0.0", "localhost"): + agent_host = head_server_host + agent_server = f"http://{agent_host}:{agent_config.port}" return agent_server - + raise ValueError("No agents found in global config") - + except requests.exceptions.RequestException as e: raise RuntimeError(f"Failed to connect to head server at {head_server_host}:{head_server_port}: {e}") @@ -61,7 +51,6 @@ class TrainingConfig: dataset_path: str task: Optional[str] = None - agent_name: Optional[str] = None learning_rate: float = 5e-6 max_steps: int = 100 @@ -75,7 +64,8 @@ class TrainingConfig: temperature: float = 1.0 top_p: float = 0.999 weight_decay: float = 0.01 - warmup_ratio: float = 0.1 + warmup_ratio: float = 0.0 + warmup_steps: int = 0 lr_scheduler_type: str = "linear" optim: str = "adamw_8bit" @@ -85,19 +75,18 @@ class TrainingConfig: run_name: str = None # Wandb project_name: str = None # Wandb -def reward_fn(completions: List[str], **kwargs) -> List[float]: - env_rewards = kwargs.get("env_reward", []) + log_completions: bool = False + num_completions_to_print: int = None - if not env_rewards: - print(f"WARNING: No rewards from Nemo Gym, returning zeros for {len(completions)} completions") - return [0.0] * len(completions) - - print(f"[reward_fn] Mean reward: {sum(env_rewards)/len(env_rewards):.3f}") - print(f"[reward_fn] Reward std dev: {np.std(env_rewards):.3f}") + eval_dataset_path: Optional[str] = None + eval_strategy: str = "no" + eval_steps: int = 50 +def reward_fn(completions: List[str], **kwargs) -> List[float]: + env_rewards = kwargs.get("env_reward") + assert env_rewards is not None, "env_reward not found in kwargs" return [float(r) for r in env_rewards] - async def call_nemo_gym_agent( prompts: List[str], dataset_items: List[Dict[str, Any]], @@ -107,13 +96,11 @@ async def call_nemo_gym_agent( temperature: float = 1.0, top_p: float = 0.999, ) -> List[Dict[str, Any]]: - print(f"[call_nemo_gym_agent] Calling Nemo Gym agent at {agent_server}") - print(f"[call_nemo_gym_agent] len(prompts): {len(prompts)}") - print(f"[call_nemo_gym_agent] max_completion_length: {max_completion_length}") + print(f"Calling Nemo Gym agent at {agent_server} with {len(prompts)} prompts") async with aiohttp.ClientSession(cookie_jar=aiohttp.CookieJar()) as session: tasks = [] - for i, (prompt, item) in enumerate(zip(prompts, dataset_items)): + for prompt, item in zip(prompts, dataset_items): request_body = item.copy() if "responses_create_params" not in request_body: @@ -126,9 +113,6 @@ async def call_nemo_gym_agent( params["temperature"] = temperature params["top_p"] = top_p - if i == 0: - print(f"First request keys: {list(params.keys())}") - task = session.post( f"{agent_server}/run", json=request_body, @@ -145,77 +129,45 @@ async def call_nemo_gym_agent( results = [] for i, response in enumerate(responses): - if isinstance(response, Exception): - # should we error instead - # might be needed for truncated - print(f"WARNING: Request {i} failed: {response}") - results.append({"response": {"output": []}, "reward": 0.0, "error": str(response)}) - else: - try: - json_data = await response.json() - if isinstance(json_data, dict): - results.append(json_data) - else: - print(f"WARNING: Request {i} returned non-dict: {type(json_data)}") - results.append({"response": {"output": []}, "reward": 0.0, "error": f"Non-dict response"}) - except Exception as e: - print(f"WARNING: Failed to parse response {i}: {e}") - results.append({"response": {"output": []}, "reward": 0.0, "error": str(e)}) + try: + json_data = await response.json() + if not isinstance(json_data, dict): + raise ValueError(f"Expected dict, got {type(json_data)}") + results.append(json_data) + except Exception as e: + print(f"WARNING: Request {i} failed: {e}") + results.append({"response": {"output": []}, "reward": 0.0, "error": str(e)}) return results def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, List]: - """ - Rollout function for Nemo Gym agent within TRL GRPOTrainer - - Builds interleaved action/observation sequence with masking of observations. - - prompt_ids: first turn's prompt only - - completion_ids: interleaved [model_gen1, tool_result1, model_gen2, tool_result2, ...] - - completion_mask: 1 for model tokens, 0 for tool results - - logprobs: for model tokens, 0.0 for tool result tokens - - This ensures: - 1. Logprobs are computed on the full context, including tool results - 2. Loss is only backpropagated on model-generated tokens - """ - - current_step = trainer.state.global_step if hasattr(trainer, 'state') else 0 - - print(f"\n{'='*80}") + current_step = trainer.state.global_step - prompts_dedup = set(prompts) - print(f"[nemo_gym_rollout_func] Got {len(prompts)} prompts, {len(prompts_dedup)} unique prompts, {trainer.args.num_generations} generations per prompt") + is_eval = not trainer.model.training + num_generations = trainer.args.num_generations_eval if is_eval and trainer.args.num_generations_eval else trainer.args.num_generations + dataset = trainer.eval_dataset if is_eval and trainer.eval_dataset is not None else trainer.train_dataset expanded_prompts = [] expanded_dataset_items = [] for prompt in prompts: matching_item = None - for item in trainer.train_dataset: + for item in dataset: if item.get("prompt") == prompt: matching_item = dict(item) - # Deserialize JSON strings back to dicts/lists - for key in ["responses_create_params", "expected_answers", "metadata", "ground_truth"]: + for key in ["responses_create_params", "expected_answers", "metadata", "ground_truth", "options", "template_metadata"]: if key in matching_item and isinstance(matching_item[key], str): - try: - matching_item[key] = json.loads(matching_item[key]) - except: - pass + matching_item[key] = json.loads(matching_item[key]) break if not matching_item: - print(f"WARNING: Could not find dataset item for prompt, using prompt only") - matching_item = {"prompt": prompt} + raise ValueError(f"Could not find dataset item for prompt: {prompt}") - for _ in range(trainer.args.num_generations): + for _ in range(num_generations): expanded_prompts.append(prompt) expanded_dataset_items.append(dict(matching_item)) - print(f"[nemo_gym_rollout_func] Expanded to {len(expanded_prompts)} total requests ({len(prompts)} prompts × {trainer.args.num_generations} generations)") - - print("[nemo_gym_rollout_func] Calling Nemo Gym agent...") - print(f"[nemo_gym_rollout_func] Using temperature: {trainer.args.temperature}, top_p: {trainer.args.top_p}") loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: @@ -233,8 +185,6 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, finally: loop.close() - print(f"[nemo_gym_rollout_func] Got {len(responses)} responses from Nemo Gym agent") - trajectory_file = os.path.join(trainer.args.output_dir, "trajectories.jsonl") os.makedirs(trainer.args.output_dir, exist_ok=True) @@ -249,8 +199,6 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, } f.write(json.dumps(trajectory_data) + "\n") - print(f"[nemo_gym_rollout_func] Saved {len(responses)} trajectories to {trajectory_file}") - tokenizer = trainer.processing_class # interleaved completion with mask @@ -261,54 +209,32 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, env_rewards: List[float] = [] num_turns_list: List[int] = [] - failed_count = 0 - success_count = 0 - for i, response in enumerate(responses): expected_prompt = expanded_prompts[i] expected_prompt_ids = tokenizer.encode(expected_prompt, add_special_tokens=False) - rollout_failed = False - failure_reason = None - if not isinstance(response, dict): rollout_failed = True - failure_reason = f"response is not a dict: {type(response)}" - elif "error" in response: + elif response.get("error"): rollout_failed = True - failure_reason = f"had error: {response['error']}" else: output_items = response.get("response", {}).get("output", []) - if not output_items: rollout_failed = True - failure_reason = "has no output items (we are masking and returing just eos for truncated rollouts right now)" else: - has_content = False - for item in output_items: - if item.get("type") == "message": - content_list = item.get("content", []) - for content_item in content_list: - if content_item.get("type") == "output_text": - text = content_item.get("text", "").strip() - if text: - has_content = True - break - elif item.get("type") == "function_call": - has_content = True - break - if has_content: - break - - if not has_content: - rollout_failed = True - failure_reason = "has empty content in all output items (we are masking and returing just eos for truncated rollouts right now)" + has_content = any( + item.get("type") == "function_call" or ( + item.get("type") == "message" and + any(c.get("type") == "output_text" and c.get("text", "").strip() + for c in item.get("content", [])) + ) + for item in output_items + ) + rollout_failed = not has_content # truncated or other failure - mask out if rollout_failed: - failed_count += 1 - print(f"[nemo_gym_rollout_func] WARNING: Rollout {i} {failure_reason}. Filling with eos and zero reward.") - eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0 + eos_token_id = tokenizer.eos_token_id or 0 prompt_ids.append(expected_prompt_ids) completion_ids.append([eos_token_id]) completion_mask.append([0]) @@ -320,7 +246,7 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, episode_reward = response.get("reward", 0.0) output_items = response.get("response", {}).get("output", []) - # Make interleaved completion: [model_gen1, tool_result1, model_gen2, tool_result2, ...] with mask 1 for generations, 0 for tool results (aka action, observation) + # Make interleaved completion: (p,a,o,a,o...) & mask all but assistant role # Each turn has prompt_ids, gen_ids. Find tool_result = current_prompt - previous_seen_tokens and mask it # on policy tokenid correction replace_prefix_tokens is done in vllm server @@ -331,9 +257,9 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, first_prompt = None num_turns = 0 - for item in output_items: + for idx, item in enumerate(output_items): if "prompt_token_ids" not in item or "generation_token_ids" not in item: - continue + raise ValueError(f"Item {idx} missing prompt_token_ids or generation_token_ids") num_turns += 1 item_prompt_ids = item["prompt_token_ids"] @@ -345,45 +271,12 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, first_prompt = item_prompt_ids seen_token_ids = list(item_prompt_ids) else: - # Extract tool result tokens if len(item_prompt_ids) > len(seen_token_ids): - # Assert prefix matches if item_prompt_ids[:len(seen_token_ids)] != seen_token_ids: - diverge_idx = -1 - for idx in range(min(len(seen_token_ids), len(item_prompt_ids))): - if seen_token_ids[idx] != item_prompt_ids[idx]: - diverge_idx = idx - break - - # Show context around divergence - # prob can delete this, was for debugging right - context_window = 20 - start = max(0, diverge_idx - context_window) if diverge_idx >= 0 else 0 - end = min(len(seen_token_ids), diverge_idx + context_window) if diverge_idx >= 0 else min(50, len(seen_token_ids)) - - error_msg = ( - f"[Turn {num_turns}] Non-contiguous messages found! " - f"This may be a tokenization issue.\n" - f"Length of expected prefix: {len(seen_token_ids)}\n" - f"Length of new prompt: {len(item_prompt_ids)}\n" + raise ValueError( + f"[Turn {num_turns}] Non-contiguous messages (tokenization issue). " + f"Expected prefix len {len(seen_token_ids)}, got prompt len {len(item_prompt_ids)}" ) - - if diverge_idx >= 0: - error_msg += ( - f"Tokens diverge at index {diverge_idx}:\n" - f"Expected[{start}:{end}]: {seen_token_ids[start:end]}\n" - f"Got[{start}:{end}]: {item_prompt_ids[start:end]}\n" - f"Expected token at [{diverge_idx}]: {seen_token_ids[diverge_idx]}\n" - f"Got token at [{diverge_idx}]: {item_prompt_ids[diverge_idx]}\n" - ) - else: - error_msg += ( - f"Prefix length mismatch but tokens match up to min length.\n" - f"Expected (first 50): {seen_token_ids[:50]}\n" - f"Got (first 50): {item_prompt_ids[:50]}\n" - ) - - raise ValueError(error_msg) tool_result_tokens = item_prompt_ids[len(seen_token_ids):] if tool_result_tokens: @@ -393,65 +286,36 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, interleaved_completion.extend(item_gen_ids) interleaved_mask.extend([1] * len(item_gen_ids)) - interleaved_logprobs.extend( - item_logprobs if len(item_logprobs) == len(item_gen_ids) else [0.0] * len(item_gen_ids) - ) + assert len(item_logprobs) == len(item_gen_ids), f"Logprobs len {len(item_logprobs)} != gen len {len(item_gen_ids)}" + interleaved_logprobs.extend(item_logprobs) seen_token_ids = list(item_prompt_ids) + list(item_gen_ids) if not interleaved_completion or first_prompt is None: raise ValueError(f"Rollout {i} has no valid turns") - - success_count += 1 - prompt_ids.append(first_prompt) completion_ids.append(interleaved_completion) completion_mask.append(interleaved_mask) logprobs.append(interleaved_logprobs) env_rewards.append(episode_reward) num_turns_list.append(num_turns) - - model_tokens = sum(interleaved_mask) - tool_tokens = len(interleaved_mask) - model_tokens - - print(f"\n{'='*60}") - print(f"[nemo_gym_rollout_func] Turns: {num_turns}, Reward: {episode_reward:.3f}") - print(f"[nemo_gym_rollout_func] Prompt tokens: {len(first_prompt)}") - print(f"[nemo_gym_rollout_func] Rollout tokens: {len(interleaved_completion)} (model: {model_tokens}, tool: {tool_tokens})") - print(f"[nemo_gym_rollout_func] Rollout preview: {tokenizer.decode(interleaved_completion)[:150]}...") - print(f"{'='*60}\n") - - print(f"\n{'='*80}") - print(f"[nemo_gym_rollout_func] Success: {success_count}, Failed: {failed_count}") if not prompt_ids: raise RuntimeError( "No valid rollouts. Check Nemo Gym and vLLM logs." ) - print(f"[nemo_gym_rollout_func] Mean reward: {sum(env_rewards) / len(env_rewards) if env_rewards else 0.0:.3f}") - print(f"[nemo_gym_rollout_func] Total action tokens: {sum(sum(m) for m in completion_mask)}, total observation tokens: {sum(len(m) - sum(m) for m in completion_mask)}") - # log num turns to wandb if num_turns_list: - import wandb wandb.log({ "train/num_turns_mean": sum(num_turns_list) / len(num_turns_list), "train/num_turns_min": min(num_turns_list), "train/num_turns_max": max(num_turns_list), }) - print(f"[nemo_gym_rollout_func] Num turns mean: {sum(num_turns_list) / len(num_turns_list):.2f}, min: {min(num_turns_list)}, max: {max(num_turns_list)}") - - # We need to deduplicate prompt_ids to match TRL's current code that re-duplicates prompts - # TRL deduplicates here: https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L1266 so we had to duplicate prompts for num_generations - # TRL reduplicates here: https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L1314 so we need to dedup prompts - unique_prompt_ids = [] - for idx in range(0, len(prompt_ids), trainer.args.num_generations): - if idx < len(prompt_ids): - unique_prompt_ids.append(prompt_ids[idx]) - print(f"{'='*80}\n") + # Deduplicate prompt_ids since TRL re-duplicates them internally + unique_prompt_ids = prompt_ids[::num_generations] return { "prompt_ids": unique_prompt_ids, @@ -464,30 +328,23 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, def get_max_prompt_length(dataset: Dataset, tokenizer) -> int: tokenizer = AutoTokenizer.from_pretrained(tokenizer) - prompt_lengths = [len(tokenizer.encode(item.get("prompt", ""))) for item in dataset if item.get("prompt", "")] - prompt_lengths.sort() - max_length = prompt_lengths[-1] - return max_length - + return max(len(tokenizer.encode(item.get("prompt", ""))) for item in dataset if item.get("prompt")) def load_dataset_from_jsonl(path: str) -> Dataset: - # TODO: standardize nemo gym dataset format or only accept 1 here (instructions field, answer field, jsonl structure...) data = [] with open(path, 'r') as f: for line in f: if line.strip(): item = json.loads(line) - # Extract prompt before serializing if "prompt" not in item: if "responses_create_params" in item and isinstance(item["responses_create_params"], dict): responses_params = item["responses_create_params"] input_data = responses_params.get("input") instructions = responses_params.get("instructions", "") - # Handle both message list format and string format if isinstance(input_data, list) and len(input_data) > 0: - # Format as messages (e.g. reasoning_gym) + # (e.g. reasoning_gym) prompt_parts = [] if instructions: prompt_parts.append(f"system: {instructions}") @@ -496,7 +353,7 @@ def load_dataset_from_jsonl(path: str) -> Dataset: prompt_parts.append(f"{msg['role']}: {msg['content']}") item["prompt"] = "\n\n".join(prompt_parts) if prompt_parts else "" elif isinstance(input_data, str): - # Format as string (e.g. google_search) + # (e.g. google_search) prompt_parts = [] if instructions: prompt_parts.append(instructions) @@ -508,39 +365,29 @@ def load_dataset_from_jsonl(path: str) -> Dataset: else: item["prompt"] = item.get("question", "") - # Serialize problematic nested structures to JSON strings - for key in ["responses_create_params", "expected_answers", "metadata", "ground_truth"]: + for key in ["responses_create_params", "expected_answers", "metadata", "ground_truth", "options", "template_metadata"]: if key in item and isinstance(item[key], (dict, list)): item[key] = json.dumps(item[key]) data.append(item) - print(f"Loaded {len(data)} examples from {path}") - - if len(data) < 100: - repeat_factor = 100 - print(f"Repeating dataset {repeat_factor}x: {len(data)} -> {len(data) * repeat_factor}") - data = data * repeat_factor - - dataset = Dataset.from_list(data) - # dataset = dataset.shuffle(seed=42) - - return dataset + return Dataset.from_list(data) def main(): parser = argparse.ArgumentParser(description="") parser.add_argument("--config", required=True, help="Path to config YAML file") parser.add_argument("--vllm_server_host", type=str, default="127.0.0.1", help="vLLM server hostname/IP") + parser.add_argument("--head_server_host", type=str, default="127.0.0.1", + help="Head server hostname/IP for ng_run") args = parser.parse_args() with open(args.config) as f: config = TrainingConfig(**yaml.safe_load(f)) agent_server = get_agent_server( - head_server_host="127.0.0.1", + head_server_host=args.head_server_host, head_server_port=11000, - agent_name=config.agent_name, ) if config.project_name: @@ -551,27 +398,15 @@ def main(): model_short = config.model_name.split("/")[-1] config.run_name = ( f"{task}_{model_short}" - f"_ng{config.num_generations}" + f"_rpp{config.num_generations}" f"_dbs{config.per_device_train_batch_size}" f"_ga{config.gradient_accumulation_steps}" f"_maxlen{config.max_seq_length}" f"_lr{config.learning_rate}" f"_temp{config.temperature}" f"_topp{config.top_p}" - f"_wd{config.weight_decay}" - f"_wu{config.warmup_ratio}" ) - print(f"\n\nModel: {config.model_name}") - print(f"Dataset: {config.dataset_path}") - print(f"Nemo Gym Agent: {agent_server}") - print(f"vLLM Server: 127.0.0.1:8000") - print(f"Output dir: {config.output_dir}") - print(f"Max steps: {config.max_steps}") - print(f"Num generations: {config.num_generations}") - print(f"Batch size: {config.per_device_train_batch_size}") - print(f"Gradient accumulation: {config.gradient_accumulation_steps}") - if config.dataset_path.endswith(('.jsonl', '.json')): dataset = load_dataset_from_jsonl(config.dataset_path) else: @@ -579,6 +414,11 @@ def main(): print(f"Dataset has {len(dataset)} examples\n") + eval_dataset = None + if config.eval_dataset_path: + eval_dataset = load_dataset_from_jsonl(config.eval_dataset_path) + print(f"Eval dataset has {len(eval_dataset)} examples\n") + if config.max_prompt_length is None: config.max_prompt_length = get_max_prompt_length(dataset, config.model_name) @@ -588,74 +428,61 @@ def main(): vllm_server_host=args.vllm_server_host, vllm_server_port=8000, + gradient_checkpointing=True, + temperature=config.temperature, learning_rate=config.learning_rate, weight_decay=config.weight_decay, warmup_ratio=config.warmup_ratio, + warmup_steps=config.warmup_steps, lr_scheduler_type=config.lr_scheduler_type, optim=config.optim, per_device_train_batch_size=config.per_device_train_batch_size, gradient_accumulation_steps=config.gradient_accumulation_steps, num_generations=config.num_generations, + num_generations_eval=1, max_steps=config.max_steps, save_steps=config.save_steps, logging_steps=1, report_to=config.report_to, output_dir=config.output_dir, - run_name=config.run_name, # wandb - + run_name=config.run_name, # wandb + + eval_strategy=config.eval_strategy, + eval_steps=config.eval_steps, + epsilon=0.2, epsilon_high=0.28, loss_type="grpo", mask_truncated_completions=True, - log_completions=False, - + log_completions=config.log_completions, + num_completions_to_print=config.num_completions_to_print, + max_prompt_length=config.max_prompt_length, max_completion_length=config.max_seq_length - config.max_prompt_length, shuffle_dataset=False, + + model_init_kwargs={ + "torch_dtype": "auto", + }, ) training_args.agent_server = agent_server training_args.request_timeout = 6000 - print("\n" + "="*80) - print("GRPO Config:\n") - print(f"per_device_train_batch_size: {training_args.per_device_train_batch_size}") - print(f"gradient_accumulation_steps: {training_args.gradient_accumulation_steps}") - print(f"num_generations: {training_args.num_generations}") - print(f"steps_per_generation: {training_args.steps_per_generation if hasattr(training_args, 'steps_per_generation') else 'Not set (will default to gradient_accumulation_steps)'}") - print(f"generation_batch_size: {training_args.generation_batch_size if hasattr(training_args, 'generation_batch_size') else 'Not set (will be calculated)'}") - print(f"shuffle_dataset: {training_args.shuffle_dataset if hasattr(training_args, 'shuffle_dataset') else 'Not set (default: True)'}") - print(f"Dataset size: {len(dataset)}") - print("="*80 + "\n") - - print("Initializing GRPO Trainer...") - trainer = GRPOTrainer( model=config.model_name, reward_funcs=reward_fn, train_dataset=dataset, + eval_dataset=eval_dataset, rollout_func=nemo_gym_rollout_func, args=training_args, ) - print("=" * 80) - print("Starting training...") - trainer.train() - print("=" * 80) - print("Training complete") - - output_dir = config.output_dir + "/final" - print(f"\nSaving model to {output_dir}") - trainer.save_model(output_dir) - trainer.processing_class.save_pretrained(output_dir) - - print("\nFinished saving model") - if __name__ == "__main__": main() From 52a3140f165f91ef5c98b1777c8c69a4778d7fbc Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Tue, 6 Jan 2026 22:14:24 -0800 Subject: [PATCH 07/51] readme Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/scripts/nemo_gym/README.md b/examples/scripts/nemo_gym/README.md index bf686c6794f..0958655ce4d 100644 --- a/examples/scripts/nemo_gym/README.md +++ b/examples/scripts/nemo_gym/README.md @@ -24,5 +24,4 @@ ng_run "+config_paths=[resources_servers/workplace_assistant/configs/workplace_a CUDA_VISIBLE_DEVICES=0 python train.py --config config.yaml ``` -We should be able to do multinode, but im having issues with ngpu > 1 for training backend currently -https://huggingface.co/docs/trl/main/en/vllm_integration#modes-of-using-vllm-during-training \ No newline at end of file +multinode is working, an example will be uploaded soon! \ No newline at end of file From 0e71cbbd8807c7294d2d22d2e311eb6ecec7234f Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Tue, 6 Jan 2026 22:14:56 -0800 Subject: [PATCH 08/51] cfg Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/config.yaml | 47 ++++++++++++++------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/examples/scripts/nemo_gym/config.yaml b/examples/scripts/nemo_gym/config.yaml index 8d5eb3973c4..95a8724181d 100644 --- a/examples/scripts/nemo_gym/config.yaml +++ b/examples/scripts/nemo_gym/config.yaml @@ -1,30 +1,33 @@ # model_name: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" # model_name: "Qwen/Qwen3-0.6B" -model_name: "Qwen/Qwen3-4B-Instruct-2507" -# dataset_path: "/home/ubuntu/unsloth-gym-integration/Gym/resources_servers/reasoning_gym/data/train_knights_knaves.jsonl" -dataset_path: "train-workplace.jsonl" - -task: "workplace-assistant" # used in wandb run_name if not set explicitly +# model_name: "Qwen/Qwen2.5-1.5B-Instruct" +# model_name: "Qwen/Qwen3-30B-A3B-Instruct-2507" +# dataset_path: "/lustre/fsw/portfolios/llmservice/users/cmunley/Gym/resources_servers/reasoning_gym/data/train_knights_knaves.jsonl" +# dataset_path: "train-workplace.jsonl" +# dataset_path: "/lustre/fsw/portfolios/llmservice/users/cmunley/Gym/resources_servers/xlam_fc/data/train.jsonl" -output_dir: "outputs/trl_nemo_gym_workplace" +model_name: "Qwen/Qwen3-4B-Instruct-2507" +dataset_path: "train-mcqa.jsonl" +task: "mcqa" # just used in wandb run_name if name not set explicitly +output_dir: "outputs/trl_nemo_gym_mcqa" project_name: "cmunley-nemo-gym-trl-int" - -learning_rate: 2.0e-5 +learning_rate: 3.0e-6 max_steps: 1000000 - -num_generations: 12 - +num_generations: 8 per_device_train_batch_size: 1 -gradient_accumulation_steps: 192 - -max_seq_length: 8192 - -temperature: 1.0 -top_p: 0.99 -weight_decay: 0.01 -warmup_ratio: 0.0 +gradient_accumulation_steps: 8 +max_seq_length: 16384 +temperature: 1 +top_p: 0.999 +weight_decay: 0.00 +# warmup_ratio: 0.0 +warmup_steps: 10 lr_scheduler_type: "linear" -optim: "adamw_8bit" - +optim: "adamw_torch_fused" save_steps: 900000 -report_to: "wandb" \ No newline at end of file +report_to: "wandb" +log_completions: true +num_completions_to_print: 2 +eval_dataset_path: "val-mcqa-50.jsonl" +eval_strategy: "steps" +eval_steps: 5 \ No newline at end of file From 35480997b243f7f090acbc55a9fe205ec068cfeb Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Tue, 6 Jan 2026 22:30:53 -0800 Subject: [PATCH 09/51] small fix Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/train.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/examples/scripts/nemo_gym/train.py b/examples/scripts/nemo_gym/train.py index be6f0d624db..a57281d89ac 100644 --- a/examples/scripts/nemo_gym/train.py +++ b/examples/scripts/nemo_gym/train.py @@ -57,7 +57,6 @@ class TrainingConfig: num_generations: int = 2 per_device_train_batch_size: int = 2 gradient_accumulation_steps: int = 16 - max_seq_length: int = 1024 max_prompt_length: int = None @@ -74,7 +73,6 @@ class TrainingConfig: report_to: str = "none" run_name: str = None # Wandb project_name: str = None # Wandb - log_completions: bool = False num_completions_to_print: int = None @@ -161,9 +159,6 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, matching_item[key] = json.loads(matching_item[key]) break - if not matching_item: - raise ValueError(f"Could not find dataset item for prompt: {prompt}") - for _ in range(num_generations): expanded_prompts.append(prompt) expanded_dataset_items.append(dict(matching_item)) @@ -201,7 +196,6 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, tokenizer = trainer.processing_class - # interleaved completion with mask prompt_ids: List[List[int]] = [] completion_ids: List[List[int]] = [] completion_mask: List[List[int]] = [] # 1 for action, 0 for observation @@ -232,7 +226,7 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, ) rollout_failed = not has_content - # truncated or other failure - mask out + # truncated or other failure - mask if rollout_failed: eos_token_id = tokenizer.eos_token_id or 0 prompt_ids.append(expected_prompt_ids) @@ -246,9 +240,10 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, episode_reward = response.get("reward", 0.0) output_items = response.get("response", {}).get("output", []) - # Make interleaved completion: (p,a,o,a,o...) & mask all but assistant role - # Each turn has prompt_ids, gen_ids. Find tool_result = current_prompt - previous_seen_tokens and mask it - # on policy tokenid correction replace_prefix_tokens is done in vllm server + # interleaved completion with mask (p,a,o,a,o...) + # Each turn has prompt_ids, gen_ids + # tool_result = prompt_ids - seen_token_ids (to mask it) + # replace_prefix_tokens done in vllm server https://docs.nvidia.com/nemo/gym/latest/contribute/rl-framework-integration/openai-compatible-http-server-on-policy-correction.html seen_token_ids: List[int] = [] interleaved_completion: List[int] = [] @@ -314,8 +309,7 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, "train/num_turns_max": max(num_turns_list), }) - # Deduplicate prompt_ids since TRL re-duplicates them internally - unique_prompt_ids = prompt_ids[::num_generations] + unique_prompt_ids = prompt_ids[::num_generations] # TRL re-duplicates them return { "prompt_ids": unique_prompt_ids, @@ -344,7 +338,7 @@ def load_dataset_from_jsonl(path: str) -> Dataset: instructions = responses_params.get("instructions", "") if isinstance(input_data, list) and len(input_data) > 0: - # (e.g. reasoning_gym) + # list of messages format (e.g. reasoning_gym) prompt_parts = [] if instructions: prompt_parts.append(f"system: {instructions}") @@ -353,8 +347,9 @@ def load_dataset_from_jsonl(path: str) -> Dataset: prompt_parts.append(f"{msg['role']}: {msg['content']}") item["prompt"] = "\n\n".join(prompt_parts) if prompt_parts else "" elif isinstance(input_data, str): - # (e.g. google_search) + # prompt as string, no list of messages (e.g. google_search) prompt_parts = [] + # system prompt if instructions: prompt_parts.append(instructions) if input_data: From 8373899800ccc4b0d7ee9d29fa44c449bd9336e2 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Thu, 8 Jan 2026 17:01:31 -0800 Subject: [PATCH 10/51] docs Signed-off-by: cmunley1 --- trl/scripts/vllm_serve.py | 73 ++++++++++++++------------------------- 1 file changed, 25 insertions(+), 48 deletions(-) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 4b3edb1983a..e3d25ed4960 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -437,41 +437,29 @@ def _replace_prefix_tokens( template_token_ids: list[int], ) -> list[int]: """ - Fixes up chat template-tokenized messages to match model output tokenization. - - This preserves the monotonic tokens property for optimized multi-turn training by ensuring - that previously generated tokens are not retokenized differently by the chat template. + This function is for fixing up the chat template-tokenized messages history + to match the model output tokenization up to the last assistant turn, + in order to preserve the monotonic tokens property for optimized multi-turn + training. + + RL training frameworks train models on token IDs, but the OpenAI compatible + server communicates in what is basically de-tokenized text. When multiple + model calls are made to the OpenAI compatible server in a single trajectory, + model generations in previous model calls may be re-tokenized to something + that is different than what was generated. This is not too big of an issue + (that we know of) at inference time, but the log probs the model produces + are different enough for the differently re-tokenized generation result that + it causes the training to be off policy. Off policy isn't necessarily a bad + thing in isolation, but this source of off-policyness may cause unexpected + issues if not properly accounted for. It also mis-aligns the token ID + sequences across model calls, which is strange during training. + + There are real cases where the model output string _does not match_ the chat + template tokenization of the parsed model output. A concrete example is + inconsistent whitespace tokens around tool call special tokens. Based on NeMo RL's _replace_prefix_tokens: https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 - - When model generates text that gets re-tokenized by chat template in subsequent turns, - the token IDs may differ due to tokenization ambiguity (e.g., whitespace, context). - - Example: - Turn 1: Model outputs [220, 17] (decodes to " 4") - Turn 2: Template retokenizes as [1001] (also decodes to " 4") - - This creates off-policy issues where training logprobs don't match generation logprobs. - - This solves the issue by keeping exact model tokens up to the last EOS, then append new template tokens after EOS. - This ensures we preserve the actual tokens the model generated, not a retokenized version. - - Args: - tokenizer: The tokenizer instance, must have eos_token_id - model_prefix_token_ids: Token IDs from actual model generation to preserve - template_prefix_token_ids: Chat template applied up to last assistant message - template_token_ids: Chat template applied to full conversation - - Returns: - Combined token sequence: model_tokens[:-1] + template_tokens[after_eos:] - - Example: - model_prefix_token_ids = [1, 2, 3, 220, 17, 2] # Last 2 is EOS - template_prefix_token_ids = [1, 2, 3, 1001, 2] # Up to last assistant - template_token_ids = [1, 2, 3, 1001, 2, 21, 22] # Full conversation - - Output: [1, 2, 3, 220, 17, 2, 21, 22] # Keeps original [220, 17], adds new [21, 22] """ if not model_prefix_token_ids: return template_token_ids @@ -481,18 +469,18 @@ def _replace_prefix_tokens( logger.warning("Tokenizer has no EOS token ID, cannot apply _replace_prefix_tokens") return template_token_ids - # Find where to cut the model prefix (before EOS if present at end) model_cut_end = len(model_prefix_token_ids) if model_prefix_token_ids and model_prefix_token_ids[-1] == eos_token_id: model_cut_end -= 1 - # Find the last EOS in template prefix + # We take everything starting with the EOS token ID. template_cut_start = -1 for pos in reversed(range(len(template_prefix_token_ids))): if template_token_ids[pos] == eos_token_id: template_cut_start = pos break + # This should never be the case, but if template_cut_start < 0: logger.warning("No EOS token found in template prefix, cannot apply _replace_prefix_tokens") return template_token_ids @@ -1009,7 +997,6 @@ async def chat_completions(request: ChatCompletionRequest): # do on policy token id correction and call generate instead of chat # see https://docs.nvidia.com/nemo/gym/latest/contribute/rl-framework-integration/openai-compatible-http-server-on-policy-correction.html # and https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 - logger.info(f"[/chat/completions] Detected prefix token IDs in assistant message, using _replace_prefix_tokens") tokenizer = cached_tokenizer # preprocess full conversation @@ -1037,7 +1024,6 @@ async def chat_completions(request: ChatCompletionRequest): template_prefix_prompts = connections[0].recv() template_prefix_token_ids = template_prefix_prompts[0]["prompt_token_ids"] - logger.info(f"[/chat/completions] Calling _replace_prefix_tokens model_prefix len: {len(model_prefix_tokens)}, template_prefix len: {len(template_prefix_token_ids)}, template_full len: {len(template_prompt['prompt_token_ids'])}") corrected_token_ids = _replace_prefix_tokens( tokenizer, @@ -1046,9 +1032,7 @@ async def chat_completions(request: ChatCompletionRequest): template_prompt["prompt_token_ids"] ) - logger.info(f"[/chat/completions] final len = {len(corrected_token_ids)}") else: - logger.info(f"[/chat/completions] Skipping _replace_prefix_tokens, model_prefix_tokens={model_prefix_tokens is not None}, last_assistant_idx={last_assistant_idx}") corrected_token_ids = template_prompt["prompt_token_ids"] corrected_prompt = {"prompt_token_ids": corrected_token_ids} @@ -1082,7 +1066,6 @@ async def chat_completions(request: ChatCompletionRequest): all_outputs = list(chain.from_iterable(all_outputs)) if not all_outputs: - logger.warning("[/chat/completions] All workers returned empty - max seq len probably exceeded. Returning empty msg with finish_reason=length") return { "id": completion_id, "object": "chat.completion", @@ -1111,10 +1094,10 @@ async def chat_completions(request: ChatCompletionRequest): tool_calls = None finish_reason = "stop" - if hasattr(gen_output, "tool_calls") and gen_output.tool_calls: + if hasattr(gen_output, "tool_calls") and gen_output.tool_calls: # native tool call parsing tool_calls = gen_output.tool_calls finish_reason = "tool_calls" - elif request.tools and text: + elif request.tools and text: # try manual tool call parsing eg qwen3 style xml format... this is a hack. pattern = r'(.*?)' matches = re.findall(pattern, text, re.DOTALL) if matches: @@ -1202,7 +1185,7 @@ async def tokenize(request: TokenizeRequest): preprocessed_prompts = connections[0].recv() if preprocessed_prompts and len(preprocessed_prompts) > 1: - logger.warning(f"More than one tokenized message returned from preprocess_chat inside tokenize, double check results!") + logger.warning("More than one tokenized message returned from preprocess_chat inside tokenize, double check results!") if not preprocessed_prompts or len(preprocessed_prompts) == 0: return {"tokens": [], "model": request.model or script_args.model} @@ -1235,8 +1218,6 @@ async def tokenize(request: TokenizeRequest): template_prefix_prompts = connections[0].recv() template_prefix_token_ids = template_prefix_prompts[0]["prompt_token_ids"] - logger.info(f"[/tokenize] Calling _replace_prefix_tokens, model_prefix length: {len(model_prefix_tokens)}, template_prefix length: {len(template_prefix_token_ids)}, template_full length: {len(template_prompt['prompt_token_ids'])}") - result_tokens = _replace_prefix_tokens( tokenizer, model_prefix_tokens, @@ -1244,10 +1225,6 @@ async def tokenize(request: TokenizeRequest): template_prompt["prompt_token_ids"] ) - logger.info(f"[/tokenize] final length = {len(result_tokens)}") - else: - logger.info(f"[/tokenize] Skipping _replace_prefix_tokens, one of model_prefix_tokens={model_prefix_tokens is not None}, last_assistant_idx={last_assistant_idx} is None") - return { "tokens": result_tokens, "model": request.model or script_args.model From fe4bce62990b68bde160d78ace73ffa0725dbc9c Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Thu, 15 Jan 2026 09:39:07 -0800 Subject: [PATCH 11/51] fixes Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/config.yaml | 15 +++++++------ examples/scripts/nemo_gym/train.py | 32 ++++++++++++++++++--------- trl/trainer/grpo_trainer.py | 24 ++++++++++++++++++-- 3 files changed, 51 insertions(+), 20 deletions(-) diff --git a/examples/scripts/nemo_gym/config.yaml b/examples/scripts/nemo_gym/config.yaml index 95a8724181d..cf813078887 100644 --- a/examples/scripts/nemo_gym/config.yaml +++ b/examples/scripts/nemo_gym/config.yaml @@ -7,15 +7,15 @@ # dataset_path: "/lustre/fsw/portfolios/llmservice/users/cmunley/Gym/resources_servers/xlam_fc/data/train.jsonl" model_name: "Qwen/Qwen3-4B-Instruct-2507" -dataset_path: "train-mcqa.jsonl" -task: "mcqa" # just used in wandb run_name if name not set explicitly -output_dir: "outputs/trl_nemo_gym_mcqa" +dataset_path: "train-workplace.jsonl" +task: "workplace_IS_false" # just used in wandb run_name if name not set explicitly +output_dir: "outputs/trl_nemo_gym_workplace" project_name: "cmunley-nemo-gym-trl-int" -learning_rate: 3.0e-6 +learning_rate: 1.0e-5 max_steps: 1000000 num_generations: 8 per_device_train_batch_size: 1 -gradient_accumulation_steps: 8 +gradient_accumulation_steps: 16 max_seq_length: 16384 temperature: 1 top_p: 0.999 @@ -28,6 +28,7 @@ save_steps: 900000 report_to: "wandb" log_completions: true num_completions_to_print: 2 -eval_dataset_path: "val-mcqa-50.jsonl" +eval_dataset_path: "val-workplace-50.jsonl" eval_strategy: "steps" -eval_steps: 5 \ No newline at end of file +eval_steps: 5 +vllm_importance_sampling_correction: false diff --git a/examples/scripts/nemo_gym/train.py b/examples/scripts/nemo_gym/train.py index a57281d89ac..f0d2acd5e6d 100644 --- a/examples/scripts/nemo_gym/train.py +++ b/examples/scripts/nemo_gym/train.py @@ -70,6 +70,7 @@ class TrainingConfig: output_dir: str = "outputs/trl_nemo_gym" save_steps: int = 100 + save_total_limit: int = None report_to: str = "none" run_name: str = None # Wandb project_name: str = None # Wandb @@ -79,6 +80,9 @@ class TrainingConfig: eval_dataset_path: Optional[str] = None eval_strategy: str = "no" eval_steps: int = 50 + eval_on_start: bool = False + + vllm_importance_sampling_correction: bool = False def reward_fn(completions: List[str], **kwargs) -> List[float]: env_rewards = kwargs.get("env_reward") @@ -96,6 +100,7 @@ async def call_nemo_gym_agent( ) -> List[Dict[str, Any]]: print(f"Calling Nemo Gym agent at {agent_server} with {len(prompts)} prompts") + # todo: increase limits async with aiohttp.ClientSession(cookie_jar=aiohttp.CookieJar()) as session: tasks = [] for prompt, item in zip(prompts, dataset_items): @@ -118,16 +123,13 @@ async def call_nemo_gym_agent( ) tasks.append(task) - responses = [] - with tqdm(total=len(tasks), desc="Agent requests") as pbar: - for coro in asyncio.as_completed(tasks): - result = await coro - responses.append(result) - pbar.update(1) + responses = await asyncio.gather(*tasks, return_exceptions=True) results = [] for i, response in enumerate(responses): try: + if isinstance(response, Exception): + raise response json_data = await response.json() if not isinstance(json_data, dict): raise ValueError(f"Expected dict, got {type(json_data)}") @@ -198,7 +200,7 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, prompt_ids: List[List[int]] = [] completion_ids: List[List[int]] = [] - completion_mask: List[List[int]] = [] # 1 for action, 0 for observation + completion_mask: List[List[int]] = [] # 1 for action, 0 for observation/user logprobs: List[List[float]] = [] env_rewards: List[float] = [] num_turns_list: List[int] = [] @@ -254,7 +256,7 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, for idx, item in enumerate(output_items): if "prompt_token_ids" not in item or "generation_token_ids" not in item: - raise ValueError(f"Item {idx} missing prompt_token_ids or generation_token_ids") + continue num_turns += 1 item_prompt_ids = item["prompt_token_ids"] @@ -278,7 +280,7 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, interleaved_completion.extend(tool_result_tokens) interleaved_mask.extend([0] * len(tool_result_tokens)) interleaved_logprobs.extend([0.0] * len(tool_result_tokens)) - + interleaved_completion.extend(item_gen_ids) interleaved_mask.extend([1] * len(item_gen_ids)) assert len(item_logprobs) == len(item_gen_ids), f"Logprobs len {len(item_logprobs)} != gen len {len(item_gen_ids)}" @@ -375,11 +377,18 @@ def main(): help="vLLM server hostname/IP") parser.add_argument("--head_server_host", type=str, default="127.0.0.1", help="Head server hostname/IP for ng_run") + parser.add_argument("--resume_from_checkpoint", type=str, default=None, + help="Path to checkpoint to resume from") args = parser.parse_args() with open(args.config) as f: config = TrainingConfig(**yaml.safe_load(f)) + if isinstance(config.learning_rate, str): + config.learning_rate = float(config.learning_rate) + if isinstance(config.weight_decay, str): + config.weight_decay = float(config.weight_decay) + agent_server = get_agent_server( head_server_host=args.head_server_host, head_server_port=11000, @@ -440,6 +449,7 @@ def main(): max_steps=config.max_steps, save_steps=config.save_steps, + save_total_limit=config.save_total_limit, logging_steps=1, report_to=config.report_to, output_dir=config.output_dir, @@ -448,6 +458,7 @@ def main(): eval_strategy=config.eval_strategy, eval_steps=config.eval_steps, + vllm_importance_sampling_correction=config.vllm_importance_sampling_correction, epsilon=0.2, epsilon_high=0.28, loss_type="grpo", @@ -457,7 +468,6 @@ def main(): max_prompt_length=config.max_prompt_length, max_completion_length=config.max_seq_length - config.max_prompt_length, - shuffle_dataset=False, model_init_kwargs={ @@ -477,7 +487,7 @@ def main(): args=training_args, ) - trainer.train() + trainer.train(resume_from_checkpoint=args.resume_from_checkpoint) if __name__ == "__main__": main() diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index be4ac452395..0c682e24ddf 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1809,7 +1809,10 @@ def _generate_and_score_completions( # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + + # attend to all non-padding tokens, but mask out user/tool result tokens in loss + completion_attention_mask = (completion_ids != self.pad_token_id).long() + attention_mask = torch.cat([prompt_mask, completion_attention_mask], dim=1) # (B, P+C) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size @@ -1864,6 +1867,19 @@ def _generate_and_score_completions( else: old_per_token_logps = None + # track sampling logp diff even when IS off for debugging + # could remove this + if self.use_vllm and sampling_per_token_logps is not None and old_per_token_logps is None: + old_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + **forward_kwargs, + ) + # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch if self.use_vllm and self.vllm_importance_sampling_correction: mask = completion_mask if not self.tools else completion_mask * tool_mask @@ -2000,7 +2016,9 @@ def _generate_and_score_completions( if images is not None: self._logs["images"].extend(gather_object(images)) - if self.use_vllm and self.vllm_importance_sampling_correction: + # track sampling logp diff even when IS off for debugging + # could remove this + if self.use_vllm and old_per_token_logps is not None and sampling_per_token_logps is not None: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) mask = completion_mask.bool() if not self.tools else (completion_mask * tool_mask).bool() delta = delta[mask] @@ -2013,6 +2031,8 @@ def _generate_and_score_completions( self.accelerator.gather(max_delta).max().item() ) + # track IS ratio only when IS correction is enabled + if self.use_vllm and self.vllm_importance_sampling_correction: if sequence_level_is: flat_is_ratio = vllm_importance_sampling_ratio.flatten() else: From facfb5af2528d5e8613406353233f903321a7a4a Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Thu, 15 Jan 2026 09:58:34 -0800 Subject: [PATCH 12/51] remove flag Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/train.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/scripts/nemo_gym/train.py b/examples/scripts/nemo_gym/train.py index f0d2acd5e6d..fda015adb1f 100644 --- a/examples/scripts/nemo_gym/train.py +++ b/examples/scripts/nemo_gym/train.py @@ -70,7 +70,6 @@ class TrainingConfig: output_dir: str = "outputs/trl_nemo_gym" save_steps: int = 100 - save_total_limit: int = None report_to: str = "none" run_name: str = None # Wandb project_name: str = None # Wandb @@ -80,7 +79,6 @@ class TrainingConfig: eval_dataset_path: Optional[str] = None eval_strategy: str = "no" eval_steps: int = 50 - eval_on_start: bool = False vllm_importance_sampling_correction: bool = False @@ -449,7 +447,6 @@ def main(): max_steps=config.max_steps, save_steps=config.save_steps, - save_total_limit=config.save_total_limit, logging_steps=1, report_to=config.report_to, output_dir=config.output_dir, From ac94e1b18d65920382a46ea2ce632fc6d6bda44f Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 16 Jan 2026 00:18:13 -0800 Subject: [PATCH 13/51] multi env Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/config.yaml | 46 +++-- examples/scripts/nemo_gym/train.py | 241 ++++++++++---------------- 2 files changed, 112 insertions(+), 175 deletions(-) diff --git a/examples/scripts/nemo_gym/config.yaml b/examples/scripts/nemo_gym/config.yaml index cf813078887..5af51bd1196 100644 --- a/examples/scripts/nemo_gym/config.yaml +++ b/examples/scripts/nemo_gym/config.yaml @@ -1,34 +1,32 @@ -# model_name: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" -# model_name: "Qwen/Qwen3-0.6B" -# model_name: "Qwen/Qwen2.5-1.5B-Instruct" -# model_name: "Qwen/Qwen3-30B-A3B-Instruct-2507" -# dataset_path: "/lustre/fsw/portfolios/llmservice/users/cmunley/Gym/resources_servers/reasoning_gym/data/train_knights_knaves.jsonl" -# dataset_path: "train-workplace.jsonl" -# dataset_path: "/lustre/fsw/portfolios/llmservice/users/cmunley/Gym/resources_servers/xlam_fc/data/train.jsonl" - model_name: "Qwen/Qwen3-4B-Instruct-2507" -dataset_path: "train-workplace.jsonl" -task: "workplace_IS_false" # just used in wandb run_name if name not set explicitly -output_dir: "outputs/trl_nemo_gym_workplace" -project_name: "cmunley-nemo-gym-trl-int" + +dataset_path: "data/train.jsonl" +eval_dataset_path: "data/val.jsonl" + +output_dir: "outputs/nemo_gym" +run_name_prefix: "nemo_gym" +report_to: "wandb" +project_name: "trl-nemo-gym" +log_completions: true +num_completions_to_print: 2 + learning_rate: 1.0e-5 -max_steps: 1000000 +max_steps: 1000 num_generations: 8 per_device_train_batch_size: 1 gradient_accumulation_steps: 16 max_seq_length: 16384 -temperature: 1 -top_p: 0.999 -weight_decay: 0.00 -# warmup_ratio: 0.0 warmup_steps: 10 lr_scheduler_type: "linear" optim: "adamw_torch_fused" -save_steps: 900000 -report_to: "wandb" -log_completions: true -num_completions_to_print: 2 -eval_dataset_path: "val-workplace-50.jsonl" -eval_strategy: "steps" -eval_steps: 5 +weight_decay: 0.0 vllm_importance_sampling_correction: false + +temperature: 1.0 +top_p: 0.999 + +save_steps: 100 + +eval_strategy: "steps" +eval_steps: 50 + diff --git a/examples/scripts/nemo_gym/train.py b/examples/scripts/nemo_gym/train.py index fda015adb1f..d64bfbefbc7 100644 --- a/examples/scripts/nemo_gym/train.py +++ b/examples/scripts/nemo_gym/train.py @@ -10,47 +10,14 @@ from dataclasses import dataclass from datasets import Dataset, load_dataset from trl import GRPOConfig, GRPOTrainer -from tqdm import tqdm import wandb -from transformers import AutoTokenizer - -def get_agent_server( - head_server_host: str = "127.0.0.1", - head_server_port: int = 11000, -) -> str: - try: - response = requests.get( - f"http://{head_server_host}:{head_server_port}/global_config_dict_yaml", - timeout=10 - ) - response.raise_for_status() - global_config_yaml = response.text - global_config_dict = OmegaConf.create(yaml.safe_load(global_config_yaml)) - - for _, project_config in global_config_dict.items(): - if hasattr(project_config, 'responses_api_agents'): - agents = project_config.responses_api_agents - for name in agents.keys(): - agent_config = getattr(agents, name) - if hasattr(agent_config, 'host') and hasattr(agent_config, 'port'): - agent_host = agent_config.host - if agent_host in ("127.0.0.1", "0.0.0.0", "localhost"): - agent_host = head_server_host - agent_server = f"http://{agent_host}:{agent_config.port}" - return agent_server - - raise ValueError("No agents found in global config") - - except requests.exceptions.RequestException as e: - raise RuntimeError(f"Failed to connect to head server at {head_server_host}:{head_server_port}: {e}") - @dataclass class TrainingConfig: model_name: str dataset_path: str - task: Optional[str] = None + run_name_prefix: Optional[str] = None learning_rate: float = 5e-6 max_steps: int = 100 @@ -82,6 +49,39 @@ class TrainingConfig: vllm_importance_sampling_correction: bool = False +def get_agent_servers( + head_server_host: str = "127.0.0.1", + head_server_port: int = 11000, +) -> Dict[str, str]: + try: + response = requests.get( + f"http://{head_server_host}:{head_server_port}/global_config_dict_yaml", + timeout=10 + ) + response.raise_for_status() + global_config_yaml = response.text + global_config_dict = OmegaConf.create(yaml.safe_load(global_config_yaml)) + + agent_servers = {} + for _, project_config in global_config_dict.items(): + if hasattr(project_config, 'responses_api_agents'): + agents = project_config.responses_api_agents + for name in agents.keys(): + agent_config = getattr(agents, name) + if hasattr(agent_config, 'host') and hasattr(agent_config, 'port'): + agent_host = agent_config.host + if agent_host in ("127.0.0.1", "0.0.0.0", "localhost"): + agent_host = head_server_host + agent_servers[name] = f"http://{agent_host}:{agent_config.port}" + + if not agent_servers: + raise ValueError("No agents found in global config") + + return agent_servers + + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Failed to connect to head server at {head_server_host}:{head_server_port}: {e}") + def reward_fn(completions: List[str], **kwargs) -> List[float]: env_rewards = kwargs.get("env_reward") assert env_rewards is not None, "env_reward not found in kwargs" @@ -90,15 +90,12 @@ def reward_fn(completions: List[str], **kwargs) -> List[float]: async def call_nemo_gym_agent( prompts: List[str], dataset_items: List[Dict[str, Any]], - agent_server: str, + agent_servers: Dict[str, str], timeout: float, max_completion_length: int = 4096, temperature: float = 1.0, top_p: float = 0.999, ) -> List[Dict[str, Any]]: - print(f"Calling Nemo Gym agent at {agent_server} with {len(prompts)} prompts") - - # todo: increase limits async with aiohttp.ClientSession(cookie_jar=aiohttp.CookieJar()) as session: tasks = [] for prompt, item in zip(prompts, dataset_items): @@ -114,8 +111,14 @@ async def call_nemo_gym_agent( params["temperature"] = temperature params["top_p"] = top_p + agent_ref = item.get("agent_ref", {}) + agent_name = agent_ref.get("name") if isinstance(agent_ref, dict) else None + if not agent_name or agent_name not in agent_servers: + raise ValueError(f"Missing or invalid agent_ref. Got: {agent_ref}. Available: {list(agent_servers.keys())}") + agent_url = agent_servers[agent_name] + task = session.post( - f"{agent_server}/run", + f"{agent_url}/run", json=request_body, timeout=aiohttp.ClientTimeout(total=timeout), ) @@ -140,8 +143,6 @@ async def call_nemo_gym_agent( def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, List]: - current_step = trainer.state.global_step - is_eval = not trainer.model.training num_generations = trainer.args.num_generations_eval if is_eval and trainer.args.num_generations_eval else trainer.args.num_generations dataset = trainer.eval_dataset if is_eval and trainer.eval_dataset is not None else trainer.train_dataset @@ -149,19 +150,17 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, expanded_prompts = [] expanded_dataset_items = [] - for prompt in prompts: - matching_item = None - for item in dataset: - if item.get("prompt") == prompt: - matching_item = dict(item) - for key in ["responses_create_params", "expected_answers", "metadata", "ground_truth", "options", "template_metadata"]: - if key in matching_item and isinstance(matching_item[key], str): - matching_item[key] = json.loads(matching_item[key]) - break + for idx_str in prompts: + idx = int(idx_str) + item = dict(dataset[idx]) + + for key in ["responses_create_params", "expected_answers", "metadata", "ground_truth", "options", "template_metadata", "agent_ref"]: + if key in item and isinstance(item[key], str): + item[key] = json.loads(item[key]) for _ in range(num_generations): - expanded_prompts.append(prompt) - expanded_dataset_items.append(dict(matching_item)) + expanded_prompts.append(idx_str) + expanded_dataset_items.append(dict(item)) loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -170,7 +169,7 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, call_nemo_gym_agent( expanded_prompts, expanded_dataset_items, - trainer.args.agent_server, + trainer.args.agent_servers, trainer.args.request_timeout, trainer.args.max_completion_length, temperature=trainer.args.temperature, @@ -180,56 +179,35 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, finally: loop.close() - trajectory_file = os.path.join(trainer.args.output_dir, "trajectories.jsonl") - os.makedirs(trainer.args.output_dir, exist_ok=True) - - with open(trajectory_file, 'a') as f: - for i, response in enumerate(responses): - trajectory_data = { - "step": current_step, - "rollout_idx": i, - "reward": response.get("reward", 0.0) if isinstance(response, dict) else 0.0, - "output": response.get("response", {}).get("output", []) if isinstance(response, dict) else [], - "error": response.get("error") if isinstance(response, dict) else str(response), - } - f.write(json.dumps(trajectory_data) + "\n") - tokenizer = trainer.processing_class - prompt_ids: List[List[int]] = [] - completion_ids: List[List[int]] = [] - completion_mask: List[List[int]] = [] # 1 for action, 0 for observation/user + prompt_ids: List[List[int]] = [] + completion_ids: List[List[int]] = [] # list of rollouts + completion_mask: List[List[int]] = [] # only train on assistant turns + logprobs: List[List[float]] = [] env_rewards: List[float] = [] num_turns_list: List[int] = [] for i, response in enumerate(responses): - expected_prompt = expanded_prompts[i] - expected_prompt_ids = tokenizer.encode(expected_prompt, add_special_tokens=False) + eos_token_id = tokenizer.eos_token_id or 0 - if not isinstance(response, dict): - rollout_failed = True - elif response.get("error"): + if not isinstance(response, dict) or response.get("error"): rollout_failed = True else: output_items = response.get("response", {}).get("output", []) - if not output_items: - rollout_failed = True - else: - has_content = any( - item.get("type") == "function_call" or ( - item.get("type") == "message" and - any(c.get("type") == "output_text" and c.get("text", "").strip() - for c in item.get("content", [])) - ) - for item in output_items + has_content = output_items and any( + item.get("type") == "function_call" or ( + item.get("type") == "message" and + any(c.get("type") == "output_text" and c.get("text", "").strip() + for c in item.get("content", [])) ) - rollout_failed = not has_content + for item in output_items + ) + rollout_failed = not has_content - # truncated or other failure - mask if rollout_failed: - eos_token_id = tokenizer.eos_token_id or 0 - prompt_ids.append(expected_prompt_ids) + prompt_ids.append([eos_token_id]) completion_ids.append([eos_token_id]) completion_mask.append([0]) logprobs.append([0.0]) @@ -240,15 +218,11 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, episode_reward = response.get("reward", 0.0) output_items = response.get("response", {}).get("output", []) - # interleaved completion with mask (p,a,o,a,o...) - # Each turn has prompt_ids, gen_ids - # tool_result = prompt_ids - seen_token_ids (to mask it) - # replace_prefix_tokens done in vllm server https://docs.nvidia.com/nemo/gym/latest/contribute/rl-framework-integration/openai-compatible-http-server-on-policy-correction.html + rollout_ids: List[int] = [] + rollout_mask: List[int] = [] + rollout_logprobs: List[float] = [] seen_token_ids: List[int] = [] - interleaved_completion: List[int] = [] - interleaved_mask: List[int] = [] - interleaved_logprobs: List[float] = [] first_prompt = None num_turns = 0 @@ -275,24 +249,24 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, tool_result_tokens = item_prompt_ids[len(seen_token_ids):] if tool_result_tokens: - interleaved_completion.extend(tool_result_tokens) - interleaved_mask.extend([0] * len(tool_result_tokens)) - interleaved_logprobs.extend([0.0] * len(tool_result_tokens)) + rollout_ids.extend(tool_result_tokens) + rollout_mask.extend([0] * len(tool_result_tokens)) + rollout_logprobs.extend([0.0] * len(tool_result_tokens)) - interleaved_completion.extend(item_gen_ids) - interleaved_mask.extend([1] * len(item_gen_ids)) + rollout_ids.extend(item_gen_ids) + rollout_mask.extend([1] * len(item_gen_ids)) assert len(item_logprobs) == len(item_gen_ids), f"Logprobs len {len(item_logprobs)} != gen len {len(item_gen_ids)}" - interleaved_logprobs.extend(item_logprobs) + rollout_logprobs.extend(item_logprobs) seen_token_ids = list(item_prompt_ids) + list(item_gen_ids) - if not interleaved_completion or first_prompt is None: + if not rollout_ids or first_prompt is None: raise ValueError(f"Rollout {i} has no valid turns") - prompt_ids.append(first_prompt) - completion_ids.append(interleaved_completion) - completion_mask.append(interleaved_mask) - logprobs.append(interleaved_logprobs) + prompt_ids.append(first_prompt) # list of prompts + completion_ids.append(rollout_ids) # list of rollouts + completion_mask.append(rollout_mask) + logprobs.append(rollout_logprobs) env_rewards.append(episode_reward) num_turns_list.append(num_turns) @@ -309,7 +283,7 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, "train/num_turns_max": max(num_turns_list), }) - unique_prompt_ids = prompt_ids[::num_generations] # TRL re-duplicates them + unique_prompt_ids = prompt_ids[::num_generations] return { "prompt_ids": unique_prompt_ids, @@ -320,47 +294,16 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, "num_turns": num_turns_list, } -def get_max_prompt_length(dataset: Dataset, tokenizer) -> int: - tokenizer = AutoTokenizer.from_pretrained(tokenizer) - return max(len(tokenizer.encode(item.get("prompt", ""))) for item in dataset if item.get("prompt")) def load_dataset_from_jsonl(path: str) -> Dataset: data = [] with open(path, 'r') as f: - for line in f: + for idx, line in enumerate(f): if line.strip(): item = json.loads(line) + item["prompt"] = str(idx) - if "prompt" not in item: - if "responses_create_params" in item and isinstance(item["responses_create_params"], dict): - responses_params = item["responses_create_params"] - input_data = responses_params.get("input") - instructions = responses_params.get("instructions", "") - - if isinstance(input_data, list) and len(input_data) > 0: - # list of messages format (e.g. reasoning_gym) - prompt_parts = [] - if instructions: - prompt_parts.append(f"system: {instructions}") - for msg in input_data: - if isinstance(msg, dict) and "role" in msg and "content" in msg: - prompt_parts.append(f"{msg['role']}: {msg['content']}") - item["prompt"] = "\n\n".join(prompt_parts) if prompt_parts else "" - elif isinstance(input_data, str): - # prompt as string, no list of messages (e.g. google_search) - prompt_parts = [] - # system prompt - if instructions: - prompt_parts.append(instructions) - if input_data: - prompt_parts.append(input_data) - item["prompt"] = "\n\n".join(prompt_parts) if prompt_parts else "" - else: - item["prompt"] = item.get("question", "") - else: - item["prompt"] = item.get("question", "") - - for key in ["responses_create_params", "expected_answers", "metadata", "ground_truth", "options", "template_metadata"]: + for key in ["responses_create_params", "expected_answers", "metadata", "ground_truth", "options", "template_metadata", "agent_ref"]: if key in item and isinstance(item[key], (dict, list)): item[key] = json.dumps(item[key]) @@ -387,7 +330,7 @@ def main(): if isinstance(config.weight_decay, str): config.weight_decay = float(config.weight_decay) - agent_server = get_agent_server( + agent_servers = get_agent_servers( head_server_host=args.head_server_host, head_server_port=11000, ) @@ -396,10 +339,10 @@ def main(): os.environ["WANDB_PROJECT"] = config.project_name if config.run_name is None: - task = config.task or os.path.basename(config.dataset_path).replace('.jsonl', '').replace('.json', '') + run_name_prefix = config.run_name_prefix or os.path.basename(config.dataset_path).replace('.jsonl', '').replace('.json', '') model_short = config.model_name.split("/")[-1] config.run_name = ( - f"{task}_{model_short}" + f"{run_name_prefix}_{model_short}" f"_rpp{config.num_generations}" f"_dbs{config.per_device_train_batch_size}" f"_ga{config.gradient_accumulation_steps}" @@ -414,16 +357,12 @@ def main(): else: dataset = load_dataset(config.dataset_path, split="train") - print(f"Dataset has {len(dataset)} examples\n") eval_dataset = None if config.eval_dataset_path: eval_dataset = load_dataset_from_jsonl(config.eval_dataset_path) print(f"Eval dataset has {len(eval_dataset)} examples\n") - if config.max_prompt_length is None: - config.max_prompt_length = get_max_prompt_length(dataset, config.model_name) - training_args = GRPOConfig( use_vllm=True, vllm_mode="server", @@ -450,7 +389,7 @@ def main(): logging_steps=1, report_to=config.report_to, output_dir=config.output_dir, - run_name=config.run_name, # wandb + run_name=config.run_name, eval_strategy=config.eval_strategy, eval_steps=config.eval_steps, @@ -472,7 +411,7 @@ def main(): }, ) - training_args.agent_server = agent_server + training_args.agent_servers = agent_servers training_args.request_timeout = 6000 trainer = GRPOTrainer( From 32c5a6bd20f656e1723b247d0c548726d131a475 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 16 Jan 2026 10:55:45 -0800 Subject: [PATCH 14/51] small fix Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/train.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/examples/scripts/nemo_gym/train.py b/examples/scripts/nemo_gym/train.py index d64bfbefbc7..f8cc58c358d 100644 --- a/examples/scripts/nemo_gym/train.py +++ b/examples/scripts/nemo_gym/train.py @@ -10,6 +10,7 @@ from dataclasses import dataclass from datasets import Dataset, load_dataset from trl import GRPOConfig, GRPOTrainer +from transformers import AutoTokenizer import wandb @dataclass @@ -17,7 +18,7 @@ class TrainingConfig: model_name: str dataset_path: str - run_name_prefix: Optional[str] = None + task: Optional[str] = None learning_rate: float = 5e-6 max_steps: int = 100 @@ -63,16 +64,16 @@ def get_agent_servers( global_config_dict = OmegaConf.create(yaml.safe_load(global_config_yaml)) agent_servers = {} - for _, project_config in global_config_dict.items(): + for project_name, project_config in global_config_dict.items(): if hasattr(project_config, 'responses_api_agents'): agents = project_config.responses_api_agents - for name in agents.keys(): - agent_config = getattr(agents, name) + for agent_key in agents.keys(): + agent_config = getattr(agents, agent_key) if hasattr(agent_config, 'host') and hasattr(agent_config, 'port'): agent_host = agent_config.host if agent_host in ("127.0.0.1", "0.0.0.0", "localhost"): agent_host = head_server_host - agent_servers[name] = f"http://{agent_host}:{agent_config.port}" + agent_servers[project_name] = f"http://{agent_host}:{agent_config.port}" if not agent_servers: raise ValueError("No agents found in global config") @@ -87,7 +88,7 @@ def reward_fn(completions: List[str], **kwargs) -> List[float]: assert env_rewards is not None, "env_reward not found in kwargs" return [float(r) for r in env_rewards] -async def call_nemo_gym_agent( +async def call_nemo_gym_agents( prompts: List[str], dataset_items: List[Dict[str, Any]], agent_servers: Dict[str, str], @@ -166,7 +167,7 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, asyncio.set_event_loop(loop) try: responses = loop.run_until_complete( - call_nemo_gym_agent( + call_nemo_gym_agents( expanded_prompts, expanded_dataset_items, trainer.args.agent_servers, @@ -339,10 +340,10 @@ def main(): os.environ["WANDB_PROJECT"] = config.project_name if config.run_name is None: - run_name_prefix = config.run_name_prefix or os.path.basename(config.dataset_path).replace('.jsonl', '').replace('.json', '') + task = config.task or os.path.basename(config.dataset_path).replace('.jsonl', '').replace('.json', '') model_short = config.model_name.split("/")[-1] config.run_name = ( - f"{run_name_prefix}_{model_short}" + f"{task}_{model_short}" f"_rpp{config.num_generations}" f"_dbs{config.per_device_train_batch_size}" f"_ga{config.gradient_accumulation_steps}" @@ -403,7 +404,7 @@ def main(): num_completions_to_print=config.num_completions_to_print, max_prompt_length=config.max_prompt_length, - max_completion_length=config.max_seq_length - config.max_prompt_length, + max_completion_length=config.max_seq_length - config.max_prompt_length if config.max_prompt_length else config.max_seq_length, shuffle_dataset=False, model_init_kwargs={ @@ -414,8 +415,11 @@ def main(): training_args.agent_servers = agent_servers training_args.request_timeout = 6000 + tokenizer = AutoTokenizer.from_pretrained(config.model_name, truncation_side="left", padding_side="left") + trainer = GRPOTrainer( model=config.model_name, + processing_class=tokenizer, reward_funcs=reward_fn, train_dataset=dataset, eval_dataset=eval_dataset, From 5619096edd2b26c7b825a9027b4f725c55faeffd Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 16 Jan 2026 14:57:53 -0800 Subject: [PATCH 15/51] dataset index Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/train.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/examples/scripts/nemo_gym/train.py b/examples/scripts/nemo_gym/train.py index f8cc58c358d..4827c6bd576 100644 --- a/examples/scripts/nemo_gym/train.py +++ b/examples/scripts/nemo_gym/train.py @@ -153,11 +153,7 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, for idx_str in prompts: idx = int(idx_str) - item = dict(dataset[idx]) - - for key in ["responses_create_params", "expected_answers", "metadata", "ground_truth", "options", "template_metadata", "agent_ref"]: - if key in item and isinstance(item[key], str): - item[key] = json.loads(item[key]) + item = json.loads(dataset[idx]["metadata"]) for _ in range(num_generations): expanded_prompts.append(idx_str) @@ -302,14 +298,10 @@ def load_dataset_from_jsonl(path: str) -> Dataset: for idx, line in enumerate(f): if line.strip(): item = json.loads(line) - item["prompt"] = str(idx) - - for key in ["responses_create_params", "expected_answers", "metadata", "ground_truth", "options", "template_metadata", "agent_ref"]: - if key in item and isinstance(item[key], (dict, list)): - item[key] = json.dumps(item[key]) - - data.append(item) - + data.append({ + "prompt": str(idx), # use index for lookup as not all nemo gym datasets have the same metadata fields. maybe not the most elegant + "metadata": json.dumps(item), + }) return Dataset.from_list(data) def main(): @@ -358,7 +350,6 @@ def main(): else: dataset = load_dataset(config.dataset_path, split="train") - eval_dataset = None if config.eval_dataset_path: eval_dataset = load_dataset_from_jsonl(config.eval_dataset_path) @@ -413,7 +404,7 @@ def main(): ) training_args.agent_servers = agent_servers - training_args.request_timeout = 6000 + training_args.request_timeout = 10800 tokenizer = AutoTokenizer.from_pretrained(config.model_name, truncation_side="left", padding_side="left") From 04821b5aab7c4f0c204399a2ecc285c655eac781 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 16 Jan 2026 19:43:14 -0800 Subject: [PATCH 16/51] multinode example Signed-off-by: cmunley1 --- submit.sh | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 submit.sh diff --git a/submit.sh b/submit.sh new file mode 100644 index 00000000000..3317bc5b228 --- /dev/null +++ b/submit.sh @@ -0,0 +1,109 @@ +#!/bin/bash +#SBATCH -A account +#SBATCH -p partition +#SBATCH -N 5 +#SBATCH --gres gpu:8 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH --time=4:00:00 +#SBATCH --job-name=trl_nemo_gym +#SBATCH --output=logs/%j/slurm.out +#SBATCH --error=logs/%j/slurm.err + +CONTAINER_IMAGE="nvcr.io/nvidia/pytorch:25.12-py3" +MOUNTS="/path/to/mounts:/path/to/mounts" + +NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) + +TRAIN_NODE_0="${NODELIST[0]}" +TRAIN_NODE_1="${NODELIST[1]}" +TRAIN_NODE_2="${NODELIST[2]}" +TRAIN_NODE_3="${NODELIST[3]}" +VLLM_NODE="${NODELIST[4]}" + +echo "Training Nodes: $TRAIN_NODE_0, $TRAIN_NODE_1, $TRAIN_NODE_2, $TRAIN_NODE_3" +echo "vLLM Node: $VLLM_NODE" +echo "Main process IP: $TRAIN_NODE_0" + +LOG_DIR="logs/${SLURM_JOB_ID}" +mkdir -p ${LOG_DIR} + +echo "Starting ng_run and vLLM on ${VLLM_NODE}..." +echo "Logs will be saved to: ${LOG_DIR}" + +srun --nodes=1 --ntasks=1 --nodelist="${VLLM_NODE}" \ + --container-image="${CONTAINER_IMAGE}" \ + --container-mounts="${MOUNTS}" \ + --container-mount-home \ + bash -c " + LOG_DIR=/path/to/logs + mkdir -p \${LOG_DIR} + + # Install uv if not already installed + curl -LsSf https://astral.sh/uv/install.sh | sh + source \$HOME/.local/bin/env + + # Start nemo gym servers + (set -x && \ + export HOME=/path/to/user && \ + export PATH=\$HOME/.local/bin:\$PATH && \ + cd /path/to/user/Gym && \ + uv venv --python 3.12 && \ + source .venv/bin/activate && \ + uv sync && \ + ray stop --force && \ + ng_run +config_paths=[responses_api_models/vllm_model/configs/vllm_model.yaml,resources_servers/workplace_assistant/configs/workplace_assistant.yaml] +head_server.host=0.0.0.0 +head_server.port=11000) > \${LOG_DIR}/ng_run.log 2>&1 & + + sleep 10 + + # Start trl vllm server + (set -x && \ + export HOME=/path/to/user && \ + export HF_HOME=/path/to/user/hf_home && \ + cd /path/to/user/trl && \ + source .venv/bin/activate && \ + python -m trl.scripts.vllm_serve \ + --model Qwen/Qwen3-4B-Instruct-2507 \ + --host 0.0.0.0 \ + --tensor-parallel-size 8 \ + --data-parallel-size 1 \ + --max-model-len 16384 \ + --gpu-memory-utilization 0.7 \ + --port 8000) > \${LOG_DIR}/vllm_serve.log 2>&1 & + + wait +" & + +echo "Waiting for nemo gym and vllm to start..." +sleep 120 + +echo "Launching training on 4 nodes..." + +TRAIN_NODES_LIST="${TRAIN_NODE_0},${TRAIN_NODE_1},${TRAIN_NODE_2},${TRAIN_NODE_3}" + +srun --nodes=4 --ntasks=4 --nodelist="${TRAIN_NODES_LIST}" \ + --container-image="${CONTAINER_IMAGE}" \ + --container-mounts="${MOUNTS}" \ + --container-mount-home \ + bash -c " + set -x && \ + export HOME=/path/to/user && \ + export HF_HOME=/path/to/user/hf_home && \ + cd /path/to/user/trl && \ + source .venv/bin/activate && \ + cd examples/scripts/nemo_gym && \ + accelerate launch \ + --config_file deepspeed_zero3.yaml \ + --num_processes 32 \ + --num_machines 4 \ + --machine_rank \$SLURM_PROCID \ + --main_process_ip ${TRAIN_NODE_0} \ + --main_process_port 29500 \ + --rdzv_backend c10d \ + train.py \ + --config config.yaml \ + --vllm_server_host ${VLLM_NODE} \ + --head_server_host ${VLLM_NODE}" & + +wait + From 52b2f5c9d4211e4903388630fa84defdb685eeac Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 16 Jan 2026 19:52:36 -0800 Subject: [PATCH 17/51] client and tests Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/README.md | 19 +++++--- tests/test_vllm_client_server.py | 58 ++++++++++++++++++++++ trl/extras/vllm_client.py | 76 +++++++++++++++++++++++++++++ 3 files changed, 147 insertions(+), 6 deletions(-) diff --git a/examples/scripts/nemo_gym/README.md b/examples/scripts/nemo_gym/README.md index 0958655ce4d..cbc26050cf3 100644 --- a/examples/scripts/nemo_gym/README.md +++ b/examples/scripts/nemo_gym/README.md @@ -1,12 +1,12 @@ -# NeMo Gym TRL GRPO integration +# Post-training with NeMo Gym and TRL -Multi-step GRPO with TRL and NeMo Gym. +This integration supports training language models in NeMo-Gym environments using TRL GRPO. -## Setup +## Interactive single node 1. Launch vLLM server: ```bash -CUDA_VISIBLE_DEVICES=4,5,6,7 trl vllm-serve \ +CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve \ --model Qwen/Qwen3-4B-Instruct-2507 \ --tensor-parallel-size 4 \ --max-model-len 8192 \ @@ -21,7 +21,14 @@ ng_run "+config_paths=[resources_servers/workplace_assistant/configs/workplace_a 3. Run training: ```bash -CUDA_VISIBLE_DEVICES=0 python train.py --config config.yaml +CUDA_VISIBLE_DEVICES=4 python train.py --config config.yaml ``` -multinode is working, an example will be uploaded soon! \ No newline at end of file +## Multinode with slurm + +See submit.sh for a multinode example! + +``` + + +``` \ No newline at end of file diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index 97fb17ab05a..0b49552e3a3 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -147,6 +147,64 @@ def test_reset_prefix_cache(self): # Test resetting the prefix cache self.client.reset_prefix_cache() + def test_chat_completions_endpoint(self): + data = self.client.chat_completions( + messages=[{"role": "user", "content": "Say hello"}], + max_tokens=32, + ) + + assert "id" in data + assert "choices" in data + assert "usage" in data + assert len(data["choices"]) > 0 + assert data["choices"][0]["message"]["role"] == "assistant" + assert data["choices"][0]["finish_reason"] in ["stop", "length", "tool_calls"] + + def test_chat_completions_with_tools(self): + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information for a location", + "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}, + }, + } + ] + data = self.client.chat_completions( + messages=[{"role": "user", "content": "What's the weather in San Francisco?"}], + tools=tools, + max_tokens=100, + ) + + assert "choices" in data + assert len(data["choices"]) > 0 + assert "message" in data["choices"][0] + + def test_chat_completions_with_params(self): + data = self.client.chat_completions( + messages=[{"role": "user", "content": "Tell me a joke"}], + n=2, + temperature=0.8, + top_p=0.9, + max_tokens=32, + ) + + assert len(data["choices"]) == 2 + + for choice in data["choices"]: + assert "message" in choice + assert choice["message"]["role"] == "assistant" + + def test_tokenize_endpoint(self): + data = self.client.tokenize(messages=[{"role": "user", "content": "Hello, how are you?"}]) + + assert "tokens" in data + assert "model" in data + assert isinstance(data["tokens"], list) + assert len(data["tokens"]) > 0 + assert all(isinstance(tok, int) for tok in data["tokens"]) + @classmethod def teardown_class(cls): # Close the client diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index e21df6d837e..14ed9c0a90d 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -492,6 +492,82 @@ def reset_prefix_cache(self): if response.status_code != 200: raise Exception(f"Request failed: {response.status_code}, {response.text}") + def chat_completions( + self, + messages: list[dict], + model: str | None = None, + temperature: float = 1.0, + top_p: float = 1.0, + max_tokens: int | None = None, + n: int = 1, + tools: list[dict] | None = None, + **kwargs, + ) -> dict: + """ + OpenAI-compatible chat completions endpoint. + + Args: + messages (`list[dict]`): + List of messages in OpenAI format with "role" and "content" keys. + model (`str`, *optional*): + Model name to use. + temperature (`float`, *optional*, defaults to `1.0`): + Temperature for sampling. + top_p (`float`, *optional*, defaults to `1.0`): + Top-p sampling parameter. + max_tokens (`int`, *optional*): + Maximum number of tokens to generate. + n (`int`, *optional*, defaults to `1`): + Number of completions to generate. + tools (`list[dict]`, *optional*): + List of tool definitions for tool calling. + **kwargs: + Additional parameters to pass to the endpoint. + + Returns: + `dict`: + OpenAI-compatible response with "choices", "usage", etc. + """ + url = f"{self.base_url}/v1/chat/completions" + response = self.session.post( + url, + json={ + "messages": messages, + "model": model, + "temperature": temperature, + "top_p": top_p, + "max_tokens": max_tokens, + "n": n, + "tools": tools, + **kwargs, + }, + ) + if response.status_code == 200: + return response.json() + else: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + def tokenize(self, messages: list[dict], tools: list[dict] | None = None) -> dict: + """ + Tokenize messages to get token IDs. + + Args: + messages (`list[dict]`): + List of messages to tokenize. + tools (`list[dict]`, *optional*): + List of tool definitions. + + Returns: + `dict`: + Dictionary with "tokens" (list of token IDs) and "model" keys. + """ + url = f"{self.base_url}/tokenize" + response = self.session.post(url, json={"messages": messages, "tools": tools}) + if response.status_code == 200: + return response.json() + else: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + def close_communicator(self): """ Closes the weight update group and cleans up the communication group. From 0793c054e4c7dd8e14f720695f0afb256bcd7b53 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 16 Jan 2026 20:10:36 -0800 Subject: [PATCH 18/51] remove native tool parsing, use fastapi state Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/README.md | 5 ----- trl/scripts/vllm_serve.py | 20 ++++---------------- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/examples/scripts/nemo_gym/README.md b/examples/scripts/nemo_gym/README.md index cbc26050cf3..736f60e2e4c 100644 --- a/examples/scripts/nemo_gym/README.md +++ b/examples/scripts/nemo_gym/README.md @@ -27,8 +27,3 @@ CUDA_VISIBLE_DEVICES=4 python train.py --config config.yaml ## Multinode with slurm See submit.sh for a multinode example! - -``` - - -``` \ No newline at end of file diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index e3d25ed4960..4c9661276ab 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -34,7 +34,6 @@ from transformers import AutoTokenizer from trl import TrlParser -# from trl.chat_template_utils import add_response_schema # For native tool call parsing from trl.import_utils import ( is_fastapi_available, is_pydantic_available, @@ -522,21 +521,10 @@ def main(script_args: ScriptArguments): connections.append(parent_connection) processes.append(process) - cached_tokenizer = None - @asynccontextmanager async def lifespan(app: FastAPI): - nonlocal cached_tokenizer - logger.info(f"Loading tokenizer for {script_args.model}...") - cached_tokenizer = AutoTokenizer.from_pretrained(script_args.model, trust_remote_code=script_args.trust_remote_code) - - # uncomment for native tool call parsing - # try: - # cached_tokenizer = add_response_schema(cached_tokenizer) - # logger.info("Response schema added - vLLM will use native tool call parsing") - # except (ValueError, AttributeError) as e: - # logger.warning(f"Could not add response schema: {e}. Will fall back to XML parsing if tools are used.") + app.state.tokenizer = AutoTokenizer.from_pretrained(script_args.model, trust_remote_code=script_args.trust_remote_code) # Wait for all workers to send "ready" ready_connections = set() @@ -994,10 +982,10 @@ async def chat_completions(request: ChatCompletionRequest): ) if has_prefix_token_ids: - # do on policy token id correction and call generate instead of chat + # do on policy token id correction and call generate instead of chat # see https://docs.nvidia.com/nemo/gym/latest/contribute/rl-framework-integration/openai-compatible-http-server-on-policy-correction.html # and https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 - tokenizer = cached_tokenizer + tokenizer = app.state.tokenizer # preprocess full conversation connections[0].send({"type": "call", "method": "preprocess_chat", "kwargs": { @@ -1194,7 +1182,7 @@ async def tokenize(request: TokenizeRequest): result_tokens = template_prompt["prompt_token_ids"] if has_prefix_token_ids: - tokenizer = cached_tokenizer + tokenizer = app.state.tokenizer # Extract model prefix tokens from last assistant message model_prefix_tokens = None From 5f8ccc9118d3c235d17bc7280e06d6113df26a8c Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 16 Jan 2026 20:15:12 -0800 Subject: [PATCH 19/51] remove old code Signed-off-by: cmunley1 --- trl/scripts/vllm_serve.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 4c9661276ab..a241ff91996 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -820,8 +820,6 @@ async def chat(request: ChatRequest): "sampling_params": sampling_params, "chat_template_kwargs": request.chat_template_kwargs, "tools": request.tools if request.tools else None, - # "tool_choice": request.tool_choice, - # "parallel_tool_calls": request.parallel_tool_calls, } connection.send({"type": "call", "method": "chat", "kwargs": kwargs}) @@ -1082,10 +1080,8 @@ async def chat_completions(request: ChatCompletionRequest): tool_calls = None finish_reason = "stop" - if hasattr(gen_output, "tool_calls") and gen_output.tool_calls: # native tool call parsing - tool_calls = gen_output.tool_calls - finish_reason = "tool_calls" - elif request.tools and text: # try manual tool call parsing eg qwen3 style xml format... this is a hack. + # Manual XML-json tool call parsing + if request.tools and text: pattern = r'(.*?)' matches = re.findall(pattern, text, re.DOTALL) if matches: From 743d5eaa7d60fb465f7ae2de0ecb7d815e270ff4 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 16 Jan 2026 20:16:03 -0800 Subject: [PATCH 20/51] enable IS Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/scripts/nemo_gym/config.yaml b/examples/scripts/nemo_gym/config.yaml index 5af51bd1196..5770c389e3a 100644 --- a/examples/scripts/nemo_gym/config.yaml +++ b/examples/scripts/nemo_gym/config.yaml @@ -20,7 +20,7 @@ warmup_steps: 10 lr_scheduler_type: "linear" optim: "adamw_torch_fused" weight_decay: 0.0 -vllm_importance_sampling_correction: false +vllm_importance_sampling_correction: true temperature: 1.0 top_p: 0.999 From d98dd8a0c279edd22331ee0ae94ca406be78a5da Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 16 Jan 2026 20:18:58 -0800 Subject: [PATCH 21/51] remove logp diff tracking without is Signed-off-by: cmunley1 --- trl/trainer/grpo_trainer.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0c682e24ddf..0141dac8c07 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1867,19 +1867,6 @@ def _generate_and_score_completions( else: old_per_token_logps = None - # track sampling logp diff even when IS off for debugging - # could remove this - if self.use_vllm and sampling_per_token_logps is not None and old_per_token_logps is None: - old_per_token_logps, _ = self._get_per_token_logps_and_entropies( - self.model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - batch_size, - num_images=num_images, - **forward_kwargs, - ) - # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch if self.use_vllm and self.vllm_importance_sampling_correction: mask = completion_mask if not self.tools else completion_mask * tool_mask @@ -2016,9 +2003,7 @@ def _generate_and_score_completions( if images is not None: self._logs["images"].extend(gather_object(images)) - # track sampling logp diff even when IS off for debugging - # could remove this - if self.use_vllm and old_per_token_logps is not None and sampling_per_token_logps is not None: + if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) mask = completion_mask.bool() if not self.tools else (completion_mask * tool_mask).bool() delta = delta[mask] @@ -2030,9 +2015,6 @@ def _generate_and_score_completions( self._metrics[mode]["sampling/sampling_logp_difference/max"].append( self.accelerator.gather(max_delta).max().item() ) - - # track IS ratio only when IS correction is enabled - if self.use_vllm and self.vllm_importance_sampling_correction: if sequence_level_is: flat_is_ratio = vllm_importance_sampling_ratio.flatten() else: From a5f91668d342e90eaa7a7131541802e7f66bcf5a Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 16 Jan 2026 20:23:49 -0800 Subject: [PATCH 22/51] restore Signed-off-by: cmunley1 --- trl/trainer/grpo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0141dac8c07..aa0d7022f31 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1873,6 +1873,7 @@ def _generate_and_score_completions( per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"] + if sequence_level_is: per_sequence_logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True) logps_diff = per_sequence_logps_diff From 17b72c8e39290c4d3dab900e5774feb1fbd47128 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 16 Jan 2026 20:29:18 -0800 Subject: [PATCH 23/51] readme Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/README.md | 2 +- trl/trainer/grpo_trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/scripts/nemo_gym/README.md b/examples/scripts/nemo_gym/README.md index 736f60e2e4c..8f48f54acf9 100644 --- a/examples/scripts/nemo_gym/README.md +++ b/examples/scripts/nemo_gym/README.md @@ -1,6 +1,6 @@ # Post-training with NeMo Gym and TRL -This integration supports training language models in NeMo-Gym environments using TRL GRPO. +This integration supports training language models in NeMo-Gym environments using TRL GRPO. Both single step and multi step tasks are supported. NeMo-Gym orchestrates rollouts, returning token ids and logprobs to TRL through the rollout function for training. Currently this integration is only supported through TRL's vllm server mode. ## Interactive single node diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index aa0d7022f31..9a0521b7e9a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1873,7 +1873,6 @@ def _generate_and_score_completions( per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"] - if sequence_level_is: per_sequence_logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True) logps_diff = per_sequence_logps_diff @@ -2016,6 +2015,7 @@ def _generate_and_score_completions( self._metrics[mode]["sampling/sampling_logp_difference/max"].append( self.accelerator.gather(max_delta).max().item() ) + if sequence_level_is: flat_is_ratio = vllm_importance_sampling_ratio.flatten() else: From 18ffaa8b95bd6129effde9b8b62e60591e2d87c5 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 16 Jan 2026 20:31:20 -0800 Subject: [PATCH 24/51] restore pyproject Signed-off-by: cmunley1 --- pyproject.toml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 99d7f5cbadf..25993a5b109 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,12 +30,7 @@ requires-python = ">=3.10" dependencies = [ "accelerate>=1.4.0", "datasets>=3.0.0", - "fastapi>=0.124.4", - "omegaconf>=2.3.0", "transformers>=4.56.1", - "uvicorn>=0.38.0", - "vllm>=0.11.2", - "wandb>=0.23.1", ] dynamic = ["version"] From cc503cb6eec8c45e0b38ff11c637da5c63ca132a Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 16 Jan 2026 20:33:47 -0800 Subject: [PATCH 25/51] readme Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/scripts/nemo_gym/README.md b/examples/scripts/nemo_gym/README.md index 8f48f54acf9..40027328adc 100644 --- a/examples/scripts/nemo_gym/README.md +++ b/examples/scripts/nemo_gym/README.md @@ -1,6 +1,6 @@ # Post-training with NeMo Gym and TRL -This integration supports training language models in NeMo-Gym environments using TRL GRPO. Both single step and multi step tasks are supported. NeMo-Gym orchestrates rollouts, returning token ids and logprobs to TRL through the rollout function for training. Currently this integration is only supported through TRL's vllm server mode. +This integration supports training language models in NeMo-Gym environments using TRL GRPO. Both single step and multi step tasks are supported, including multi-environment training. NeMo-Gym orchestrates rollouts, returning token ids and logprobs to TRL through the rollout function for training. Currently this integration is only supported through TRL's vllm server mode. ## Interactive single node @@ -27,3 +27,7 @@ CUDA_VISIBLE_DEVICES=4 python train.py --config config.yaml ## Multinode with slurm See submit.sh for a multinode example! + +## Multi environment training + +Docs coming soon! \ No newline at end of file From 843938f2d1e958c4ad6eac52fd2d7a2fad6e6973 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 16 Jan 2026 21:00:23 -0800 Subject: [PATCH 26/51] move submit Signed-off-by: cmunley1 --- submit.sh => examples/scripts/nemo_gym/submit.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename submit.sh => examples/scripts/nemo_gym/submit.sh (96%) diff --git a/submit.sh b/examples/scripts/nemo_gym/submit.sh similarity index 96% rename from submit.sh rename to examples/scripts/nemo_gym/submit.sh index 3317bc5b228..3f5aed73ee8 100644 --- a/submit.sh +++ b/examples/scripts/nemo_gym/submit.sh @@ -52,7 +52,7 @@ srun --nodes=1 --ntasks=1 --nodelist="${VLLM_NODE}" \ source .venv/bin/activate && \ uv sync && \ ray stop --force && \ - ng_run +config_paths=[responses_api_models/vllm_model/configs/vllm_model.yaml,resources_servers/workplace_assistant/configs/workplace_assistant.yaml] +head_server.host=0.0.0.0 +head_server.port=11000) > \${LOG_DIR}/ng_run.log 2>&1 & + ng_run +config_paths=[responses_api_models/vllm_model/configs/vllm_model.yaml,resources_servers/workplace_assistant/configs/workplace_assistant.yaml] +head_server.host=0.0.0.0) > \${LOG_DIR}/ng_run.log 2>&1 & sleep 10 From 209b12ed70d3117ca291eb84b3379840a582f9ab Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 16 Jan 2026 21:05:02 -0800 Subject: [PATCH 27/51] config Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/scripts/nemo_gym/config.yaml b/examples/scripts/nemo_gym/config.yaml index 5770c389e3a..3b36f27cb55 100644 --- a/examples/scripts/nemo_gym/config.yaml +++ b/examples/scripts/nemo_gym/config.yaml @@ -14,9 +14,9 @@ learning_rate: 1.0e-5 max_steps: 1000 num_generations: 8 per_device_train_batch_size: 1 -gradient_accumulation_steps: 16 +gradient_accumulation_steps: 8 max_seq_length: 16384 -warmup_steps: 10 +warmup_steps: 5 lr_scheduler_type: "linear" optim: "adamw_torch_fused" weight_decay: 0.0 @@ -25,8 +25,8 @@ vllm_importance_sampling_correction: true temperature: 1.0 top_p: 0.999 -save_steps: 100 +save_steps: 10 eval_strategy: "steps" -eval_steps: 50 +eval_steps: 10 From a8f7b36c89291ef9b1bab7cdb4f8661153987e63 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Tue, 20 Jan 2026 18:15:15 -0800 Subject: [PATCH 28/51] draft docs Signed-off-by: cmunley1 --- docs/source/example_overview.md | 1 + examples/scripts/nemo_gym/config.yaml | 6 +++--- examples/scripts/nemo_gym/{train.py => train_multi_env.py} | 0 3 files changed, 4 insertions(+), 3 deletions(-) rename examples/scripts/nemo_gym/{train.py => train_multi_env.py} (100%) diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index daad32f29d1..a9e99bd755f 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -58,6 +58,7 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl | [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`experimental.kto.KTOTrainer`] to fine-tune a model. | | [`examples/scripts/mpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/mpo_vlm.py) | This script shows how to use MPO via the [`DPOTrainer`] to align a model based on preferences using the [HuggingFaceH4/rlaif-v_formatted](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted) dataset and a set of loss weights with weights. | | [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`experimental.nash_md.NashMDTrainer`] to fine-tune a model. | +| [`examples/scripts/nemo_gym/train_multi_env.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/train_multi_env.py) | This script shows how to use the [`GRPOTrainer`] to train language models in NVIDIA NeMo-Gym environments. Supports multi-turn and tool calling environments, and multi-environment training. See the [NeMo-Gym Integration](nemo_gym_integration) guide for setup and usage. | | [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a model. | | [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a a Vision Language Model. | | [`examples/scripts/openenv/browsergym.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/browsergym.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's BrowserGym environment and vLLM for VLMs | diff --git a/examples/scripts/nemo_gym/config.yaml b/examples/scripts/nemo_gym/config.yaml index 3b36f27cb55..1998e9f66fc 100644 --- a/examples/scripts/nemo_gym/config.yaml +++ b/examples/scripts/nemo_gym/config.yaml @@ -1,10 +1,10 @@ model_name: "Qwen/Qwen3-4B-Instruct-2507" -dataset_path: "data/train.jsonl" -eval_dataset_path: "data/val.jsonl" +dataset_path: "/path/to/data/train.jsonl" +eval_dataset_path: "/path/to/data/val.jsonl" output_dir: "outputs/nemo_gym" -run_name_prefix: "nemo_gym" +task: "workplace" # just used in wandb run name report_to: "wandb" project_name: "trl-nemo-gym" log_completions: true diff --git a/examples/scripts/nemo_gym/train.py b/examples/scripts/nemo_gym/train_multi_env.py similarity index 100% rename from examples/scripts/nemo_gym/train.py rename to examples/scripts/nemo_gym/train_multi_env.py From e883dcd941513e223ed6bc65efc12cb0f50965c5 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Tue, 20 Jan 2026 18:18:33 -0800 Subject: [PATCH 29/51] draft docs Signed-off-by: cmunley1 --- docs/source/nemo_gym_integration.md | 321 ++++++++++++++++++++++++++++ 1 file changed, 321 insertions(+) create mode 100644 docs/source/nemo_gym_integration.md diff --git a/docs/source/nemo_gym_integration.md b/docs/source/nemo_gym_integration.md new file mode 100644 index 00000000000..9e6affa8507 --- /dev/null +++ b/docs/source/nemo_gym_integration.md @@ -0,0 +1,321 @@ +# NeMo-Gym Integration + +NVIDIA NeMo-Gym is a library for building reinforcement learning environments for large language models. This integration enables training models in NeMo-Gym environments using TRL's [`GRPOTrainer`]. + +NeMo-Gym orchestrates multi-step and multi-turn rollouts, providing token IDs and log probabilities to TRL through a custom rollout function. This integration currently requires TRL's vLLM server mode. + +## Overview + +The integration supports: + +- **NeMo-Gym RL environments**: Any NeMo-Gym environment should work through this integration, though not all have been tested. We thorougly tested the following environments in the development of this integration: workplace assistant, reasoning gym, mcqa, and math with judge. +- **Multi-turn tasks**: Multi-step environments involve the agent performing multiple tool calls or other steps sequentially. Multi-turn environments involve follow-up user messages, in addition to potentially multiple tool calls or other steps in the environment. +- **Multi-environment training**: Train on multiple tasks or environments simultaneously and efficiently at scale. + +## Why NeMo Gym + +NeMo-Gym was designed to support large-scale, production-grade reinforcement learning training: + +- **Scale and Coverage**: NeMo-Gym supports diverse environments running in parallel, with many examples across domains (math, coding, tool use, knowledge, reasoning, search, ...). +- **Production-Ready**: Tested for frontier model training at large scale. The infrastructure is designed for the scale and reliability required for production LLM training. +- **Multi-Verifier RL Training**: Built for training with multiple verification methods simultaneously. Supports algorithmic verification (code execution, math verification), LLM-as-a-judge, and custom verification logic across different environments in a single training run. +- **Decoupled Architecture**: Enables building agents and environments independently from the training loop. Environments can be developed, tested, and deployed without requiring expertise in the RL training framework. +- **OpenAI-Compatible API**: All environments are compatble with standardized OpenAI Responses API, allowing seamless integration with any inference server (vLLM, SGLang, etc.) and enabling environment reuse across different training frameworks. +- **Container-Ready**: Designed for containerized deployment with REST APIs, supporting complex multi-agent systems and environments like SWE-Bench that require isolated Docker containers. + +## Installation + +Install TRL with vLLM support: + +```bash +pip install trl[vllm] +``` + +Install NeMo-Gym: + +```bash +git clone https://github.com/NVIDIA-NeMo/Gym.git +cd Gym +uv venv --python 3.12 +source .venv/bin/activate +uv sync --extra dev +``` + +## Available Environments + +NeMo-Gym provides training-ready environments across various domains, including but not limited to: + +| Environment | Domain | Description | +|-------------|--------|-------------| +| Workplace Assistant | Agent | Multi-step tool calling in common office scenarios (calendar, email, etc.) | +| Math with Judge | Math | Math problems with algorithmic or judge-based verification | +| Code Gen | Coding | Competitive programming problems with code execution | +| MCQA | Knowledge | Multiple-choice question answering | +| Instruction Following | Instruction Following | IFEval/IFBench style tasks | +| Reasoning Gym | Multiple | Single-step procedurally generated verifiable tasks across various domains | + +See a complete list of available training environments in the [NeMo-Gym repository](https://github.com/NVIDIA-NeMo/Gym#-available-resource-servers). + +## Preparing a Dataset + +For creating a new environment, check out the [official guide](https://docs.nvidia.com/nemo/gym/latest/contribute/environments/new-environment.html). + +Many NeMo-Gym datasets used in training Nemotron models are available on Hugging Face, corresponding to existing RL environments. + +### Download and Prepare Workplace Assistant Data + +Use `ng_prepare_data` to download and prepare the dataset. This command: +- Downloads the dataset from Hugging Face +- Validates the data format +- Adds an `agent_ref` field to each example that tells NeMo-Gym which agent server should handle that example + +Note that `train_multi_env.py` adds `agent_ref` field when loading datasets in case that datasets are created some other way. + +First, set `env.yaml` in `Gym/` to contain your Hugging Face token: +``` +hf_token: +``` + +Example dataset preparation for the workplace assistant environment: + +```bash +cd Gym +source .venv/bin/activate + +config_paths="responses_api_models/vllm_model/configs/vllm_model.yaml,\ +resources_servers/workplace_assistant/configs/workplace_assistant.yaml" + +ng_prepare_data "+config_paths=[${config_paths}]" \ + +output_dirpath=data/workplace_assistant \ + +mode=train_preparation \ + +should_download=true \ + +data_source=huggingface +``` + +This creates `train.jsonl` and `validation.jsonl` files in `data/workplace_assistant/`. + +### Dataset Format + +In NeMo Gym, datasets are stored as JSONL. Each line contains a task with input messages, potential tool definitions, metadata such as ground truth for verification, and an agent server reference. The workplace dataset is structured like shown below. The metadata fields can differ between datasets, as long as the corresponding resources server leverages the fields appropriately. + +```json +{ + "responses_create_params": { + "input": [ + {"role": "system", "content": "..."}, + {"role": "user", "content": "Move any of jinsoo's tasks that are in review to completed"} + ], + "tools": [...], // Full tool definitions + "parallel_tool_calls": false, + "temperature": 1 + }, + "ground_truth": [ + {"name": "project_management_update_task", "arguments": "{...}"}, + ... + ], + "category": "workbench_project_management", + "environment_name": "workbench", + "agent_ref": { + "type": "responses_api_agents", + "name": "workplace_assistant_simple_agent" + } +} +``` + +## Training Configuration + +Create a `config_workplace.yaml` file with your training parameters: + +```yaml +model_name: "Qwen/Qwen2.5-1.5B-Instruct" + +dataset_path: "data/workplace_assistant/train.jsonl" +eval_dataset_path: "data/workplace_assistant/validation.jsonl" + +task: 'workplace' # used in wandb run name +output_dir: "outputs/nemo_gym" +report_to: "wandb" # set to none if you don't have wandb set up. +project_name: "trl-nemo-gym" + +learning_rate: 1.0e-5 +max_steps: 1000 +num_generations: 8 +per_device_train_batch_size: 1 +gradient_accumulation_steps: 4 +max_seq_length: 16384 + +temperature: 1.0 +top_p: 0.999 + +save_steps: 10 +eval_strategy: "steps" +eval_steps: 10 +``` + +## Interactive Training + +For development and testing on a single node: + +### Step 1: Update environment config + +Update `env.yaml` to include model information: + +``` +policy_base_url: http://127.0.0.1:8000/v1 +policy_api_key: EMPTY +policy_model_name: Qwen/Qwen3-30B-A3B-Instruct-2507 +hf_token: ... +``` + +### Step 2: Start NeMo-Gym Servers + +First, start the NeMo-Gym environment servers: + +```bash +cd Gym +source .venv/bin/activate + +config_paths="resources_servers/workplace_assistant/configs/workplace_assistant.yaml,\ +responses_api_models/vllm_model/configs/vllm_model_for_training.yaml" + +ng_run "+config_paths=[${config_paths}]" +``` + +This starts: +- **Head server**: Manages servers used in training +- **Agent server**: Orchestrates rollouts by leveraging resource servers and model servers +- **Resources server**: Supports environment logic such as state-based feedback, tool implementations, and task verification +- **Model server**: Adapts vLLM server requests to support NeMo Gym agents and ensures OpenAI API compatibility + +### Step 2: Start TRL vLLM Server + +In a second terminal, start the TRL vLLM server on GPU 0: + +```bash +cd trl + +CUDA_VISIBLE_DEVICES=0 trl vllm-serve \ + --model Qwen/Qwen2.5-1.5B-Instruct \ + --max-model-len 16384 \ + --host 0.0.0.0 \ + --port 8000 +``` + + +### Step 3: Run Training + +In a third terminal, launch the training script on GPU 1: + +```bash +cd trl/ +source .venv/bin/activate + +cd examples/scripts/nemo_gym + +# if using wandb +export WANDB_API_KEY=... +uv pip install wandb # TODO: double check its missing from trl + +CUDA_VISIBLE_DEVICES=1 python train_multi_env.py --config config_workplace.yaml +``` + +Note that these separate terminals can also be tmux sessions or processes ran in the background. + +## Multi-Node Training with Slurm + +An example 5-node training script is provided in `submit.sh`. To use this, update your slurm account and partition, path to Gym repository, and your training config. + +Nodes 1-4 are used for training backend, while node 5 is used for vLLM inference. For more details on TRL's vLLM integration, visit vllm integration page. + +Submit the job: +```bash +sbatch submit.sh +``` + +Monitor training logs: +```bash +tail -f logs//* +``` + +Set up wandb logging for detailed training metrics! + +## Multi-Environment Training + +Train on multiple NeMo-Gym environments simultaneously. This allows learning diverse capabilities (e.g., tool calling + math reasoning) in a single training run. + +### Step 1: Prepare Individual Datasets + +First, prepare datasets for each environment you want to use. Above, we prepared the workplace dataset. Now, create a reasoning gym dataset: + +```bash +cd Gym +source .venv/bin/activate +uv add reasoning-gym +cd resources_servers/reasoning_gym +python scripts/create_dataset.py \ + --task mini_sudoku \ + --size 2000 \ + --seed 42 \ + --output data/reasoning_gym/train_mini_sudoku.jsonl + +python scripts/create_dataset.py \ + --task mini_sudoku \ + --size 50 \ + --seed 24 \ + --output data/reasoning_gym/val_mini_sudoku.jsonl +``` + +### Step 2: Create Blended Dataset + +Create a single dataset with tasks from both environments mixed together. This can be done with a simple bash command, such as the following: +```bash +cat data/workplace_assistant/train_workplace.jsonl data/reasoning_gym/train_mini_sudoku.jsonl | shuf > train_multi_env.jsonl +``` + +Note you may want to ensure that the datasets are the same size before shuffling to get an even blend of tasks. Do the same for the validation dataset. + +### Step 3: Update Training Config + +Create `config_multi_env.yaml` pointing to the blended dataset: + +```yaml +model_name: "Qwen/Qwen3-4B-Instruct-2507" + +dataset_path: "/path/to/data/train_multi_env.jsonl" +eval_dataset_path: "/path/to/data/val_multi_env.jsonl" + +task: "workplace-sudoku" # used in wandb run name +output_dir: "outputs/nemo_gym_multi_env" + +# ... rest of config same +``` + +### Step 4: Launch Resources Servers + +Start NeMo-Gym with both resources servers in the config: + +```bash +cd Gym +source .venv/bin/activate + +config_paths="responses_api_models/vllm_model/configs/vllm_model.yaml,\ +resources_servers/workplace_assistant/configs/workplace_assistant.yaml,\ +resources_servers/reasoning_gym/configs/reasoning_gym.yaml" + +ng_run "+config_paths=[${config_paths}]" +head_server.host=0.0.0.0 +``` + +This starts servers for both environments. The training script will automatically route each example to the correct agent server based on its `agent_ref` field. + +### Step 5: Run Training + +Just update the slurm submission script to use the new config, then submit the job as before! + +The training script reads `agent_ref` from each example's metadata, routes requests to the correct NeMo-Gym agent server, and handles different environments in the same batch + +## Resources + +- [NeMo-Gym GitHub](https://github.com/NVIDIA-NeMo/Gym) +- [NeMo-Gym Documentation](https://docs.nvidia.com/nemo/gym/latest/) +- [Training Script](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/train_multi_env.py) +- [TRL GRPO Trainer](grpo_trainer) \ No newline at end of file From 2c7de07ebfed7177dba9d0bd5cc9ce51cc48cb11 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Tue, 20 Jan 2026 18:30:54 -0800 Subject: [PATCH 30/51] docs update Signed-off-by: cmunley1 --- docs/source/nemo_gym_integration.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/source/nemo_gym_integration.md b/docs/source/nemo_gym_integration.md index 9e6affa8507..38008d621ab 100644 --- a/docs/source/nemo_gym_integration.md +++ b/docs/source/nemo_gym_integration.md @@ -20,8 +20,8 @@ NeMo-Gym was designed to support large-scale, production-grade reinforcement lea - **Production-Ready**: Tested for frontier model training at large scale. The infrastructure is designed for the scale and reliability required for production LLM training. - **Multi-Verifier RL Training**: Built for training with multiple verification methods simultaneously. Supports algorithmic verification (code execution, math verification), LLM-as-a-judge, and custom verification logic across different environments in a single training run. - **Decoupled Architecture**: Enables building agents and environments independently from the training loop. Environments can be developed, tested, and deployed without requiring expertise in the RL training framework. -- **OpenAI-Compatible API**: All environments are compatble with standardized OpenAI Responses API, allowing seamless integration with any inference server (vLLM, SGLang, etc.) and enabling environment reuse across different training frameworks. -- **Container-Ready**: Designed for containerized deployment with REST APIs, supporting complex multi-agent systems and environments like SWE-Bench that require isolated Docker containers. +- **OpenAI-Compatible API**: All environments are compatble with standardized OpenAI Responses API, allowing seamless integration with compatible inference endpoint (local vLLM, OpenAI models, etc.) and enabling environment reuse across different training frameworks. +- **Flexible Environment Isolation**: Environments can manage packages via lightweight python environments or using containers. ## Installation @@ -163,7 +163,7 @@ Update `env.yaml` to include model information: ``` policy_base_url: http://127.0.0.1:8000/v1 policy_api_key: EMPTY -policy_model_name: Qwen/Qwen3-30B-A3B-Instruct-2507 +policy_model_name: Qwen/Qwen2.5-1.5B-Instruct hf_token: ... ``` @@ -265,7 +265,7 @@ python scripts/create_dataset.py \ --output data/reasoning_gym/val_mini_sudoku.jsonl ``` -### Step 2: Create Blended Dataset +### Step 2: Create Combined Dataset Create a single dataset with tasks from both environments mixed together. This can be done with a simple bash command, such as the following: ```bash @@ -276,7 +276,7 @@ Note you may want to ensure that the datasets are the same size before shuffling ### Step 3: Update Training Config -Create `config_multi_env.yaml` pointing to the blended dataset: +Create `config_multi_env.yaml` pointing to the combined dataset: ```yaml model_name: "Qwen/Qwen3-4B-Instruct-2507" @@ -290,9 +290,9 @@ output_dir: "outputs/nemo_gym_multi_env" # ... rest of config same ``` -### Step 4: Launch Resources Servers +### Step 4: Update ng_run -Start NeMo-Gym with both resources servers in the config: +Whether training interactively or via slurm, update the ng_run command to use the config file from each resources server used in training. ```bash cd Gym @@ -309,9 +309,9 @@ This starts servers for both environments. The training script will automaticall ### Step 5: Run Training -Just update the slurm submission script to use the new config, then submit the job as before! +Just update the slurm submission script to use the new train config and both ng_run resources server configs, then submit the job as before! -The training script reads `agent_ref` from each example's metadata, routes requests to the correct NeMo-Gym agent server, and handles different environments in the same batch +The training script reads `agent_ref` from each example's metadata, routes requests to the correct NeMo-Gym agent server, and handles different agents and environments in the same batch ## Resources From aad21ee2c1cff8f142bf380ed1eb5b744f5340d2 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Thu, 22 Jan 2026 14:25:42 -0800 Subject: [PATCH 31/51] ds cfg, submit update Signed-off-by: cmunley1 --- .../scripts/nemo_gym/deepspeed_zero3.yaml | 22 +++++++++++++++++++ examples/scripts/nemo_gym/submit.sh | 7 +++--- examples/scripts/nemo_gym/train_multi_env.py | 2 +- 3 files changed, 27 insertions(+), 4 deletions(-) create mode 100644 examples/scripts/nemo_gym/deepspeed_zero3.yaml diff --git a/examples/scripts/nemo_gym/deepspeed_zero3.yaml b/examples/scripts/nemo_gym/deepspeed_zero3.yaml new file mode 100644 index 00000000000..ac6ad51adb0 --- /dev/null +++ b/examples/scripts/nemo_gym/deepspeed_zero3.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 4 +num_processes: 32 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/scripts/nemo_gym/submit.sh b/examples/scripts/nemo_gym/submit.sh index 3f5aed73ee8..69bdc6568d2 100644 --- a/examples/scripts/nemo_gym/submit.sh +++ b/examples/scripts/nemo_gym/submit.sh @@ -52,7 +52,7 @@ srun --nodes=1 --ntasks=1 --nodelist="${VLLM_NODE}" \ source .venv/bin/activate && \ uv sync && \ ray stop --force && \ - ng_run +config_paths=[responses_api_models/vllm_model/configs/vllm_model.yaml,resources_servers/workplace_assistant/configs/workplace_assistant.yaml] +head_server.host=0.0.0.0) > \${LOG_DIR}/ng_run.log 2>&1 & + ng_run +config_paths=[responses_api_models/vllm_model/configs/vllm_model.yaml,resources_servers/workplace_assistant/configs/workplace_assistant.yaml] +head_server.host=0.0.0.0 +head_server.port=11000) > \${LOG_DIR}/ng_run.log 2>&1 & sleep 10 @@ -61,7 +61,7 @@ srun --nodes=1 --ntasks=1 --nodelist="${VLLM_NODE}" \ export HOME=/path/to/user && \ export HF_HOME=/path/to/user/hf_home && \ cd /path/to/user/trl && \ - source .venv/bin/activate && \ + rm -rf .venv && uv venv && source .venv/bin/activate && uv sync && uv pip install -e .[vllm] && uv pip install fastapi uvicorn && \ python -m trl.scripts.vllm_serve \ --model Qwen/Qwen3-4B-Instruct-2507 \ --host 0.0.0.0 \ @@ -90,8 +90,9 @@ srun --nodes=4 --ntasks=4 --nodelist="${TRAIN_NODES_LIST}" \ export HOME=/path/to/user && \ export HF_HOME=/path/to/user/hf_home && \ cd /path/to/user/trl && \ - source .venv/bin/activate && \ + source .venv/bin/activate && uv pip install accelerate deepseed wandb omegaconf && \ cd examples/scripts/nemo_gym && \ + export WANDB_API_KEY= && \ accelerate launch \ --config_file deepspeed_zero3.yaml \ --num_processes 32 \ diff --git a/examples/scripts/nemo_gym/train_multi_env.py b/examples/scripts/nemo_gym/train_multi_env.py index 4827c6bd576..a4167f672b7 100644 --- a/examples/scripts/nemo_gym/train_multi_env.py +++ b/examples/scripts/nemo_gym/train_multi_env.py @@ -394,7 +394,7 @@ def main(): log_completions=config.log_completions, num_completions_to_print=config.num_completions_to_print, - max_prompt_length=config.max_prompt_length, + # max_prompt_length=config.max_prompt_length, max_completion_length=config.max_seq_length - config.max_prompt_length if config.max_prompt_length else config.max_seq_length, shuffle_dataset=False, From 06ab2a2c29a46fbdd01a50723306ced975b9b9fa Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Thu, 22 Jan 2026 14:31:13 -0800 Subject: [PATCH 32/51] readme Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/README.md | 30 +---------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/examples/scripts/nemo_gym/README.md b/examples/scripts/nemo_gym/README.md index 40027328adc..db4fff18b52 100644 --- a/examples/scripts/nemo_gym/README.md +++ b/examples/scripts/nemo_gym/README.md @@ -2,32 +2,4 @@ This integration supports training language models in NeMo-Gym environments using TRL GRPO. Both single step and multi step tasks are supported, including multi-environment training. NeMo-Gym orchestrates rollouts, returning token ids and logprobs to TRL through the rollout function for training. Currently this integration is only supported through TRL's vllm server mode. -## Interactive single node - -1. Launch vLLM server: -```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve \ - --model Qwen/Qwen3-4B-Instruct-2507 \ - --tensor-parallel-size 4 \ - --max-model-len 8192 \ - --trust-remote-code -``` - -2. Start NeMo Gym servers -``` -ng_run "+config_paths=[resources_servers/workplace_assistant/configs/workplace_assistant.yaml,responses_api_models/vllm_model/configs/vllm_model.yaml]" -``` - - -3. Run training: -```bash -CUDA_VISIBLE_DEVICES=4 python train.py --config config.yaml -``` - -## Multinode with slurm - -See submit.sh for a multinode example! - -## Multi environment training - -Docs coming soon! \ No newline at end of file +Check out docs/source/nemo_gym_integration.md for a full integration guide! \ No newline at end of file From cf9f177ce04e68ea530a75f21b3bb09d24d53efb Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Thu, 22 Jan 2026 14:34:29 -0800 Subject: [PATCH 33/51] rename train, update docs Signed-off-by: cmunley1 --- docs/source/nemo_gym_integration.md | 16 ++++++++-------- .../{train_multi_env.py => run_grpo_nemo_gym.py} | 0 examples/scripts/nemo_gym/submit.sh | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) rename examples/scripts/nemo_gym/{train_multi_env.py => run_grpo_nemo_gym.py} (100%) diff --git a/docs/source/nemo_gym_integration.md b/docs/source/nemo_gym_integration.md index 38008d621ab..f6567c0cb0b 100644 --- a/docs/source/nemo_gym_integration.md +++ b/docs/source/nemo_gym_integration.md @@ -1,6 +1,6 @@ # NeMo-Gym Integration -NVIDIA NeMo-Gym is a library for building reinforcement learning environments for large language models. This integration enables training models in NeMo-Gym environments using TRL's [`GRPOTrainer`]. +NVIDIA NeMo-Gym is a library for building reinforcement learning environments for large language models. This integration enables training models in NeMo-Gym environments using TRL's GRPOTrainer. NeMo-Gym orchestrates multi-step and multi-turn rollouts, providing token IDs and log probabilities to TRL through a custom rollout function. This integration currently requires TRL's vLLM server mode. @@ -14,9 +14,9 @@ The integration supports: ## Why NeMo Gym -NeMo-Gym was designed to support large-scale, production-grade reinforcement learning training: +NeMo-Gym was designed to support large-scale agentic RL: -- **Scale and Coverage**: NeMo-Gym supports diverse environments running in parallel, with many examples across domains (math, coding, tool use, knowledge, reasoning, search, ...). +- **Scale and Coverage**: NeMo-Gym supports diverse environments running in parallel, with examples across various domains (math, coding, tool use, knowledge, reasoning, search, ...). - **Production-Ready**: Tested for frontier model training at large scale. The infrastructure is designed for the scale and reliability required for production LLM training. - **Multi-Verifier RL Training**: Built for training with multiple verification methods simultaneously. Supports algorithmic verification (code execution, math verification), LLM-as-a-judge, and custom verification logic across different environments in a single training run. - **Decoupled Architecture**: Enables building agents and environments independently from the training loop. Environments can be developed, tested, and deployed without requiring expertise in the RL training framework. @@ -69,7 +69,7 @@ Use `ng_prepare_data` to download and prepare the dataset. This command: - Validates the data format - Adds an `agent_ref` field to each example that tells NeMo-Gym which agent server should handle that example -Note that `train_multi_env.py` adds `agent_ref` field when loading datasets in case that datasets are created some other way. +Note that `run_grpo_nemo_gym.py` adds `agent_ref` field when loading datasets in case that datasets are created some other way. First, set `env.yaml` in `Gym/` to contain your Hugging Face token: ``` @@ -105,7 +105,7 @@ In NeMo Gym, datasets are stored as JSONL. Each line contains a task with input {"role": "system", "content": "..."}, {"role": "user", "content": "Move any of jinsoo's tasks that are in review to completed"} ], - "tools": [...], // Full tool definitions + "tools": [...], "parallel_tool_calls": false, "temperature": 1 }, @@ -214,9 +214,9 @@ cd examples/scripts/nemo_gym # if using wandb export WANDB_API_KEY=... -uv pip install wandb # TODO: double check its missing from trl +uv pip install wandb -CUDA_VISIBLE_DEVICES=1 python train_multi_env.py --config config_workplace.yaml +CUDA_VISIBLE_DEVICES=1 python run_grpo_nemo_gym.py --config config_workplace.yaml ``` Note that these separate terminals can also be tmux sessions or processes ran in the background. @@ -317,5 +317,5 @@ The training script reads `agent_ref` from each example's metadata, routes reque - [NeMo-Gym GitHub](https://github.com/NVIDIA-NeMo/Gym) - [NeMo-Gym Documentation](https://docs.nvidia.com/nemo/gym/latest/) -- [Training Script](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/train_multi_env.py) +- [Training Script](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/run_grpo_nemo_gym.py) - [TRL GRPO Trainer](grpo_trainer) \ No newline at end of file diff --git a/examples/scripts/nemo_gym/train_multi_env.py b/examples/scripts/nemo_gym/run_grpo_nemo_gym.py similarity index 100% rename from examples/scripts/nemo_gym/train_multi_env.py rename to examples/scripts/nemo_gym/run_grpo_nemo_gym.py diff --git a/examples/scripts/nemo_gym/submit.sh b/examples/scripts/nemo_gym/submit.sh index 69bdc6568d2..83327aa1375 100644 --- a/examples/scripts/nemo_gym/submit.sh +++ b/examples/scripts/nemo_gym/submit.sh @@ -101,7 +101,7 @@ srun --nodes=4 --ntasks=4 --nodelist="${TRAIN_NODES_LIST}" \ --main_process_ip ${TRAIN_NODE_0} \ --main_process_port 29500 \ --rdzv_backend c10d \ - train.py \ + run_grpo_nemo_gym.py \ --config config.yaml \ --vllm_server_host ${VLLM_NODE} \ --head_server_host ${VLLM_NODE}" & From 7669c006aa39199dda58d7bb9d166d2c411d7cac Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Thu, 22 Jan 2026 14:35:19 -0800 Subject: [PATCH 34/51] comment Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/submit.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/scripts/nemo_gym/submit.sh b/examples/scripts/nemo_gym/submit.sh index 83327aa1375..5f046d55bfc 100644 --- a/examples/scripts/nemo_gym/submit.sh +++ b/examples/scripts/nemo_gym/submit.sh @@ -31,6 +31,8 @@ mkdir -p ${LOG_DIR} echo "Starting ng_run and vLLM on ${VLLM_NODE}..." echo "Logs will be saved to: ${LOG_DIR}" +# NOTE: If you have already set up your TRL venv, you can remove all of the pip installs and uv venv related commands below! + srun --nodes=1 --ntasks=1 --nodelist="${VLLM_NODE}" \ --container-image="${CONTAINER_IMAGE}" \ --container-mounts="${MOUNTS}" \ From 3a455a90f48d6ec7aea869ea06193f6c5729df20 Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Fri, 23 Jan 2026 17:47:37 +0100 Subject: [PATCH 35/51] Update trl/trainer/grpo_trainer.py Co-authored-by: Kashif Rasul --- trl/trainer/grpo_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index a95d78c5987..e9099494289 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -2145,7 +2145,6 @@ def _generate_and_score_completions( self._metrics[mode]["sampling/sampling_logp_difference/max"].append( self.accelerator.gather(max_delta).max().item() ) - if sequence_level_is: flat_is_ratio = vllm_importance_sampling_ratio.flatten() else: From f69a70a46d74ed1cfa964c7ef873ca2d32bccd80 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Mon, 26 Jan 2026 09:55:23 -0800 Subject: [PATCH 36/51] Update trl/scripts/vllm_serve.py Co-authored-by: Kashif Rasul --- trl/scripts/vllm_serve.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index ac7183c3d6b..826b857b741 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -1244,6 +1244,7 @@ async def tokenize(request: TokenizeRequest): "model": request.model or script_args.model } + # Start the server uvicorn.run( app, host=script_args.host, From 56535f2f5099e935679b9231050e74a2df5c06ad Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Mon, 26 Jan 2026 10:01:17 -0800 Subject: [PATCH 37/51] rename docs file Signed-off-by: cmunley1 --- docs/source/{nemo_gym_integration.md => nemo_gym.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/source/{nemo_gym_integration.md => nemo_gym.md} (100%) diff --git a/docs/source/nemo_gym_integration.md b/docs/source/nemo_gym.md similarity index 100% rename from docs/source/nemo_gym_integration.md rename to docs/source/nemo_gym.md From 7b1fe8ae5ff915c8bef7ae5166195c7eaadabc14 Mon Sep 17 00:00:00 2001 From: Lawrence Lane Date: Wed, 28 Jan 2026 11:29:35 -0500 Subject: [PATCH 38/51] nemo gym trl edits Signed-off-by: Lawrence Lane --- docs/source/nemo_gym.md | 350 ++++++++++++++++++++-------------------- 1 file changed, 175 insertions(+), 175 deletions(-) diff --git a/docs/source/nemo_gym.md b/docs/source/nemo_gym.md index f6567c0cb0b..3a41e0f8bad 100644 --- a/docs/source/nemo_gym.md +++ b/docs/source/nemo_gym.md @@ -1,102 +1,94 @@ -# NeMo-Gym Integration +# NeMo Gym Integration -NVIDIA NeMo-Gym is a library for building reinforcement learning environments for large language models. This integration enables training models in NeMo-Gym environments using TRL's GRPOTrainer. +NVIDIA NeMo Gym is a library for building RL environments for large language models. This integration enables training models in NeMo Gym environments using TRL's GRPOTrainer with vLLM server mode. -NeMo-Gym orchestrates multi-step and multi-turn rollouts, providing token IDs and log probabilities to TRL through a custom rollout function. This integration currently requires TRL's vLLM server mode. - -## Overview - -The integration supports: - -- **NeMo-Gym RL environments**: Any NeMo-Gym environment should work through this integration, though not all have been tested. We thorougly tested the following environments in the development of this integration: workplace assistant, reasoning gym, mcqa, and math with judge. -- **Multi-turn tasks**: Multi-step environments involve the agent performing multiple tool calls or other steps sequentially. Multi-turn environments involve follow-up user messages, in addition to potentially multiple tool calls or other steps in the environment. -- **Multi-environment training**: Train on multiple tasks or environments simultaneously and efficiently at scale. +The integration supports multi-step and multi-turn rollouts, multi-environment training, and any NeMo Gym environment (thoroughly tested: workplace assistant, reasoning gym, MCQA, and math with judge). ## Why NeMo Gym -NeMo-Gym was designed to support large-scale agentic RL: - -- **Scale and Coverage**: NeMo-Gym supports diverse environments running in parallel, with examples across various domains (math, coding, tool use, knowledge, reasoning, search, ...). -- **Production-Ready**: Tested for frontier model training at large scale. The infrastructure is designed for the scale and reliability required for production LLM training. -- **Multi-Verifier RL Training**: Built for training with multiple verification methods simultaneously. Supports algorithmic verification (code execution, math verification), LLM-as-a-judge, and custom verification logic across different environments in a single training run. -- **Decoupled Architecture**: Enables building agents and environments independently from the training loop. Environments can be developed, tested, and deployed without requiring expertise in the RL training framework. -- **OpenAI-Compatible API**: All environments are compatble with standardized OpenAI Responses API, allowing seamless integration with compatible inference endpoint (local vLLM, OpenAI models, etc.) and enabling environment reuse across different training frameworks. -- **Flexible Environment Isolation**: Environments can manage packages via lightweight python environments or using containers. - -## Installation - -Install TRL with vLLM support: - -```bash -pip install trl[vllm] -``` - -Install NeMo-Gym: - -```bash -git clone https://github.com/NVIDIA-NeMo/Gym.git -cd Gym -uv venv --python 3.12 -source .venv/bin/activate -uv sync --extra dev -``` +- **Production-Ready Scale**: Tested for frontier model training with diverse environments running in parallel across math, coding, tool use, reasoning, and more. +- **Multi-Verifier Training**: Supports algorithmic verification, LLM-as-a-judge, and custom verification logic in a single training run. +- **Decoupled Architecture**: Build agents and environments independently from the training loop—no RL framework expertise required. +- **OpenAI-Compatible API**: All environments use the standardized OpenAI Responses API for seamless integration with vLLM, OpenAI models, and other endpoints. ## Available Environments -NeMo-Gym provides training-ready environments across various domains, including but not limited to: +NeMo Gym provides training-ready environments across multiple domains: | Environment | Domain | Description | |-------------|--------|-------------| -| Workplace Assistant | Agent | Multi-step tool calling in common office scenarios (calendar, email, etc.) | +| Workplace Assistant | Agent | Multi-step tool calling in common office scenarios (calendar, email, and more) | | Math with Judge | Math | Math problems with algorithmic or judge-based verification | | Code Gen | Coding | Competitive programming problems with code execution | | MCQA | Knowledge | Multiple-choice question answering | | Instruction Following | Instruction Following | IFEval/IFBench style tasks | -| Reasoning Gym | Multiple | Single-step procedurally generated verifiable tasks across various domains | +| Reasoning Gym | Multiple | Single-step procedurally generated verifiable tasks across domains | + +For a complete list of available training environments, refer to the [NeMo Gym repository](https://github.com/NVIDIA-NeMo/Gym#-available-resource-servers). + +## Before You Start + +Complete these one-time setup steps before running training. + +### Install TRL and NeMo Gym -See a complete list of available training environments in the [NeMo-Gym repository](https://github.com/NVIDIA-NeMo/Gym#-available-resource-servers). +1. **Install TRL with vLLM support** -## Preparing a Dataset + ```bash + pip install trl[vllm] + ``` -For creating a new environment, check out the [official guide](https://docs.nvidia.com/nemo/gym/latest/contribute/environments/new-environment.html). +1. **Install NeMo Gym** -Many NeMo-Gym datasets used in training Nemotron models are available on Hugging Face, corresponding to existing RL environments. + ```bash + git clone https://github.com/NVIDIA-NeMo/Gym.git + cd Gym + uv venv --python 3.12 + source .venv/bin/activate + uv sync --extra dev + ``` -### Download and Prepare Workplace Assistant Data +### Prepare a Dataset + +Many NeMo Gym datasets used to train Nemotron models are available on Hugging Face. Use `ng_prepare_data` to download and prepare datasets. This command: -Use `ng_prepare_data` to download and prepare the dataset. This command: - Downloads the dataset from Hugging Face - Validates the data format -- Adds an `agent_ref` field to each example that tells NeMo-Gym which agent server should handle that example +- Adds an `agent_ref` field to each example that tells NeMo Gym which agent server should handle that example -Note that `run_grpo_nemo_gym.py` adds `agent_ref` field when loading datasets in case that datasets are created some other way. +> **Note**: `run_grpo_nemo_gym.py` adds the `agent_ref` field when loading datasets, so this step is optional if datasets are created another way. -First, set `env.yaml` in `Gym/` to contain your Hugging Face token: -``` -hf_token: -``` +1. **Set Hugging Face Token** -Example dataset preparation for the workplace assistant environment: + Create `env.yaml` in `Gym/` with your token: -```bash -cd Gym -source .venv/bin/activate + ```yaml + hf_token: + ``` -config_paths="responses_api_models/vllm_model/configs/vllm_model.yaml,\ -resources_servers/workplace_assistant/configs/workplace_assistant.yaml" +1. **Prepare Dataset** -ng_prepare_data "+config_paths=[${config_paths}]" \ - +output_dirpath=data/workplace_assistant \ - +mode=train_preparation \ - +should_download=true \ - +data_source=huggingface -``` + ```bash + cd Gym + source .venv/bin/activate -This creates `train.jsonl` and `validation.jsonl` files in `data/workplace_assistant/`. + config_paths="responses_api_models/vllm_model/configs/vllm_model.yaml,\ + resources_servers/workplace_assistant/configs/workplace_assistant.yaml" -### Dataset Format + ng_prepare_data "+config_paths=[${config_paths}]" \ + +output_dirpath=data/workplace_assistant \ + +mode=train_preparation \ + +should_download=true \ + +data_source=huggingface + ``` -In NeMo Gym, datasets are stored as JSONL. Each line contains a task with input messages, potential tool definitions, metadata such as ground truth for verification, and an agent server reference. The workplace dataset is structured like shown below. The metadata fields can differ between datasets, as long as the corresponding resources server leverages the fields appropriately. + This creates `train.jsonl` and `validation.jsonl` files in `data/workplace_assistant/`. + +To create a new environment, refer to the [environment creation guide](https://docs.nvidia.com/nemo/gym/latest/contribute/environments/new-environment.html). + +#### Dataset Format + +NeMo Gym datasets are stored as JSONL. Each line contains a task with input messages, tool definitions, metadata such as ground truth for verification, and an agent server reference. The following example shows the workplace dataset structure. Metadata fields can differ between datasets, as long as the corresponding resources server uses the fields appropriately. ```json { @@ -122,7 +114,7 @@ In NeMo Gym, datasets are stored as JSONL. Each line contains a task with input } ``` -## Training Configuration +### Create Training Config Create a `config_workplace.yaml` file with your training parameters: @@ -154,168 +146,176 @@ eval_steps: 10 ## Interactive Training -For development and testing on a single node: +For development and testing on a single node. The following steps run in three separate terminals concurrently. -### Step 1: Update environment config +### Set Up -Update `env.yaml` to include model information: +1. **Verify Prerequisites** -``` -policy_base_url: http://127.0.0.1:8000/v1 -policy_api_key: EMPTY -policy_model_name: Qwen/Qwen2.5-1.5B-Instruct -hf_token: ... -``` + Confirm you have completed the [Before You Start](#before-you-start) section: + - Dataset prepared in `data/workplace_assistant/` + - Training config created (`config_workplace.yaml`) -### Step 2: Start NeMo-Gym Servers +1. **Update Environment Config** -First, start the NeMo-Gym environment servers: + Update `env.yaml` to include model information: -```bash -cd Gym -source .venv/bin/activate + ```yaml + policy_base_url: http://127.0.0.1:8000/v1 + policy_api_key: EMPTY + policy_model_name: Qwen/Qwen2.5-1.5B-Instruct + hf_token: ... + ``` -config_paths="resources_servers/workplace_assistant/configs/workplace_assistant.yaml,\ -responses_api_models/vllm_model/configs/vllm_model_for_training.yaml" +### Start Servers and Run -ng_run "+config_paths=[${config_paths}]" -``` +1. **Start NeMo Gym Servers** (Terminal 1) -This starts: -- **Head server**: Manages servers used in training -- **Agent server**: Orchestrates rollouts by leveraging resource servers and model servers -- **Resources server**: Supports environment logic such as state-based feedback, tool implementations, and task verification -- **Model server**: Adapts vLLM server requests to support NeMo Gym agents and ensures OpenAI API compatibility + ```bash + cd Gym + source .venv/bin/activate -### Step 2: Start TRL vLLM Server + config_paths="resources_servers/workplace_assistant/configs/workplace_assistant.yaml,\ + responses_api_models/vllm_model/configs/vllm_model_for_training.yaml" -In a second terminal, start the TRL vLLM server on GPU 0: + ng_run "+config_paths=[${config_paths}]" + ``` -```bash -cd trl + This starts: + - **Head server**: Manages servers used in training + - **Agent server**: Orchestrates rollouts using resource servers and model servers + - **Resources server**: Supports environment logic such as state-based feedback, tool implementations, and task verification + - **Model server**: Adapts vLLM server requests to support NeMo Gym agents and ensures OpenAI API compatibility -CUDA_VISIBLE_DEVICES=0 trl vllm-serve \ - --model Qwen/Qwen2.5-1.5B-Instruct \ - --max-model-len 16384 \ - --host 0.0.0.0 \ - --port 8000 -``` +1. **Start TRL vLLM Server** (Terminal 2) + ```bash + cd trl -### Step 3: Run Training + CUDA_VISIBLE_DEVICES=0 trl vllm-serve \ + --model Qwen/Qwen2.5-1.5B-Instruct \ + --max-model-len 16384 \ + --host 0.0.0.0 \ + --port 8000 + ``` -In a third terminal, launch the training script on GPU 1: +1. **Run Training** (Terminal 3) -```bash -cd trl/ -source .venv/bin/activate + ```bash + cd trl/ + source .venv/bin/activate -cd examples/scripts/nemo_gym + cd examples/scripts/nemo_gym -# if using wandb -export WANDB_API_KEY=... -uv pip install wandb + # if using wandb + export WANDB_API_KEY=... + uv pip install wandb -CUDA_VISIBLE_DEVICES=1 python run_grpo_nemo_gym.py --config config_workplace.yaml -``` + CUDA_VISIBLE_DEVICES=1 python run_grpo_nemo_gym.py --config config_workplace.yaml + ``` -Note that these separate terminals can also be tmux sessions or processes ran in the background. +> **Note**: These separate terminals can also be tmux sessions or background processes. ## Multi-Node Training with Slurm -An example 5-node training script is provided in `submit.sh`. To use this, update your slurm account and partition, path to Gym repository, and your training config. +An example five-node training script is provided in `submit.sh`. Nodes one through four run the training backend, while node five runs vLLM inference. -Nodes 1-4 are used for training backend, while node 5 is used for vLLM inference. For more details on TRL's vLLM integration, visit vllm integration page. +1. **Configure the Script** -Submit the job: -```bash -sbatch submit.sh -``` + Update `submit.sh` with your Slurm account, partition, path to Gym repository, and training config. -Monitor training logs: -```bash -tail -f logs//* -``` +1. **Submit the Job** -Set up wandb logging for detailed training metrics! + ```bash + sbatch submit.sh + ``` + +1. **Monitor Training** + + ```bash + tail -f logs//* + ``` + +> **Tip**: Set up wandb logging for detailed training metrics. For more details on TRL's vLLM integration, refer to the vLLM integration page. ## Multi-Environment Training -Train on multiple NeMo-Gym environments simultaneously. This allows learning diverse capabilities (e.g., tool calling + math reasoning) in a single training run. +Train on multiple NeMo Gym environments simultaneously. This allows learning diverse capabilities (such as tool calling and math reasoning) in a single training run. -### Step 1: Prepare Individual Datasets +1. **Prepare Individual Datasets** -First, prepare datasets for each environment you want to use. Above, we prepared the workplace dataset. Now, create a reasoning gym dataset: + Prepare datasets for each environment. The workplace dataset was prepared above. Now, create a reasoning gym dataset: -```bash -cd Gym -source .venv/bin/activate -uv add reasoning-gym -cd resources_servers/reasoning_gym -python scripts/create_dataset.py \ - --task mini_sudoku \ - --size 2000 \ - --seed 42 \ - --output data/reasoning_gym/train_mini_sudoku.jsonl + ```bash + cd Gym + source .venv/bin/activate + uv add reasoning-gym + cd resources_servers/reasoning_gym + python scripts/create_dataset.py \ + --task mini_sudoku \ + --size 2000 \ + --seed 42 \ + --output data/reasoning_gym/train_mini_sudoku.jsonl -python scripts/create_dataset.py \ - --task mini_sudoku \ - --size 50 \ - --seed 24 \ - --output data/reasoning_gym/val_mini_sudoku.jsonl -``` + python scripts/create_dataset.py \ + --task mini_sudoku \ + --size 50 \ + --seed 24 \ + --output data/reasoning_gym/val_mini_sudoku.jsonl + ``` -### Step 2: Create Combined Dataset +1. **Create Combined Dataset** -Create a single dataset with tasks from both environments mixed together. This can be done with a simple bash command, such as the following: -```bash -cat data/workplace_assistant/train_workplace.jsonl data/reasoning_gym/train_mini_sudoku.jsonl | shuf > train_multi_env.jsonl -``` + Combine datasets into a single file with tasks from both environments: -Note you may want to ensure that the datasets are the same size before shuffling to get an even blend of tasks. Do the same for the validation dataset. + ```bash + cat data/workplace_assistant/train_workplace.jsonl data/reasoning_gym/train_mini_sudoku.jsonl | shuf > train_multi_env.jsonl + ``` -### Step 3: Update Training Config + > **Tip**: Ensure datasets are the same size before shuffling for an even blend of tasks. Repeat for the validation dataset. -Create `config_multi_env.yaml` pointing to the combined dataset: +1. **Update Training Config** -```yaml -model_name: "Qwen/Qwen3-4B-Instruct-2507" + Create `config_multi_env.yaml` pointing to the combined dataset: -dataset_path: "/path/to/data/train_multi_env.jsonl" -eval_dataset_path: "/path/to/data/val_multi_env.jsonl" + ```yaml + model_name: "Qwen/Qwen3-4B-Instruct-2507" -task: "workplace-sudoku" # used in wandb run name -output_dir: "outputs/nemo_gym_multi_env" + dataset_path: "/path/to/data/train_multi_env.jsonl" + eval_dataset_path: "/path/to/data/val_multi_env.jsonl" -# ... rest of config same -``` + task: "workplace-sudoku" # used in wandb run name + output_dir: "outputs/nemo_gym_multi_env" -### Step 4: Update ng_run + # ... rest of config same + ``` -Whether training interactively or via slurm, update the ng_run command to use the config file from each resources server used in training. +1. **Update ng_run** -```bash -cd Gym -source .venv/bin/activate + Whether training interactively or via Slurm, update the `ng_run` command to include config files from each resources server: -config_paths="responses_api_models/vllm_model/configs/vllm_model.yaml,\ -resources_servers/workplace_assistant/configs/workplace_assistant.yaml,\ -resources_servers/reasoning_gym/configs/reasoning_gym.yaml" + ```bash + cd Gym + source .venv/bin/activate -ng_run "+config_paths=[${config_paths}]" +head_server.host=0.0.0.0 -``` + config_paths="responses_api_models/vllm_model/configs/vllm_model.yaml,\ + resources_servers/workplace_assistant/configs/workplace_assistant.yaml,\ + resources_servers/reasoning_gym/configs/reasoning_gym.yaml" + + ng_run "+config_paths=[${config_paths}]" +head_server.host=0.0.0.0 + ``` -This starts servers for both environments. The training script will automatically route each example to the correct agent server based on its `agent_ref` field. + This starts servers for both environments. The training script automatically routes each example to the correct agent server based on its `agent_ref` field. -### Step 5: Run Training +1. **Run Training** -Just update the slurm submission script to use the new train config and both ng_run resources server configs, then submit the job as before! + Update the Slurm submission script to use the new training config and both `ng_run` resources server configs, then submit the job as before. -The training script reads `agent_ref` from each example's metadata, routes requests to the correct NeMo-Gym agent server, and handles different agents and environments in the same batch + The training script reads `agent_ref` from each example's metadata, routes requests to the correct NeMo Gym agent server, and handles different agents and environments in the same batch. ## Resources -- [NeMo-Gym GitHub](https://github.com/NVIDIA-NeMo/Gym) -- [NeMo-Gym Documentation](https://docs.nvidia.com/nemo/gym/latest/) +- [NeMo Gym GitHub](https://github.com/NVIDIA-NeMo/Gym) +- [NeMo Gym Documentation](https://docs.nvidia.com/nemo/gym/latest/) - [Training Script](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/run_grpo_nemo_gym.py) - [TRL GRPO Trainer](grpo_trainer) \ No newline at end of file From 4d6012ef961d8121a310d6ea1dcc58cbc714f956 Mon Sep 17 00:00:00 2001 From: Christian Munley Date: Fri, 30 Jan 2026 00:16:35 -0800 Subject: [PATCH 39/51] lint Signed-off-by: Christian Munley --- .../scripts/nemo_gym/run_grpo_nemo_gym.py | 187 ++++++++------- trl/scripts/vllm_serve.py | 221 ++++++++++-------- trl/trainer/grpo_trainer.py | 4 +- 3 files changed, 228 insertions(+), 184 deletions(-) diff --git a/examples/scripts/nemo_gym/run_grpo_nemo_gym.py b/examples/scripts/nemo_gym/run_grpo_nemo_gym.py index a4167f672b7..0c48e747fff 100644 --- a/examples/scripts/nemo_gym/run_grpo_nemo_gym.py +++ b/examples/scripts/nemo_gym/run_grpo_nemo_gym.py @@ -1,24 +1,41 @@ -import os +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse import asyncio -import aiohttp import json -import yaml -import requests -from omegaconf import OmegaConf -from typing import Any, Dict, List, Optional +import os from dataclasses import dataclass +from typing import Any + +import aiohttp +import requests +import wandb +import yaml from datasets import Dataset, load_dataset -from trl import GRPOConfig, GRPOTrainer +from omegaconf import OmegaConf from transformers import AutoTokenizer -import wandb + +from trl import GRPOConfig, GRPOTrainer + @dataclass class TrainingConfig: model_name: str dataset_path: str - task: Optional[str] = None + task: str | None = None learning_rate: float = 5e-6 max_steps: int = 100 @@ -44,32 +61,30 @@ class TrainingConfig: log_completions: bool = False num_completions_to_print: int = None - eval_dataset_path: Optional[str] = None + eval_dataset_path: str | None = None eval_strategy: str = "no" eval_steps: int = 50 vllm_importance_sampling_correction: bool = False + def get_agent_servers( head_server_host: str = "127.0.0.1", head_server_port: int = 11000, -) -> Dict[str, str]: +) -> dict[str, str]: try: - response = requests.get( - f"http://{head_server_host}:{head_server_port}/global_config_dict_yaml", - timeout=10 - ) + response = requests.get(f"http://{head_server_host}:{head_server_port}/global_config_dict_yaml", timeout=10) response.raise_for_status() global_config_yaml = response.text global_config_dict = OmegaConf.create(yaml.safe_load(global_config_yaml)) agent_servers = {} for project_name, project_config in global_config_dict.items(): - if hasattr(project_config, 'responses_api_agents'): + if hasattr(project_config, "responses_api_agents"): agents = project_config.responses_api_agents for agent_key in agents.keys(): agent_config = getattr(agents, agent_key) - if hasattr(agent_config, 'host') and hasattr(agent_config, 'port'): + if hasattr(agent_config, "host") and hasattr(agent_config, "port"): agent_host = agent_config.host if agent_host in ("127.0.0.1", "0.0.0.0", "localhost"): agent_host = head_server_host @@ -81,25 +96,27 @@ def get_agent_servers( return agent_servers except requests.exceptions.RequestException as e: - raise RuntimeError(f"Failed to connect to head server at {head_server_host}:{head_server_port}: {e}") + raise RuntimeError(f"Failed to connect to head server at {head_server_host}:{head_server_port}: {e}") from e + -def reward_fn(completions: List[str], **kwargs) -> List[float]: +def reward_fn(completions: list[str], **kwargs) -> list[float]: env_rewards = kwargs.get("env_reward") assert env_rewards is not None, "env_reward not found in kwargs" return [float(r) for r in env_rewards] + async def call_nemo_gym_agents( - prompts: List[str], - dataset_items: List[Dict[str, Any]], - agent_servers: Dict[str, str], + prompts: list[str], + dataset_items: list[dict[str, Any]], + agent_servers: dict[str, str], timeout: float, max_completion_length: int = 4096, temperature: float = 1.0, top_p: float = 0.999, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: async with aiohttp.ClientSession(cookie_jar=aiohttp.CookieJar()) as session: tasks = [] - for prompt, item in zip(prompts, dataset_items): + for prompt, item in zip(prompts, dataset_items, strict=False): request_body = item.copy() if "responses_create_params" not in request_body: @@ -115,7 +132,9 @@ async def call_nemo_gym_agents( agent_ref = item.get("agent_ref", {}) agent_name = agent_ref.get("name") if isinstance(agent_ref, dict) else None if not agent_name or agent_name not in agent_servers: - raise ValueError(f"Missing or invalid agent_ref. Got: {agent_ref}. Available: {list(agent_servers.keys())}") + raise ValueError( + f"Missing or invalid agent_ref. Got: {agent_ref}. Available: {list(agent_servers.keys())}" + ) agent_url = agent_servers[agent_name] task = session.post( @@ -143,9 +162,13 @@ async def call_nemo_gym_agents( return results -def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, List]: +def nemo_gym_rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: is_eval = not trainer.model.training - num_generations = trainer.args.num_generations_eval if is_eval and trainer.args.num_generations_eval else trainer.args.num_generations + num_generations = ( + trainer.args.num_generations_eval + if is_eval and trainer.args.num_generations_eval + else trainer.args.num_generations + ) dataset = trainer.eval_dataset if is_eval and trainer.eval_dataset is not None else trainer.train_dataset expanded_prompts = [] @@ -177,14 +200,14 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, loop.close() tokenizer = trainer.processing_class - - prompt_ids: List[List[int]] = [] - completion_ids: List[List[int]] = [] # list of rollouts - completion_mask: List[List[int]] = [] # only train on assistant turns - - logprobs: List[List[float]] = [] - env_rewards: List[float] = [] - num_turns_list: List[int] = [] + + prompt_ids: list[list[int]] = [] + completion_ids: list[list[int]] = [] # list of rollouts + completion_mask: list[list[int]] = [] # only train on assistant turns + + logprobs: list[list[float]] = [] + env_rewards: list[float] = [] + num_turns_list: list[int] = [] for i, response in enumerate(responses): eos_token_id = tokenizer.eos_token_id or 0 @@ -194,10 +217,12 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, else: output_items = response.get("response", {}).get("output", []) has_content = output_items and any( - item.get("type") == "function_call" or ( - item.get("type") == "message" and - any(c.get("type") == "output_text" and c.get("text", "").strip() - for c in item.get("content", [])) + item.get("type") == "function_call" + or ( + item.get("type") == "message" + and any( + c.get("type") == "output_text" and c.get("text", "").strip() for c in item.get("content", []) + ) ) for item in output_items ) @@ -215,15 +240,15 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, episode_reward = response.get("reward", 0.0) output_items = response.get("response", {}).get("output", []) - rollout_ids: List[int] = [] - rollout_mask: List[int] = [] - rollout_logprobs: List[float] = [] - - seen_token_ids: List[int] = [] + rollout_ids: list[int] = [] + rollout_mask: list[int] = [] + rollout_logprobs: list[float] = [] + + seen_token_ids: list[int] = [] first_prompt = None num_turns = 0 - for idx, item in enumerate(output_items): + for _idx, item in enumerate(output_items): if "prompt_token_ids" not in item or "generation_token_ids" not in item: continue @@ -238,12 +263,12 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, seen_token_ids = list(item_prompt_ids) else: if len(item_prompt_ids) > len(seen_token_ids): - if item_prompt_ids[:len(seen_token_ids)] != seen_token_ids: + if item_prompt_ids[: len(seen_token_ids)] != seen_token_ids: raise ValueError( f"[Turn {num_turns}] Non-contiguous messages (tokenization issue). " f"Expected prefix len {len(seen_token_ids)}, got prompt len {len(item_prompt_ids)}" ) - tool_result_tokens = item_prompt_ids[len(seen_token_ids):] + tool_result_tokens = item_prompt_ids[len(seen_token_ids) :] if tool_result_tokens: rollout_ids.extend(tool_result_tokens) @@ -252,7 +277,9 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, rollout_ids.extend(item_gen_ids) rollout_mask.extend([1] * len(item_gen_ids)) - assert len(item_logprobs) == len(item_gen_ids), f"Logprobs len {len(item_logprobs)} != gen len {len(item_gen_ids)}" + assert len(item_logprobs) == len(item_gen_ids), ( + f"Logprobs len {len(item_logprobs)} != gen len {len(item_gen_ids)}" + ) rollout_logprobs.extend(item_logprobs) seen_token_ids = list(item_prompt_ids) + list(item_gen_ids) @@ -260,25 +287,24 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, if not rollout_ids or first_prompt is None: raise ValueError(f"Rollout {i} has no valid turns") - prompt_ids.append(first_prompt) # list of prompts - completion_ids.append(rollout_ids) # list of rollouts - completion_mask.append(rollout_mask) + prompt_ids.append(first_prompt) # list of prompts + completion_ids.append(rollout_ids) # list of rollouts + completion_mask.append(rollout_mask) logprobs.append(rollout_logprobs) env_rewards.append(episode_reward) num_turns_list.append(num_turns) if not prompt_ids: - raise RuntimeError( - "No valid rollouts. Check Nemo Gym and vLLM logs." - ) - + raise RuntimeError("No valid rollouts. Check Nemo Gym and vLLM logs.") if num_turns_list: - wandb.log({ - "train/num_turns_mean": sum(num_turns_list) / len(num_turns_list), - "train/num_turns_min": min(num_turns_list), - "train/num_turns_max": max(num_turns_list), - }) + wandb.log( + { + "train/num_turns_mean": sum(num_turns_list) / len(num_turns_list), + "train/num_turns_min": min(num_turns_list), + "train/num_turns_max": max(num_turns_list), + } + ) unique_prompt_ids = prompt_ids[::num_generations] @@ -294,25 +320,27 @@ def nemo_gym_rollout_func(prompts: List[str], trainer: GRPOTrainer) -> Dict[str, def load_dataset_from_jsonl(path: str) -> Dataset: data = [] - with open(path, 'r') as f: + with open(path) as f: for idx, line in enumerate(f): if line.strip(): item = json.loads(line) - data.append({ - "prompt": str(idx), # use index for lookup as not all nemo gym datasets have the same metadata fields. maybe not the most elegant - "metadata": json.dumps(item), - }) + data.append( + { + "prompt": str( + idx + ), # use index for lookup as not all nemo gym datasets have the same metadata fields. maybe not the most elegant + "metadata": json.dumps(item), + } + ) return Dataset.from_list(data) + def main(): parser = argparse.ArgumentParser(description="") parser.add_argument("--config", required=True, help="Path to config YAML file") - parser.add_argument("--vllm_server_host", type=str, default="127.0.0.1", - help="vLLM server hostname/IP") - parser.add_argument("--head_server_host", type=str, default="127.0.0.1", - help="Head server hostname/IP for ng_run") - parser.add_argument("--resume_from_checkpoint", type=str, default=None, - help="Path to checkpoint to resume from") + parser.add_argument("--vllm_server_host", type=str, default="127.0.0.1", help="vLLM server hostname/IP") + parser.add_argument("--head_server_host", type=str, default="127.0.0.1", help="Head server hostname/IP for ng_run") + parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to checkpoint to resume from") args = parser.parse_args() with open(args.config) as f: @@ -332,7 +360,7 @@ def main(): os.environ["WANDB_PROJECT"] = config.project_name if config.run_name is None: - task = config.task or os.path.basename(config.dataset_path).replace('.jsonl', '').replace('.json', '') + task = config.task or os.path.basename(config.dataset_path).replace(".jsonl", "").replace(".json", "") model_short = config.model_name.split("/")[-1] config.run_name = ( f"{task}_{model_short}" @@ -345,7 +373,7 @@ def main(): f"_topp{config.top_p}" ) - if config.dataset_path.endswith(('.jsonl', '.json')): + if config.dataset_path.endswith((".jsonl", ".json")): dataset = load_dataset_from_jsonl(config.dataset_path) else: dataset = load_dataset(config.dataset_path, split="train") @@ -360,9 +388,7 @@ def main(): vllm_mode="server", vllm_server_host=args.vllm_server_host, vllm_server_port=8000, - gradient_checkpointing=True, - temperature=config.temperature, learning_rate=config.learning_rate, weight_decay=config.weight_decay, @@ -370,22 +396,18 @@ def main(): warmup_steps=config.warmup_steps, lr_scheduler_type=config.lr_scheduler_type, optim=config.optim, - per_device_train_batch_size=config.per_device_train_batch_size, gradient_accumulation_steps=config.gradient_accumulation_steps, num_generations=config.num_generations, num_generations_eval=1, - max_steps=config.max_steps, save_steps=config.save_steps, logging_steps=1, report_to=config.report_to, output_dir=config.output_dir, run_name=config.run_name, - eval_strategy=config.eval_strategy, eval_steps=config.eval_steps, - vllm_importance_sampling_correction=config.vllm_importance_sampling_correction, epsilon=0.2, epsilon_high=0.28, @@ -393,11 +415,11 @@ def main(): mask_truncated_completions=True, log_completions=config.log_completions, num_completions_to_print=config.num_completions_to_print, - # max_prompt_length=config.max_prompt_length, - max_completion_length=config.max_seq_length - config.max_prompt_length if config.max_prompt_length else config.max_seq_length, + max_completion_length=config.max_seq_length - config.max_prompt_length + if config.max_prompt_length + else config.max_seq_length, shuffle_dataset=False, - model_init_kwargs={ "torch_dtype": "auto", }, @@ -420,5 +442,6 @@ def main(): trainer.train(resume_from_checkpoint=args.resume_from_checkpoint) + if __name__ == "__main__": main() diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 826b857b741..f5bf4333e81 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -31,8 +31,7 @@ import torch import torch.distributed.distributed_c10d as c10d from packaging.version import Version -from transformers import is_torch_xpu_available, is_vision_available -from transformers import AutoTokenizer +from transformers import AutoTokenizer, is_torch_xpu_available, is_vision_available from trl import TrlParser from trl.import_utils import ( @@ -450,26 +449,21 @@ def _replace_prefix_tokens( template_token_ids: list[int], ) -> list[int]: """ - This function is for fixing up the chat template-tokenized messages history - to match the model output tokenization up to the last assistant turn, - in order to preserve the monotonic tokens property for optimized multi-turn + This function is for fixing up the chat template-tokenized messages history to match the model output tokenization + up to the last assistant turn, in order to preserve the monotonic tokens property for optimized multi-turn training. - RL training frameworks train models on token IDs, but the OpenAI compatible - server communicates in what is basically de-tokenized text. When multiple - model calls are made to the OpenAI compatible server in a single trajectory, - model generations in previous model calls may be re-tokenized to something - that is different than what was generated. This is not too big of an issue - (that we know of) at inference time, but the log probs the model produces - are different enough for the differently re-tokenized generation result that - it causes the training to be off policy. Off policy isn't necessarily a bad - thing in isolation, but this source of off-policyness may cause unexpected - issues if not properly accounted for. It also mis-aligns the token ID - sequences across model calls, which is strange during training. - - There are real cases where the model output string _does not match_ the chat - template tokenization of the parsed model output. A concrete example is - inconsistent whitespace tokens around tool call special tokens. + RL training frameworks train models on token IDs, but the OpenAI compatible server communicates in what is + basically de-tokenized text. When multiple model calls are made to the OpenAI compatible server in a single + trajectory, model generations in previous model calls may be re-tokenized to something that is different than what + was generated. This is not too big of an issue (that we know of) at inference time, but the log probs the model + produces are different enough for the differently re-tokenized generation result that it causes the training to be + off policy. Off policy isn't necessarily a bad thing in isolation, but this source of off-policyness may cause + unexpected issues if not properly accounted for. It also mis-aligns the token ID sequences across model calls, + which is strange during training. + + There are real cases where the model output string _does not match_ the chat template tokenization of the parsed + model output. A concrete example is inconsistent whitespace tokens around tool call special tokens. Based on NeMo RL's _replace_prefix_tokens: https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 @@ -498,13 +492,11 @@ def _replace_prefix_tokens( logger.warning("No EOS token found in template prefix, cannot apply _replace_prefix_tokens") return template_token_ids - result = ( - model_prefix_token_ids[:model_cut_end] + - template_token_ids[template_cut_start:] - ) + result = model_prefix_token_ids[:model_cut_end] + template_token_ids[template_cut_start:] return result + def main(script_args: ScriptArguments): if not is_fastapi_available(): raise ImportError( @@ -538,7 +530,9 @@ def main(script_args: ScriptArguments): @asynccontextmanager async def lifespan(app: FastAPI): logger.info(f"Loading tokenizer for {script_args.model}...") - app.state.tokenizer = AutoTokenizer.from_pretrained(script_args.model, trust_remote_code=script_args.trust_remote_code) + app.state.tokenizer = AutoTokenizer.from_pretrained( + script_args.model, trust_remote_code=script_args.trust_remote_code + ) # Wait for all workers to send "ready" ready_connections = set() @@ -851,7 +845,7 @@ async def chat(request: ChatRequest): "chat_template_kwargs": request.chat_template_kwargs, "tools": request.tools if request.tools else None, } - + connection.send({"type": "call", "method": "chat", "kwargs": kwargs}) # Receive results @@ -984,7 +978,7 @@ async def chat_completions(request: ChatCompletionRequest): messages.append(msg) max_tokens = request.max_completion_tokens or request.max_tokens or 512 - + sampling_kwargs = { "n": request.n, "temperature": request.temperature, @@ -1004,10 +998,7 @@ async def chat_completions(request: ChatCompletionRequest): if request.tool_choice and request.tool_choice != "auto": chat_template_kwargs["tool_choice"] = request.tool_choice - has_prefix_token_ids = any( - msg.get("role") == "assistant" and "prompt_token_ids" in msg - for msg in messages - ) + has_prefix_token_ids = any(msg.get("role") == "assistant" and "prompt_token_ids" in msg for msg in messages) if has_prefix_token_ids: # do on policy token id correction and call generate instead of chat @@ -1016,9 +1007,18 @@ async def chat_completions(request: ChatCompletionRequest): tokenizer = app.state.tokenizer # preprocess full conversation - connections[0].send({"type": "call", "method": "preprocess_chat", "kwargs": { - "messages": [messages], "chat_template_kwargs": chat_template_kwargs, - "tools": request.tools, "add_generation_prompt": True}}) + connections[0].send( + { + "type": "call", + "method": "preprocess_chat", + "kwargs": { + "messages": [messages], + "chat_template_kwargs": chat_template_kwargs, + "tools": request.tools, + "add_generation_prompt": True, + }, + } + ) template_prompts = connections[0].recv() template_prompt = template_prompts[0] @@ -1029,23 +1029,30 @@ async def chat_completions(request: ChatCompletionRequest): if messages[i].get("role") == "assistant": last_assistant_idx = i if "prompt_token_ids" in messages[i]: - model_prefix_tokens = messages[i]["prompt_token_ids"] + messages[i].get("generation_token_ids", []) + model_prefix_tokens = messages[i]["prompt_token_ids"] + messages[i].get( + "generation_token_ids", [] + ) break if model_prefix_tokens and last_assistant_idx is not None: - messages_to_last_assistant = messages[:last_assistant_idx + 1] - connections[0].send({"type": "call", "method": "preprocess_chat", "kwargs": { - "messages": [messages_to_last_assistant], "chat_template_kwargs": chat_template_kwargs, - "tools": request.tools, "add_generation_prompt": False}}) + messages_to_last_assistant = messages[: last_assistant_idx + 1] + connections[0].send( + { + "type": "call", + "method": "preprocess_chat", + "kwargs": { + "messages": [messages_to_last_assistant], + "chat_template_kwargs": chat_template_kwargs, + "tools": request.tools, + "add_generation_prompt": False, + }, + } + ) template_prefix_prompts = connections[0].recv() template_prefix_token_ids = template_prefix_prompts[0]["prompt_token_ids"] - corrected_token_ids = _replace_prefix_tokens( - tokenizer, - model_prefix_tokens, - template_prefix_token_ids, - template_prompt["prompt_token_ids"] + tokenizer, model_prefix_tokens, template_prefix_token_ids, template_prompt["prompt_token_ids"] ) else: @@ -1057,8 +1064,13 @@ async def chat_completions(request: ChatCompletionRequest): for connection, prompts in zip(connections, chunked_prompts, strict=True): if not prompts: prompts = [{"prompt_token_ids": [tokenizer.eos_token_id]}] - connection.send({"type": "call", "method": "generate", "kwargs": { - "prompts": prompts, "sampling_params": sampling_params}}) + connection.send( + { + "type": "call", + "method": "generate", + "kwargs": {"prompts": prompts, "sampling_params": sampling_params}, + } + ) else: # no prefix token IDs, use chat() chunked_messages = chunk_list([messages], script_args.data_parallel_size) @@ -1070,7 +1082,7 @@ async def chat_completions(request: ChatCompletionRequest): "messages": message_chunk, "sampling_params": sampling_params, "tools": request.tools, - "chat_template_kwargs": chat_template_kwargs + "chat_template_kwargs": chat_template_kwargs, } connection.send({"type": "call", "method": "chat", "kwargs": kwargs}) @@ -1078,7 +1090,9 @@ async def chat_completions(request: ChatCompletionRequest): if has_prefix_token_ids: all_outputs = [o for o in all_outputs if o] else: - all_outputs = [output for output, msg_chunk in zip(all_outputs, chunked_messages, strict=True) if msg_chunk] + all_outputs = [ + output for output, msg_chunk in zip(all_outputs, chunked_messages, strict=True) if msg_chunk + ] all_outputs = list(chain.from_iterable(all_outputs)) if not all_outputs: @@ -1087,13 +1101,15 @@ async def chat_completions(request: ChatCompletionRequest): "object": "chat.completion", "created": created_at, "model": request.model or script_args.model, - "choices": [{ - "index": 0, - "message": {"role": "assistant", "content": ""}, - "finish_reason": "length", - "logprobs": None - }], - "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": ""}, + "finish_reason": "length", + "logprobs": None, + } + ], + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, } choices = [] @@ -1112,21 +1128,23 @@ async def chat_completions(request: ChatCompletionRequest): # Manual XML-json tool call parsing if request.tools and text: - pattern = r'(.*?)' + pattern = r"(.*?)" matches = re.findall(pattern, text, re.DOTALL) if matches: tool_calls = [] for match in matches: try: data = json.loads(match.strip()) - tool_calls.append({ - "id": f"call_{uuid.uuid4().hex[:24]}", - "type": "function", - "function": { - "name": data.get("name", ""), - "arguments": json.dumps(data.get("arguments", {})) + tool_calls.append( + { + "id": f"call_{uuid.uuid4().hex[:24]}", + "type": "function", + "function": { + "name": data.get("name", ""), + "arguments": json.dumps(data.get("arguments", {})), + }, } - }) + ) except json.JSONDecodeError: continue if tool_calls: @@ -1144,22 +1162,24 @@ async def chat_completions(request: ChatCompletionRequest): "token": str(token_id), "logprob": float(list(logprob_dict.values())[0].logprob) if logprob_dict else 0.0, "bytes": None, - "top_logprobs": [] + "top_logprobs": [], } - for token_id, logprob_dict in zip(gen_output.token_ids, gen_output.logprobs) + for token_id, logprob_dict in zip(gen_output.token_ids, gen_output.logprobs, strict=False) ] } - choices.append({ - "index": idx, - "message": { - "role": "assistant", - "content": text if not tool_calls else None, - "tool_calls": tool_calls - }, - "logprobs": logprobs_data, - "finish_reason": finish_reason - }) + choices.append( + { + "index": idx, + "message": { + "role": "assistant", + "content": text if not tool_calls else None, + "tool_calls": tool_calls, + }, + "logprobs": logprobs_data, + "finish_reason": finish_reason, + } + ) return { "id": completion_id, @@ -1170,8 +1190,8 @@ async def chat_completions(request: ChatCompletionRequest): "usage": { "prompt_tokens": total_input_tokens, "completion_tokens": total_output_tokens, - "total_tokens": total_input_tokens + total_output_tokens - } + "total_tokens": total_input_tokens + total_output_tokens, + }, } class TokenizeRequest(BaseModel): @@ -1183,23 +1203,22 @@ class TokenizeRequest(BaseModel): async def tokenize(request: TokenizeRequest): messages = request.messages - has_prefix_token_ids = any( - msg.get("role") == "assistant" and "prompt_token_ids" in msg - for msg in messages - ) + has_prefix_token_ids = any(msg.get("role") == "assistant" and "prompt_token_ids" in msg for msg in messages) kwargs = { "messages": [messages], "tools": request.tools, "add_generation_prompt": True, - "chat_template_kwargs": {} + "chat_template_kwargs": {}, } connections[0].send({"type": "call", "method": "preprocess_chat", "kwargs": kwargs}) preprocessed_prompts = connections[0].recv() if preprocessed_prompts and len(preprocessed_prompts) > 1: - logger.warning("More than one tokenized message returned from preprocess_chat inside tokenize, double check results!") + logger.warning( + "More than one tokenized message returned from preprocess_chat inside tokenize, double check results!" + ) if not preprocessed_prompts or len(preprocessed_prompts) == 0: return {"tokens": [], "model": request.model or script_args.model} @@ -1217,32 +1236,34 @@ async def tokenize(request: TokenizeRequest): if messages[i].get("role") == "assistant": last_assistant_idx = i if "prompt_token_ids" in messages[i]: - model_prefix_tokens = messages[i]["prompt_token_ids"] + messages[i].get("generation_token_ids", []) + model_prefix_tokens = messages[i]["prompt_token_ids"] + messages[i].get( + "generation_token_ids", [] + ) break if model_prefix_tokens and last_assistant_idx is not None: # Preprocess up to last assistant - messages_to_last_assistant = messages[:last_assistant_idx + 1] - connections[0].send({"type": "call", "method": "preprocess_chat", "kwargs": { - "messages": [messages_to_last_assistant], - "tools": request.tools, - "add_generation_prompt": False, - "chat_template_kwargs": {} - }}) + messages_to_last_assistant = messages[: last_assistant_idx + 1] + connections[0].send( + { + "type": "call", + "method": "preprocess_chat", + "kwargs": { + "messages": [messages_to_last_assistant], + "tools": request.tools, + "add_generation_prompt": False, + "chat_template_kwargs": {}, + }, + } + ) template_prefix_prompts = connections[0].recv() template_prefix_token_ids = template_prefix_prompts[0]["prompt_token_ids"] result_tokens = _replace_prefix_tokens( - tokenizer, - model_prefix_tokens, - template_prefix_token_ids, - template_prompt["prompt_token_ids"] + tokenizer, model_prefix_tokens, template_prefix_token_ids, template_prompt["prompt_token_ids"] ) - return { - "tokens": result_tokens, - "model": request.model or script_args.model - } + return {"tokens": result_tokens, "model": request.model or script_args.model} # Start the server uvicorn.run( @@ -1252,7 +1273,7 @@ async def tokenize(request: TokenizeRequest): log_level=script_args.log_level, limit_concurrency=256, backlog=4096, - timeout_keep_alive=600 + timeout_keep_alive=600, ) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d86c42c3d09..b27f9a35a5b 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1562,7 +1562,7 @@ def _generate_and_score_completions( prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] - + # Allow custom completion_mask from rollout_func for multi-turn training if "completion_mask" in extra_fields: completion_mask_list = extra_fields.pop("completion_mask") @@ -1588,7 +1588,7 @@ def _generate_and_score_completions( # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) - + # attend to all non-padding tokens, but mask out user/tool result tokens in loss completion_attention_mask = (completion_ids != self.pad_token_id).long() attention_mask = torch.cat([prompt_mask, completion_attention_mask], dim=1) # (B, P+C) From d5443eb5b21199851b32cdbba952e11079a41453 Mon Sep 17 00:00:00 2001 From: Christian Munley Date: Fri, 30 Jan 2026 00:35:39 -0800 Subject: [PATCH 40/51] docs Signed-off-by: Christian Munley --- docs/source/nemo_gym.md | 8 ++++---- examples/scripts/nemo_gym/README.md | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/nemo_gym.md b/docs/source/nemo_gym.md index 3a41e0f8bad..328957d7a85 100644 --- a/docs/source/nemo_gym.md +++ b/docs/source/nemo_gym.md @@ -13,7 +13,7 @@ The integration supports multi-step and multi-turn rollouts, multi-environment t ## Available Environments -NeMo Gym provides training-ready environments across multiple domains: +NeMo Gym provides training-ready environments across multiple domains, including but not limited to: | Environment | Domain | Description | |-------------|--------|-------------| @@ -116,7 +116,7 @@ NeMo Gym datasets are stored as JSONL. Each line contains a task with input mess ### Create Training Config -Create a `config_workplace.yaml` file with your training parameters: +Create a config file, `config_workplace.yaml`: ```yaml model_name: "Qwen/Qwen2.5-1.5B-Instruct" @@ -184,8 +184,8 @@ For development and testing on a single node. The following steps run in three s This starts: - **Head server**: Manages servers used in training - **Agent server**: Orchestrates rollouts using resource servers and model servers - - **Resources server**: Supports environment logic such as state-based feedback, tool implementations, and task verification - - **Model server**: Adapts vLLM server requests to support NeMo Gym agents and ensures OpenAI API compatibility + - **Resources server**: Supports environment logic such as state-management, tool implementations, and task verification + - **Model server**: Adapts vLLM server requests to support NeMo Gym agents and on-policy RL training while ensuring OpenAI API compatibility 1. **Start TRL vLLM Server** (Terminal 2) diff --git a/examples/scripts/nemo_gym/README.md b/examples/scripts/nemo_gym/README.md index db4fff18b52..23784c594cd 100644 --- a/examples/scripts/nemo_gym/README.md +++ b/examples/scripts/nemo_gym/README.md @@ -2,4 +2,4 @@ This integration supports training language models in NeMo-Gym environments using TRL GRPO. Both single step and multi step tasks are supported, including multi-environment training. NeMo-Gym orchestrates rollouts, returning token ids and logprobs to TRL through the rollout function for training. Currently this integration is only supported through TRL's vllm server mode. -Check out docs/source/nemo_gym_integration.md for a full integration guide! \ No newline at end of file +Check out the docs page `docs/source/nemo_gym.md` for a guide. \ No newline at end of file From 2837bdad272abdea68314c403673ba684d961506 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Fri, 30 Jan 2026 01:15:01 -0800 Subject: [PATCH 41/51] improve docs, rename train script Signed-off-by: cmunley1 --- docs/source/_toctree.yml | 2 + docs/source/example_overview.md | 2 +- docs/source/nemo_gym.md | 38 +++++++++++-------- examples/scripts/nemo_gym/submit.sh | 2 +- ...nemo_gym.py => train_multi_environment.py} | 0 5 files changed, 27 insertions(+), 17 deletions(-) rename examples/scripts/nemo_gym/{run_grpo_nemo_gym.py => train_multi_environment.py} (100%) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 16a420ff0d4..d726c45f284 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -117,6 +117,8 @@ title: MiniLLM - local: nash_md_trainer title: Nash-MD + - local: nemo_gym + title: NeMo Gym - local: online_dpo_trainer title: Online DPO - local: orpo_trainer diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index 8a3a6742554..b78db67020e 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -61,7 +61,7 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl | [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`experimental.kto.KTOTrainer`] to fine-tune a model. | | [`examples/scripts/mpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/mpo_vlm.py) | This script shows how to use MPO via the [`DPOTrainer`] to align a model based on preferences using the [HuggingFaceH4/rlaif-v_formatted](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted) dataset and a set of loss weights with weights. | | [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`experimental.nash_md.NashMDTrainer`] to fine-tune a model. | -| [`examples/scripts/nemo_gym/train_multi_env.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/train_multi_env.py) | This script shows how to use the [`GRPOTrainer`] to train language models in NVIDIA NeMo-Gym environments. Supports multi-turn and tool calling environments, and multi-environment training. See the [NeMo-Gym Integration](nemo_gym_integration) guide for setup and usage. | +| [`examples/scripts/nemo_gym/train_multi_environment.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/train_multi_environment.py) | This script shows how to use the [`GRPOTrainer`] to train language models in NVIDIA NeMo-Gym environments. Supports multi-turn and tool calling environments, and multi-environment training. See the [NeMo-Gym Integration](nemo_gym) guide for setup and usage. | | [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a model. | | [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a a Vision Language Model. | | [`examples/scripts/openenv/browsergym.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/browsergym.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's BrowserGym environment and vLLM for VLMs | diff --git a/docs/source/nemo_gym.md b/docs/source/nemo_gym.md index 328957d7a85..47592dcd2c6 100644 --- a/docs/source/nemo_gym.md +++ b/docs/source/nemo_gym.md @@ -32,20 +32,25 @@ Complete these one-time setup steps before running training. ### Install TRL and NeMo Gym -1. **Install TRL with vLLM support** +1. **Install TRL with vLLM extras** ```bash - pip install trl[vllm] + cd trl/ + uv venv + source .venv/bin/activate + uv sync --extra vllm ``` 1. **Install NeMo Gym** ```bash + # deactivate trl venv + deactivate git clone https://github.com/NVIDIA-NeMo/Gym.git cd Gym uv venv --python 3.12 source .venv/bin/activate - uv sync --extra dev + uv sync ``` ### Prepare a Dataset @@ -60,7 +65,7 @@ Many NeMo Gym datasets used to train Nemotron models are available on Hugging Fa 1. **Set Hugging Face Token** - Create `env.yaml` in `Gym/` with your token: + Create `env.yaml` in `Gym/` with your HF token: ```yaml hf_token: @@ -69,12 +74,15 @@ Many NeMo Gym datasets used to train Nemotron models are available on Hugging Fa 1. **Prepare Dataset** ```bash + # Enter Gym and activate the venv cd Gym source .venv/bin/activate + # Set config paths config_paths="responses_api_models/vllm_model/configs/vllm_model.yaml,\ resources_servers/workplace_assistant/configs/workplace_assistant.yaml" + # Download data and prep for training ng_prepare_data "+config_paths=[${config_paths}]" \ +output_dirpath=data/workplace_assistant \ +mode=train_preparation \ @@ -84,7 +92,7 @@ Many NeMo Gym datasets used to train Nemotron models are available on Hugging Fa This creates `train.jsonl` and `validation.jsonl` files in `data/workplace_assistant/`. -To create a new environment, refer to the [environment creation guide](https://docs.nvidia.com/nemo/gym/latest/contribute/environments/new-environment.html). +To create a new environment, refer to the [environment creation guide](https://docs.nvidia.com/nemo/gym/latest/contribute/environments/new-environment.html). We suggest running an existing one first! #### Dataset Format @@ -146,7 +154,7 @@ eval_steps: 10 ## Interactive Training -For development and testing on a single node. The following steps run in three separate terminals concurrently. +For development and testing on a single node. ### Set Up @@ -158,7 +166,7 @@ For development and testing on a single node. The following steps run in three s 1. **Update Environment Config** - Update `env.yaml` to include model information: + Update `env.yaml` in `Gym/` to include model information: ```yaml policy_base_url: http://127.0.0.1:8000/v1 @@ -167,7 +175,9 @@ For development and testing on a single node. The following steps run in three s hf_token: ... ``` -### Start Servers and Run +### Run Training + +The following steps run in 3 terminals. It can also be ran with processes in the background, or using tmux. 1. **Start NeMo Gym Servers** (Terminal 1) @@ -182,12 +192,12 @@ For development and testing on a single node. The following steps run in three s ``` This starts: - - **Head server**: Manages servers used in training - **Agent server**: Orchestrates rollouts using resource servers and model servers - **Resources server**: Supports environment logic such as state-management, tool implementations, and task verification - **Model server**: Adapts vLLM server requests to support NeMo Gym agents and on-policy RL training while ensuring OpenAI API compatibility + - **Head server**: Manages servers used in training enabling their discovery -1. **Start TRL vLLM Server** (Terminal 2) +1. **Start TRL vLLM Server on GPU 0** (Terminal 2) ```bash cd trl @@ -199,7 +209,7 @@ For development and testing on a single node. The following steps run in three s --port 8000 ``` -1. **Run Training** (Terminal 3) +1. **Run Training on GPU 1** (Terminal 3) ```bash cd trl/ @@ -214,11 +224,9 @@ For development and testing on a single node. The following steps run in three s CUDA_VISIBLE_DEVICES=1 python run_grpo_nemo_gym.py --config config_workplace.yaml ``` -> **Note**: These separate terminals can also be tmux sessions or background processes. - ## Multi-Node Training with Slurm -An example five-node training script is provided in `submit.sh`. Nodes one through four run the training backend, while node five runs vLLM inference. +An example five-node training script is provided in `submit.sh`. Nodes one through four run the training algorithm, while node five runs vLLM inference for NeMo Gym agent rollouts. 1. **Configure the Script** @@ -244,7 +252,7 @@ Train on multiple NeMo Gym environments simultaneously. This allows learning div 1. **Prepare Individual Datasets** - Prepare datasets for each environment. The workplace dataset was prepared above. Now, create a reasoning gym dataset: + Prepare datasets for each environment. The workplace assistant dataset was prepared above. Now lets create a dataset for the mini sudoku environment implemented by the reasoning gym resources server in NeMo Gym: ```bash cd Gym diff --git a/examples/scripts/nemo_gym/submit.sh b/examples/scripts/nemo_gym/submit.sh index 5f046d55bfc..49fa9ed8bd1 100644 --- a/examples/scripts/nemo_gym/submit.sh +++ b/examples/scripts/nemo_gym/submit.sh @@ -103,7 +103,7 @@ srun --nodes=4 --ntasks=4 --nodelist="${TRAIN_NODES_LIST}" \ --main_process_ip ${TRAIN_NODE_0} \ --main_process_port 29500 \ --rdzv_backend c10d \ - run_grpo_nemo_gym.py \ + train_multi_environment.py \ --config config.yaml \ --vllm_server_host ${VLLM_NODE} \ --head_server_host ${VLLM_NODE}" & diff --git a/examples/scripts/nemo_gym/run_grpo_nemo_gym.py b/examples/scripts/nemo_gym/train_multi_environment.py similarity index 100% rename from examples/scripts/nemo_gym/run_grpo_nemo_gym.py rename to examples/scripts/nemo_gym/train_multi_environment.py From 93d97a7c03d35c9305122d7d4f338d868830be8b Mon Sep 17 00:00:00 2001 From: Christian Munley Date: Fri, 30 Jan 2026 08:50:10 -0800 Subject: [PATCH 42/51] fixes based on review Signed-off-by: Christian Munley --- docs/source/nemo_gym.md | 6 +-- .../nemo_gym/train_multi_environment.py | 37 ++++++++++++------- tests/test_vllm_client_server.py | 3 +- trl/scripts/vllm_serve.py | 10 +++-- 4 files changed, 35 insertions(+), 21 deletions(-) diff --git a/docs/source/nemo_gym.md b/docs/source/nemo_gym.md index 47592dcd2c6..ec7e8afdc7e 100644 --- a/docs/source/nemo_gym.md +++ b/docs/source/nemo_gym.md @@ -61,7 +61,7 @@ Many NeMo Gym datasets used to train Nemotron models are available on Hugging Fa - Validates the data format - Adds an `agent_ref` field to each example that tells NeMo Gym which agent server should handle that example -> **Note**: `run_grpo_nemo_gym.py` adds the `agent_ref` field when loading datasets, so this step is optional if datasets are created another way. +> **Note**: `train_multi_environment.py` adds the `agent_ref` field when loading datasets, so this step is optional if datasets are created another way. 1. **Set Hugging Face Token** @@ -221,7 +221,7 @@ The following steps run in 3 terminals. It can also be ran with processes in the export WANDB_API_KEY=... uv pip install wandb - CUDA_VISIBLE_DEVICES=1 python run_grpo_nemo_gym.py --config config_workplace.yaml + CUDA_VISIBLE_DEVICES=1 python train_multi_environment.py --config config_workplace.yaml ``` ## Multi-Node Training with Slurm @@ -325,5 +325,5 @@ Train on multiple NeMo Gym environments simultaneously. This allows learning div - [NeMo Gym GitHub](https://github.com/NVIDIA-NeMo/Gym) - [NeMo Gym Documentation](https://docs.nvidia.com/nemo/gym/latest/) -- [Training Script](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/run_grpo_nemo_gym.py) +- [Training Script](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/train_multi_environment.py) - [TRL GRPO Trainer](grpo_trainer) \ No newline at end of file diff --git a/examples/scripts/nemo_gym/train_multi_environment.py b/examples/scripts/nemo_gym/train_multi_environment.py index 0c48e747fff..29c9f86e804 100644 --- a/examples/scripts/nemo_gym/train_multi_environment.py +++ b/examples/scripts/nemo_gym/train_multi_environment.py @@ -12,6 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +# /// script +# dependencies = [ +# "trl[vllm]", +# "nemo_gym @ git+https://github.com/NVIDIA-NeMo/Gym", +# ] +# /// + import argparse import asyncio import json @@ -21,7 +28,6 @@ import aiohttp import requests -import wandb import yaml from datasets import Dataset, load_dataset from omegaconf import OmegaConf @@ -30,6 +36,14 @@ from trl import GRPOConfig, GRPOTrainer +@dataclass +class NeMoGymGRPOConfig(GRPOConfig): + """GRPOConfig subclass with NeMo Gym specific fields.""" + + agent_servers: dict[str, str] | None = None + request_timeout: float = 10800 + + @dataclass class TrainingConfig: model_name: str @@ -43,7 +57,6 @@ class TrainingConfig: per_device_train_batch_size: int = 2 gradient_accumulation_steps: int = 16 max_seq_length: int = 1024 - max_prompt_length: int = None temperature: float = 1.0 top_p: float = 0.999 @@ -298,11 +311,11 @@ def nemo_gym_rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, raise RuntimeError("No valid rollouts. Check Nemo Gym and vLLM logs.") if num_turns_list: - wandb.log( + trainer.log( { - "train/num_turns_mean": sum(num_turns_list) / len(num_turns_list), - "train/num_turns_min": min(num_turns_list), - "train/num_turns_max": max(num_turns_list), + "num_turns_mean": sum(num_turns_list) / len(num_turns_list), + "num_turns_min": min(num_turns_list), + "num_turns_max": max(num_turns_list), } ) @@ -383,7 +396,7 @@ def main(): eval_dataset = load_dataset_from_jsonl(config.eval_dataset_path) print(f"Eval dataset has {len(eval_dataset)} examples\n") - training_args = GRPOConfig( + training_args = NeMoGymGRPOConfig( use_vllm=True, vllm_mode="server", vllm_server_host=args.vllm_server_host, @@ -415,19 +428,15 @@ def main(): mask_truncated_completions=True, log_completions=config.log_completions, num_completions_to_print=config.num_completions_to_print, - # max_prompt_length=config.max_prompt_length, - max_completion_length=config.max_seq_length - config.max_prompt_length - if config.max_prompt_length - else config.max_seq_length, + max_completion_length=config.max_seq_length, shuffle_dataset=False, model_init_kwargs={ "torch_dtype": "auto", }, + agent_servers=agent_servers, + request_timeout=10800, ) - training_args.agent_servers = agent_servers - training_args.request_timeout = 10800 - tokenizer = AutoTokenizer.from_pretrained(config.model_name, truncation_side="left", padding_side="left") trainer = GRPOTrainer( diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index bddbca4cf7b..e3464dae6d6 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -197,7 +197,8 @@ def test_chat_completions_with_params(self): assert len(data["choices"]) == 2 - for choice in data["choices"]: + for i, choice in enumerate(data["choices"]): + assert choice["index"] == i, f"Expected choice at position {i} to have index {i}, got {choice['index']}" assert "message" in choice assert choice["message"]["role"] == "assistant" diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index f5bf4333e81..ce94ad1f539 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -466,7 +466,7 @@ def _replace_prefix_tokens( model output. A concrete example is inconsistent whitespace tokens around tool call special tokens. Based on NeMo RL's _replace_prefix_tokens: - https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 + https://github.com/NVIDIA-NeMo/RL/blob/748b9caff4e6d672b8a98a10b6e612d028cfc96b/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 """ if not model_prefix_token_ids: return template_token_ids @@ -1088,7 +1088,9 @@ async def chat_completions(request: ChatCompletionRequest): all_outputs = [connection.recv() for connection in connections] if has_prefix_token_ids: - all_outputs = [o for o in all_outputs if o] + all_outputs = [ + output for output, prompt_chunk in zip(all_outputs, chunked_prompts, strict=True) if prompt_chunk + ] else: all_outputs = [ output for output, msg_chunk in zip(all_outputs, chunked_messages, strict=True) if msg_chunk @@ -1116,7 +1118,8 @@ async def chat_completions(request: ChatCompletionRequest): total_input_tokens = 0 total_output_tokens = 0 - for idx, output in enumerate(all_outputs): + idx = 0 + for output in all_outputs: total_input_tokens += len(output.prompt_token_ids) for gen_output in output.outputs: @@ -1180,6 +1183,7 @@ async def chat_completions(request: ChatCompletionRequest): "finish_reason": finish_reason, } ) + idx += 1 return { "id": completion_id, From b15ab6305d6675ff740d45952e8140a88e940b65 Mon Sep 17 00:00:00 2001 From: Christian Munley Date: Fri, 30 Jan 2026 11:19:48 -0800 Subject: [PATCH 43/51] subclass Signed-off-by: Christian Munley --- docs/source/nemo_gym.md | 2 +- examples/scripts/nemo_gym/config.yaml | 2 +- .../nemo_gym/train_multi_environment.py | 125 +++++------------- 3 files changed, 38 insertions(+), 91 deletions(-) diff --git a/docs/source/nemo_gym.md b/docs/source/nemo_gym.md index ec7e8afdc7e..a1eea762203 100644 --- a/docs/source/nemo_gym.md +++ b/docs/source/nemo_gym.md @@ -142,7 +142,7 @@ max_steps: 1000 num_generations: 8 per_device_train_batch_size: 1 gradient_accumulation_steps: 4 -max_seq_length: 16384 +max_completion_length: 16384 temperature: 1.0 top_p: 0.999 diff --git a/examples/scripts/nemo_gym/config.yaml b/examples/scripts/nemo_gym/config.yaml index 1998e9f66fc..448cd07da9f 100644 --- a/examples/scripts/nemo_gym/config.yaml +++ b/examples/scripts/nemo_gym/config.yaml @@ -15,7 +15,7 @@ max_steps: 1000 num_generations: 8 per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -max_seq_length: 16384 +max_completion_length: 16384 warmup_steps: 5 lr_scheduler_type: "linear" optim: "adamw_torch_fused" diff --git a/examples/scripts/nemo_gym/train_multi_environment.py b/examples/scripts/nemo_gym/train_multi_environment.py index 29c9f86e804..3dbe58b8a37 100644 --- a/examples/scripts/nemo_gym/train_multi_environment.py +++ b/examples/scripts/nemo_gym/train_multi_environment.py @@ -44,43 +44,6 @@ class NeMoGymGRPOConfig(GRPOConfig): request_timeout: float = 10800 -@dataclass -class TrainingConfig: - model_name: str - dataset_path: str - - task: str | None = None - - learning_rate: float = 5e-6 - max_steps: int = 100 - num_generations: int = 2 - per_device_train_batch_size: int = 2 - gradient_accumulation_steps: int = 16 - max_seq_length: int = 1024 - - temperature: float = 1.0 - top_p: float = 0.999 - weight_decay: float = 0.01 - warmup_ratio: float = 0.0 - warmup_steps: int = 0 - lr_scheduler_type: str = "linear" - optim: str = "adamw_8bit" - - output_dir: str = "outputs/trl_nemo_gym" - save_steps: int = 100 - report_to: str = "none" - run_name: str = None # Wandb - project_name: str = None # Wandb - log_completions: bool = False - num_completions_to_print: int = None - - eval_dataset_path: str | None = None - eval_strategy: str = "no" - eval_steps: int = 50 - - vllm_importance_sampling_correction: bool = False - - def get_agent_servers( head_server_host: str = "127.0.0.1", head_server_port: int = 11000, @@ -357,43 +320,35 @@ def main(): args = parser.parse_args() with open(args.config) as f: - config = TrainingConfig(**yaml.safe_load(f)) + config = yaml.safe_load(f) + + model_name = config.pop("model_name") + dataset_path = config.pop("dataset_path") + eval_dataset_path = config.pop("eval_dataset_path", None) + task = config.pop("task", None) + project_name = config.pop("project_name", None) - if isinstance(config.learning_rate, str): - config.learning_rate = float(config.learning_rate) - if isinstance(config.weight_decay, str): - config.weight_decay = float(config.weight_decay) + if "learning_rate" in config and isinstance(config["learning_rate"], str): + config["learning_rate"] = float(config["learning_rate"]) + if "weight_decay" in config and isinstance(config["weight_decay"], str): + config["weight_decay"] = float(config["weight_decay"]) agent_servers = get_agent_servers( head_server_host=args.head_server_host, head_server_port=11000, ) - if config.project_name: - os.environ["WANDB_PROJECT"] = config.project_name - - if config.run_name is None: - task = config.task or os.path.basename(config.dataset_path).replace(".jsonl", "").replace(".json", "") - model_short = config.model_name.split("/")[-1] - config.run_name = ( - f"{task}_{model_short}" - f"_rpp{config.num_generations}" - f"_dbs{config.per_device_train_batch_size}" - f"_ga{config.gradient_accumulation_steps}" - f"_maxlen{config.max_seq_length}" - f"_lr{config.learning_rate}" - f"_temp{config.temperature}" - f"_topp{config.top_p}" - ) + if project_name: + os.environ["WANDB_PROJECT"] = project_name - if config.dataset_path.endswith((".jsonl", ".json")): - dataset = load_dataset_from_jsonl(config.dataset_path) + if dataset_path.endswith((".jsonl", ".json")): + dataset = load_dataset_from_jsonl(dataset_path) else: - dataset = load_dataset(config.dataset_path, split="train") + dataset = load_dataset(dataset_path, split="train") eval_dataset = None - if config.eval_dataset_path: - eval_dataset = load_dataset_from_jsonl(config.eval_dataset_path) + if eval_dataset_path: + eval_dataset = load_dataset_from_jsonl(eval_dataset_path) print(f"Eval dataset has {len(eval_dataset)} examples\n") training_args = NeMoGymGRPOConfig( @@ -402,45 +357,37 @@ def main(): vllm_server_host=args.vllm_server_host, vllm_server_port=8000, gradient_checkpointing=True, - temperature=config.temperature, - learning_rate=config.learning_rate, - weight_decay=config.weight_decay, - warmup_ratio=config.warmup_ratio, - warmup_steps=config.warmup_steps, - lr_scheduler_type=config.lr_scheduler_type, - optim=config.optim, - per_device_train_batch_size=config.per_device_train_batch_size, - gradient_accumulation_steps=config.gradient_accumulation_steps, - num_generations=config.num_generations, num_generations_eval=1, - max_steps=config.max_steps, - save_steps=config.save_steps, logging_steps=1, - report_to=config.report_to, - output_dir=config.output_dir, - run_name=config.run_name, - eval_strategy=config.eval_strategy, - eval_steps=config.eval_steps, - vllm_importance_sampling_correction=config.vllm_importance_sampling_correction, epsilon=0.2, epsilon_high=0.28, loss_type="grpo", mask_truncated_completions=True, - log_completions=config.log_completions, - num_completions_to_print=config.num_completions_to_print, - max_completion_length=config.max_seq_length, shuffle_dataset=False, - model_init_kwargs={ - "torch_dtype": "auto", - }, + model_init_kwargs={"torch_dtype": "auto"}, agent_servers=agent_servers, request_timeout=10800, + **config, ) - tokenizer = AutoTokenizer.from_pretrained(config.model_name, truncation_side="left", padding_side="left") + if training_args.run_name is None: + task_name = task or os.path.basename(dataset_path).replace(".jsonl", "").replace(".json", "") + model_short = model_name.split("/")[-1] + training_args.run_name = ( + f"{task_name}_{model_short}" + f"_rpp{training_args.num_generations}" + f"_dbs{training_args.per_device_train_batch_size}" + f"_ga{training_args.gradient_accumulation_steps}" + f"_maxlen{training_args.max_completion_length}" + f"_lr{training_args.learning_rate}" + f"_temp{training_args.temperature}" + f"_topp{training_args.top_p}" + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name, truncation_side="left", padding_side="left") trainer = GRPOTrainer( - model=config.model_name, + model=model_name, processing_class=tokenizer, reward_funcs=reward_fn, train_dataset=dataset, From 6d7e8d0da9bfde14b4097b73063449e1fee3fad0 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Sat, 31 Jan 2026 02:46:51 -0800 Subject: [PATCH 44/51] config update Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/config.yaml | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/scripts/nemo_gym/config.yaml b/examples/scripts/nemo_gym/config.yaml index 448cd07da9f..2efa5b30ae0 100644 --- a/examples/scripts/nemo_gym/config.yaml +++ b/examples/scripts/nemo_gym/config.yaml @@ -1,8 +1,11 @@ -model_name: "Qwen/Qwen3-4B-Instruct-2507" +# Model +model_name: "Qwen/Qwen2.5-1.5B-Instruct" -dataset_path: "/path/to/data/train.jsonl" -eval_dataset_path: "/path/to/data/val.jsonl" +# Data +dataset_path: "/home/ubuntu/Gym/resources_servers/workplace_assistant/data/train.jsonl" +eval_dataset_path: "/home/ubuntu/Gym/resources_servers/workplace_assistant/data/validation.jsonl" +# Logging output_dir: "outputs/nemo_gym" task: "workplace" # just used in wandb run name report_to: "wandb" @@ -10,6 +13,7 @@ project_name: "trl-nemo-gym" log_completions: true num_completions_to_print: 2 +# Training hyperparameters learning_rate: 1.0e-5 max_steps: 1000 num_generations: 8 @@ -22,11 +26,12 @@ optim: "adamw_torch_fused" weight_decay: 0.0 vllm_importance_sampling_correction: true +# Inference sampling parameters temperature: 1.0 top_p: 0.999 +# Checkpointing and Eval save_steps: 10 - eval_strategy: "steps" eval_steps: 10 From 13c378ceed933a7fb8a2472c1f4319d457c25556 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Sat, 31 Jan 2026 02:49:23 -0800 Subject: [PATCH 45/51] docs Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/train_multi_environment.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/scripts/nemo_gym/train_multi_environment.py b/examples/scripts/nemo_gym/train_multi_environment.py index 3dbe58b8a37..b28dd4158f2 100644 --- a/examples/scripts/nemo_gym/train_multi_environment.py +++ b/examples/scripts/nemo_gym/train_multi_environment.py @@ -38,8 +38,6 @@ @dataclass class NeMoGymGRPOConfig(GRPOConfig): - """GRPOConfig subclass with NeMo Gym specific fields.""" - agent_servers: dict[str, str] | None = None request_timeout: float = 10800 From c5dcb5d89cacf9dad450fb79b9dce90a15fbf206 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Sat, 31 Jan 2026 02:53:01 -0800 Subject: [PATCH 46/51] typo in submit Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/submit.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/scripts/nemo_gym/submit.sh b/examples/scripts/nemo_gym/submit.sh index 49fa9ed8bd1..c819c0fa45d 100644 --- a/examples/scripts/nemo_gym/submit.sh +++ b/examples/scripts/nemo_gym/submit.sh @@ -31,7 +31,7 @@ mkdir -p ${LOG_DIR} echo "Starting ng_run and vLLM on ${VLLM_NODE}..." echo "Logs will be saved to: ${LOG_DIR}" -# NOTE: If you have already set up your TRL venv, you can remove all of the pip installs and uv venv related commands below! +# NOTE: If you have already set up your TRL venv, you can remove all of the pip installs and uv venv related commands below! srun --nodes=1 --ntasks=1 --nodelist="${VLLM_NODE}" \ --container-image="${CONTAINER_IMAGE}" \ @@ -92,7 +92,7 @@ srun --nodes=4 --ntasks=4 --nodelist="${TRAIN_NODES_LIST}" \ export HOME=/path/to/user && \ export HF_HOME=/path/to/user/hf_home && \ cd /path/to/user/trl && \ - source .venv/bin/activate && uv pip install accelerate deepseed wandb omegaconf && \ + source .venv/bin/activate && uv pip install accelerate deepspeed wandb omegaconf && \ cd examples/scripts/nemo_gym && \ export WANDB_API_KEY= && \ accelerate launch \ From b4678fb81234ac641c71aa3d2f6036bdf74a4707 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Sat, 31 Jan 2026 03:45:44 -0800 Subject: [PATCH 47/51] improve nemo gym docs Signed-off-by: cmunley1 --- docs/source/nemo_gym.md | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/docs/source/nemo_gym.md b/docs/source/nemo_gym.md index a1eea762203..342721e31b4 100644 --- a/docs/source/nemo_gym.md +++ b/docs/source/nemo_gym.md @@ -182,7 +182,7 @@ The following steps run in 3 terminals. It can also be ran with processes in the 1. **Start NeMo Gym Servers** (Terminal 1) ```bash - cd Gym + cd Gym/ source .venv/bin/activate config_paths="resources_servers/workplace_assistant/configs/workplace_assistant.yaml,\ @@ -200,8 +200,8 @@ The following steps run in 3 terminals. It can also be ran with processes in the 1. **Start TRL vLLM Server on GPU 0** (Terminal 2) ```bash - cd trl - + cd trl/ + source .venv/bin/activate CUDA_VISIBLE_DEVICES=0 trl vllm-serve \ --model Qwen/Qwen2.5-1.5B-Instruct \ --max-model-len 16384 \ @@ -212,14 +212,10 @@ The following steps run in 3 terminals. It can also be ran with processes in the 1. **Run Training on GPU 1** (Terminal 3) ```bash - cd trl/ - source .venv/bin/activate - - cd examples/scripts/nemo_gym - - # if using wandb - export WANDB_API_KEY=... - uv pip install wandb + source trl/.venv/bin/activate + cd trl/examples/scripts/nemo_gym + export WANDB_API_KEY=... + uv add omegaconf CUDA_VISIBLE_DEVICES=1 python train_multi_environment.py --config config_workplace.yaml ``` @@ -326,4 +322,4 @@ Train on multiple NeMo Gym environments simultaneously. This allows learning div - [NeMo Gym GitHub](https://github.com/NVIDIA-NeMo/Gym) - [NeMo Gym Documentation](https://docs.nvidia.com/nemo/gym/latest/) - [Training Script](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/train_multi_environment.py) -- [TRL GRPO Trainer](grpo_trainer) \ No newline at end of file +- [TRL GRPO Trainer](grpo_trainer) From a476ac5ac43e18c5b09be637102f1c488b3674f8 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Sat, 31 Jan 2026 03:54:46 -0800 Subject: [PATCH 48/51] update docs Signed-off-by: cmunley1 --- docs/source/nemo_gym.md | 42 +++++------------------------------------ 1 file changed, 5 insertions(+), 37 deletions(-) diff --git a/docs/source/nemo_gym.md b/docs/source/nemo_gym.md index 342721e31b4..5fda978d3f7 100644 --- a/docs/source/nemo_gym.md +++ b/docs/source/nemo_gym.md @@ -122,48 +122,12 @@ NeMo Gym datasets are stored as JSONL. Each line contains a task with input mess } ``` -### Create Training Config - -Create a config file, `config_workplace.yaml`: - -```yaml -model_name: "Qwen/Qwen2.5-1.5B-Instruct" - -dataset_path: "data/workplace_assistant/train.jsonl" -eval_dataset_path: "data/workplace_assistant/validation.jsonl" - -task: 'workplace' # used in wandb run name -output_dir: "outputs/nemo_gym" -report_to: "wandb" # set to none if you don't have wandb set up. -project_name: "trl-nemo-gym" - -learning_rate: 1.0e-5 -max_steps: 1000 -num_generations: 8 -per_device_train_batch_size: 1 -gradient_accumulation_steps: 4 -max_completion_length: 16384 - -temperature: 1.0 -top_p: 0.999 - -save_steps: 10 -eval_strategy: "steps" -eval_steps: 10 -``` - ## Interactive Training For development and testing on a single node. ### Set Up -1. **Verify Prerequisites** - - Confirm you have completed the [Before You Start](#before-you-start) section: - - Dataset prepared in `data/workplace_assistant/` - - Training config created (`config_workplace.yaml`) - 1. **Update Environment Config** Update `env.yaml` in `Gym/` to include model information: @@ -175,6 +139,10 @@ For development and testing on a single node. hf_token: ... ``` +2. **Update Training Config** + + Update `examples/scripts/nemo_gym/config.yaml` to point to the dataset generated above, and any other optional modifications. + ### Run Training The following steps run in 3 terminals. It can also be ran with processes in the background, or using tmux. @@ -226,7 +194,7 @@ An example five-node training script is provided in `submit.sh`. Nodes one throu 1. **Configure the Script** - Update `submit.sh` with your Slurm account, partition, path to Gym repository, and training config. + Update `submit.sh` with your Slurm account, partition, paths to your project directory, and updated training configs. 1. **Submit the Job** From 5e70a33a22c7899d884a0c53a6c45b76ecf86196 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Mon, 2 Feb 2026 13:52:05 -0800 Subject: [PATCH 49/51] rename project to server Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/train_multi_environment.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/scripts/nemo_gym/train_multi_environment.py b/examples/scripts/nemo_gym/train_multi_environment.py index b28dd4158f2..3ec95bac980 100644 --- a/examples/scripts/nemo_gym/train_multi_environment.py +++ b/examples/scripts/nemo_gym/train_multi_environment.py @@ -53,16 +53,16 @@ def get_agent_servers( global_config_dict = OmegaConf.create(yaml.safe_load(global_config_yaml)) agent_servers = {} - for project_name, project_config in global_config_dict.items(): - if hasattr(project_config, "responses_api_agents"): - agents = project_config.responses_api_agents + for server_name, server_config in global_config_dict.items(): + if hasattr(server_config, "responses_api_agents"): + agents = server_config.responses_api_agents for agent_key in agents.keys(): agent_config = getattr(agents, agent_key) if hasattr(agent_config, "host") and hasattr(agent_config, "port"): agent_host = agent_config.host if agent_host in ("127.0.0.1", "0.0.0.0", "localhost"): agent_host = head_server_host - agent_servers[project_name] = f"http://{agent_host}:{agent_config.port}" + agent_servers[server_name] = f"http://{agent_host}:{agent_config.port}" if not agent_servers: raise ValueError("No agents found in global config") From a3f241e20f31018b870d239798a93fc5ff996508 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Mon, 2 Feb 2026 13:52:30 -0800 Subject: [PATCH 50/51] vllm finish reason Signed-off-by: cmunley1 --- trl/scripts/vllm_serve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index ce94ad1f539..13120274d00 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -1127,7 +1127,7 @@ async def chat_completions(request: ChatCompletionRequest): text = gen_output.text if hasattr(gen_output, "text") else "" tool_calls = None - finish_reason = "stop" + finish_reason = gen_output.finish_reason if hasattr(gen_output, "finish_reason") else "stop" # Manual XML-json tool call parsing if request.tools and text: From e123a88f7e70ece6558051ea2d5b86b3cd3559f6 Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Wed, 4 Feb 2026 17:17:54 +0100 Subject: [PATCH 51/51] Update docs/source/nemo_gym.md --- docs/source/nemo_gym.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/nemo_gym.md b/docs/source/nemo_gym.md index 5fda978d3f7..62c4e49966d 100644 --- a/docs/source/nemo_gym.md +++ b/docs/source/nemo_gym.md @@ -185,7 +185,7 @@ The following steps run in 3 terminals. It can also be ran with processes in the export WANDB_API_KEY=... uv add omegaconf - CUDA_VISIBLE_DEVICES=1 python train_multi_environment.py --config config_workplace.yaml + CUDA_VISIBLE_DEVICES=1 python train_multi_environment.py --config config.yaml ``` ## Multi-Node Training with Slurm