diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 16a420ff0d4..d726c45f284 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -117,6 +117,8 @@ title: MiniLLM - local: nash_md_trainer title: Nash-MD + - local: nemo_gym + title: NeMo Gym - local: online_dpo_trainer title: Online DPO - local: orpo_trainer diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index 67367a6844b..b78db67020e 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -61,6 +61,7 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl | [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`experimental.kto.KTOTrainer`] to fine-tune a model. | | [`examples/scripts/mpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/mpo_vlm.py) | This script shows how to use MPO via the [`DPOTrainer`] to align a model based on preferences using the [HuggingFaceH4/rlaif-v_formatted](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted) dataset and a set of loss weights with weights. | | [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`experimental.nash_md.NashMDTrainer`] to fine-tune a model. | +| [`examples/scripts/nemo_gym/train_multi_environment.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/train_multi_environment.py) | This script shows how to use the [`GRPOTrainer`] to train language models in NVIDIA NeMo-Gym environments. Supports multi-turn and tool calling environments, and multi-environment training. See the [NeMo-Gym Integration](nemo_gym) guide for setup and usage. 
| | [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a model. | | [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a a Vision Language Model. | | [`examples/scripts/openenv/browsergym.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/browsergym.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's BrowserGym environment and vLLM for VLMs | diff --git a/docs/source/nemo_gym.md b/docs/source/nemo_gym.md new file mode 100644 index 00000000000..62c4e49966d --- /dev/null +++ b/docs/source/nemo_gym.md @@ -0,0 +1,293 @@ +# NeMo Gym Integration + +NVIDIA NeMo Gym is a library for building RL environments for large language models. This integration enables training models in NeMo Gym environments using TRL's GRPOTrainer with vLLM server mode. + +The integration supports multi-step and multi-turn rollouts, multi-environment training, and any NeMo Gym environment (thoroughly tested: workplace assistant, reasoning gym, MCQA, and math with judge). + +## Why NeMo Gym + +- **Production-Ready Scale**: Tested for frontier model training with diverse environments running in parallel across math, coding, tool use, reasoning, and more. +- **Multi-Verifier Training**: Supports algorithmic verification, LLM-as-a-judge, and custom verification logic in a single training run. +- **Decoupled Architecture**: Build agents and environments independently from the training loop—no RL framework expertise required. +- **OpenAI-Compatible API**: All environments use the standardized OpenAI Responses API for seamless integration with vLLM, OpenAI models, and other endpoints. 
+ +## Available Environments + +NeMo Gym provides training-ready environments across multiple domains, including but not limited to: + +| Environment | Domain | Description | +|-------------|--------|-------------| +| Workplace Assistant | Agent | Multi-step tool calling in common office scenarios (calendar, email, and more) | +| Math with Judge | Math | Math problems with algorithmic or judge-based verification | +| Code Gen | Coding | Competitive programming problems with code execution | +| MCQA | Knowledge | Multiple-choice question answering | +| Instruction Following | Instruction Following | IFEval/IFBench style tasks | +| Reasoning Gym | Multiple | Single-step procedurally generated verifiable tasks across domains | + +For a complete list of available training environments, refer to the [NeMo Gym repository](https://github.com/NVIDIA-NeMo/Gym#-available-resource-servers). + +## Before You Start + +Complete these one-time setup steps before running training. + +### Install TRL and NeMo Gym + +1. **Install TRL with vLLM extras** + + ```bash + cd trl/ + uv venv + source .venv/bin/activate + uv sync --extra vllm + ``` + +1. **Install NeMo Gym** + + ```bash + # deactivate trl venv + deactivate + git clone https://github.com/NVIDIA-NeMo/Gym.git + cd Gym + uv venv --python 3.12 + source .venv/bin/activate + uv sync + ``` + +### Prepare a Dataset + +Many NeMo Gym datasets used to train Nemotron models are available on Hugging Face. Use `ng_prepare_data` to download and prepare datasets. This command: + +- Downloads the dataset from Hugging Face +- Validates the data format +- Adds an `agent_ref` field to each example that tells NeMo Gym which agent server should handle that example + +> **Note**: `train_multi_environment.py` adds the `agent_ref` field when loading datasets, so this step is optional if datasets are created another way. + +1. **Set Hugging Face Token** + + Create `env.yaml` in `Gym/` with your HF token: + + ```yaml + hf_token: + ``` + +1. 
**Prepare Dataset** + + ```bash + # Enter Gym and activate the venv + cd Gym + source .venv/bin/activate + + # Set config paths + config_paths="responses_api_models/vllm_model/configs/vllm_model.yaml,\ + resources_servers/workplace_assistant/configs/workplace_assistant.yaml" + + # Download data and prep for training + ng_prepare_data "+config_paths=[${config_paths}]" \ + +output_dirpath=data/workplace_assistant \ + +mode=train_preparation \ + +should_download=true \ + +data_source=huggingface + ``` + + This creates `train.jsonl` and `validation.jsonl` files in `data/workplace_assistant/`. + +To create a new environment, refer to the [environment creation guide](https://docs.nvidia.com/nemo/gym/latest/contribute/environments/new-environment.html). We suggest running an existing one first! + +#### Dataset Format + +NeMo Gym datasets are stored as JSONL. Each line contains a task with input messages, tool definitions, metadata such as ground truth for verification, and an agent server reference. The following example shows the workplace dataset structure. Metadata fields can differ between datasets, as long as the corresponding resources server uses the fields appropriately. + +```json +{ + "responses_create_params": { + "input": [ + {"role": "system", "content": "..."}, + {"role": "user", "content": "Move any of jinsoo's tasks that are in review to completed"} + ], + "tools": [...], + "parallel_tool_calls": false, + "temperature": 1 + }, + "ground_truth": [ + {"name": "project_management_update_task", "arguments": "{...}"}, + ... + ], + "category": "workbench_project_management", + "environment_name": "workbench", + "agent_ref": { + "type": "responses_api_agents", + "name": "workplace_assistant_simple_agent" + } +} +``` + +## Interactive Training + +For development and testing on a single node. + +### Set Up + +1. 
**Update Environment Config** + + Update `env.yaml` in `Gym/` to include model information: + + ```yaml + policy_base_url: http://127.0.0.1:8000/v1 + policy_api_key: EMPTY + policy_model_name: Qwen/Qwen2.5-1.5B-Instruct + hf_token: ... + ``` + +2. **Update Training Config** + + Update `examples/scripts/nemo_gym/config.yaml` to point to the dataset generated above, and make any other optional modifications. + +### Run Training + +The following steps run in three terminals. They can also be run as background processes or in tmux. + +1. **Start NeMo Gym Servers** (Terminal 1) + + ```bash + cd Gym/ + source .venv/bin/activate + + config_paths="resources_servers/workplace_assistant/configs/workplace_assistant.yaml,\ + responses_api_models/vllm_model/configs/vllm_model_for_training.yaml" + + ng_run "+config_paths=[${config_paths}]" + ``` + + This starts: + - **Agent server**: Orchestrates rollouts using resources servers and model servers + - **Resources server**: Supports environment logic such as state management, tool implementations, and task verification + - **Model server**: Adapts vLLM server requests to support NeMo Gym agents and on-policy RL training while ensuring OpenAI API compatibility + - **Head server**: Manages the servers used in training and enables their discovery + +1. **Start TRL vLLM Server on GPU 0** (Terminal 2) + + ```bash + cd trl/ + source .venv/bin/activate + CUDA_VISIBLE_DEVICES=0 trl vllm-serve \ + --model Qwen/Qwen2.5-1.5B-Instruct \ + --max-model-len 16384 \ + --host 0.0.0.0 \ + --port 8000 + ``` + +1. **Run Training on GPU 1** (Terminal 3) + + ```bash + source trl/.venv/bin/activate + cd trl/examples/scripts/nemo_gym + export WANDB_API_KEY=... + uv add omegaconf + + CUDA_VISIBLE_DEVICES=1 python train_multi_environment.py --config config.yaml + ``` + +## Multi-Node Training with Slurm + +An example five-node training script is provided in `submit.sh`. 
Nodes one through four run the training algorithm, while node five runs vLLM inference for NeMo Gym agent rollouts. + +1. **Configure the Script** + + Update `submit.sh` with your Slurm account, partition, paths to your project directory, and updated training configs. + +1. **Submit the Job** + + ```bash + sbatch submit.sh + ``` + +1. **Monitor Training** + + ```bash + tail -f logs//* + ``` + +> **Tip**: Set up wandb logging for detailed training metrics. For more details on TRL's vLLM integration, refer to the vLLM integration page. + +## Multi-Environment Training + +Train on multiple NeMo Gym environments simultaneously. This allows the model to learn diverse capabilities (such as tool calling and math reasoning) in a single training run. + +1. **Prepare Individual Datasets** + + Prepare a dataset for each environment. The workplace assistant dataset was prepared above. Now let's create a dataset for the mini sudoku environment implemented by the reasoning gym resources server in NeMo Gym: + + ```bash + cd Gym + source .venv/bin/activate + uv add reasoning-gym + cd resources_servers/reasoning_gym + python scripts/create_dataset.py \ + --task mini_sudoku \ + --size 2000 \ + --seed 42 \ + --output data/reasoning_gym/train_mini_sudoku.jsonl + + python scripts/create_dataset.py \ + --task mini_sudoku \ + --size 50 \ + --seed 24 \ + --output data/reasoning_gym/val_mini_sudoku.jsonl + ``` + +1. **Create Combined Dataset** + + Combine datasets into a single file with tasks from both environments: + + ```bash + cat data/workplace_assistant/train_workplace.jsonl data/reasoning_gym/train_mini_sudoku.jsonl | shuf > train_multi_env.jsonl + ``` + + > **Tip**: Ensure datasets are the same size before shuffling for an even blend of tasks. Repeat for the validation dataset. + +1. 
**Update Training Config** + + Create `config_multi_env.yaml` pointing to the combined dataset: + + ```yaml + model_name: "Qwen/Qwen3-4B-Instruct-2507" + + dataset_path: "/path/to/data/train_multi_env.jsonl" + eval_dataset_path: "/path/to/data/val_multi_env.jsonl" + + task: "workplace-sudoku" # used in wandb run name + output_dir: "outputs/nemo_gym_multi_env" + + # ... rest of config same + ``` + +1. **Update ng_run** + + Whether training interactively or via Slurm, update the `ng_run` command to include config files from each resources server: + + ```bash + cd Gym + source .venv/bin/activate + + config_paths="responses_api_models/vllm_model/configs/vllm_model.yaml,\ + resources_servers/workplace_assistant/configs/workplace_assistant.yaml,\ + resources_servers/reasoning_gym/configs/reasoning_gym.yaml" + + ng_run "+config_paths=[${config_paths}]" +head_server.host=0.0.0.0 + ``` + + This starts servers for both environments. The training script automatically routes each example to the correct agent server based on its `agent_ref` field. + +1. **Run Training** + + Update the Slurm submission script to use the new training config and both `ng_run` resources server configs, then submit the job as before. + + The training script reads `agent_ref` from each example's metadata, routes requests to the correct NeMo Gym agent server, and handles different agents and environments in the same batch. 
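The `agent_ref` routing described above can be sketched in a few lines. This is a minimal illustration rather than the training script's exact code: the helper name `resolve_agent_url`, the agent name `reasoning_gym_simple_agent`, and the ports are assumptions made for the example, and in a real run the server map comes from the head server's global config.

```python
from typing import Any


def resolve_agent_url(item: dict[str, Any], agent_servers: dict[str, str]) -> str:
    """Return the base URL of the agent server named in item["agent_ref"]."""
    agent_ref = item.get("agent_ref")
    name = agent_ref.get("name") if isinstance(agent_ref, dict) else None
    if not name or name not in agent_servers:
        raise ValueError(f"Missing or invalid agent_ref: {agent_ref!r}")
    return agent_servers[name]


# Hypothetical server map (names and ports are placeholders for illustration).
agent_servers = {
    "workplace_assistant_simple_agent": "http://127.0.0.1:9001",
    "reasoning_gym_simple_agent": "http://127.0.0.1:9002",
}

# A mixed batch: each example carries its own agent_ref, so one batch can
# fan out to several agent servers.
batch = [
    {"agent_ref": {"type": "responses_api_agents", "name": "workplace_assistant_simple_agent"}},
    {"agent_ref": {"type": "responses_api_agents", "name": "reasoning_gym_simple_agent"}},
]
urls = [resolve_agent_url(item, agent_servers) for item in batch]
print(urls)
```

Failing loudly on an unknown agent name, as the sketch does, surfaces a missing `ng_run` config path at the first batch instead of silently dropping examples from one environment.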
+ +## Resources + +- [NeMo Gym GitHub](https://github.com/NVIDIA-NeMo/Gym) +- [NeMo Gym Documentation](https://docs.nvidia.com/nemo/gym/latest/) +- [Training Script](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/train_multi_environment.py) +- [TRL GRPO Trainer](grpo_trainer) diff --git a/examples/scripts/nemo_gym/README.md b/examples/scripts/nemo_gym/README.md new file mode 100644 index 00000000000..23784c594cd --- /dev/null +++ b/examples/scripts/nemo_gym/README.md @@ -0,0 +1,5 @@ +# Post-training with NeMo Gym and TRL + +This integration supports training language models in NeMo Gym environments using TRL's GRPO implementation. Single-step and multi-step tasks are both supported, as is multi-environment training. NeMo Gym orchestrates rollouts and returns token IDs and logprobs to TRL through the rollout function for training. Currently, this integration is only supported through TRL's vLLM server mode. + +Check out the docs page `docs/source/nemo_gym.md` for a guide. \ No newline at end of file diff --git a/examples/scripts/nemo_gym/config.yaml b/examples/scripts/nemo_gym/config.yaml new file mode 100644 index 00000000000..2efa5b30ae0 --- /dev/null +++ b/examples/scripts/nemo_gym/config.yaml @@ -0,0 +1,37 @@ +# Model +model_name: "Qwen/Qwen2.5-1.5B-Instruct" + +# Data +dataset_path: "/home/ubuntu/Gym/resources_servers/workplace_assistant/data/train.jsonl" +eval_dataset_path: "/home/ubuntu/Gym/resources_servers/workplace_assistant/data/validation.jsonl" + +# Logging +output_dir: "outputs/nemo_gym" +task: "workplace" # only used in the wandb run name +report_to: "wandb" +project_name: "trl-nemo-gym" +log_completions: true +num_completions_to_print: 2 + +# Training hyperparameters +learning_rate: 1.0e-5 +max_steps: 1000 +num_generations: 8 +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 +max_completion_length: 16384 +warmup_steps: 5 +lr_scheduler_type: "linear" +optim: "adamw_torch_fused" +weight_decay: 0.0 
+vllm_importance_sampling_correction: true + +# Inference sampling parameters +temperature: 1.0 +top_p: 0.999 + +# Checkpointing and Eval +save_steps: 10 +eval_strategy: "steps" +eval_steps: 10 + diff --git a/examples/scripts/nemo_gym/deepspeed_zero3.yaml b/examples/scripts/nemo_gym/deepspeed_zero3.yaml new file mode 100644 index 00000000000..ac6ad51adb0 --- /dev/null +++ b/examples/scripts/nemo_gym/deepspeed_zero3.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 4 +num_processes: 32 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/scripts/nemo_gym/submit.sh b/examples/scripts/nemo_gym/submit.sh new file mode 100644 index 00000000000..c819c0fa45d --- /dev/null +++ b/examples/scripts/nemo_gym/submit.sh @@ -0,0 +1,112 @@ +#!/bin/bash +#SBATCH -A account +#SBATCH -p partition +#SBATCH -N 5 +#SBATCH --gres gpu:8 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH --time=4:00:00 +#SBATCH --job-name=trl_nemo_gym +#SBATCH --output=logs/%j/slurm.out +#SBATCH --error=logs/%j/slurm.err + +CONTAINER_IMAGE="nvcr.io/nvidia/pytorch:25.12-py3" +MOUNTS="/path/to/mounts:/path/to/mounts" + +NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) + +TRAIN_NODE_0="${NODELIST[0]}" +TRAIN_NODE_1="${NODELIST[1]}" +TRAIN_NODE_2="${NODELIST[2]}" +TRAIN_NODE_3="${NODELIST[3]}" +VLLM_NODE="${NODELIST[4]}" + +echo "Training Nodes: $TRAIN_NODE_0, $TRAIN_NODE_1, $TRAIN_NODE_2, $TRAIN_NODE_3" +echo "vLLM Node: $VLLM_NODE" +echo "Main process IP: $TRAIN_NODE_0" + +LOG_DIR="logs/${SLURM_JOB_ID}" +mkdir -p ${LOG_DIR} + +echo "Starting ng_run and vLLM 
on ${VLLM_NODE}..." +echo "Logs will be saved to: ${LOG_DIR}" + +# NOTE: If you have already set up your TRL venv, you can remove all of the pip installs and uv venv related commands below! + +srun --nodes=1 --ntasks=1 --nodelist="${VLLM_NODE}" \ + --container-image="${CONTAINER_IMAGE}" \ + --container-mounts="${MOUNTS}" \ + --container-mount-home \ + bash -c " + LOG_DIR=/path/to/logs + mkdir -p \${LOG_DIR} + + # Install uv if not already installed + curl -LsSf https://astral.sh/uv/install.sh | sh + source \$HOME/.local/bin/env + + # Start nemo gym servers + (set -x && \ + export HOME=/path/to/user && \ + export PATH=\$HOME/.local/bin:\$PATH && \ + cd /path/to/user/Gym && \ + uv venv --python 3.12 && \ + source .venv/bin/activate && \ + uv sync && \ + ray stop --force && \ + ng_run +config_paths=[responses_api_models/vllm_model/configs/vllm_model.yaml,resources_servers/workplace_assistant/configs/workplace_assistant.yaml] +head_server.host=0.0.0.0 +head_server.port=11000) > \${LOG_DIR}/ng_run.log 2>&1 & + + sleep 10 + + # Start trl vllm server + (set -x && \ + export HOME=/path/to/user && \ + export HF_HOME=/path/to/user/hf_home && \ + cd /path/to/user/trl && \ + rm -rf .venv && uv venv && source .venv/bin/activate && uv sync && uv pip install -e .[vllm] && uv pip install fastapi uvicorn && \ + python -m trl.scripts.vllm_serve \ + --model Qwen/Qwen3-4B-Instruct-2507 \ + --host 0.0.0.0 \ + --tensor-parallel-size 8 \ + --data-parallel-size 1 \ + --max-model-len 16384 \ + --gpu-memory-utilization 0.7 \ + --port 8000) > \${LOG_DIR}/vllm_serve.log 2>&1 & + + wait +" & + +echo "Waiting for nemo gym and vllm to start..." +sleep 120 + +echo "Launching training on 4 nodes..." 
+ +TRAIN_NODES_LIST="${TRAIN_NODE_0},${TRAIN_NODE_1},${TRAIN_NODE_2},${TRAIN_NODE_3}" + +srun --nodes=4 --ntasks=4 --nodelist="${TRAIN_NODES_LIST}" \ + --container-image="${CONTAINER_IMAGE}" \ + --container-mounts="${MOUNTS}" \ + --container-mount-home \ + bash -c " + set -x && \ + export HOME=/path/to/user && \ + export HF_HOME=/path/to/user/hf_home && \ + cd /path/to/user/trl && \ + source .venv/bin/activate && uv pip install accelerate deepspeed wandb omegaconf && \ + cd examples/scripts/nemo_gym && \ + export WANDB_API_KEY= && \ + accelerate launch \ + --config_file deepspeed_zero3.yaml \ + --num_processes 32 \ + --num_machines 4 \ + --machine_rank \$SLURM_PROCID \ + --main_process_ip ${TRAIN_NODE_0} \ + --main_process_port 29500 \ + --rdzv_backend c10d \ + train_multi_environment.py \ + --config config.yaml \ + --vllm_server_host ${VLLM_NODE} \ + --head_server_host ${VLLM_NODE}" & + +wait + diff --git a/examples/scripts/nemo_gym/train_multi_environment.py b/examples/scripts/nemo_gym/train_multi_environment.py new file mode 100644 index 00000000000..3ec95bac980 --- /dev/null +++ b/examples/scripts/nemo_gym/train_multi_environment.py @@ -0,0 +1,401 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# /// script +# dependencies = [ +# "trl[vllm]", +# "nemo_gym @ git+https://github.com/NVIDIA-NeMo/Gym", +# ] +# /// + +import argparse +import asyncio +import json +import os +from dataclasses import dataclass +from typing import Any + +import aiohttp +import requests +import yaml +from datasets import Dataset, load_dataset +from omegaconf import OmegaConf +from transformers import AutoTokenizer + +from trl import GRPOConfig, GRPOTrainer + + +@dataclass +class NeMoGymGRPOConfig(GRPOConfig): + agent_servers: dict[str, str] | None = None + request_timeout: float = 10800 + + +def get_agent_servers( + head_server_host: str = "127.0.0.1", + head_server_port: int = 11000, +) -> dict[str, str]: + try: + response = requests.get(f"http://{head_server_host}:{head_server_port}/global_config_dict_yaml", timeout=10) + response.raise_for_status() + global_config_yaml = response.text + global_config_dict = OmegaConf.create(yaml.safe_load(global_config_yaml)) + + agent_servers = {} + for server_name, server_config in global_config_dict.items(): + if hasattr(server_config, "responses_api_agents"): + agents = server_config.responses_api_agents + for agent_key in agents.keys(): + agent_config = getattr(agents, agent_key) + if hasattr(agent_config, "host") and hasattr(agent_config, "port"): + agent_host = agent_config.host + if agent_host in ("127.0.0.1", "0.0.0.0", "localhost"): + agent_host = head_server_host + agent_servers[server_name] = f"http://{agent_host}:{agent_config.port}" + + if not agent_servers: + raise ValueError("No agents found in global config") + + return agent_servers + + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Failed to connect to head server at {head_server_host}:{head_server_port}: {e}") from e + + +def reward_fn(completions: list[str], **kwargs) -> list[float]: + env_rewards = kwargs.get("env_reward") + assert env_rewards is not None, "env_reward not found in kwargs" + return [float(r) for r in env_rewards] + + +async def 
call_nemo_gym_agents( + prompts: list[str], + dataset_items: list[dict[str, Any]], + agent_servers: dict[str, str], + timeout: float, + max_completion_length: int = 4096, + temperature: float = 1.0, + top_p: float = 0.999, +) -> list[dict[str, Any]]: + async with aiohttp.ClientSession(cookie_jar=aiohttp.CookieJar()) as session: + tasks = [] + for prompt, item in zip(prompts, dataset_items, strict=False): + request_body = item.copy() + + if "responses_create_params" not in request_body: + request_body["responses_create_params"] = { + "input": [{"role": "user", "content": prompt}], + } + + params = request_body["responses_create_params"] + params.setdefault("max_output_tokens", max_completion_length) + params["temperature"] = temperature + params["top_p"] = top_p + + agent_ref = item.get("agent_ref", {}) + agent_name = agent_ref.get("name") if isinstance(agent_ref, dict) else None + if not agent_name or agent_name not in agent_servers: + raise ValueError( + f"Missing or invalid agent_ref. Got: {agent_ref}. 
Available: {list(agent_servers.keys())}" + ) + agent_url = agent_servers[agent_name] + + task = session.post( + f"{agent_url}/run", + json=request_body, + timeout=aiohttp.ClientTimeout(total=timeout), + ) + tasks.append(task) + + responses = await asyncio.gather(*tasks, return_exceptions=True) + + results = [] + for i, response in enumerate(responses): + try: + if isinstance(response, Exception): + raise response + json_data = await response.json() + if not isinstance(json_data, dict): + raise ValueError(f"Expected dict, got {type(json_data)}") + results.append(json_data) + except Exception as e: + print(f"WARNING: Request {i} failed: {e}") + results.append({"response": {"output": []}, "reward": 0.0, "error": str(e)}) + + return results + + +def nemo_gym_rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: + is_eval = not trainer.model.training + num_generations = ( + trainer.args.num_generations_eval + if is_eval and trainer.args.num_generations_eval + else trainer.args.num_generations + ) + dataset = trainer.eval_dataset if is_eval and trainer.eval_dataset is not None else trainer.train_dataset + + expanded_prompts = [] + expanded_dataset_items = [] + + for idx_str in prompts: + idx = int(idx_str) + item = json.loads(dataset[idx]["metadata"]) + + for _ in range(num_generations): + expanded_prompts.append(idx_str) + expanded_dataset_items.append(dict(item)) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + responses = loop.run_until_complete( + call_nemo_gym_agents( + expanded_prompts, + expanded_dataset_items, + trainer.args.agent_servers, + trainer.args.request_timeout, + trainer.args.max_completion_length, + temperature=trainer.args.temperature, + top_p=trainer.args.top_p, + ) + ) + finally: + loop.close() + + tokenizer = trainer.processing_class + + prompt_ids: list[list[int]] = [] + completion_ids: list[list[int]] = [] # list of rollouts + completion_mask: list[list[int]] = [] # only train on assistant turns + + 
logprobs: list[list[float]] = [] + env_rewards: list[float] = [] + num_turns_list: list[int] = [] + + for i, response in enumerate(responses): + eos_token_id = tokenizer.eos_token_id or 0 + + if not isinstance(response, dict) or response.get("error"): + rollout_failed = True + else: + output_items = response.get("response", {}).get("output", []) + has_content = output_items and any( + item.get("type") == "function_call" + or ( + item.get("type") == "message" + and any( + c.get("type") == "output_text" and c.get("text", "").strip() for c in item.get("content", []) + ) + ) + for item in output_items + ) + rollout_failed = not has_content + + if rollout_failed: + prompt_ids.append([eos_token_id]) + completion_ids.append([eos_token_id]) + completion_mask.append([0]) + logprobs.append([0.0]) + env_rewards.append(0.0) + num_turns_list.append(0) + continue + + episode_reward = response.get("reward", 0.0) + output_items = response.get("response", {}).get("output", []) + + rollout_ids: list[int] = [] + rollout_mask: list[int] = [] + rollout_logprobs: list[float] = [] + + seen_token_ids: list[int] = [] + first_prompt = None + num_turns = 0 + + for _idx, item in enumerate(output_items): + if "prompt_token_ids" not in item or "generation_token_ids" not in item: + continue + + num_turns += 1 + item_prompt_ids = item["prompt_token_ids"] + item_gen_ids = item["generation_token_ids"] + item_logprobs = item.get("generation_log_probs", []) + tool_result_tokens = [] + + if first_prompt is None: + first_prompt = item_prompt_ids + seen_token_ids = list(item_prompt_ids) + else: + if len(item_prompt_ids) > len(seen_token_ids): + if item_prompt_ids[: len(seen_token_ids)] != seen_token_ids: + raise ValueError( + f"[Turn {num_turns}] Non-contiguous messages (tokenization issue). 
" + f"Expected prefix len {len(seen_token_ids)}, got prompt len {len(item_prompt_ids)}" + ) + tool_result_tokens = item_prompt_ids[len(seen_token_ids) :] + + if tool_result_tokens: + rollout_ids.extend(tool_result_tokens) + rollout_mask.extend([0] * len(tool_result_tokens)) + rollout_logprobs.extend([0.0] * len(tool_result_tokens)) + + rollout_ids.extend(item_gen_ids) + rollout_mask.extend([1] * len(item_gen_ids)) + assert len(item_logprobs) == len(item_gen_ids), ( + f"Logprobs len {len(item_logprobs)} != gen len {len(item_gen_ids)}" + ) + rollout_logprobs.extend(item_logprobs) + + seen_token_ids = list(item_prompt_ids) + list(item_gen_ids) + + if not rollout_ids or first_prompt is None: + raise ValueError(f"Rollout {i} has no valid turns") + + prompt_ids.append(first_prompt) # list of prompts + completion_ids.append(rollout_ids) # list of rollouts + completion_mask.append(rollout_mask) + logprobs.append(rollout_logprobs) + env_rewards.append(episode_reward) + num_turns_list.append(num_turns) + + if not prompt_ids: + raise RuntimeError("No valid rollouts. Check Nemo Gym and vLLM logs.") + + if num_turns_list: + trainer.log( + { + "num_turns_mean": sum(num_turns_list) / len(num_turns_list), + "num_turns_min": min(num_turns_list), + "num_turns_max": max(num_turns_list), + } + ) + + unique_prompt_ids = prompt_ids[::num_generations] + + return { + "prompt_ids": unique_prompt_ids, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "logprobs": logprobs, + "env_reward": env_rewards, + "num_turns": num_turns_list, + } + + +def load_dataset_from_jsonl(path: str) -> Dataset: + data = [] + with open(path) as f: + for idx, line in enumerate(f): + if line.strip(): + item = json.loads(line) + data.append( + { + "prompt": str( + idx + ), # use index for lookup as not all nemo gym datasets have the same metadata fields. 
maybe not the most elegant + "metadata": json.dumps(item), + } + ) + return Dataset.from_list(data) + + +def main(): + parser = argparse.ArgumentParser(description="") + parser.add_argument("--config", required=True, help="Path to config YAML file") + parser.add_argument("--vllm_server_host", type=str, default="127.0.0.1", help="vLLM server hostname/IP") + parser.add_argument("--head_server_host", type=str, default="127.0.0.1", help="Head server hostname/IP for ng_run") + parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to checkpoint to resume from") + args = parser.parse_args() + + with open(args.config) as f: + config = yaml.safe_load(f) + + model_name = config.pop("model_name") + dataset_path = config.pop("dataset_path") + eval_dataset_path = config.pop("eval_dataset_path", None) + task = config.pop("task", None) + project_name = config.pop("project_name", None) + + if "learning_rate" in config and isinstance(config["learning_rate"], str): + config["learning_rate"] = float(config["learning_rate"]) + if "weight_decay" in config and isinstance(config["weight_decay"], str): + config["weight_decay"] = float(config["weight_decay"]) + + agent_servers = get_agent_servers( + head_server_host=args.head_server_host, + head_server_port=11000, + ) + + if project_name: + os.environ["WANDB_PROJECT"] = project_name + + if dataset_path.endswith((".jsonl", ".json")): + dataset = load_dataset_from_jsonl(dataset_path) + else: + dataset = load_dataset(dataset_path, split="train") + + eval_dataset = None + if eval_dataset_path: + eval_dataset = load_dataset_from_jsonl(eval_dataset_path) + print(f"Eval dataset has {len(eval_dataset)} examples\n") + + training_args = NeMoGymGRPOConfig( + use_vllm=True, + vllm_mode="server", + vllm_server_host=args.vllm_server_host, + vllm_server_port=8000, + gradient_checkpointing=True, + num_generations_eval=1, + logging_steps=1, + epsilon=0.2, + epsilon_high=0.28, + loss_type="grpo", + 
mask_truncated_completions=True, + shuffle_dataset=False, + model_init_kwargs={"torch_dtype": "auto"}, + agent_servers=agent_servers, + request_timeout=10800, + **config, + ) + + if training_args.run_name is None: + task_name = task or os.path.basename(dataset_path).replace(".jsonl", "").replace(".json", "") + model_short = model_name.split("/")[-1] + training_args.run_name = ( + f"{task_name}_{model_short}" + f"_rpp{training_args.num_generations}" + f"_dbs{training_args.per_device_train_batch_size}" + f"_ga{training_args.gradient_accumulation_steps}" + f"_maxlen{training_args.max_completion_length}" + f"_lr{training_args.learning_rate}" + f"_temp{training_args.temperature}" + f"_topp{training_args.top_p}" + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name, truncation_side="left", padding_side="left") + + trainer = GRPOTrainer( + model=model_name, + processing_class=tokenizer, + reward_funcs=reward_fn, + train_dataset=dataset, + eval_dataset=eval_dataset, + rollout_func=nemo_gym_rollout_func, + args=training_args, + ) + + trainer.train(resume_from_checkpoint=args.resume_from_checkpoint) + + +if __name__ == "__main__": + main() diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index 611f9a2f781..e3464dae6d6 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -152,6 +152,65 @@ def test_reset_prefix_cache(self): # Test resetting the prefix cache self.client.reset_prefix_cache() + def test_chat_completions_endpoint(self): + data = self.client.chat_completions( + messages=[{"role": "user", "content": "Say hello"}], + max_tokens=32, + ) + + assert "id" in data + assert "choices" in data + assert "usage" in data + assert len(data["choices"]) > 0 + assert data["choices"][0]["message"]["role"] == "assistant" + assert data["choices"][0]["finish_reason"] in ["stop", "length", "tool_calls"] + + def test_chat_completions_with_tools(self): + tools = [ + { + "type": "function", + "function": { + "name": 
"get_weather", + "description": "Get weather information for a location", + "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}, + }, + } + ] + data = self.client.chat_completions( + messages=[{"role": "user", "content": "What's the weather in San Francisco?"}], + tools=tools, + max_tokens=100, + ) + + assert "choices" in data + assert len(data["choices"]) > 0 + assert "message" in data["choices"][0] + + def test_chat_completions_with_params(self): + data = self.client.chat_completions( + messages=[{"role": "user", "content": "Tell me a joke"}], + n=2, + temperature=0.8, + top_p=0.9, + max_tokens=32, + ) + + assert len(data["choices"]) == 2 + + for i, choice in enumerate(data["choices"]): + assert choice["index"] == i, f"Expected choice at position {i} to have index {i}, got {choice['index']}" + assert "message" in choice + assert choice["message"]["role"] == "assistant" + + def test_tokenize_endpoint(self): + data = self.client.tokenize(messages=[{"role": "user", "content": "Hello, how are you?"}]) + + assert "tokens" in data + assert "model" in data + assert isinstance(data["tokens"], list) + assert len(data["tokens"]) > 0 + assert all(isinstance(tok, int) for tok in data["tokens"]) + @pytest.mark.xfail(reason="Importing `bitsandbytes` causes issues, see vllm-project/vllm#32793") def test_logprobs_match_with_non_default_sampling(self): prompts = ["Hello, AI!", "Tell me a joke"] diff --git a/trl/generation/vllm_client.py b/trl/generation/vllm_client.py index 1f20dd39af9..1919345eaf7 100644 --- a/trl/generation/vllm_client.py +++ b/trl/generation/vllm_client.py @@ -514,6 +514,82 @@ def reset_prefix_cache(self): if response.status_code != 200: raise Exception(f"Request failed: {response.status_code}, {response.text}") + def chat_completions( + self, + messages: list[dict], + model: str | None = None, + temperature: float = 1.0, + top_p: float = 1.0, + max_tokens: int | None = None, + n: int = 1, + tools: list[dict] | None = None, + 
**kwargs, + ) -> dict: + """ + OpenAI-compatible chat completions endpoint. + + Args: + messages (`list[dict]`): + List of messages in OpenAI format with "role" and "content" keys. + model (`str`, *optional*): + Model name to use. + temperature (`float`, *optional*, defaults to `1.0`): + Temperature for sampling. + top_p (`float`, *optional*, defaults to `1.0`): + Top-p sampling parameter. + max_tokens (`int`, *optional*): + Maximum number of tokens to generate. + n (`int`, *optional*, defaults to `1`): + Number of completions to generate. + tools (`list[dict]`, *optional*): + List of tool definitions for tool calling. + **kwargs: + Additional parameters to pass to the endpoint. + + Returns: + `dict`: + OpenAI-compatible response with "choices", "usage", etc. + """ + url = f"{self.base_url}/v1/chat/completions" + response = self.session.post( + url, + json={ + "messages": messages, + "model": model, + "temperature": temperature, + "top_p": top_p, + "max_tokens": max_tokens, + "n": n, + "tools": tools, + **kwargs, + }, + ) + if response.status_code == 200: + return response.json() + else: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + def tokenize(self, messages: list[dict], tools: list[dict] | None = None) -> dict: + """ + Tokenize messages to get token IDs. + + Args: + messages (`list[dict]`): + List of messages to tokenize. + tools (`list[dict]`, *optional*): + List of tool definitions. + + Returns: + `dict`: + Dictionary with "tokens" (list of token IDs) and "model" keys. + """ + url = f"{self.base_url}/tokenize" + response = self.session.post(url, json={"messages": messages, "tools": tools}) + if response.status_code == 200: + return response.json() + else: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + def close_communicator(self): """ Closes the weight update group and cleans up the communication group. 
diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 521aac7f4ba..13120274d00 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -14,8 +14,12 @@ import argparse import base64 +import json import logging import os +import re +import time +import uuid from collections.abc import Sequence from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -27,7 +31,7 @@ import torch import torch.distributed.distributed_c10d as c10d from packaging.version import Version -from transformers import is_torch_xpu_available, is_vision_available +from transformers import AutoTokenizer, is_torch_xpu_available, is_vision_available from trl import TrlParser from trl.import_utils import ( @@ -384,7 +388,23 @@ def llm_worker( method_name = command["method"] args, kwargs = command.get("args", ()), command.get("kwargs", {}) method = getattr(llm, method_name) - result = method(*args, **kwargs) + + try: + result = method(*args, **kwargs) + except ValueError as e: + error_msg = str(e) + if "longer than the maximum model length" in error_msg or "context length" in error_msg: + logger.error(f"[Worker] Context length exceeded: {error_msg}") + if method_name in ["generate", "chat"]: + result = [] + else: + raise + else: + raise + except Exception as e: + logger.error(f"[Worker] Unexpected error in {method_name}: {e}") + raise + if command["type"] == "call": connection.send(result) elif command["type"] == "shutdown": @@ -422,6 +442,61 @@ def sanitize_logprob(logprob): return value +def _replace_prefix_tokens( + tokenizer, + model_prefix_token_ids: list[int], + template_prefix_token_ids: list[int], + template_token_ids: list[int], +) -> list[int]: + """ + This function is for fixing up the chat template-tokenized messages history to match the model output tokenization + up to the last assistant turn, in order to preserve the monotonic tokens property for optimized multi-turn + training. 
+ + RL training frameworks train models on token IDs, but the OpenAI compatible server communicates in what is + basically de-tokenized text. When multiple model calls are made to the OpenAI compatible server in a single + trajectory, model generations in previous model calls may be re-tokenized to something that is different than what + was generated. This is not too big of an issue (that we know of) at inference time, but the log probs the model + produces are different enough for the differently re-tokenized generation result that it causes the training to be + off policy. Off policy isn't necessarily a bad thing in isolation, but this source of off-policyness may cause + unexpected issues if not properly accounted for. It also mis-aligns the token ID sequences across model calls, + which is strange during training. + + There are real cases where the model output string _does not match_ the chat template tokenization of the parsed + model output. A concrete example is inconsistent whitespace tokens around tool call special tokens. + + Based on NeMo RL's _replace_prefix_tokens: + https://github.com/NVIDIA-NeMo/RL/blob/748b9caff4e6d672b8a98a10b6e612d028cfc96b/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 + """ + if not model_prefix_token_ids: + return template_token_ids + + eos_token_id = tokenizer.eos_token_id + if eos_token_id is None: + logger.warning("Tokenizer has no EOS token ID, cannot apply _replace_prefix_tokens") + return template_token_ids + + model_cut_end = len(model_prefix_token_ids) + if model_prefix_token_ids and model_prefix_token_ids[-1] == eos_token_id: + model_cut_end -= 1 + + # We take everything starting with the EOS token ID. 
+ template_cut_start = -1 + for pos in reversed(range(len(template_prefix_token_ids))): + if template_token_ids[pos] == eos_token_id: + template_cut_start = pos + break + + # This should never be the case, but + if template_cut_start < 0: + logger.warning("No EOS token found in template prefix, cannot apply _replace_prefix_tokens") + return template_token_ids + + result = model_prefix_token_ids[:model_cut_end] + template_token_ids[template_cut_start:] + + return result + + def main(script_args: ScriptArguments): if not is_fastapi_available(): raise ImportError( @@ -454,6 +529,11 @@ def main(script_args: ScriptArguments): @asynccontextmanager async def lifespan(app: FastAPI): + logger.info(f"Loading tokenizer for {script_args.model}...") + app.state.tokenizer = AutoTokenizer.from_pretrained( + script_args.model, trust_remote_code=script_args.trust_remote_code + ) + # Wait for all workers to send "ready" ready_connections = set() while len(ready_connections) < script_args.data_parallel_size: @@ -650,6 +730,7 @@ class ChatRequest(BaseModel): structured_outputs_regex: str | None = None generation_kwargs: dict = field(default_factory=dict) chat_template_kwargs: dict = field(default_factory=dict) + tools: list[dict] | None = None class ChatResponse(BaseModel): prompt_ids: list[list[int]] @@ -762,7 +843,9 @@ async def chat(request: ChatRequest): "messages": messages, "sampling_params": sampling_params, "chat_template_kwargs": request.chat_template_kwargs, + "tools": request.tools if request.tools else None, } + connection.send({"type": "call", "method": "chat", "kwargs": kwargs}) # Receive results @@ -865,8 +948,337 @@ async def close_communicator(): connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}) return {"message": "Request received, closing communicator"} + class ChatCompletionRequest(BaseModel): + messages: list[dict] + model: str | None = None + temperature: float = 1.0 + top_p: float = 1.0 + max_completion_tokens: int | None 
= None + max_tokens: int | None = None + n: int = 1 + stop: str | list[str] | None = None + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + logprobs: bool = False + top_logprobs: int | None = None + tools: list[dict] | None = None + tool_choice: str | dict = "auto" + parallel_tool_calls: bool = True + + @app.post("/v1/chat/completions") + async def chat_completions(request: ChatCompletionRequest): + completion_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + created_at = int(time.time()) + + messages = [] + for msg in request.messages: + role = msg.get("role", "") + if role not in ["system", "user", "assistant", "tool"]: + logger.warning(f"Unknown message role: {role}") + messages.append(msg) + + max_tokens = request.max_completion_tokens or request.max_tokens or 512 + + sampling_kwargs = { + "n": request.n, + "temperature": request.temperature, + "top_p": request.top_p, + "max_tokens": max_tokens, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "stop": request.stop, + } + + if request.logprobs or request.top_logprobs: + sampling_kwargs["logprobs"] = request.top_logprobs if request.top_logprobs else 1 + + sampling_params = SamplingParams(**sampling_kwargs) + + chat_template_kwargs = {} + if request.tool_choice and request.tool_choice != "auto": + chat_template_kwargs["tool_choice"] = request.tool_choice + + has_prefix_token_ids = any(msg.get("role") == "assistant" and "prompt_token_ids" in msg for msg in messages) + + if has_prefix_token_ids: + # do on policy token id correction and call generate instead of chat + # see https://docs.nvidia.com/nemo/gym/latest/contribute/rl-framework-integration/openai-compatible-http-server-on-policy-correction.html + # and https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 + tokenizer = app.state.tokenizer + + # preprocess full conversation + connections[0].send( + { + "type": "call", + "method": "preprocess_chat", + 
"kwargs": { + "messages": [messages], + "chat_template_kwargs": chat_template_kwargs, + "tools": request.tools, + "add_generation_prompt": True, + }, + } + ) + template_prompts = connections[0].recv() + template_prompt = template_prompts[0] + + # extract model prefix tokens from last assistant message + model_prefix_tokens = None + last_assistant_idx = None + for i in reversed(range(len(messages))): + if messages[i].get("role") == "assistant": + last_assistant_idx = i + if "prompt_token_ids" in messages[i]: + model_prefix_tokens = messages[i]["prompt_token_ids"] + messages[i].get( + "generation_token_ids", [] + ) + break + + if model_prefix_tokens and last_assistant_idx is not None: + messages_to_last_assistant = messages[: last_assistant_idx + 1] + connections[0].send( + { + "type": "call", + "method": "preprocess_chat", + "kwargs": { + "messages": [messages_to_last_assistant], + "chat_template_kwargs": chat_template_kwargs, + "tools": request.tools, + "add_generation_prompt": False, + }, + } + ) + template_prefix_prompts = connections[0].recv() + template_prefix_token_ids = template_prefix_prompts[0]["prompt_token_ids"] + + corrected_token_ids = _replace_prefix_tokens( + tokenizer, model_prefix_tokens, template_prefix_token_ids, template_prompt["prompt_token_ids"] + ) + + else: + corrected_token_ids = template_prompt["prompt_token_ids"] + + corrected_prompt = {"prompt_token_ids": corrected_token_ids} + chunked_prompts = chunk_list([corrected_prompt], script_args.data_parallel_size) + + for connection, prompts in zip(connections, chunked_prompts, strict=True): + if not prompts: + prompts = [{"prompt_token_ids": [tokenizer.eos_token_id]}] + connection.send( + { + "type": "call", + "method": "generate", + "kwargs": {"prompts": prompts, "sampling_params": sampling_params}, + } + ) + else: + # no prefix token IDs, use chat() + chunked_messages = chunk_list([messages], script_args.data_parallel_size) + + for connection, message_chunk in zip(connections, 
chunked_messages, strict=True): + if not message_chunk: + message_chunk = [[{"role": "user", "content": ""}]] + kwargs = { + "messages": message_chunk, + "sampling_params": sampling_params, + "tools": request.tools, + "chat_template_kwargs": chat_template_kwargs, + } + connection.send({"type": "call", "method": "chat", "kwargs": kwargs}) + + all_outputs = [connection.recv() for connection in connections] + if has_prefix_token_ids: + all_outputs = [ + output for output, prompt_chunk in zip(all_outputs, chunked_prompts, strict=True) if prompt_chunk + ] + else: + all_outputs = [ + output for output, msg_chunk in zip(all_outputs, chunked_messages, strict=True) if msg_chunk + ] + all_outputs = list(chain.from_iterable(all_outputs)) + + if not all_outputs: + return { + "id": completion_id, + "object": "chat.completion", + "created": created_at, + "model": request.model or script_args.model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": ""}, + "finish_reason": "length", + "logprobs": None, + } + ], + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + } + + choices = [] + total_input_tokens = 0 + total_output_tokens = 0 + + idx = 0 + for output in all_outputs: + total_input_tokens += len(output.prompt_token_ids) + + for gen_output in output.outputs: + total_output_tokens += len(gen_output.token_ids) + text = gen_output.text if hasattr(gen_output, "text") else "" + + tool_calls = None + finish_reason = gen_output.finish_reason if hasattr(gen_output, "finish_reason") else "stop" + + # Manual XML-json tool call parsing + if request.tools and text: + pattern = r"<tool_call>(.*?)</tool_call>" + matches = re.findall(pattern, text, re.DOTALL) + if matches: + tool_calls = [] + for match in matches: + try: + data = json.loads(match.strip()) + tool_calls.append( + { + "id": f"call_{uuid.uuid4().hex[:24]}", + "type": "function", + "function": { + "name": data.get("name", ""), + "arguments": json.dumps(data.get("arguments", {})), + }, + } + ) + 
except json.JSONDecodeError: + continue + if tool_calls: + finish_reason = "tool_calls" + text = re.sub(pattern, "", text, flags=re.DOTALL).strip() + + if not request.parallel_tool_calls and tool_calls and len(tool_calls) > 1: + tool_calls = [tool_calls[0]] + + logprobs_data = None + if request.logprobs and hasattr(gen_output, "logprobs") and gen_output.logprobs: + logprobs_data = { + "content": [ + { + "token": str(token_id), + "logprob": float(list(logprob_dict.values())[0].logprob) if logprob_dict else 0.0, + "bytes": None, + "top_logprobs": [], + } + for token_id, logprob_dict in zip(gen_output.token_ids, gen_output.logprobs, strict=False) + ] + } + + choices.append( + { + "index": idx, + "message": { + "role": "assistant", + "content": text if not tool_calls else None, + "tool_calls": tool_calls, + }, + "logprobs": logprobs_data, + "finish_reason": finish_reason, + } + ) + idx += 1 + + return { + "id": completion_id, + "object": "chat.completion", + "created": created_at, + "model": request.model or script_args.model, + "choices": choices, + "usage": { + "prompt_tokens": total_input_tokens, + "completion_tokens": total_output_tokens, + "total_tokens": total_input_tokens + total_output_tokens, + }, + } + + class TokenizeRequest(BaseModel): + model: str | None = None + messages: list[dict] + tools: list[dict] | None = None + + @app.post("/tokenize") + async def tokenize(request: TokenizeRequest): + messages = request.messages + + has_prefix_token_ids = any(msg.get("role") == "assistant" and "prompt_token_ids" in msg for msg in messages) + + kwargs = { + "messages": [messages], + "tools": request.tools, + "add_generation_prompt": True, + "chat_template_kwargs": {}, + } + + connections[0].send({"type": "call", "method": "preprocess_chat", "kwargs": kwargs}) + preprocessed_prompts = connections[0].recv() + + if preprocessed_prompts and len(preprocessed_prompts) > 1: + logger.warning( + "More than one tokenized message returned from preprocess_chat inside tokenize, 
double check results!" + ) + + if not preprocessed_prompts or len(preprocessed_prompts) == 0: + return {"tokens": [], "model": request.model or script_args.model} + + template_prompt = preprocessed_prompts[0] + result_tokens = template_prompt["prompt_token_ids"] + + if has_prefix_token_ids: + tokenizer = app.state.tokenizer + + # Extract model prefix tokens from last assistant message + model_prefix_tokens = None + last_assistant_idx = None + for i in reversed(range(len(messages))): + if messages[i].get("role") == "assistant": + last_assistant_idx = i + if "prompt_token_ids" in messages[i]: + model_prefix_tokens = messages[i]["prompt_token_ids"] + messages[i].get( + "generation_token_ids", [] + ) + break + + if model_prefix_tokens and last_assistant_idx is not None: + # Preprocess up to last assistant + messages_to_last_assistant = messages[: last_assistant_idx + 1] + connections[0].send( + { + "type": "call", + "method": "preprocess_chat", + "kwargs": { + "messages": [messages_to_last_assistant], + "tools": request.tools, + "add_generation_prompt": False, + "chat_template_kwargs": {}, + }, + } + ) + template_prefix_prompts = connections[0].recv() + template_prefix_token_ids = template_prefix_prompts[0]["prompt_token_ids"] + + result_tokens = _replace_prefix_tokens( + tokenizer, model_prefix_tokens, template_prefix_token_ids, template_prompt["prompt_token_ids"] + ) + + return {"tokens": result_tokens, "model": request.model or script_args.model} + # Start the server - uvicorn.run(app, host=script_args.host, port=script_args.port, log_level=script_args.log_level) + uvicorn.run( + app, + host=script_args.host, + port=script_args.port, + log_level=script_args.log_level, + limit_concurrency=256, + backlog=4096, + timeout_keep_alive=600, + ) def make_parser(subparsers: argparse._SubParsersAction | None = None): diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 1d8dcf10673..47a50c4ad00 100644 --- a/trl/trainer/grpo_trainer.py +++ 
b/trl/trainer/grpo_trainer.py @@ -1556,7 +1556,13 @@ def _generate_and_score_completions( prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] - completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + + # Allow custom completion_mask from rollout_func for multi-turn training + if "completion_mask" in extra_fields: + completion_mask_list = extra_fields.pop("completion_mask") + completion_mask = [torch.tensor(m, device=device, dtype=torch.long) for m in completion_mask_list] + else: + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") completion_mask = pad(completion_mask, padding_value=0, padding_side="right") if sampling_per_token_logps_list is not None: @@ -1578,7 +1584,10 @@ def _generate_and_score_completions( # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + + # attend to all non-padding tokens, but mask out user/tool result tokens in loss + completion_attention_mask = (completion_ids != self.pad_token_id).long() + attention_mask = torch.cat([prompt_mask, completion_attention_mask], dim=1) # (B, P+C) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size @@ -1799,7 +1808,6 @@ def _generate_and_score_completions( self._metrics[mode]["sampling/sampling_logp_difference/max"].append( self.accelerator.gather(max_delta).max().item() ) - if sequence_level_is: flat_is_ratio = 
vllm_importance_sampling_ratio.flatten() else: