Skip to content

feat: Enable simulated user for multi-turn GRPO#1412

Open
ahmadki wants to merge 22 commits intomainfrom
ahmadki/simulated-user-rec
Open

feat: Enable simulated user for multi-turn GRPO#1412
ahmadki wants to merge 22 commits intomainfrom
ahmadki/simulated-user-rec

Conversation

@ahmadki
Copy link
Member

@ahmadki ahmadki commented Oct 22, 2025

Replacing PR #732

Waiting for runs to confirm the convergence graph attached in the original PR

What does this PR do ?

Add an simple example on multi-turn GRPO using ADK.

Issues

List issues that this PR closes (syntax):

Usage

uv run --extra adk --extra automodel python examples/run_grpo_unique_numbers_w_adk.py

Training reward:
image
Validation acc:
image

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features

    • Added GRPO training example and configuration files for unique numbers task using ADK integration
    • Introduced UniqueNumbersEnv: a configurable simulated user environment with reward computation and optional grading capabilities
  • Dependencies

    • Added Google ADK and Google GenAI packages for enhanced agent capabilities

jialei777 and others added 11 commits July 23, 2025 15:44
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Jialei Chen <jialeic@google.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
@ahmadki ahmadki requested review from a team as code owners October 22, 2025 21:25
@ahmadki ahmadki changed the title feat: Enable simulated user for multi-turn GRPO DRAFT: feat: Enable simulated user for multi-turn GRPO Oct 22, 2025
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Oct 22, 2025

📝 Walkthrough

Walkthrough

This PR introduces a new simulated user environment called "unique numbers" that uses Google ADK agents to simulate user-agent interactions within a GRPO training framework. It includes configuration files, environment implementation, utility functions, an example training script, infrastructure updates, and comprehensive tests.

Changes

Cohort / File(s) Summary
GRPO Configuration Files
examples/configs/grpo_adk_gemma.yaml, examples/configs/grpo_adk_llama8b.yaml
New YAML configurations for GRPO training with unique-numbers environment, specifying model, tokenizer, data handling, checkpointing, wandb logging, and cluster GPU allocation.
Example Training Script
examples/run_grpo_unique_numbers_w_adk.py
New script that wires GRPO training with ADK-based unique-numbers environment; includes datum generation pipeline, IterableNumbersDataset class, Ray environment setup, and orchestration of training workflow.
Simulated User Environment Implementation
nemo_rl/environments/simulated_user/unique_numbers.py
New Ray-enabled environment class (UniqueNumbersEnv) that orchestrates simulated user and grader runners; implements step processing, reward computation, termination logic, and optional ADK logging.
ADK Utilities and Prompts
nemo_rl/environments/simulated_user/adk_utils.py, nemo_rl/environments/simulated_user/prompt.py
New module providing agent creation, session extraction, async prompting with retry logic, runner setup, and demonstration orchestration; plus three new prompt templates for game rules, simulated user behavior, and grader scoring.
Environment Infrastructure
nemo_rl/distributed/ray_actor_environment_registry.py, nemo_rl/distributed/virtual_cluster.py, nemo_rl/environments/interfaces.py
Registry entry mapping UniqueNumbersEnv to ADK executable; new ADK executable definition in PY_EXECUTABLES; optional default value added to EnvironmentReturn.answers field.
Rollout Processing
nemo_rl/experience/rollouts.py
Modified to use chat-template-aware formatting for environment observations instead of plain tokenization; adds multimodal data handling and per-turn logging.
Dependencies and Configuration
pyproject.toml, pyrefly.toml
New dependency group adk with google-adk and google-genai packages; extended import replacements and project includes for ADK modules.
Unit Tests
tests/unit/environments/test_simulated_user.py
New comprehensive test module with mocked ADK components, testing UniqueNumbersRunner behavior, async prompting flow, conversation extraction, and integration with simulated user/grader runners.

Sequence Diagram(s)

