Skip to content
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
16640e8
nemo gym integration
cmunley1 Dec 17, 2025
6261758
couple updates
cmunley1 Dec 18, 2025
4105340
baseline without on policy correction
cmunley1 Dec 20, 2025
be5c156
readme
cmunley1 Dec 20, 2025
64b9ed4
wip
cmunley1 Dec 22, 2025
948869f
fixes
cmunley1 Jan 7, 2026
52a3140
readme
cmunley1 Jan 7, 2026
0e71cbb
cfg
cmunley1 Jan 7, 2026
3548099
small fix
cmunley1 Jan 7, 2026
8373899
docs
cmunley1 Jan 9, 2026
fe4bce6
fixes
cmunley1 Jan 15, 2026
facfb5a
remove flag
cmunley1 Jan 15, 2026
ac94e1b
multi env
cmunley1 Jan 16, 2026
32c5a6b
small fix
cmunley1 Jan 16, 2026
5619096
dataset index
cmunley1 Jan 16, 2026
04821b5
multinode example
cmunley1 Jan 17, 2026
52b2f5c
client and tests
cmunley1 Jan 17, 2026
0793c05
remove native tool parsing, use fastapi state
cmunley1 Jan 17, 2026
5f8ccc9
remove old code
cmunley1 Jan 17, 2026
743d5ea
enable IS
cmunley1 Jan 17, 2026
d98dd8a
remove logp diff tracking without is
cmunley1 Jan 17, 2026
a5f9166
restore
cmunley1 Jan 17, 2026
17b72c8
readme
cmunley1 Jan 17, 2026
18ffaa8
restore pyproject
cmunley1 Jan 17, 2026
cc503cb
readme
cmunley1 Jan 17, 2026
843938f
move submit
cmunley1 Jan 17, 2026
209b12e
config
cmunley1 Jan 17, 2026
2ec1a0f
Merge branch 'main' into cmunley1/nemo_gym_on_policy
sergiopaniego Jan 20, 2026
a8f7b36
draft docs
cmunley1 Jan 21, 2026
6893625
Merge branch 'cmunley1/nemo_gym_on_policy' of github.com:cmunley1/trl…
cmunley1 Jan 21, 2026
e883dcd
draft docs
cmunley1 Jan 21, 2026
2c7de07
docs update
cmunley1 Jan 21, 2026
aad21ee
ds cfg, submit update
cmunley1 Jan 22, 2026
06ab2a2
readme
cmunley1 Jan 22, 2026
cf9f177
rename train, update docs
cmunley1 Jan 22, 2026
7669c00
comment
cmunley1 Jan 22, 2026
df2f350
Merge branch 'main' into cmunley1/nemo_gym_on_policy
kashif Jan 23, 2026
3a455a9
Update trl/trainer/grpo_trainer.py
sergiopaniego Jan 23, 2026
f69a70a
Update trl/scripts/vllm_serve.py
cmunley1 Jan 26, 2026
56535f2
rename docs file
cmunley1 Jan 26, 2026
f1c6614
Merge branch 'main' into cmunley1/nemo_gym_on_policy
cmunley1 Jan 26, 2026
537f82e
Merge branch 'main' into cmunley1/nemo_gym_on_policy
sergiopaniego Jan 27, 2026
7b1fe8a
nemo gym trl edits
lbliii Jan 28, 2026
92227f0
Merge pull request #1 from lbliii/llane/nemo-gym-trl-edits
cmunley1 Jan 28, 2026
9f7f45f
Merge remote-tracking branch 'upstream/main' into cmunley1/nemo_gym_o…
cmunley1 Jan 29, 2026
4d6012e
lint
cmunley1 Jan 30, 2026
d5443eb
docs
cmunley1 Jan 30, 2026
2837bda
improve docs, rename train script
cmunley1 Jan 30, 2026
93d97a7
fixes based on review
cmunley1 Jan 30, 2026
b15ab63
subclass
cmunley1 Jan 30, 2026
6d7e8d0
config update
cmunley1 Jan 31, 2026
13c378c
docs
cmunley1 Jan 31, 2026
c5dcb5d
typo in submit
cmunley1 Jan 31, 2026
03ffa0b
Merge pull request #2 from cmunley1/cmunley1/ng-fix
cmunley1 Jan 31, 2026
b4678fb
improve nemo gym docs
cmunley1 Jan 31, 2026
a476ac5
update docs
cmunley1 Jan 31, 2026
5e70a33
rename project to server
cmunley1 Feb 2, 2026
a3f241e
vllm finish reason
cmunley1 Feb 2, 2026
82cd8d5
Merge branch 'main' into cmunley1/nemo_gym_on_policy
sergiopaniego Feb 4, 2026
e123a88
Update docs/source/nemo_gym.md
sergiopaniego Feb 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions examples/scripts/nemo_gym/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Post-training with NeMo Gym and TRL

This integration supports training language models in NeMo-Gym environments using TRL GRPO. Both single step and multi step tasks are supported, including multi-environment training. NeMo-Gym orchestrates rollouts, returning token ids and logprobs to TRL through the rollout function for training. Currently this integration is only supported through TRL's vllm server mode.

## Interactive single node

