Closed
Changes from all commits — 61 commits
a6aacdb
ini
xxman-google Jul 3, 2025
eba5e76
fix: load HF model only on rank 0 (#544)
parthchadha Jul 2, 2025
e790556
feat: supports evaluation of multiple-choice benchmarks (#559)
xxman-google Jul 2, 2025
9325915
fix: enable expandable segments for hopper+ (#594)
parthchadha Jul 3, 2025
ab84468
feat: Enable vLLM cudagraphs (#498)
jiemingz Jul 3, 2025
b14ae90
docs: Update guide to include minimum compute requirement (#505)
abukharin-nv Jul 3, 2025
e694d38
fix: skip HelpSteer3 unit test if downloading failed (#612)
yuki-97 Jul 4, 2025
4e9bd60
feat: optimize get logprobs when cp enabled. (#528)
joyang-nv Jul 5, 2025
6d45949
enable mcore rope fusion (#608)
jiemingz Jul 5, 2025
d1bd4a5
fix: fix non-colocated with vllm tp>1 (#601)
yuki-97 Jul 6, 2025
9bc851c
feat: Refit: reduce the number of IPC calls by packing weights (#589)
guyueh1 Jul 7, 2025
01cbc28
feat: add flash-attn==2.7.4.post1 to backend dependencies (#622)
terrykong Jul 8, 2025
bb614d4
fix: Fix crash for logprob error plot (#623)
yfw Jul 8, 2025
fe8ada3
refactor: remove fsdp1 path (#614)
yuki-97 Jul 8, 2025
76b6c89
fix: fix a answer parsing bug in MMLU-Pro. (#598)
xxman-google Jul 8, 2025
8711249
feat: add MMMLU eval benchmark. (#596)
xxman-google Jul 8, 2025
0ce6470
fix: pytest_sessionfinish hook in case there is no _unit_test_data. (…
ffrujeri Jul 8, 2025
adf167c
fix: Don't call broadcast on dtensor (#627)
parthchadha Jul 8, 2025
c2384bb
fix: Fix eval when using async engine (#626)
parthchadha Jul 8, 2025
299bf0c
feat: Megatron MoE Support (#590)
yfw Jul 9, 2025
e7d9253
chore: exclude ray.remote from coverage (#624)
terrykong Jul 9, 2025
ea7938f
feat: guide to configure custom vllm version (#529)
terrykong Jul 9, 2025
9f5f833
feat: Deepseek Support (#591)
yfw Jul 9, 2025
387856a
feat: decouple checkpointing from validation (#575)
ashors1 Jul 9, 2025
b5b3424
feat: dynamically detect --gres=gpu:8 arg to work on clusters that do…
terrykong Jul 10, 2025
ebf2c05
fix: fix nccl P2P initialization error for non-colocated (#636)
zhandaz Jul 10, 2025
3a37061
fix: Mcore: Added functional grpo test and typing fixes (#527)
SahilJain314 Jul 11, 2025
89aa84c
feat: plumb environment variables to RayWorkerGroup (#631)
ashors1 Jul 11, 2025
6d385a2
feat: Qwen3 support (#592)
ashors1 Jul 14, 2025
883d573
fix: Fix megatron llama3.1-8b config (#652)
yfw Jul 14, 2025
aa3cb34
fix: update qwen32b config (#658)
yuki-97 Jul 14, 2025
a27ab00
fix: Make trust_remote_code default true in checkpoint (#663)
parthchadha Jul 14, 2025
c06c264
feat: add script to redact hparam paths from tensorboard logs (#347)
terrykong Jul 14, 2025
28aae4a
test: add a unit test that verifies that the correct keys are present…
ashors1 Jul 15, 2025
9083330
add adk
xxman-google Jul 16, 2025
95cb615
rollout stable and faster
xxman-google Jul 22, 2025
86505bb
cleanup
xxman-google Jul 22, 2025
3c5aa35
docs: Add GitHub icon and link to top bar (#669)
aschilling-nv Jul 15, 2025
171ef50
fix: Tie weights after set_model_state_dict if required (#666)
parthchadha Jul 15, 2025
905d089
feat: optimize refit by reducing set of IPC handles sent to each devi…
ZhiyuLi-Nvidia Jul 15, 2025
d0424cf
fix: adjust temperature scaling logic based on engine version (#660)
jubick1337 Jul 16, 2025
07d7d92
feat: introduce megatron checkpoint dir precedence (#665)
terrykong Jul 16, 2025
00d74b7
feat: optimize refit by preparing refit info ahead of time (#638)
yuki-97 Jul 16, 2025
6b97100
docs: update converter path in README. (#672)
xxman-google Jul 17, 2025
c8115f9
fix: make mcore lr scheduler configuration consistent with dtensor (#…
ashors1 Jul 17, 2025
88a9429
fix: fix mcore LR increment (#685)
ashors1 Jul 17, 2025
a0df2ef
fix: upgrade datasets to fix squad download (#692)
ashors1 Jul 18, 2025
22a984a
fix: Megatron config updates to avoid OOM (#687)
ashors1 Jul 18, 2025
00e33a9
fix: fix lr scheduler for config that was missed in #681 (#693)
ashors1 Jul 18, 2025
c4b5151
fix: Fix gemma models broken by HF update (#676)
yfw Jul 19, 2025
6a22c8a
chore: add CP+SP (sequence parallel) assertion in DTensor worker (#689)
yuki-97 Jul 19, 2025
75a5a6d
feat: MLFlow Integration for experiment tracking (#697)
terrykong Jul 21, 2025
022647b
fix: Fix activation checkpointing for mcore path (#703)
yfw Jul 21, 2025
163d750
feat: Enable Context Parallelism and Sequence Packing for MCore and D…
SahilJain314 Jul 22, 2025
71ed6e7
fix: SyntaxWarning: invalid escape sequence '\s' (#705)
RayenTian Jul 22, 2025
6c1898a
cleanup and add docstring
jialei777 Jul 23, 2025
af6fa8d
remove image
jialei777 Jul 23, 2025
084235d
ci: add a job that checks if submodules are fast forwarded (#695)
terrykong Jul 22, 2025
1b972ce
add uv.lock file
jialei777 Jul 23, 2025
6d9460b
update uv lock
jialei777 Jul 23, 2025
e6890cc
Merge branch 'main' into jialei/simulated-user
jialei777 Jul 23, 2025
44 changes: 44 additions & 0 deletions examples/configs/grpo_adk_llama8b.yaml
@@ -0,0 +1,44 @@
# GRPO configuration for unique numbers environment
defaults: "grpo_math_8B.yaml"

grpo:
  num_prompts_per_step: 32
  num_generations_per_prompt: 16
  max_rollout_turns: 20
  max_num_steps: 100
  val_at_start: false

data:
  add_system_prompt: false

checkpointing:
  enabled: false
  checkpoint_dir: "results/grpo-adk"
  metric_name: "val_reward"
  higher_is_better: true
  keep_top_k: 3
  save_period: 10

env:
  unique_numbers:
    cfg:
      max_turns: 15
      min_length: 5
      max_length: 10
      max_integer: 15

logger:
  wandb_enabled: True
  wandb:
    project: "grpo-simulated-adk"
    name: "llama-8b-__NOW__"

policy:
  train_global_batch_size: 512
  dynamic_batching:
    enabled: False
  tokenizer:
    chat_template: "{% for message in messages %}{% if loop.first %}<|begin_of_text|>{% endif %}<|start_header_id|>{{ message['role'] }}<|end_header_id|>\n{{ message['content'] }}<|eot_id|>{% endfor %}<|start_header_id|>assistant<|end_header_id|>\n"

cluster:
  gpus_per_node: 8
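As a quick sanity check of the inline `chat_template` above, the rendering it produces can be mimicked in plain Python. This is a sketch, not part of the PR; `render_llama3_chat` is a hypothetical helper that reproduces the template's output for a non-empty message list:

```python
def render_llama3_chat(messages):
    # Mirrors the Jinja chat_template from grpo_adk_llama8b.yaml:
    # <|begin_of_text|> before the first message, then a header/content/
    # <|eot_id|> block per message, then a trailing assistant header as
    # the generation prompt. (The template emits <|begin_of_text|> inside
    # the loop, so it is absent for an empty message list.)
    out = "<|begin_of_text|>"
    for m in messages:
        out += (
            f"<|start_header_id|>{m['role']}<|end_header_id|>\n"
            f"{m['content']}<|eot_id|>"
        )
    out += "<|start_header_id|>assistant<|end_header_id|>\n"
    return out

print(render_llama3_chat([{"role": "user", "content": "hi"}]))
```

The trailing assistant header is what cues the policy model to generate the next turn.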
245 changes: 245 additions & 0 deletions examples/run_grpo_unique_numbers_w_adk.py
@@ -0,0 +1,245 @@
"""Run GRPO with the Unique Numbers Simulator using ADK.

This script sets up and executes the Group Relative Policy Optimization (GRPO) algorithm
in a multi-turn conversational environment powered by the ADK framework.

### Task Overview
The objective is to train an agent to guess the number of unique integers in a list generated by a simulated user.
The interaction is structured as a turn-based dialogue:
- The user generates a list of integers.
- The agent queries specific positions in the list (by index).
- The user replies with the value at that index (if available).
- The agent continues the interaction until it makes a final guess at the number of unique integers.

### Environment Details
The environment is a simulated user that:
- Randomly generates a list of integers at setup.
- Responds to the agent's queries using an LLM via the ADK endpoint.
- Optionally evaluates the agent's final guess using an LLM-based grader (included for extensibility, though not essential for this task).

### Example Usage
uv run python examples/run_grpo_unique_numbers_w_adk.py

### Requirements
- A working ADK environment with access to a compatible LLM endpoint.
  For the default Gemini endpoint, the following environment variables must be set:
  - `GOOGLE_GENAI_USE_VERTEXAI=1`
  - `GOOGLE_CLOUD_PROJECT="your-project-id"`
  - `GOOGLE_CLOUD_LOCATION="your-location"`

- A properly configured GRPO YAML file.
  By default, the script uses:
  `examples/configs/grpo_adk_llama8b.yaml`
"""

import argparse
import itertools
import os
import pprint
import random
from datetime import datetime, timedelta
from typing import Iterator

from omegaconf import OmegaConf
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer

from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup
from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType
from nemo_rl.distributed.virtual_cluster import init_ray
from nemo_rl.environments.simulated_user.prompt import starting_user_prompt
from nemo_rl.environments.simulated_user.unique_numbers import (
    UniqueNumbersEnv,
    UniqueNumbersMetadata,
)
from nemo_rl.models.generation import configure_generation_config
from nemo_rl.utils.config import load_config, parse_hydra_overrides
from nemo_rl.utils.logger import get_next_experiment_dir

OmegaConf.register_new_resolver("mul", lambda a, b: a * b)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Run GRPO with unique numbers simulator"
    )
    parser.add_argument(
        "--config", type=str, default=None, help="Path to YAML config file"
    )
    args, overrides = parser.parse_known_args()
    return args, overrides


def generate_datum(
    tokenizer: AutoTokenizer,
    env_cfg: dict,
    task_name: str,
    idx: int,
    add_system_prompt: bool,
) -> DatumSpec:
    formatted_prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": starting_user_prompt}],
        tokenize=False,
        add_system_prompt=add_system_prompt,
        add_generation_prompt=True,
        add_special_tokens=False,
    ).strip()
    token_ids = tokenizer(
        formatted_prompt, return_tensors="pt", add_special_tokens=False
    )["input_ids"][0]

    def _generate_numbers(
        min_length, max_length, max_integer, default_max_turns
    ) -> UniqueNumbersMetadata:
        length = random.randint(min_length, max_length)
        numbers = [random.randint(0, max_integer) for _ in range(length)]
        return UniqueNumbersMetadata(
            numbers=numbers,
            unique_count=len(set(numbers)),
            turn=0,
            max_turns=default_max_turns,
        )

    metadata = _generate_numbers(
        min_length=env_cfg["cfg"]["min_length"],
        max_length=env_cfg["cfg"]["max_length"],
        max_integer=env_cfg["cfg"]["max_integer"],
        default_max_turns=env_cfg["cfg"]["max_turns"],
    )

    message_log: LLMMessageLogType = [
        {"role": "user", "content": formatted_prompt, "token_ids": token_ids}
    ]
    return {
        "message_log": message_log,
        "length": len(token_ids),
        "extra_env_info": metadata,
        "loss_multiplier": 1.0,
        "idx": idx,
        "task_name": task_name,
    }


class IterableNumbersDataset(IterableDataset):
    def __init__(self, tokenizer, env_cfg, task_name, add_system_prompt, length):
        super().__init__()
        self.tokenizer = tokenizer
        self.env_cfg = env_cfg
        self.task_name = task_name
        self.add_system_prompt = add_system_prompt
        self.length = length

    def __iter__(self) -> Iterator[DatumSpec]:
        for i in itertools.count():
            yield generate_datum(
                tokenizer=self.tokenizer,
                env_cfg=self.env_cfg,
                task_name=self.task_name,
                idx=i,
                add_system_prompt=self.add_system_prompt,
            )

    def __len__(self):
        return self.length


def setup_data(tokenizer, env_cfg, task_name, length, val_length, add_system_prompt):
    env_config = env_cfg[task_name]
    env = UniqueNumbersEnv.options(num_gpus=0).remote(cfg=dict(env_config["cfg"]))
    task_to_env = {task_name: env}

    train_ds = IterableNumbersDataset(
        tokenizer=tokenizer,
        env_cfg=env_config,
        task_name=task_name,
        add_system_prompt=add_system_prompt,
        length=length,
    )
    val_ds = IterableNumbersDataset(
        tokenizer=tokenizer,
        env_cfg=env_config,
        task_name=task_name,
        add_system_prompt=add_system_prompt,
        length=val_length,
    )
    val_task_to_env = task_to_env
    return train_ds, val_ds, task_to_env, val_task_to_env


def main():
    args, overrides = parse_args()
    if not args.config:
        args.config = os.path.join(
            os.path.dirname(__file__), "configs", "grpo_adk_llama8b.yaml"
        )
    config = load_config(args.config)
    if overrides:
        config = parse_hydra_overrides(config, overrides)
    config: MasterConfig = OmegaConf.to_container(config, resolve=True)

    now_pst = datetime.utcnow() + timedelta(hours=-7)
    config["logger"]["wandb"]["name"] = config["logger"]["wandb"]["name"].replace(
        "__NOW__", now_pst.strftime("%m/%d-%H:%M")
    )

    config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"])
    if config["checkpointing"]["enabled"]:
        print(
            f"\U0001f4ca Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}"
        )

    pprint.pprint(config)

    init_ray()

    tokenizer = get_tokenizer(config["policy"]["tokenizer"])
    config["policy"]["generation"] = configure_generation_config(
        config["policy"]["generation"], tokenizer
    )

    ds_length = (
        config["grpo"]["num_prompts_per_step"]
        * config["grpo"]["num_generations_per_prompt"]
        * config["grpo"]["max_num_steps"]
    )
    dataset, val_dataset, task_to_env, val_task_to_env = setup_data(
        tokenizer=tokenizer,
        env_cfg=config["env"],
        task_name="unique_numbers",
        length=ds_length,
        val_length=config["grpo"]["max_val_samples"],
        add_system_prompt=config["data"]["add_system_prompt"],
    )

    (
        policy,
        policy_generation,
        cluster,
        dataloader,
        val_dataloader,
        loss_fn,
        logger,
        checkpointer,
        grpo_state,
        master_config,
    ) = setup(config, tokenizer, dataset, val_dataset)

    grpo_train(
        policy,
        policy_generation,
        dataloader,
        val_dataloader,
        tokenizer,
        loss_fn,
        task_to_env,
        val_task_to_env,
        logger,
        checkpointer,
        grpo_state,
        master_config,
    )


if __name__ == "__main__":
main()
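The episode setup in `generate_datum` above can be exercised in isolation. Below is a minimal sketch of the metadata generation step; the dataclass here is a stand-in with the same field names as `UniqueNumbersMetadata`, not the nemo_rl type, and `generate_metadata` is a hypothetical extraction of the script's `_generate_numbers` helper:

```python
import random
from dataclasses import dataclass

@dataclass
class UniqueNumbersMetadata:
    # Stand-in mirroring the fields used by the script above.
    numbers: list
    unique_count: int
    turn: int
    max_turns: int

def generate_metadata(min_length, max_length, max_integer, max_turns, rng):
    # Draw a random-length list of small integers; the agent's goal is to
    # guess unique_count, which the environment knows from the start.
    length = rng.randint(min_length, max_length)
    numbers = [rng.randint(0, max_integer) for _ in range(length)]
    return UniqueNumbersMetadata(numbers, len(set(numbers)), 0, max_turns)

# Values match env.unique_numbers.cfg in grpo_adk_llama8b.yaml.
rng = random.Random(0)
md = generate_metadata(min_length=5, max_length=10, max_integer=15,
                       max_turns=15, rng=rng)
```

With `max_integer=15` and lists of 5-10 elements, duplicates are common, so `unique_count` is usually strictly less than the list length and the agent cannot simply count its queries.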
1 change: 1 addition & 0 deletions nemo_rl/distributed/ray_actor_environment_registry.py
@@ -22,6 +22,7 @@
    "nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker": PY_EXECUTABLES.MCORE,
    "nemo_rl.environments.math_environment.MathEnvironment": PY_EXECUTABLES.SYSTEM,
    "nemo_rl.environments.games.sliding_puzzle.SlidingPuzzleEnv": PY_EXECUTABLES.SYSTEM,
    "nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv": PY_EXECUTABLES.SYSTEM,
}