sequenceDiagram
    participant User as Training User
    participant GRPO as GRPO Training
    participant UniqueEnv as Unique Numbers Env
    participant SimUser as Simulated User Runner
    participant Grader as Grader Runner
    
    User->>GRPO: Start training with config
    GRPO->>UniqueEnv: Initialize environment
    UniqueEnv->>SimUser: Create ADK agent
    UniqueEnv->>Grader: Create ADK agent
    
    loop Per training step
        GRPO->>UniqueEnv: step(message_log, metadata)
        UniqueEnv->>SimUser: extract last assistant message
        UniqueEnv->>SimUser: run_prompt_async(query or statement)
        SimUser-->>UniqueEnv: simulated user response
        UniqueEnv->>UniqueEnv: check if guess pattern matched
        alt Guess detected
            UniqueEnv->>UniqueEnv: compute reward (correct/incorrect)
            UniqueEnv->>Grader: run_prompt_async(grade conversation)
            Grader-->>UniqueEnv: optional score adjustment
            UniqueEnv-->>GRPO: EnvironmentReturn (reward, terminated=True)
        else Query or other
            UniqueEnv-->>GRPO: EnvironmentReturn (response, turn increment)
        end
    end
    
    GRPO->>UniqueEnv: shutdown()
    UniqueEnv->>SimUser: cleanup
    UniqueEnv->>Grader: cleanup
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

This PR introduces substantial new functionality spanning multiple files with heterogeneous changes: new environment class with orchestration logic, async utilities with retry mechanisms, integration with external ADK library, rollout processing modifications, and comprehensive test coverage. The logic density is moderate-to-high, with coordination between simulated user/grader runners and reward computation requiring careful review.

Possibly related PRs

  • feat: Add Penguin env #1327: Adds environment-to-executable mappings and corresponding PY_EXECUTABLES constants in the same files (ray_actor_environment_registry.py, virtual_cluster.py) for a different environment type.

Suggested labels

documentation, CI:L1

Suggested reviewers

  • terrykong
  • parthchadha

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 4.26% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Test Results For Major Changes ⚠️ Warning This PR introduces substantial new features—including a multi-turn GRPO setup, ADK integration, and a novel simulated-user environment—but the description only alludes to reward and accuracy plots without providing the actual convergence graphs, numeric results, or performance comparisons, and even notes that runs are still pending confirmation. Consequently, there is no concrete testing or performance data documented to validate that these major changes are correct or free of regressions. Please include the actual test results in the PR description: attach the convergence plots, numeric validation metrics, and any before-and-after performance figures along with details of the experimental setup so reviewers can verify no regressions or unintended performance impacts.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title 'feat: Enable simulated user for multi-turn GRPO' clearly and concisely summarizes the main feature addition - enabling simulated user functionality for multi-turn GRPO training, which is the central theme across the changeset.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch ahmadki/simulated-user-rec

Tip

📝 Customizable high-level summaries are now available in beta!

You can now customize how CodeRabbit generates the high-level summary in your pull requests — including its content, structure, tone, and formatting.

  • Provide your own instructions using the high_level_summary_instructions setting.
  • Format the summary however you like (bullet lists, tables, multi-section layouts, contributor stats, etc.).
  • Use high_level_summary_in_walkthrough to move the summary from the description to the walkthrough section.

Example instruction:

"Divide the high-level summary into five sections:

  1. 📝 Description — Summarize the main change in 50–60 words, explaining what was done.
  2. 📓 References — List relevant issues, discussions, documentation, or related PRs.
  3. 📦 Dependencies & Requirements — Mention any new/updated dependencies, environment variable changes, or configuration updates.
  4. 📊 Contributor Summary — Include a Markdown table showing contributions:
    | Contributor | Lines Added | Lines Removed | Files Changed |
  5. ✔️ Additional Notes — Add any extra reviewer context.
    Keep each section concise (under 200 words) and use bullet or numbered lists for clarity."

Note: This feature is currently in beta for Pro-tier users, and pricing will be announced later.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 14

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
nemo_rl/distributed/ray_actor_environment_registry.py (1)

19-26: Honor NEMO_RL_PY_EXECUTABLES_SYSTEM for ADK for parity.

Other entries (VLLM/MCORE) respect the system override; do the same for ADK.

Apply:

 USE_SYSTEM_EXECUTABLE = os.environ.get("NEMO_RL_PY_EXECUTABLES_SYSTEM", "0") == "1"
 VLLM_EXECUTABLE = (
     PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.VLLM
 )
 MCORE_EXECUTABLE = (
     PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.MCORE
 )