1. Launch vLLM server:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve \
--model Qwen/Qwen3-4B-Instruct-2507 \
--tensor-parallel-size 4 \
--max-model-len 8192 \
--trust-remote-code
```

2. Start NeMo Gym servers
```
ng_run "+config_paths=[resources_servers/workplace_assistant/configs/workplace_assistant.yaml,responses_api_models/vllm_model/configs/vllm_model.yaml]"
```


3. Run training:
```bash
CUDA_VISIBLE_DEVICES=4 python train.py --config config.yaml
```

## Multinode with slurm

See submit.sh for a multinode example!

## Multi environment training

Docs coming soon!
32 changes: 32 additions & 0 deletions examples/scripts/nemo_gym/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
model_name: "Qwen/Qwen3-4B-Instruct-2507"

dataset_path: "data/train.jsonl"
eval_dataset_path: "data/val.jsonl"

output_dir: "outputs/nemo_gym"
run_name_prefix: "nemo_gym"
report_to: "wandb"
project_name: "trl-nemo-gym"
log_completions: true
num_completions_to_print: 2

learning_rate: 1.0e-5
max_steps: 1000
num_generations: 8
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
max_seq_length: 16384
warmup_steps: 5
lr_scheduler_type: "linear"
optim: "adamw_torch_fused"
weight_decay: 0.0
vllm_importance_sampling_correction: true

temperature: 1.0
top_p: 0.999

save_steps: 10

eval_strategy: "steps"
eval_steps: 10

109 changes: 109 additions & 0 deletions examples/scripts/nemo_gym/submit.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#!/bin/bash
#SBATCH -A account
#SBATCH -p partition
#SBATCH -N 5
#SBATCH --gres gpu:8
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=16
#SBATCH --time=4:00:00
#SBATCH --job-name=trl_nemo_gym
#SBATCH --output=logs/%j/slurm.out
#SBATCH --error=logs/%j/slurm.err

CONTAINER_IMAGE="nvcr.io/nvidia/pytorch:25.12-py3"
MOUNTS="/path/to/mounts:/path/to/mounts"

NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))

TRAIN_NODE_0="${NODELIST[0]}"
TRAIN_NODE_1="${NODELIST[1]}"
TRAIN_NODE_2="${NODELIST[2]}"
TRAIN_NODE_3="${NODELIST[3]}"
VLLM_NODE="${NODELIST[4]}"

echo "Training Nodes: $TRAIN_NODE_0, $TRAIN_NODE_1, $TRAIN_NODE_2, $TRAIN_NODE_3"
echo "vLLM Node: $VLLM_NODE"
echo "Main process IP: $TRAIN_NODE_0"

LOG_DIR="logs/${SLURM_JOB_ID}"
mkdir -p ${LOG_DIR}

echo "Starting ng_run and vLLM on ${VLLM_NODE}..."
echo "Logs will be saved to: ${LOG_DIR}"

srun --nodes=1 --ntasks=1 --nodelist="${VLLM_NODE}" \
--container-image="${CONTAINER_IMAGE}" \
--container-mounts="${MOUNTS}" \
--container-mount-home \
bash -c "
LOG_DIR=/path/to/logs
mkdir -p \${LOG_DIR}

# Install uv if not already installed
curl -LsSf https://astral.sh/uv/install.sh | sh
source \$HOME/.local/bin/env

# Start nemo gym servers
(set -x && \
export HOME=/path/to/user && \
export PATH=\$HOME/.local/bin:\$PATH && \
cd /path/to/user/Gym && \
uv venv --python 3.12 && \
source .venv/bin/activate && \
uv sync && \
ray stop --force && \
ng_run +config_paths=[responses_api_models/vllm_model/configs/vllm_model.yaml,resources_servers/workplace_assistant/configs/workplace_assistant.yaml] +head_server.host=0.0.0.0) > \${LOG_DIR}/ng_run.log 2>&1 &

sleep 10

# Start trl vllm server
(set -x && \
export HOME=/path/to/user && \
export HF_HOME=/path/to/user/hf_home && \
cd /path/to/user/trl && \
source .venv/bin/activate && \
python -m trl.scripts.vllm_serve \
--model Qwen/Qwen3-4B-Instruct-2507 \
--host 0.0.0.0 \
--tensor-parallel-size 8 \
--data-parallel-size 1 \
--max-model-len 16384 \
--gpu-memory-utilization 0.7 \
--port 8000) > \${LOG_DIR}/vllm_serve.log 2>&1 &

wait
" &

echo "Waiting for nemo gym and vllm to start..."
sleep 120

echo "Launching training on 4 nodes..."

TRAIN_NODES_LIST="${TRAIN_NODE_0},${TRAIN_NODE_1},${TRAIN_NODE_2},${TRAIN_NODE_3}"

srun --nodes=4 --ntasks=4 --nodelist="${TRAIN_NODES_LIST}" \
--container-image="${CONTAINER_IMAGE}" \
--container-mounts="${MOUNTS}" \
--container-mount-home \
bash -c "
set -x && \
export HOME=/path/to/user && \
export HF_HOME=/path/to/user/hf_home && \
cd /path/to/user/trl && \
source .venv/bin/activate && \
cd examples/scripts/nemo_gym && \
accelerate launch \
--config_file deepspeed_zero3.yaml \
--num_processes 32 \
--num_machines 4 \
--machine_rank \$SLURM_PROCID \
--main_process_ip ${TRAIN_NODE_0} \
--main_process_port 29500 \
--rdzv_backend c10d \
train.py \
--config config.yaml \
--vllm_server_host ${VLLM_NODE} \
--head_server_host ${VLLM_NODE}" &

wait

Loading