Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/source/nemo_gym.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Many NeMo Gym datasets used to train Nemotron models are available on Hugging Fa
- Validates the data format
- Adds an `agent_ref` field to each example that tells NeMo Gym which agent server should handle that example

> **Note**: `run_grpo_nemo_gym.py` adds the `agent_ref` field when loading datasets, so this step is optional if datasets are created another way.
> **Note**: `train_multi_environment.py` adds the `agent_ref` field when loading datasets, so this step is optional if datasets are created another way.

1. **Set Hugging Face Token**

Expand Down Expand Up @@ -142,7 +142,7 @@ max_steps: 1000
num_generations: 8
per_device_train_batch_size: 1
gradient_accumulation_steps: 4
max_seq_length: 16384
max_completion_length: 16384

temperature: 1.0
top_p: 0.999
Expand Down Expand Up @@ -221,7 +221,7 @@ The following steps run in 3 terminals. It can also be run with processes in the
export WANDB_API_KEY=...
uv pip install wandb

CUDA_VISIBLE_DEVICES=1 python run_grpo_nemo_gym.py --config config_workplace.yaml
CUDA_VISIBLE_DEVICES=1 python train_multi_environment.py --config config_workplace.yaml
```

## Multi-Node Training with Slurm
Expand Down Expand Up @@ -325,5 +325,5 @@ Train on multiple NeMo Gym environments simultaneously. This allows learning div

- [NeMo Gym GitHub](https://github.com/NVIDIA-NeMo/Gym)
- [NeMo Gym Documentation](https://docs.nvidia.com/nemo/gym/latest/)
- [Training Script](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/run_grpo_nemo_gym.py)
- [Training Script](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/train_multi_environment.py)
- [TRL GRPO Trainer](grpo_trainer)
15 changes: 10 additions & 5 deletions examples/scripts/nemo_gym/config.yaml
Original file line number Diff line number Diff line change
@@ -1,32 +1,37 @@
model_name: "Qwen/Qwen3-4B-Instruct-2507"
# Model
model_name: "Qwen/Qwen2.5-1.5B-Instruct"

dataset_path: "/path/to/data/train.jsonl"
eval_dataset_path: "/path/to/data/val.jsonl"
# Data
dataset_path: "/home/ubuntu/Gym/resources_servers/workplace_assistant/data/train.jsonl"
eval_dataset_path: "/home/ubuntu/Gym/resources_servers/workplace_assistant/data/validation.jsonl"

# Logging
output_dir: "outputs/nemo_gym"
task: "workplace" # only used to build the default run name (e.g. shown in wandb) when run_name is not set
report_to: "wandb"
project_name: "trl-nemo-gym"
log_completions: true
num_completions_to_print: 2

# Training hyperparameters
learning_rate: 1.0e-5
max_steps: 1000
num_generations: 8
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
max_seq_length: 16384
max_completion_length: 16384
warmup_steps: 5
lr_scheduler_type: "linear"
optim: "adamw_torch_fused"
weight_decay: 0.0
vllm_importance_sampling_correction: true

# Inference sampling parameters
temperature: 1.0
top_p: 0.999

# Checkpointing and Eval
save_steps: 10

eval_strategy: "steps"
eval_steps: 10

4 changes: 2 additions & 2 deletions examples/scripts/nemo_gym/submit.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ mkdir -p ${LOG_DIR}
echo "Starting ng_run and vLLM on ${VLLM_NODE}..."
echo "Logs will be saved to: ${LOG_DIR}"

# NOTE: If you have already set up your TRL venv, you can remove all of the pip installs and uv venv related commands below!
# NOTE: If you have already set up your TRL venv, you can remove all of the pip installs and uv venv related commands below!

srun --nodes=1 --ntasks=1 --nodelist="${VLLM_NODE}" \
--container-image="${CONTAINER_IMAGE}" \
Expand Down Expand Up @@ -92,7 +92,7 @@ srun --nodes=4 --ntasks=4 --nodelist="${TRAIN_NODES_LIST}" \
export HOME=/path/to/user && \
export HF_HOME=/path/to/user/hf_home && \
cd /path/to/user/trl && \
source .venv/bin/activate && uv pip install accelerate deepseed wandb omegaconf && \
source .venv/bin/activate && uv pip install accelerate deepspeed wandb omegaconf && \
cd examples/scripts/nemo_gym && \
export WANDB_API_KEY=<your wandb api key> && \
accelerate launch \
Expand Down
150 changes: 52 additions & 98 deletions examples/scripts/nemo_gym/train_multi_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
# "trl[vllm]",
# "nemo_gym @ git+https://github.com/NVIDIA-NeMo/Gym",
# ]
# ///

import argparse
import asyncio
import json
Expand All @@ -21,7 +28,6 @@

import aiohttp
import requests
import wandb
import yaml
from datasets import Dataset, load_dataset
from omegaconf import OmegaConf
Expand All @@ -31,41 +37,9 @@


@dataclass
class TrainingConfig:
model_name: str
dataset_path: str

task: str | None = None

learning_rate: float = 5e-6
max_steps: int = 100
num_generations: int = 2
per_device_train_batch_size: int = 2
gradient_accumulation_steps: int = 16
max_seq_length: int = 1024
max_prompt_length: int = None

temperature: float = 1.0
top_p: float = 0.999
weight_decay: float = 0.01
warmup_ratio: float = 0.0
warmup_steps: int = 0
lr_scheduler_type: str = "linear"
optim: str = "adamw_8bit"

output_dir: str = "outputs/trl_nemo_gym"
save_steps: int = 100
report_to: str = "none"
run_name: str = None # Wandb
project_name: str = None # Wandb
log_completions: bool = False
num_completions_to_print: int = None

eval_dataset_path: str | None = None
eval_strategy: str = "no"
eval_steps: int = 50

vllm_importance_sampling_correction: bool = False
class NeMoGymGRPOConfig(GRPOConfig):
agent_servers: dict[str, str] | None = None
request_timeout: float = 10800


def get_agent_servers(
Expand Down Expand Up @@ -298,11 +272,11 @@ def nemo_gym_rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str,
raise RuntimeError("No valid rollouts. Check Nemo Gym and vLLM logs.")

if num_turns_list:
wandb.log(
trainer.log(
{
"train/num_turns_mean": sum(num_turns_list) / len(num_turns_list),
"train/num_turns_min": min(num_turns_list),
"train/num_turns_max": max(num_turns_list),
"num_turns_mean": sum(num_turns_list) / len(num_turns_list),
"num_turns_min": min(num_turns_list),
"num_turns_max": max(num_turns_list),
}
)

Expand Down Expand Up @@ -344,94 +318,74 @@ def main():
args = parser.parse_args()

with open(args.config) as f:
config = TrainingConfig(**yaml.safe_load(f))
config = yaml.safe_load(f)

model_name = config.pop("model_name")
dataset_path = config.pop("dataset_path")
eval_dataset_path = config.pop("eval_dataset_path", None)
task = config.pop("task", None)
project_name = config.pop("project_name", None)

if isinstance(config.learning_rate, str):
config.learning_rate = float(config.learning_rate)
if isinstance(config.weight_decay, str):
config.weight_decay = float(config.weight_decay)
if "learning_rate" in config and isinstance(config["learning_rate"], str):
config["learning_rate"] = float(config["learning_rate"])
if "weight_decay" in config and isinstance(config["weight_decay"], str):
config["weight_decay"] = float(config["weight_decay"])

agent_servers = get_agent_servers(
head_server_host=args.head_server_host,
head_server_port=11000,
)

if config.project_name:
os.environ["WANDB_PROJECT"] = config.project_name

if config.run_name is None:
task = config.task or os.path.basename(config.dataset_path).replace(".jsonl", "").replace(".json", "")
model_short = config.model_name.split("/")[-1]
config.run_name = (
f"{task}_{model_short}"
f"_rpp{config.num_generations}"
f"_dbs{config.per_device_train_batch_size}"
f"_ga{config.gradient_accumulation_steps}"
f"_maxlen{config.max_seq_length}"
f"_lr{config.learning_rate}"
f"_temp{config.temperature}"
f"_topp{config.top_p}"
)
if project_name:
os.environ["WANDB_PROJECT"] = project_name

if config.dataset_path.endswith((".jsonl", ".json")):
dataset = load_dataset_from_jsonl(config.dataset_path)
if dataset_path.endswith((".jsonl", ".json")):
dataset = load_dataset_from_jsonl(dataset_path)
else:
dataset = load_dataset(config.dataset_path, split="train")
dataset = load_dataset(dataset_path, split="train")

eval_dataset = None
if config.eval_dataset_path:
eval_dataset = load_dataset_from_jsonl(config.eval_dataset_path)
if eval_dataset_path:
eval_dataset = load_dataset_from_jsonl(eval_dataset_path)
print(f"Eval dataset has {len(eval_dataset)} examples\n")

training_args = GRPOConfig(
training_args = NeMoGymGRPOConfig(
use_vllm=True,
vllm_mode="server",
vllm_server_host=args.vllm_server_host,
vllm_server_port=8000,
gradient_checkpointing=True,
temperature=config.temperature,
learning_rate=config.learning_rate,
weight_decay=config.weight_decay,
warmup_ratio=config.warmup_ratio,
warmup_steps=config.warmup_steps,
lr_scheduler_type=config.lr_scheduler_type,
optim=config.optim,
per_device_train_batch_size=config.per_device_train_batch_size,
gradient_accumulation_steps=config.gradient_accumulation_steps,
num_generations=config.num_generations,
num_generations_eval=1,
max_steps=config.max_steps,
save_steps=config.save_steps,
logging_steps=1,
report_to=config.report_to,
output_dir=config.output_dir,
run_name=config.run_name,
eval_strategy=config.eval_strategy,
eval_steps=config.eval_steps,
vllm_importance_sampling_correction=config.vllm_importance_sampling_correction,
epsilon=0.2,
epsilon_high=0.28,
loss_type="grpo",
mask_truncated_completions=True,
log_completions=config.log_completions,
num_completions_to_print=config.num_completions_to_print,
# max_prompt_length=config.max_prompt_length,
max_completion_length=config.max_seq_length - config.max_prompt_length
if config.max_prompt_length
else config.max_seq_length,
shuffle_dataset=False,
model_init_kwargs={
"torch_dtype": "auto",
},
model_init_kwargs={"torch_dtype": "auto"},
agent_servers=agent_servers,
request_timeout=10800,
**config,
)

training_args.agent_servers = agent_servers
training_args.request_timeout = 10800
if training_args.run_name is None:
task_name = task or os.path.basename(dataset_path).replace(".jsonl", "").replace(".json", "")
model_short = model_name.split("/")[-1]
training_args.run_name = (
f"{task_name}_{model_short}"
f"_rpp{training_args.num_generations}"
f"_dbs{training_args.per_device_train_batch_size}"
f"_ga{training_args.gradient_accumulation_steps}"
f"_maxlen{training_args.max_completion_length}"
f"_lr{training_args.learning_rate}"
f"_temp{training_args.temperature}"
f"_topp{training_args.top_p}"
)

tokenizer = AutoTokenizer.from_pretrained(config.model_name, truncation_side="left", padding_side="left")
tokenizer = AutoTokenizer.from_pretrained(model_name, truncation_side="left", padding_side="left")

trainer = GRPOTrainer(
model=config.model_name,
model=model_name,
processing_class=tokenizer,
reward_funcs=reward_fn,
train_dataset=dataset,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_vllm_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ def test_chat_completions_with_params(self):

assert len(data["choices"]) == 2

for choice in data["choices"]:
for i, choice in enumerate(data["choices"]):
assert choice["index"] == i, f"Expected choice at position {i} to have index {i}, got {choice['index']}"
assert "message" in choice
assert choice["message"]["role"] == "assistant"

Expand Down
10 changes: 7 additions & 3 deletions trl/scripts/vllm_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def _replace_prefix_tokens(
model output. A concrete example is inconsistent whitespace tokens around tool call special tokens.

Based on NeMo RL's _replace_prefix_tokens:
https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40
https://github.com/NVIDIA-NeMo/RL/blob/748b9caff4e6d672b8a98a10b6e612d028cfc96b/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40
"""
if not model_prefix_token_ids:
return template_token_ids
Expand Down Expand Up @@ -1088,7 +1088,9 @@ async def chat_completions(request: ChatCompletionRequest):

all_outputs = [connection.recv() for connection in connections]
if has_prefix_token_ids:
all_outputs = [o for o in all_outputs if o]
all_outputs = [
output for output, prompt_chunk in zip(all_outputs, chunked_prompts, strict=True) if prompt_chunk
]
else:
all_outputs = [
output for output, msg_chunk in zip(all_outputs, chunked_messages, strict=True) if msg_chunk
Expand Down Expand Up @@ -1116,7 +1118,8 @@ async def chat_completions(request: ChatCompletionRequest):
total_input_tokens = 0
total_output_tokens = 0

for idx, output in enumerate(all_outputs):
idx = 0
for output in all_outputs:
total_input_tokens += len(output.prompt_token_ids)

for gen_output in output.outputs:
Expand Down Expand Up @@ -1180,6 +1183,7 @@ async def chat_completions(request: ChatCompletionRequest):
"finish_reason": finish_reason,
}
)
idx += 1

return {
"id": completion_id,
Expand Down