+ADK_EXECUTABLE = (
+    PY_EXECUTABLES.SYSTEM if USE_SYSTEM_EXECUTABLE else PY_EXECUTABLES.ADK
+)
@@
-    "nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv": PY_EXECUTABLES.ADK,
+    "nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv": ADK_EXECUTABLE,

Also applies to: 40-41

nemo_rl/environments/interfaces.py (1)

80-82: Update step() return docs to include answers.

Docstring enumerates 5 fields; add the optional answers to avoid confusion.

Apply:

-        - EnvironmentReturn NamedTuple containing observations, metadata, next_stop_strings, rewards, and terminateds flags.
+        - EnvironmentReturn NamedTuple containing observations, metadata, next_stop_strings, rewards, terminateds, and answers (optional).
nemo_rl/experience/rollouts.py (2)

475-491: Bug: Tensor used in if condition; will raise truth-value error.

active_input_lengths[i] is a Tensor; if (... + active_input_lengths[i] >= max_seq_len) yields a Tensor boolean, which is invalid in if.

Apply:

-            if (
-                len(tokenized_obs) + len(generated_ids[i]) + active_input_lengths[i]
-                >= max_seq_len
-            ):
+            input_len = int(active_input_lengths[i].item())
+            if len(tokenized_obs) + len(generated_ids[i]) + input_len >= max_seq_len:
                 tokens_left_for_obs = max_seq_len - (
-                    len(generated_ids[i]) + active_input_lengths[i]
+                    len(generated_ids[i]) + input_len
                 )

758-766: Bug: Tensor truth value in if condition (single-sample path).

input_lengths is a Tensor; comparing directly in if is invalid.

Apply:

-        if input_lengths + gen_token_count + len(tokenized_obs) >= max_seq_len:
+        input_len = int(input_lengths.item())
+        if input_len + gen_token_count + len(tokenized_obs) >= max_seq_len:
             # Truncate environment observation
-            max_env_tokens = max_seq_len - input_lengths - gen_token_count
+            max_env_tokens = max_seq_len - input_len - gen_token_count
🧹 Nitpick comments (13)
examples/configs/grpo_adk_llama8b.yaml (1)

37-43: Confirm intentional batching differences vs Gemma config.

dynamic_batching.enabled is False here but True in the Gemma config. If this is model‑specific tuning, consider a short YAML comment noting why.

nemo_rl/experience/rollouts.py (1)

376-380: Prefer logger over print for per-turn progress.

Switch to logging.getLogger(__name__).info/debug and allow callers to control verbosity.

To support this outside the hunk, add:

# at top-level
import logging
logger = logging.getLogger(__name__)

Then:

-        if max_rollout_turns > 1:
-            print(
-                f"▶ ▶ ▶ Running rollout turn {turn + 1} / {max_rollout_turns} with {len(active_indices)} active samples..."
-            )
+        if max_rollout_turns > 1:
+            logger.info(
+                "▶ ▶ ▶ Running rollout turn %d / %d with %d active samples...",
+                turn + 1, max_rollout_turns, len(active_indices)
+            )
nemo_rl/environments/simulated_user/prompt.py (3)

1-7: Promote constants to UPPER_SNAKE_CASE and keep aliases.

Module-level prompt strings are constants. Rename to UPPER_SNAKE_CASE and keep lowercase aliases for compatibility.

