feat: Enable simulated user for multi-turn GRPO#1412
Conversation
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
📝 WalkthroughWalkthroughThis PR introduces a new simulated user environment called "unique numbers" that uses Google ADK agents to simulate user-agent interactions within a GRPO training framework. It includes configuration files, environment implementation, utility functions, an example training script, infrastructure updates, and comprehensive tests. Changes
Sequence Diagram(s)sequenceDiagram
participant User as Training User
participant GRPO as GRPO Training
participant UniqueEnv as Unique Numbers Env
participant SimUser as Simulated User Runner
participant Grader as Grader Runner
User->>GRPO: Start training with config
GRPO->>UniqueEnv: Initialize environment
UniqueEnv->>SimUser: Create ADK agent
UniqueEnv->>Grader: Create ADK agent
loop Per training step
GRPO->>UniqueEnv: step(message_log, metadata)
UniqueEnv->>SimUser: extract last assistant message
UniqueEnv->>SimUser: run_prompt_async(query or statement)
SimUser-->>UniqueEnv: simulated user response
UniqueEnv->>UniqueEnv: check if guess pattern matched
alt Guess detected
UniqueEnv->>UniqueEnv: compute reward (correct/incorrect)
UniqueEnv->>Grader: run_prompt_async(grade conversation)
Grader-->>UniqueEnv: optional score adjustment
UniqueEnv-->>GRPO: EnvironmentReturn (reward, terminated=True)
else Query or other
UniqueEnv-->>GRPO: EnvironmentReturn (response, turn increment)
end
end
GRPO->>UniqueEnv: shutdown()
UniqueEnv->>SimUser: cleanup
UniqueEnv->>Grader: cleanup
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes This PR introduces substantial new functionality spanning multiple files with heterogeneous changes: new environment class with orchestration logic, async utilities with retry mechanisms, integration with external ADK library, rollout processing modifications, and comprehensive test coverage. The logic density is moderate-to-high, with coordination between simulated user/grader runners and reward computation requiring careful review. Possibly related PRs
Suggested labels
Suggested reviewers
Pre-merge checks and finishing touches❌ Failed checks (2 warnings)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Tip 📝 Customizable high-level summaries are now available in beta!You can now customize how CodeRabbit generates the high-level summary in your pull requests — including its content, structure, tone, and formatting.
Example instruction:
Note: This feature is currently in beta for Pro-tier users, and pricing will be announced later. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 14
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
nemo_rl/distributed/ray_actor_environment_registry.py (1)
19-26: Honor NEMO_RL_PY_EXECUTABLES_SYSTEM for ADK for parity.Other entries (VLLM/MCORE) respect the system override; do the same for ADK.
Apply:
USE_SYSTEM_EXECUTABLE = os.environ.get("NEMO_RL_PY_EXECUTABLES_SYSTEM", "0") == "1" VLLM_EXECUTABLE = ( PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.VLLM ) MCORE_EXECUTABLE = ( PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.MCORE ) +ADK_EXECUTABLE = ( + PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.ADK +) @@ - "nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv": PY_EXECUTABLES.ADK, + "nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv": ADK_EXECUTABLE,Also applies to: 40-41
nemo_rl/environments/interfaces.py (1)
80-82: Update step() return docs to include answers.Docstring enumerates 5 fields; add the optional
answersto avoid confusion.Apply:
- - EnvironmentReturn NamedTuple containing observations, metadata, next_stop_strings, rewards, and terminateds flags. + - EnvironmentReturn NamedTuple containing observations, metadata, next_stop_strings, rewards, terminateds, and answers (optional).nemo_rl/experience/rollouts.py (2)
475-491: Bug: Tensor used inifcondition; will raise truth-value error.
active_input_lengths[i]is a Tensor;if (... + active_input_lengths[i] >= max_seq_len)yields a Tensor boolean, which is invalid inif.Apply:
- if ( - len(tokenized_obs) + len(generated_ids[i]) + active_input_lengths[i] - >= max_seq_len - ): + input_len = int(active_input_lengths[i].item()) + if len(tokenized_obs) + len(generated_ids[i]) + input_len >= max_seq_len: tokens_left_for_obs = max_seq_len - ( - len(generated_ids[i]) + active_input_lengths[i] + len(generated_ids[i]) + input_len )
758-766: Bug: Tensor truth value inifcondition (single-sample path).
input_lengthsis a Tensor; comparing directly inifis invalid.Apply:
- if input_lengths + gen_token_count + len(tokenized_obs) >= max_seq_len: + input_len = int(input_lengths.item()) + if input_len + gen_token_count + len(tokenized_obs) >= max_seq_len: # Truncate environment observation - max_env_tokens = max_seq_len - input_lengths - gen_token_count + max_env_tokens = max_seq_len - input_len - gen_token_count
🧹 Nitpick comments (13)
examples/configs/grpo_adk_llama8b.yaml (1)
37-43: Confirm intentional batching differences vs Gemma config.
dynamic_batching.enabledis False here but True in the Gemma config. If this is model‑specific tuning, consider a short YAML comment noting why.nemo_rl/experience/rollouts.py (1)
376-380: Prefer logger over print for per-turn progress.Switch to
logging.getLogger(__name__).info/debugand allow callers to control verbosity.To support this outside the hunk, add:
# at top-level import logging logger = logging.getLogger(__name__)Then:
- if max_rollout_turns > 1: - print( - f"▶ ▶ ▶ Running rollout turn {turn + 1} / {max_rollout_turns} with {len(active_indices)} active samples..." - ) + if max_rollout_turns > 1: + logger.info( + "▶ ▶ ▶ Running rollout turn %d / %d with %d active samples...", + turn + 1, max_rollout_turns, len(active_indices) + )nemo_rl/environments/simulated_user/prompt.py (3)
1-7: Promote constants to UPPER_SNAKE_CASE and keep aliases.Module-level prompt strings are constants. Rename to UPPER_SNAKE_CASE and keep lowercase aliases for compatibility.
Apply:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +from typing import Final + - starting_user_prompt = ( + STARTING_USER_PROMPT: Final[str] = ( "I will play a game with you. I have a list of integers in mind and can NOT tell you. " "Your goal is to guess the count of UNIQUE numbers in my list. The only 2 things you can do is the following: " "You can either ask me 'what is number k?' to get the number at position k in my list, " "or answer 'there are m unique numbers' whenever you feel you want to make a guess. " "Please do not say anything else. You cannot ask me to provide the list of integers." ) +simulated_user_instruction = SIMULATED_USER_INSTRUCTION = SIMULATED_USER_INSTRUCTION # type: ignore # back-compat alias +starting_user_prompt = STARTING_USER_PROMPT # back-compat alias +grader_instruction = GRADER_INSTRUCTION # back-compat aliasFollow-up diff below adjusts definitions and typos.
As per coding guidelines.
2-7: Fix grammar in the starting prompt.Minor clarity/grammar tweaks.
Apply:
- "I will play a game with you. I have a list of integers in mind and can NOT tell you. " - "Your goal is to guess the count of UNIQUE numbers in my list. The only 2 things you can do is the following: " + "I will play a game with you. I have a list of integers that I will not reveal. " + "Your goal is to guess the count of UNIQUE numbers in my list. The only 2 things you can do are: "
10-19: Use consistent naming and strip once.Define as constant and avoid trailing strip duplication.
Apply:
-simulated_user_instruction = """ +SIMULATED_USER_INSTRUCTION: Final[str] = """ ... -""".strip() +""".strip()tests/unit/environments/test_simulated_user.py (2)
90-107: Fixture stubs: keep scope local and silence ruff ARG warnings.Good isolation overall. Consider prefixing unused fixture/function args with “_” or add “# noqa: ARG001/ARG005” where appropriate to keep linters quiet without affecting readability.
220-249: Retry test is solid; minor style nit.
monkeypatcharg is unused. Prefix with_monkeypatchfor clarity.examples/run_grpo_unique_numbers_w_adk.py (2)
195-198: Timezone: avoid manual UTC offsets.Use zoneinfo to format local time reliably (handles DST).
Apply:
-from datetime import datetime, timedelta +from datetime import datetime +from zoneinfo import ZoneInfo ... - now_pst = datetime.utcnow() + timedelta(hours=-7) + now_pst = datetime.now(ZoneInfo("America/Los_Angeles"))
229-240: Unused variable ‘cluster’.Prefix with underscore to silence linters.
Apply:
- cluster, + _cluster,nemo_rl/environments/simulated_user/adk_utils.py (3)
22-36: Minor grammar in default instruction.“help people” → “helps people”.
Apply:
- instruction=instruction - or "You are a helpful assistant that help people answer questions.", + instruction=instruction + or "You are a helpful assistant that helps people answer questions.",
100-116: Use logger.exception and tighten generic except.Prefer
logger.exceptionto capture traceback. Keep a narrowServerErrorexcept, and retain a broad fallback but clearly log traceback.Apply:
- except ServerError as e: + except ServerError as e: retries += 1 delay_with_jitter = delay + (random.random() * 2 - 1) * (delay * 0.5) - logger.error( + logger.exception( f"Gemini API call (with message {new_message}) failed with ServerError {e} (attempt {retries}/{max_retries}). Retrying in {delay_with_jitter} seconds..." ) await asyncio.sleep(delay_with_jitter) delay *= 2 # Exponential backoff - except Exception as e: - logger.error( + except Exception as e: # keep as last-resort + logger.exception( f"Gemini API call (with message {new_message}) failed with an unexpected error: {e}." ) return f"<No response due to unexpected error: {e}>" - logger.error( + logger.error( f"Gemini API call (with message {new_message}) reached maximum retries ({max_retries}) without success." ) - return f"<No response due after {max_retries} retries>" + return f"<No response after {max_retries} retries>"
39-46: Session access assertions are brittle.Assuming exactly one app/user/session can break multi-user scenarios. Return helpful errors or search by keys defensively.
Apply (illustrative):
- assert len(app_session_map) == 1, "Expected exactly one app in session_service" - user_sessions_map = next(iter(app_session_map.values())) - sessions = user_sessions_map[user_id] - assert len(sessions) == 1, "Expected exactly one user in app session" - return next(iter(sessions.values())) + if not app_session_map: + raise RuntimeError("No sessions available in session_service") + # Prefer the first app containing the user_id + for user_sessions_map in app_session_map.values(): + if user_id in user_sessions_map: + sessions = user_sessions_map[user_id] + if not sessions: + raise RuntimeError(f"No sessions for user_id={user_id}") + return next(iter(sessions.values())) + raise KeyError(f"user_id={user_id} not found in session_service")nemo_rl/environments/simulated_user/unique_numbers.py (1)
246-261: zip(strict=...) for clarity and static analysis.Be explicit with
strict=Falseto document intent and silence linters.Apply:
- for log, meta in zip(message_log_batch, metadata): + for log, meta in zip(message_log_batch, metadata, strict=False):
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
uv.lockis excluded by!**/*.lock
📒 Files selected for processing (13)
examples/configs/grpo_adk_gemma.yaml(1 hunks)examples/configs/grpo_adk_llama8b.yaml(1 hunks)examples/run_grpo_unique_numbers_w_adk.py(1 hunks)nemo_rl/distributed/ray_actor_environment_registry.py(1 hunks)nemo_rl/distributed/virtual_cluster.py(1 hunks)nemo_rl/environments/interfaces.py(1 hunks)nemo_rl/environments/simulated_user/adk_utils.py(1 hunks)nemo_rl/environments/simulated_user/prompt.py(1 hunks)nemo_rl/environments/simulated_user/unique_numbers.py(1 hunks)nemo_rl/experience/rollouts.py(6 hunks)pyproject.toml(1 hunks)pyrefly.toml(2 hunks)tests/unit/environments/test_simulated_user.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
nemo_rl/distributed/ray_actor_environment_registry.pynemo_rl/environments/simulated_user/prompt.pynemo_rl/environments/interfaces.pynemo_rl/environments/simulated_user/adk_utils.pynemo_rl/distributed/virtual_cluster.pyexamples/run_grpo_unique_numbers_w_adk.pynemo_rl/environments/simulated_user/unique_numbers.pynemo_rl/experience/rollouts.pytests/unit/environments/test_simulated_user.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
nemo_rl/distributed/ray_actor_environment_registry.pynemo_rl/environments/simulated_user/prompt.pynemo_rl/environments/interfaces.pynemo_rl/environments/simulated_user/adk_utils.pynemo_rl/distributed/virtual_cluster.pynemo_rl/environments/simulated_user/unique_numbers.pynemo_rl/experience/rollouts.py
examples/configs/*.yaml
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
examples/configs/*.yaml: Exemplar configs under examples/configs/.yaml must include documented defaults
When adding a new config key, reflect its recommended default in exemplar YAMLs under examples/configs/.yaml
Files:
examples/configs/grpo_adk_gemma.yamlexamples/configs/grpo_adk_llama8b.yaml
🧬 Code graph analysis (6)
nemo_rl/distributed/ray_actor_environment_registry.py (1)
nemo_rl/distributed/virtual_cluster.py (1)
PY_EXECUTABLES(42-60)
nemo_rl/environments/simulated_user/adk_utils.py (1)
tests/unit/environments/test_simulated_user.py (2)
from_text(26-27)create_session(125-130)
examples/run_grpo_unique_numbers_w_adk.py (8)
nemo_rl/algorithms/utils.py (1)
get_tokenizer(157-288)nemo_rl/data/interfaces.py (1)
DatumSpec(32-40)nemo_rl/distributed/ray_actor_environment_registry.py (1)
get_actor_python_env(50-65)nemo_rl/distributed/virtual_cluster.py (1)
init_ray(86-172)nemo_rl/environments/simulated_user/unique_numbers.py (2)
UniqueNumbersEnv(229-296)UniqueNumbersMetadata(44-52)nemo_rl/models/generation/__init__.py (1)
configure_generation_config(24-45)nemo_rl/utils/config.py (1)
parse_hydra_overrides(146-166)nemo_rl/utils/logger.py (1)
get_next_experiment_dir(1311-1345)
nemo_rl/environments/simulated_user/unique_numbers.py (2)
nemo_rl/environments/interfaces.py (2)
EnvironmentInterface(52-88)EnvironmentReturn(26-49)nemo_rl/environments/simulated_user/adk_utils.py (3)
extract_conversation_history(52-61)create_agent(14-36)run_prompt_async(64-116)
nemo_rl/experience/rollouts.py (1)
tests/unit/data/test_data_processor.py (1)
apply_chat_template(45-57)
tests/unit/environments/test_simulated_user.py (2)
nemo_rl/environments/simulated_user/unique_numbers.py (1)
_UniqueNumbersRunner(55-225)nemo_rl/environments/simulated_user/adk_utils.py (2)
run_prompt_async(64-116)extract_conversation_history(52-61)
🪛 Ruff (0.14.1)
nemo_rl/environments/simulated_user/adk_utils.py
101-101: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
102-104: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
107-107: Do not catch blind exception: Exception
(BLE001)
108-110: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
175-175: Unpacked variable convo2 is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
examples/run_grpo_unique_numbers_w_adk.py
82-82: Unused function argument: add_system_prompt
(ARG001)
100-100: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
101-101: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
232-232: Unpacked variable cluster is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
nemo_rl/environments/simulated_user/unique_numbers.py
197-197: Do not catch blind exception: Exception
(BLE001)
246-246: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
tests/unit/environments/test_simulated_user.py
92-92: Unused function argument: user_id
(ARG001)
92-92: Unused function argument: silence
(ARG001)
105-105: Unused lambda argument: a
(ARG005)
105-105: Unused lambda argument: k
(ARG005)
166-166: Unused function argument: patch_unique_numbers
(ARG001)
180-180: Unused function argument: patch_unique_numbers
(ARG001)
194-194: Unused function argument: patch_unique_numbers
(ARG001)
207-207: Unused function argument: patch_unique_numbers
(ARG001)
220-220: Unused function argument: patch_adk_utils
(ARG001)
228-228: Unused function argument: monkeypatch
(ARG001)
228-228: Unused function argument: patch_adk_utils
(ARG001)
252-252: Unused function argument: patch_adk_utils
(ARG001)
🔇 Additional comments (5)
pyrefly.toml (1)
19-20: Stub additions look good.Import stubs for
google.adk.*andgoogle.genai.*plus includes for simulated_user modules are appropriate for static checks.Also applies to: 89-92
nemo_rl/distributed/virtual_cluster.py (1)
60-61: ADK executable entry is consistent.Matches the existing pattern for extras-based executables.
pyproject.toml (1)
104-107: Verify ADK extra resolves and imports using pipThe
uvcommand isn’t available; please run:python3 -m venv venv source venv/bin/activate pip install --upgrade pip pip install .[adk] python - <<'PY' import google.adk, google.genai print("OK", google.adk.__version__, getattr(google.genai, "__version__", "n/a")) PYnemo_rl/environments/simulated_user/unique_numbers.py (1)
145-156: Async orchestration inside thread pool: verify no running event loop.
asyncio.runinside threads works, but if an event loop is active it raisesRuntimeError. Consider extracting an asyncprocess_turn_asyncand running it at the caller, or add a small helper that usesasyncio.get_running_loop()and falls back appropriately.Would you like a patch to make
process_turnasync and wire it throughThreadPoolExecutorwithasyncio.runat the top-level call site?examples/run_grpo_unique_numbers_w_adk.py (1)
186-193: Default config and registry verified:examples/configs/grpo_adk_llama8b.yamlis present andUniqueNumbersEnvis correctly registered in the environment registry.
| query_re = re.compile(r"what is number (\d+)\??$", re.IGNORECASE) | ||
| guess_re = re.compile(r"there are (\d+) unique number", re.IGNORECASE) | ||
|
|
There was a problem hiding this comment.
Regex misses plural ‘numbers’; guesses won’t register.
Pattern only matches “unique number”. It should accept both singular and plural.
Apply:
- guess_re = re.compile(r"there are (\d+) unique number", re.IGNORECASE)
+ guess_re = re.compile(r"there are (\d+)\s+unique numbers?\b", re.IGNORECASE)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| query_re = re.compile(r"what is number (\d+)\??$", re.IGNORECASE) | |
| guess_re = re.compile(r"there are (\d+) unique number", re.IGNORECASE) | |
| query_re = re.compile(r"what is number (\d+)\??$", re.IGNORECASE) | |
| guess_re = re.compile(r"there are (\d+)\s+unique numbers?\b", re.IGNORECASE) |
🤖 Prompt for AI Agents
In nemo_rl/environments/simulated_user/unique_numbers.py around lines 56 to 58,
the guess_re regex only matches the singular phrase "unique number" so plural
guesses like "unique numbers" won't register; update the pattern to accept both
singular and plural (e.g., make "number" optional "s" or use "numbers?" with
appropriate word boundaries) and keep the re.IGNORECASE flag so both "number"
and "numbers" are matched.
| grading_prompt = f"Here is the converstation \n{convo_str}\nAnd please give the score between 0 and 1." | ||
| grading_response = asyncio.run( | ||
| run_prompt_async( | ||
| metadata["grader_runner"], | ||
| "grader", | ||
| grading_prompt, | ||
| silence=True, | ||
| ) | ||
| ) | ||
| try: | ||
| grade = int(re.search(r"(\d+)", grading_response).group(1)) | ||
| reward = (reward + grade) / 2.0 | ||
| except Exception as e: | ||
| print( | ||
| f"Failed to parse grade from grader response '{grading_response}': {e}" | ||
| ) | ||
|
|
There was a problem hiding this comment.
Grader integration: fix typo and parse numeric [0,1] robustly.
- “converstation” → “conversation”.
- Parse floats, clamp to [0,1]; don’t cast to int (drops partial credit).
Apply:
- grading_prompt = f"Here is the converstation \n{convo_str}\nAnd please give the score between 0 and 1."
+ grading_prompt = f"Here is the conversation:\n{convo_str}\nPlease return only a numeric score between 0 and 1."
grading_response = asyncio.run(
run_prompt_async(
metadata["grader_runner"],
"grader",
grading_prompt,
silence=True,
)
)
try:
- grade = int(re.search(r"(\d+)", grading_response).group(1))
- reward = (reward + grade) / 2.0
+ m = re.search(r"\b([01](?:\.\d+)?)\b", grading_response)
+ if m:
+ grade = float(m.group(1))
+ grade = max(0.0, min(1.0, grade))
+ reward = 0.5 * (reward + grade)
+ else:
+ raise ValueError("No numeric score found")
except Exception as e:
print(
f"Failed to parse grade from grader response '{grading_response}': {e}"
)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| grading_prompt = f"Here is the converstation \n{convo_str}\nAnd please give the score between 0 and 1." | |
| grading_response = asyncio.run( | |
| run_prompt_async( | |
| metadata["grader_runner"], | |
| "grader", | |
| grading_prompt, | |
| silence=True, | |
| ) | |
| ) | |
| try: | |
| grade = int(re.search(r"(\d+)", grading_response).group(1)) | |
| reward = (reward + grade) / 2.0 | |
| except Exception as e: | |
| print( | |
| f"Failed to parse grade from grader response '{grading_response}': {e}" | |
| ) | |
| grading_prompt = f"Here is the conversation:\n{convo_str}\nPlease return only a numeric score between 0 and 1." | |
| grading_response = asyncio.run( | |
| run_prompt_async( | |
| metadata["grader_runner"], | |
| "grader", | |
| grading_prompt, | |
| silence=True, | |
| ) | |
| ) | |
| try: | |
| m = re.search(r"\b([01](?:\.\d+)?)\b", grading_response) | |
| if m: | |
| grade = float(m.group(1)) | |
| grade = max(0.0, min(1.0, grade)) | |
| reward = 0.5 * (reward + grade) | |
| else: | |
| raise ValueError("No numeric score found") | |
| except Exception as e: | |
| print( | |
| f"Failed to parse grade from grader response '{grading_response}': {e}" | |
| ) |
🧰 Tools
🪛 Ruff (0.14.1)
197-197: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
In nemo_rl/environments/simulated_user/unique_numbers.py around lines 185 to
201, fix the typo in the grading prompt ("converstation" → "conversation") and
change the grade parsing to robustly extract a float (allow integers or
decimals, e.g., via regex for optional decimal), do not cast to int, clamp the
parsed value to the [0.0, 1.0] range, compute reward = (reward + parsed_grade) /
2.0, and keep the existing exception handling but include the raw
grading_response in the log; ensure any ValueError or None from regex is handled
gracefully and does not crash the function.
| @ray.remote | ||
| class UniqueNumbersEnv(EnvironmentInterface): | ||
| """Environment where the LLM must deduce the count of unique numbers.""" | ||
|
|
||
| def __init__(self, cfg: Optional[UniqueNumbersConfig] = None): | ||
| cfg = cfg or UniqueNumbersConfig() | ||
| self.min_length = cfg.get("min_length", 3) | ||
| self.max_length = cfg.get("max_length", 7) | ||
| self.default_max_turns = cfg.get("max_turns", 10) |
There was a problem hiding this comment.
🛠️ Refactor suggestion | 🟠 Major
@ray.remote requires pragma and avoid hidden defaults.
- Add “# pragma: no cover” on the class line.
- Do not set non-None config defaults in code; YAML is the single source of truth. Either require cfg or read keys directly without defaults. Current
self.min_length/max_length/default_max_turnsare unused; remove or read from cfg directly.
Apply:
-@ray.remote
-class UniqueNumbersEnv(EnvironmentInterface):
+@ray.remote
+class UniqueNumbersEnv(EnvironmentInterface): # pragma: no cover
@@
- def __init__(self, cfg: Optional[UniqueNumbersConfig] = None):
- cfg = cfg or UniqueNumbersConfig()
- self.min_length = cfg.get("min_length", 3)
- self.max_length = cfg.get("max_length", 7)
- self.default_max_turns = cfg.get("max_turns", 10)
+ def __init__(self, cfg: Optional[UniqueNumbersConfig] = None):
+ if cfg is None:
+ raise ValueError("cfg is required; defaults must be provided via YAML.")
+ # If needed later, access required keys directly:
+ # self.min_length = cfg["min_length"]; self.max_length = cfg["max_length"]; self.default_max_turns = cfg["max_turns"]As per coding guidelines.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| @ray.remote | |
| class UniqueNumbersEnv(EnvironmentInterface): | |
| """Environment where the LLM must deduce the count of unique numbers.""" | |
| def __init__(self, cfg: Optional[UniqueNumbersConfig] = None): | |
| cfg = cfg or UniqueNumbersConfig() | |
| self.min_length = cfg.get("min_length", 3) | |
| self.max_length = cfg.get("max_length", 7) | |
| self.default_max_turns = cfg.get("max_turns", 10) | |
| @ray.remote | |
| class UniqueNumbersEnv(EnvironmentInterface): # pragma: no cover | |
| """Environment where the LLM must deduce the count of unique numbers.""" | |
| def __init__(self, cfg: Optional[UniqueNumbersConfig] = None): | |
| if cfg is None: | |
| raise ValueError("cfg is required; defaults must be provided via YAML.") | |
| # If needed later, access required keys directly: | |
| # self.min_length = cfg["min_length"] | |
| # self.max_length = cfg["max_length"] | |
| # self.default_max_turns = cfg["max_turns"] |
🤖 Prompt for AI Agents
nemo_rl/environments/simulated_user/unique_numbers.py around lines 228-236: the
@ray.remote class decorator needs a coverage pragma and the constructor must not
embed hidden defaults from YAML; add "# pragma: no cover" to the @ray.remote
line, and stop assigning non-None fallback values inside __init__. Either make
cfg required (remove the Optional and raise a clear error if None) or read
config keys directly without providing hardcoded defaults (e.g., assign
self.min_length = cfg["min_length"] etc.), and if those attributes are actually
unused remove them altogether; ensure no hidden defaults remain in code and
update the type signature and any callers accordingly.
9c5ff7a to
4522038
Compare
4522038 to
7efcaa1
Compare
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
Signed-off-by: Mohammadreza Mohseni <mmohseni@google.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
7efcaa1 to
e495cb8
Compare
nemo_rl/experience/rollouts.py
Outdated
| and hasattr(tokenizer, "bos_token_id") | ||
| and formatted_obs[0] == tokenizer.bos_token_id | ||
| ): | ||
| formatted_obs = formatted_obs[1:] |
There was a problem hiding this comment.
can we have another assert that the first token isn't bos?
nemo_rl/experience/rollouts.py
Outdated
| ).input_ids[0] | ||
| # Tokenize the raw content from the environment into chat format if needed | ||
| env_role = env_output.observations[i]["role"].lower() | ||
| if env_role in {"user", "assistant", "system"}: |
There was a problem hiding this comment.
could we have unit tests for both paths which will also demonstrate what the message log should look like?
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
649ad02 to
d0d4629
Compare
Replacing PR #732
Waiting for runs to confirm the convergence graph attached in the original PR
What does this PR do ?
Add an simple example on multi-turn GRPO using ADK.
Issues
List issues that this PR closes (syntax):
Usage
Training reward:


Validation acc:
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
New Features
Dependencies