From 93d97a7c03d35c9305122d7d4f338d868830be8b Mon Sep 17 00:00:00 2001 From: Christian Munley Date: Fri, 30 Jan 2026 08:50:10 -0800 Subject: [PATCH 1/5] fixes based on review Signed-off-by: Christian Munley --- docs/source/nemo_gym.md | 6 +-- .../nemo_gym/train_multi_environment.py | 37 ++++++++++++------- tests/test_vllm_client_server.py | 3 +- trl/scripts/vllm_serve.py | 10 +++-- 4 files changed, 35 insertions(+), 21 deletions(-) diff --git a/docs/source/nemo_gym.md b/docs/source/nemo_gym.md index 47592dcd2c6..ec7e8afdc7e 100644 --- a/docs/source/nemo_gym.md +++ b/docs/source/nemo_gym.md @@ -61,7 +61,7 @@ Many NeMo Gym datasets used to train Nemotron models are available on Hugging Fa - Validates the data format - Adds an `agent_ref` field to each example that tells NeMo Gym which agent server should handle that example -> **Note**: `run_grpo_nemo_gym.py` adds the `agent_ref` field when loading datasets, so this step is optional if datasets are created another way. +> **Note**: `train_multi_environment.py` adds the `agent_ref` field when loading datasets, so this step is optional if datasets are created another way. 1. **Set Hugging Face Token** @@ -221,7 +221,7 @@ The following steps run in 3 terminals. It can also be ran with processes in the export WANDB_API_KEY=... uv pip install wandb - CUDA_VISIBLE_DEVICES=1 python run_grpo_nemo_gym.py --config config_workplace.yaml + CUDA_VISIBLE_DEVICES=1 python train_multi_environment.py --config config_workplace.yaml ``` ## Multi-Node Training with Slurm @@ -325,5 +325,5 @@ Train on multiple NeMo Gym environments simultaneously. This allows learning div - [NeMo Gym GitHub](https://github.com/NVIDIA-NeMo/Gym) - [NeMo Gym Documentation](https://docs.nvidia.com/nemo/gym/latest/) -- [Training Script](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/run_grpo_nemo_gym.py) +- [Training Script](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/train_multi_environment.py) - [TRL GRPO Trainer](grpo_trainer) \ No newline at end of file diff --git a/examples/scripts/nemo_gym/train_multi_environment.py b/examples/scripts/nemo_gym/train_multi_environment.py index 0c48e747fff..29c9f86e804 100644 --- a/examples/scripts/nemo_gym/train_multi_environment.py +++ b/examples/scripts/nemo_gym/train_multi_environment.py @@ -12,6 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +# /// script +# dependencies = [ +# "trl[vllm]", +# "nemo_gym @ git+https://github.com/NVIDIA-NeMo/Gym", +# ] +# /// + import argparse import asyncio import json @@ -21,7 +28,6 @@ import aiohttp import requests -import wandb import yaml from datasets import Dataset, load_dataset from omegaconf import OmegaConf @@ -30,6 +36,14 @@ from trl import GRPOConfig, GRPOTrainer +@dataclass +class NeMoGymGRPOConfig(GRPOConfig): + """GRPOConfig subclass with NeMo Gym specific fields.""" + + agent_servers: dict[str, str] | None = None + request_timeout: float = 10800 + + @dataclass class TrainingConfig: model_name: str @@ -43,7 +57,6 @@ class TrainingConfig: per_device_train_batch_size: int = 2 gradient_accumulation_steps: int = 16 max_seq_length: int = 1024 - max_prompt_length: int = None temperature: float = 1.0 top_p: float = 0.999 @@ -298,11 +311,11 @@ def nemo_gym_rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, raise RuntimeError("No valid rollouts. Check Nemo Gym and vLLM logs.") if num_turns_list: - wandb.log( + trainer.log( { - "train/num_turns_mean": sum(num_turns_list) / len(num_turns_list), - "train/num_turns_min": min(num_turns_list), - "train/num_turns_max": max(num_turns_list), + "num_turns_mean": sum(num_turns_list) / len(num_turns_list), + "num_turns_min": min(num_turns_list), + "num_turns_max": max(num_turns_list), } ) @@ -383,7 +396,7 @@ def main(): eval_dataset = load_dataset_from_jsonl(config.eval_dataset_path) print(f"Eval dataset has {len(eval_dataset)} examples\n") - training_args = GRPOConfig( + training_args = NeMoGymGRPOConfig( use_vllm=True, vllm_mode="server", vllm_server_host=args.vllm_server_host, @@ -415,19 +428,15 @@ def main(): mask_truncated_completions=True, log_completions=config.log_completions, num_completions_to_print=config.num_completions_to_print, - # max_prompt_length=config.max_prompt_length, - max_completion_length=config.max_seq_length - config.max_prompt_length - if config.max_prompt_length - else config.max_seq_length, + max_completion_length=config.max_seq_length, shuffle_dataset=False, model_init_kwargs={ "torch_dtype": "auto", }, + agent_servers=agent_servers, + request_timeout=10800, ) - training_args.agent_servers = agent_servers - training_args.request_timeout = 10800 - tokenizer = AutoTokenizer.from_pretrained(config.model_name, truncation_side="left", padding_side="left") trainer = GRPOTrainer( diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index bddbca4cf7b..e3464dae6d6 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -197,7 +197,8 @@ def test_chat_completions_with_params(self): assert len(data["choices"]) == 2 - for choice in data["choices"]: + for i, choice in enumerate(data["choices"]): + assert choice["index"] == i, f"Expected choice at position {i} to have index {i}, got {choice['index']}" assert "message" in choice assert choice["message"]["role"] == "assistant" diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index f5bf4333e81..ce94ad1f539 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -466,7 +466,7 @@ def _replace_prefix_tokens( model output. A concrete example is inconsistent whitespace tokens around tool call special tokens. Based on NeMo RL's _replace_prefix_tokens: - https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 + https://github.com/NVIDIA-NeMo/RL/blob/748b9caff4e6d672b8a98a10b6e612d028cfc96b/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 """ if not model_prefix_token_ids: return template_token_ids @@ -1088,7 +1088,9 @@ async def chat_completions(request: ChatCompletionRequest): all_outputs = [connection.recv() for connection in connections] if has_prefix_token_ids: - all_outputs = [o for o in all_outputs if o] + all_outputs = [ + output for output, prompt_chunk in zip(all_outputs, chunked_prompts, strict=True) if prompt_chunk + ] else: all_outputs = [ output for output, msg_chunk in zip(all_outputs, chunked_messages, strict=True) if msg_chunk @@ -1116,7 +1118,8 @@ async def chat_completions(request: ChatCompletionRequest): total_input_tokens = 0 total_output_tokens = 0 - for idx, output in enumerate(all_outputs): + idx = 0 + for output in all_outputs: total_input_tokens += len(output.prompt_token_ids) for gen_output in output.outputs: @@ -1180,6 +1183,7 @@ async def chat_completions(request: ChatCompletionRequest): "finish_reason": finish_reason, } ) + idx += 1 return { "id": completion_id, From b15ab6305d6675ff740d45952e8140a88e940b65 Mon Sep 17 00:00:00 2001 From: Christian Munley Date: Fri, 30 Jan 2026 11:19:48 -0800 Subject: [PATCH 2/5] subclass Signed-off-by: Christian Munley --- docs/source/nemo_gym.md | 2 +- examples/scripts/nemo_gym/config.yaml | 2 +- .../nemo_gym/train_multi_environment.py | 125 +++++------------- 3 files changed, 38 insertions(+), 91 deletions(-) diff --git a/docs/source/nemo_gym.md b/docs/source/nemo_gym.md index ec7e8afdc7e..a1eea762203 100644 --- a/docs/source/nemo_gym.md +++ b/docs/source/nemo_gym.md @@ -142,7 +142,7 @@ max_steps: 1000 num_generations: 8 per_device_train_batch_size: 1 gradient_accumulation_steps: 4 -max_seq_length: 16384 +max_completion_length: 16384 temperature: 1.0 top_p: 0.999 diff --git a/examples/scripts/nemo_gym/config.yaml b/examples/scripts/nemo_gym/config.yaml index 1998e9f66fc..448cd07da9f 100644 --- a/examples/scripts/nemo_gym/config.yaml +++ b/examples/scripts/nemo_gym/config.yaml @@ -15,7 +15,7 @@ max_steps: 1000 num_generations: 8 per_device_train_batch_size: 1 gradient_accumulation_steps: 8 -max_seq_length: 16384 +max_completion_length: 16384 warmup_steps: 5 lr_scheduler_type: "linear" optim: "adamw_torch_fused" diff --git a/examples/scripts/nemo_gym/train_multi_environment.py b/examples/scripts/nemo_gym/train_multi_environment.py index 29c9f86e804..3dbe58b8a37 100644 --- a/examples/scripts/nemo_gym/train_multi_environment.py +++ b/examples/scripts/nemo_gym/train_multi_environment.py @@ -44,43 +44,6 @@ class NeMoGymGRPOConfig(GRPOConfig): request_timeout: float = 10800 -@dataclass -class TrainingConfig: - model_name: str - dataset_path: str - - task: str | None = None - - learning_rate: float = 5e-6 - max_steps: int = 100 - num_generations: int = 2 - per_device_train_batch_size: int = 2 - gradient_accumulation_steps: int = 16 - max_seq_length: int = 1024 - - temperature: float = 1.0 - top_p: float = 0.999 - weight_decay: float = 0.01 - warmup_ratio: float = 0.0 - warmup_steps: int = 0 - lr_scheduler_type: str = "linear" - optim: str = "adamw_8bit" - - output_dir: str = "outputs/trl_nemo_gym" - save_steps: int = 100 - report_to: str = "none" - run_name: str = None # Wandb - project_name: str = None # Wandb - log_completions: bool = False - num_completions_to_print: int = None - - eval_dataset_path: str | None = None - eval_strategy: str = "no" - eval_steps: int = 50 - - vllm_importance_sampling_correction: bool = False - - def get_agent_servers( head_server_host: str = "127.0.0.1", head_server_port: int = 11000, @@ -357,43 +320,35 @@ def main(): args = parser.parse_args() with open(args.config) as f: - config = TrainingConfig(**yaml.safe_load(f)) + config = yaml.safe_load(f) + + model_name = config.pop("model_name") + dataset_path = config.pop("dataset_path") + eval_dataset_path = config.pop("eval_dataset_path", None) + task = config.pop("task", None) + project_name = config.pop("project_name", None) - if isinstance(config.learning_rate, str): - config.learning_rate = float(config.learning_rate) - if isinstance(config.weight_decay, str): - config.weight_decay = float(config.weight_decay) + if "learning_rate" in config and isinstance(config["learning_rate"], str): + config["learning_rate"] = float(config["learning_rate"]) + if "weight_decay" in config and isinstance(config["weight_decay"], str): + config["weight_decay"] = float(config["weight_decay"]) agent_servers = get_agent_servers( head_server_host=args.head_server_host, head_server_port=11000, ) - if config.project_name: - os.environ["WANDB_PROJECT"] = config.project_name - - if config.run_name is None: - task = config.task or os.path.basename(config.dataset_path).replace(".jsonl", "").replace(".json", "") - model_short = config.model_name.split("/")[-1] - config.run_name = ( - f"{task}_{model_short}" - f"_rpp{config.num_generations}" - f"_dbs{config.per_device_train_batch_size}" - f"_ga{config.gradient_accumulation_steps}" - f"_maxlen{config.max_seq_length}" - f"_lr{config.learning_rate}" - f"_temp{config.temperature}" - f"_topp{config.top_p}" - ) + if project_name: + os.environ["WANDB_PROJECT"] = project_name - if config.dataset_path.endswith((".jsonl", ".json")): - dataset = load_dataset_from_jsonl(config.dataset_path) + if dataset_path.endswith((".jsonl", ".json")): + dataset = load_dataset_from_jsonl(dataset_path) else: - dataset = load_dataset(config.dataset_path, split="train") + dataset = load_dataset(dataset_path, split="train") eval_dataset = None - if config.eval_dataset_path: - eval_dataset = load_dataset_from_jsonl(config.eval_dataset_path) + if eval_dataset_path: + eval_dataset = load_dataset_from_jsonl(eval_dataset_path) print(f"Eval dataset has {len(eval_dataset)} examples\n") training_args = NeMoGymGRPOConfig( @@ -402,45 +357,37 @@ def main(): vllm_server_host=args.vllm_server_host, vllm_server_port=8000, gradient_checkpointing=True, - temperature=config.temperature, - learning_rate=config.learning_rate, - weight_decay=config.weight_decay, - warmup_ratio=config.warmup_ratio, - warmup_steps=config.warmup_steps, - lr_scheduler_type=config.lr_scheduler_type, - optim=config.optim, - per_device_train_batch_size=config.per_device_train_batch_size, - gradient_accumulation_steps=config.gradient_accumulation_steps, - num_generations=config.num_generations, num_generations_eval=1, - max_steps=config.max_steps, - save_steps=config.save_steps, logging_steps=1, - report_to=config.report_to, - output_dir=config.output_dir, - run_name=config.run_name, - eval_strategy=config.eval_strategy, - eval_steps=config.eval_steps, - vllm_importance_sampling_correction=config.vllm_importance_sampling_correction, epsilon=0.2, epsilon_high=0.28, loss_type="grpo", mask_truncated_completions=True, - log_completions=config.log_completions, - num_completions_to_print=config.num_completions_to_print, - max_completion_length=config.max_seq_length, shuffle_dataset=False, - model_init_kwargs={ - "torch_dtype": "auto", - }, + model_init_kwargs={"torch_dtype": "auto"}, agent_servers=agent_servers, request_timeout=10800, + **config, ) - tokenizer = AutoTokenizer.from_pretrained(config.model_name, truncation_side="left", padding_side="left") + if training_args.run_name is None: + task_name = task or os.path.basename(dataset_path).replace(".jsonl", "").replace(".json", "") + model_short = model_name.split("/")[-1] + training_args.run_name = ( + f"{task_name}_{model_short}" + f"_rpp{training_args.num_generations}" + f"_dbs{training_args.per_device_train_batch_size}" + f"_ga{training_args.gradient_accumulation_steps}" + f"_maxlen{training_args.max_completion_length}" + f"_lr{training_args.learning_rate}" + f"_temp{training_args.temperature}" + f"_topp{training_args.top_p}" + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name, truncation_side="left", padding_side="left") trainer = GRPOTrainer( - model=config.model_name, + model=model_name, processing_class=tokenizer, reward_funcs=reward_fn, train_dataset=dataset, From 6d7e8d0da9bfde14b4097b73063449e1fee3fad0 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Sat, 31 Jan 2026 02:46:51 -0800 Subject: [PATCH 3/5] config update Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/config.yaml | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/scripts/nemo_gym/config.yaml b/examples/scripts/nemo_gym/config.yaml index 448cd07da9f..2efa5b30ae0 100644 --- a/examples/scripts/nemo_gym/config.yaml +++ b/examples/scripts/nemo_gym/config.yaml @@ -1,8 +1,11 @@ -model_name: "Qwen/Qwen3-4B-Instruct-2507" +# Model +model_name: "Qwen/Qwen2.5-1.5B-Instruct" -dataset_path: "/path/to/data/train.jsonl" -eval_dataset_path: "/path/to/data/val.jsonl" +# Data +dataset_path: "/home/ubuntu/Gym/resources_servers/workplace_assistant/data/train.jsonl" +eval_dataset_path: "/home/ubuntu/Gym/resources_servers/workplace_assistant/data/validation.jsonl" +# Logging output_dir: "outputs/nemo_gym" task: "workplace" # just used in wandb run name report_to: "wandb" @@ -10,6 +13,7 @@ project_name: "trl-nemo-gym" log_completions: true num_completions_to_print: 2 +# Training hyperparameters learning_rate: 1.0e-5 max_steps: 1000 num_generations: 8 @@ -22,11 +26,12 @@ optim: "adamw_torch_fused" weight_decay: 0.0 vllm_importance_sampling_correction: true +# Inference sampling parameters temperature: 1.0 top_p: 0.999 +# Checkpointing and Eval save_steps: 10 - eval_strategy: "steps" eval_steps: 10 From 13c378ceed933a7fb8a2472c1f4319d457c25556 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Sat, 31 Jan 2026 02:49:23 -0800 Subject: [PATCH 4/5] docs Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/train_multi_environment.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/scripts/nemo_gym/train_multi_environment.py b/examples/scripts/nemo_gym/train_multi_environment.py index 3dbe58b8a37..b28dd4158f2 100644 --- a/examples/scripts/nemo_gym/train_multi_environment.py +++ b/examples/scripts/nemo_gym/train_multi_environment.py @@ -38,8 +38,6 @@ @dataclass class NeMoGymGRPOConfig(GRPOConfig): - """GRPOConfig subclass with NeMo Gym specific fields.""" - agent_servers: dict[str, str] | None = None request_timeout: float = 10800 From c5dcb5d89cacf9dad450fb79b9dce90a15fbf206 Mon Sep 17 00:00:00 2001 From: cmunley1 Date: Sat, 31 Jan 2026 02:53:01 -0800 Subject: [PATCH 5/5] typo in submit Signed-off-by: cmunley1 --- examples/scripts/nemo_gym/submit.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/scripts/nemo_gym/submit.sh b/examples/scripts/nemo_gym/submit.sh index 49fa9ed8bd1..c819c0fa45d 100644 --- a/examples/scripts/nemo_gym/submit.sh +++ b/examples/scripts/nemo_gym/submit.sh @@ -31,7 +31,7 @@ mkdir -p ${LOG_DIR} echo "Starting ng_run and vLLM on ${VLLM_NODE}..." echo "Logs will be saved to: ${LOG_DIR}" -# NOTE: If you have already set up your TRL venv, you can remove all of the pip installs and uv venv related commands below! +# NOTE: If you have already set up your TRL venv, you can remove all of the pip installs and uv venv related commands below! srun --nodes=1 --ntasks=1 --nodelist="${VLLM_NODE}" \ --container-image="${CONTAINER_IMAGE}" \ @@ -92,7 +92,7 @@ srun --nodes=4 --ntasks=4 --nodelist="${TRAIN_NODES_LIST}" \ export HOME=/path/to/user && \ export HF_HOME=/path/to/user/hf_home && \ cd /path/to/user/trl && \ - source .venv/bin/activate && uv pip install accelerate deepseed wandb omegaconf && \ + source .venv/bin/activate && uv pip install accelerate deepspeed wandb omegaconf && \ cd examples/scripts/nemo_gym && \ export WANDB_API_KEY= && \ accelerate launch \