Apply:

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+
+from typing import Final
+
- starting_user_prompt = (
+ STARTING_USER_PROMPT: Final[str] = (
     "I will play a game with you. I have a list of integers in mind and can NOT tell you. "
     "Your goal is to guess the count of UNIQUE numbers in my list. The only 2 things you can do is the following: "
     "You can either ask me 'what is number k?' to get the number at position k in my list, "
     "or answer 'there are m unique numbers' whenever you feel you want to make a guess. "
     "Please do not say anything else. You cannot ask me to provide the list of integers."
 )
 
+simulated_user_instruction = SIMULATED_USER_INSTRUCTION = SIMULATED_USER_INSTRUCTION  # type: ignore  # back-compat alias
+starting_user_prompt = STARTING_USER_PROMPT  # back-compat alias
+grader_instruction = GRADER_INSTRUCTION  # back-compat alias

Follow-up diff below adjusts definitions and typos.
As per coding guidelines.


2-7: Fix grammar in the starting prompt.

Minor clarity/grammar tweaks.

Apply:

- "I will play a game with you. I have a list of integers in mind and can NOT tell you. "
- "Your goal is to guess the count of UNIQUE numbers in my list. The only 2 things you can do is the following: "
+ "I will play a game with you. I have a list of integers that I will not reveal. "
+ "Your goal is to guess the count of UNIQUE numbers in my list. The only 2 things you can do are: "

10-19: Use consistent naming and strip once.

Define as constant and avoid trailing strip duplication.

Apply:

-simulated_user_instruction = """
+SIMULATED_USER_INSTRUCTION: Final[str] = """
 ...
-""".strip()
+""".strip()
tests/unit/environments/test_simulated_user.py (2)

90-107: Fixture stubs: keep scope local and silence ruff ARG warnings.

Good isolation overall. Consider prefixing unused fixture/function args with “_” or add “# noqa: ARG001/ARG005” where appropriate to keep linters quiet without affecting readability.


220-249: Retry test is solid; minor style nit.

monkeypatch arg is unused. Prefix with _monkeypatch for clarity.

examples/run_grpo_unique_numbers_w_adk.py (2)

195-198: Timezone: avoid manual UTC offsets.

Use zoneinfo to format local time reliably (handles DST).

Apply:

-from datetime import datetime, timedelta
+from datetime import datetime
+from zoneinfo import ZoneInfo
...
-    now_pst = datetime.utcnow() + timedelta(hours=-7)
+    now_pst = datetime.now(ZoneInfo("America/Los_Angeles"))

229-240: Unused variable ‘cluster’.

Prefix with underscore to silence linters.

Apply:

-        cluster,
+        _cluster,
nemo_rl/environments/simulated_user/adk_utils.py (3)

22-36: Minor grammar in default instruction.

“help people” → “helps people”.

Apply:

-        instruction=instruction
-        or "You are a helpful assistant that help people answer questions.",
+        instruction=instruction
+        or "You are a helpful assistant that helps people answer questions.",

100-116: Use logger.exception and tighten generic except.

Prefer logger.exception to capture traceback. Keep a narrow ServerError except, and retain a broad fallback but clearly log traceback.

Apply:

-        except ServerError as e:
+        except ServerError as e:
             retries += 1
             delay_with_jitter = delay + (random.random() * 2 - 1) * (delay * 0.5)
-            logger.error(
+            logger.exception(
                 f"Gemini API call (with message {new_message}) failed with ServerError {e} (attempt {retries}/{max_retries}). Retrying in {delay_with_jitter} seconds..."
             )
             await asyncio.sleep(delay_with_jitter)
             delay *= 2  # Exponential backoff
-        except Exception as e:
-            logger.error(
+        except Exception as e:  # keep as last-resort
+            logger.exception(
                 f"Gemini API call (with message {new_message}) failed with an unexpected error: {e}."
             )
             return f"<No response due to unexpected error: {e}>"
 
-    logger.error(
+    logger.error(
         f"Gemini API call (with message {new_message}) reached maximum retries ({max_retries}) without success."
     )
-    return f"<No response due after {max_retries} retries>"
+    return f"<No response after {max_retries} retries>"

39-46: Session access assertions are brittle.

Assuming exactly one app/user/session can break multi-user scenarios. Return helpful errors or search by keys defensively.

Apply (illustrative):

-    assert len(app_session_map) == 1, "Expected exactly one app in session_service"
-    user_sessions_map = next(iter(app_session_map.values()))
-    sessions = user_sessions_map[user_id]
-    assert len(sessions) == 1, "Expected exactly one user in app session"
-    return next(iter(sessions.values()))
+    if not app_session_map:
+        raise RuntimeError("No sessions available in session_service")
+    # Prefer the first app containing the user_id
+    for user_sessions_map in app_session_map.values():
+        if user_id in user_sessions_map:
+            sessions = user_sessions_map[user_id]
+            if not sessions:
+                raise RuntimeError(f"No sessions for user_id={user_id}")
+            return next(iter(sessions.values()))
+    raise KeyError(f"user_id={user_id} not found in session_service")
nemo_rl/environments/simulated_user/unique_numbers.py (1)

246-261: zip(strict=...) for clarity and static analysis.

Be explicit with strict=False to document intent and silence linters.

Apply:

-        for log, meta in zip(message_log_batch, metadata):
+        for log, meta in zip(message_log_batch, metadata, strict=False):
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3a69c21 and 611c70f.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (13)
  • examples/configs/grpo_adk_gemma.yaml (1 hunks)
  • examples/configs/grpo_adk_llama8b.yaml (1 hunks)
  • examples/run_grpo_unique_numbers_w_adk.py (1 hunks)
  • nemo_rl/distributed/ray_actor_environment_registry.py (1 hunks)
  • nemo_rl/distributed/virtual_cluster.py (1 hunks)
  • nemo_rl/environments/interfaces.py (1 hunks)
  • nemo_rl/environments/simulated_user/adk_utils.py (1 hunks)
  • nemo_rl/environments/simulated_user/prompt.py (1 hunks)
  • nemo_rl/environments/simulated_user/unique_numbers.py (1 hunks)
  • nemo_rl/experience/rollouts.py (6 hunks)
  • pyproject.toml (1 hunks)
  • pyrefly.toml (2 hunks)
  • tests/unit/environments/test_simulated_user.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • nemo_rl/distributed/ray_actor_environment_registry.py
  • nemo_rl/environments/simulated_user/prompt.py
  • nemo_rl/environments/interfaces.py
  • nemo_rl/environments/simulated_user/adk_utils.py
  • nemo_rl/distributed/virtual_cluster.py
  • examples/run_grpo_unique_numbers_w_adk.py
  • nemo_rl/environments/simulated_user/unique_numbers.py
  • nemo_rl/experience/rollouts.py
  • tests/unit/environments/test_simulated_user.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)

Files:

  • nemo_rl/distributed/ray_actor_environment_registry.py
  • nemo_rl/environments/simulated_user/prompt.py
  • nemo_rl/environments/interfaces.py
  • nemo_rl/environments/simulated_user/adk_utils.py
  • nemo_rl/distributed/virtual_cluster.py
  • nemo_rl/environments/simulated_user/unique_numbers.py
  • nemo_rl/experience/rollouts.py
examples/configs/*.yaml

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

examples/configs/*.yaml: Exemplar configs under examples/configs/.yaml must include documented defaults
When adding a new config key, reflect its recommended default in exemplar YAMLs under examples/configs/
.yaml

Files:

  • examples/configs/grpo_adk_gemma.yaml
  • examples/configs/grpo_adk_llama8b.yaml
🧬 Code graph analysis (6)
nemo_rl/distributed/ray_actor_environment_registry.py (1)
nemo_rl/distributed/virtual_cluster.py (1)
  • PY_EXECUTABLES (42-60)
nemo_rl/environments/simulated_user/adk_utils.py (1)
tests/unit/environments/test_simulated_user.py (2)
  • from_text (26-27)
  • create_session (125-130)
examples/run_grpo_unique_numbers_w_adk.py (8)
nemo_rl/algorithms/utils.py (1)
  • get_tokenizer (157-288)
nemo_rl/data/interfaces.py (1)
  • DatumSpec (32-40)
nemo_rl/distributed/ray_actor_environment_registry.py (1)
  • get_actor_python_env (50-65)
nemo_rl/distributed/virtual_cluster.py (1)
  • init_ray (86-172)
nemo_rl/environments/simulated_user/unique_numbers.py (2)
  • UniqueNumbersEnv (229-296)
  • UniqueNumbersMetadata (44-52)
nemo_rl/models/generation/__init__.py (1)
  • configure_generation_config (24-45)
nemo_rl/utils/config.py (1)
  • parse_hydra_overrides (146-166)
nemo_rl/utils/logger.py (1)
  • get_next_experiment_dir (1311-1345)
nemo_rl/environments/simulated_user/unique_numbers.py (2)
nemo_rl/environments/interfaces.py (2)
  • EnvironmentInterface (52-88)
  • EnvironmentReturn (26-49)
nemo_rl/environments/simulated_user/adk_utils.py (3)
  • extract_conversation_history (52-61)
  • create_agent (14-36)
  • run_prompt_async (64-116)
nemo_rl/experience/rollouts.py (1)
tests/unit/data/test_data_processor.py (1)
  • apply_chat_template (45-57)
tests/unit/environments/test_simulated_user.py (2)
nemo_rl/environments/simulated_user/unique_numbers.py (1)
  • _UniqueNumbersRunner (55-225)
nemo_rl/environments/simulated_user/adk_utils.py (2)
  • run_prompt_async (64-116)
  • extract_conversation_history (52-61)
🪛 Ruff (0.14.1)
nemo_rl/environments/simulated_user/adk_utils.py

101-101: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)


102-104: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


107-107: Do not catch blind exception: Exception

(BLE001)


108-110: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


175-175: Unpacked variable convo2 is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

examples/run_grpo_unique_numbers_w_adk.py

82-82: Unused function argument: add_system_prompt

(ARG001)


100-100: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)


101-101: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)


232-232: Unpacked variable cluster is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

nemo_rl/environments/simulated_user/unique_numbers.py

197-197: Do not catch blind exception: Exception

(BLE001)


246-246: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

tests/unit/environments/test_simulated_user.py

92-92: Unused function argument: user_id

(ARG001)


92-92: Unused function argument: silence

(ARG001)


105-105: Unused lambda argument: a

(ARG005)


105-105: Unused lambda argument: k

(ARG005)


166-166: Unused function argument: patch_unique_numbers

(ARG001)


180-180: Unused function argument: patch_unique_numbers

(ARG001)


194-194: Unused function argument: patch_unique_numbers

(ARG001)


207-207: Unused function argument: patch_unique_numbers

(ARG001)


220-220: Unused function argument: patch_adk_utils

(ARG001)


228-228: Unused function argument: monkeypatch

(ARG001)


228-228: Unused function argument: patch_adk_utils

(ARG001)


252-252: Unused function argument: patch_adk_utils

(ARG001)

🔇 Additional comments (5)
pyrefly.toml (1)

19-20: Stub additions look good.

Import stubs for google.adk.* and google.genai.* plus includes for simulated_user modules are appropriate for static checks.

Also applies to: 89-92

nemo_rl/distributed/virtual_cluster.py (1)

60-61: ADK executable entry is consistent.

Matches the existing pattern for extras-based executables.

pyproject.toml (1)

104-107: Verify ADK extra resolves and imports using pip

The uv command isn’t available; please run:

python3 -m venv venv
source venv/bin/activate
pip install --upgrade pip
pip install .[adk]
python - <<'PY'
import google.adk, google.genai
print("OK", google.adk.__version__, getattr(google.genai, "__version__", "n/a"))
PY
nemo_rl/environments/simulated_user/unique_numbers.py (1)

145-156: Async orchestration inside thread pool: verify no running event loop.

asyncio.run inside threads works, but if an event loop is active it raises RuntimeError. Consider extracting an async process_turn_async and running it at the caller, or add a small helper that uses asyncio.get_running_loop() and falls back appropriately.

Would you like a patch to make process_turn async and wire it through ThreadPoolExecutor with asyncio.run at the top-level call site?

examples/run_grpo_unique_numbers_w_adk.py (1)

186-193: Default config and registry verified: examples/configs/grpo_adk_llama8b.yaml is present and UniqueNumbersEnv is correctly registered in the environment registry.

Comment on lines +56 to +58
query_re = re.compile(r"what is number (\d+)\??$", re.IGNORECASE)
guess_re = re.compile(r"there are (\d+) unique number", re.IGNORECASE)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Regex misses plural ‘numbers’; guesses won’t register.

Pattern only matches “unique number”. It should accept both singular and plural.

Apply:

-    guess_re = re.compile(r"there are (\d+) unique number", re.IGNORECASE)
+    guess_re = re.compile(r"there are (\d+)\s+unique numbers?\b", re.IGNORECASE)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
query_re = re.compile(r"what is number (\d+)\??$", re.IGNORECASE)
guess_re = re.compile(r"there are (\d+) unique number", re.IGNORECASE)
query_re = re.compile(r"what is number (\d+)\??$", re.IGNORECASE)
guess_re = re.compile(r"there are (\d+)\s+unique numbers?\b", re.IGNORECASE)
🤖 Prompt for AI Agents
In nemo_rl/environments/simulated_user/unique_numbers.py around lines 56 to 58,
the guess_re regex only matches the singular phrase "unique number" so plural
guesses like "unique numbers" won't register; update the pattern to accept both
singular and plural (e.g., make "number" optional "s" or use "numbers?" with
appropriate word boundaries) and keep the re.IGNORECASE flag so both "number"
and "numbers" are matched.

Comment on lines +185 to +201
grading_prompt = f"Here is the converstation \n{convo_str}\nAnd please give the score between 0 and 1."
grading_response = asyncio.run(
run_prompt_async(
metadata["grader_runner"],
"grader",
grading_prompt,
silence=True,
)
)
try:
grade = int(re.search(r"(\d+)", grading_response).group(1))
reward = (reward + grade) / 2.0
except Exception as e:
print(
f"Failed to parse grade from grader response '{grading_response}': {e}"
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Grader integration: fix typo and parse numeric [0,1] robustly.

  • “converstation” → “conversation”.
  • Parse floats, clamp to [0,1]; don’t cast to int (drops partial credit).

Apply:

-                grading_prompt = f"Here is the converstation \n{convo_str}\nAnd please give the score between 0 and 1."
+                grading_prompt = f"Here is the conversation:\n{convo_str}\nPlease return only a numeric score between 0 and 1."
                 grading_response = asyncio.run(
                     run_prompt_async(
                         metadata["grader_runner"],
                         "grader",
                         grading_prompt,
                         silence=True,
                     )
                 )
                 try:
-                    grade = int(re.search(r"(\d+)", grading_response).group(1))
-                    reward = (reward + grade) / 2.0
+                    m = re.search(r"\b([01](?:\.\d+)?)\b", grading_response)
+                    if m:
+                        grade = float(m.group(1))
+                        grade = max(0.0, min(1.0, grade))
+                        reward = 0.5 * (reward + grade)
+                    else:
+                        raise ValueError("No numeric score found")
                 except Exception as e:
                     print(
                         f"Failed to parse grade from grader response '{grading_response}': {e}"
                     )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
grading_prompt = f"Here is the converstation \n{convo_str}\nAnd please give the score between 0 and 1."
grading_response = asyncio.run(
run_prompt_async(
metadata["grader_runner"],
"grader",
grading_prompt,
silence=True,
)
)
try:
grade = int(re.search(r"(\d+)", grading_response).group(1))
reward = (reward + grade) / 2.0
except Exception as e:
print(
f"Failed to parse grade from grader response '{grading_response}': {e}"
)
grading_prompt = f"Here is the conversation:\n{convo_str}\nPlease return only a numeric score between 0 and 1."
grading_response = asyncio.run(
run_prompt_async(
metadata["grader_runner"],
"grader",
grading_prompt,
silence=True,
)
)
try:
m = re.search(r"\b([01](?:\.\d+)?)\b", grading_response)
if m:
grade = float(m.group(1))
grade = max(0.0, min(1.0, grade))
reward = 0.5 * (reward + grade)
else:
raise ValueError("No numeric score found")
except Exception as e:
print(
f"Failed to parse grade from grader response '{grading_response}': {e}"
)
🧰 Tools
🪛 Ruff (0.14.1)

197-197: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
In nemo_rl/environments/simulated_user/unique_numbers.py around lines 185 to
201, fix the typo in the grading prompt ("converstation" → "conversation") and
change the grade parsing to robustly extract a float (allow integers or
decimals, e.g., via regex for optional decimal), do not cast to int, clamp the
parsed value to the [0.0, 1.0] range, compute reward = (reward + parsed_grade) /
2.0, and keep the existing exception handling but include the raw
grading_response in the log; ensure any ValueError or None from regex is handled
gracefully and does not crash the function.

Comment on lines +228 to +236
@ray.remote
class UniqueNumbersEnv(EnvironmentInterface):
"""Environment where the LLM must deduce the count of unique numbers."""

def __init__(self, cfg: Optional[UniqueNumbersConfig] = None):
cfg = cfg or UniqueNumbersConfig()
self.min_length = cfg.get("min_length", 3)
self.max_length = cfg.get("max_length", 7)
self.default_max_turns = cfg.get("max_turns", 10)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

@ray.remote requires pragma and avoid hidden defaults.

  • Add “# pragma: no cover” on the class line.
  • Do not set non-None config defaults in code; YAML is the single source of truth. Either require cfg or read keys directly without defaults. Current self.min_length/max_length/default_max_turns are unused; remove or read from cfg directly.

Apply:

-@ray.remote
-class UniqueNumbersEnv(EnvironmentInterface):
+@ray.remote
+class UniqueNumbersEnv(EnvironmentInterface):  # pragma: no cover
@@
-    def __init__(self, cfg: Optional[UniqueNumbersConfig] = None):
-        cfg = cfg or UniqueNumbersConfig()
-        self.min_length = cfg.get("min_length", 3)
-        self.max_length = cfg.get("max_length", 7)
-        self.default_max_turns = cfg.get("max_turns", 10)
+    def __init__(self, cfg: Optional[UniqueNumbersConfig] = None):
+        if cfg is None:
+            raise ValueError("cfg is required; defaults must be provided via YAML.")
+        # If needed later, access required keys directly:
+        # self.min_length = cfg["min_length"]; self.max_length = cfg["max_length"]; self.default_max_turns = cfg["max_turns"]

As per coding guidelines.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@ray.remote
class UniqueNumbersEnv(EnvironmentInterface):
"""Environment where the LLM must deduce the count of unique numbers."""
def __init__(self, cfg: Optional[UniqueNumbersConfig] = None):
cfg = cfg or UniqueNumbersConfig()
self.min_length = cfg.get("min_length", 3)
self.max_length = cfg.get("max_length", 7)
self.default_max_turns = cfg.get("max_turns", 10)
@ray.remote
class UniqueNumbersEnv(EnvironmentInterface): # pragma: no cover
"""Environment where the LLM must deduce the count of unique numbers."""
def __init__(self, cfg: Optional[UniqueNumbersConfig] = None):
if cfg is None:
raise ValueError("cfg is required; defaults must be provided via YAML.")
# If needed later, access required keys directly:
# self.min_length = cfg["min_length"]
# self.max_length = cfg["max_length"]
# self.default_max_turns = cfg["max_turns"]
🤖 Prompt for AI Agents
nemo_rl/environments/simulated_user/unique_numbers.py around lines 228-236: the
@ray.remote class decorator needs a coverage pragma and the constructor must not
embed hidden defaults from YAML; add "# pragma: no cover" to the @ray.remote
line, and stop assigning non-None fallback values inside __init__. Either make
cfg required (remove the Optional and raise a clear error if None) or read
config keys directly without providing hardcoded defaults (e.g., assign
self.min_length = cfg["min_length"] etc.), and if those attributes are actually
unused remove them altogether; ensure no hidden defaults remain in code and
update the type signature and any callers accordingly.

@terrykong terrykong removed the r0.4.0 label Oct 28, 2025
@ahmadki ahmadki force-pushed the ahmadki/simulated-user-rec branch 2 times, most recently from 9c5ff7a to 4522038 Compare November 22, 2025 20:37
@ahmadki ahmadki added the CI:L1 Run doctests, unit tests, and functional tests label Nov 22, 2025
@ahmadki ahmadki changed the title DRAFT: feat: Enable simulated user for multi-turn GRPO feat: Enable simulated user for multi-turn GRPO Nov 22, 2025
@ahmadki ahmadki force-pushed the ahmadki/simulated-user-rec branch from 4522038 to 7efcaa1 Compare November 22, 2025 20:48
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
ahmadki and others added 6 commits November 22, 2025 22:50
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
Signed-off-by: Mohammadreza Mohseni <mmohseni@google.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
@ahmadki ahmadki force-pushed the ahmadki/simulated-user-rec branch from 7efcaa1 to e495cb8 Compare November 22, 2025 20:50
@ahmadki ahmadki added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Nov 22, 2025
Copy link
Collaborator

@terrykong terrykong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ahmadki for driving this. Left some comments

Could you resolve the test failures?

and hasattr(tokenizer, "bos_token_id")
and formatted_obs[0] == tokenizer.bos_token_id
):
formatted_obs = formatted_obs[1:]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we have another assert that the first token isn't bos?

).input_ids[0]
# Tokenize the raw content from the environment into chat format if needed
env_role = env_output.observations[i]["role"].lower()
if env_role in {"user", "assistant", "system"}:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we have unit tests for both paths which will also demonstrate what the message log should look like?

Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
@ahmadki ahmadki added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Dec 12, 2025
Signed-off-by: Ahmad Kiswani <kiswani.ahmad@gmail.com>
@ahmadki ahmadki force-pushed the ahmadki/simulated-user-rec branch from 649ad02 to d0d4629 Compare December 12, 2025 22:07
@ahmadki ahmadki added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Dec 12, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI:L1 Run doctests, unit tests, and functional tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants