Add Gym GRPO example for multi-step tool calling with Workplace Assistant environment#1625

Closed
shashank3959 wants to merge 8 commits into NVIDIA-NeMo:main from shashank3959:dev/shashankv-add-grpo-workplace-asst

Conversation


@shashank3959 shashank3959 commented Dec 11, 2025

What does this PR do ?

Add Gym GRPO example for multi-step tool calling with Workplace Assistant environment

Issues

Closes NVIDIA-NeMo/Gym#376

Usage

The launch command, as documented in the README:

EXP_NAME="$(date +%Y%m%d)/nemo_gym_grpo/nemotron_nano_v2_9b/workplace_assistant_001"

# Configuration file path
CONFIG_PATH=examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml

# Launch training
# Set these environment variables before running:
TORCH_CUDA_ARCH_LIST="9.0 10.0" \
HF_HOME=.cache/ \
HF_TOKEN="YOUR_HUGGINGFACE_TOKEN" \
WANDB_API_KEY="YOUR_WANDB_API_KEY" \
NRL_FORCE_REBUILD_VENVS=true \
VLLM_LOGGING_LEVEL=ERROR \
uv run python examples/nemo_gym/run_grpo_nemo_gym.py \
    --config=$CONFIG_PATH \
    logger.wandb.project="${USER}-nemo-gym-rl-integration" \
    logger.wandb.name=$EXP_NAME \
    logger.log_dir=results/$EXP_NAME

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features
    • Added example configurations for training workplace assistant models with multiple model variants (Nemotron Nano and Qwen3).
    • Introduced support for Nemotron JSON-based tool calling in the generation backend.


Add GRPO training configuration for the Workplace Assistant agentic
tool-use environment using the Qwen3-4B-Instruct model.

Configuration details:
- Model: Qwen/Qwen3-4B-Instruct-2507 with Hermes tool parser
- Dataset: nvidia/Nemotron-RL-agent-workplace_assistant (1129 train, 126 val)
- Environment: 26 tools across 5 categories (email, calendar, CRM, etc.)
- Agent: workplace_assistant_simple_agent with 6 max tool-calling steps
- Training: Single-turn rollouts with up to 6 tool calls per episode
- VLLM: Auto-tools enabled for efficient multi-tool handling
- Hardware: Configured for single-node (8 GPUs) with DTensor (TP=2)

Default training settings:
- 64 prompts/step × 16 generations = 1024 rollouts/step
- Max context: 32k tokens
- Validation every 10 steps
- Checkpointing enabled with top-3 by validation accuracy

Data preparation requires:
1. Download dataset using download_workplace_assistant.py
2. Run ng_prepare_data with workplace_assistant.yaml config

Signed-off-by: Shashank Verma <shashank3959@gmail.com>
- Add nemotron_toolcall_parser_no_streaming.py for Nemotron Nano v2 tool calling
- Register nemotron_json parser in vllm_worker_async.py on import
- Fix edge case in _replace_prefix_tokens for mismatched prefix lengths
- Add GRPO configs for workplace assistant with Nemotron Nano v2 9B and 12B
- Set num_prompts_per_step=8, max_num_steps=36 for ~1 epoch
- Set val_period=6, save_period=6 (aligned)
- Set num_nodes=1 for single node default
Signed-off-by: Shashank Verma <shashank3959@gmail.com>
@shashank3959 shashank3959 requested review from a team as code owners December 11, 2025 11:15
Signed-off-by: Shashank Verma <shashank3959@gmail.com>

coderabbitai bot commented Dec 11, 2025

📝 Walkthrough

Introduces GRPO workplace assistant training configurations for three model variants (Nemotron Nano 12B, Nemotron Nano 9B, Qwen3-4B Instruct), adds a non-streaming Nemotron JSON tool-call parser for vLLM integration, and refines prefix token detection logic in the vLLM async worker.

Changes

  • GRPO Workplace Assistant Configurations (examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_12b.yaml, examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml, examples/nemo_gym/grpo_workplace_assistant_qwen3_4binstruct.yaml): Adds three comprehensive YAML configurations for GRPO training with distributed Megatron settings, vLLM generation backends, dynamic batching, sequence packing, tool-calling workflows, data paths, and logging integrations (W&B, TensorBoard, MLflow). Each targets a different model variant with corresponding training hyperparameters and resource allocation.
  • Nemotron Tool Parser (nemo_rl/models/generation/vllm/nemotron_toolcall_parser_no_streaming.py): Introduces the NemotronJSONToolParser class, registered as "nemotron_json", which extracts tool calls from Nemotron-formatted <TOOLCALL> JSON blocks in vLLM model outputs and converts them to standardized ToolCall objects with error handling and logging. Streaming tool extraction is explicitly unsupported.
  • vLLM Worker Refinement (nemo_rl/models/generation/vllm/vllm_worker_async.py): Registers the Nemotron JSON tool parser at import time (with graceful fallback if unavailable) and refines _replace_prefix_tokens to constrain EOS token detection to the common prefix length between the template prefix and the full template token IDs, adding bounds checking to prevent out-of-range access.
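The non-streaming extraction performed by NemotronJSONToolParser can be sketched as a standalone function (a simplified, hypothetical version: the real class subclasses vLLM's ToolParser interface and returns ExtractedToolCallInformation rather than plain dicts):

```python
import json
import re

# Regex for a single <TOOLCALL>[...]</TOOLCALL> block; DOTALL lets the
# JSON payload span multiple lines.
TOOLCALL_RE = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)


def extract_tool_calls(model_output: str) -> list[dict]:
    """Return a list of {"name", "arguments"} dicts, or [] if no valid block."""
    matches = TOOLCALL_RE.findall(model_output)
    if not matches:
        return []
    try:
        parsed = json.loads(matches[0].strip())
    except json.JSONDecodeError:
        # Malformed JSON inside the tags: treat as "no tool calls".
        return []
    if not isinstance(parsed, list):
        return []
    calls = []
    for entry in parsed:
        if not isinstance(entry, dict):
            continue  # skip malformed entries rather than failing hard
        args = entry.get("arguments")
        calls.append(
            {
                "name": entry.get("name"),
                # Arguments are re-serialized to a JSON string, matching the
                # OpenAI-style tool-call payload shape.
                "arguments": json.dumps(args, ensure_ascii=False)
                if isinstance(args, dict)
                else args,
            }
        )
    return calls
```

A model output such as `Sure. <TOOLCALL>[{"name": "send_email", "arguments": {"to": "a@b.c"}}]</TOOLCALL>` would yield one call named `send_email` with its arguments as a JSON string.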

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~70 minutes

  • YAML configuration files: Each defines extensive nested hyperparameters for GRPO, Megatron distributed training, vLLM generation, and logging; three files with similar structure but distinct model-specific values require cross-checking of references, constraints, and parameter interactions (~20 min per file).
  • Nemotron tool parser: Non-trivial JSON extraction and error handling logic with regex parsing and state management; correctness of tool call conversion, malformed entry skipping, and fallback behavior needs careful verification (~15 min).
  • vLLM worker prefix token refinement: Bounds-checking logic and template token handling warrant close inspection to ensure correctness and prevent edge-case failures (~10 min).

Possibly related PRs

Suggested reviewers

  • terrykong

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 75.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
  • Test Results For Major Changes ⚠️ Warning: The PR introduces major changes (three GRPO configuration files and a new tool parser) but provides no test results, performance metrics, or validation evidence. Document test results showing successful GRPO training execution, NemotronJSONToolParser validation, multi-step tool-calling verification, and relevant convergence metrics.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title directly corresponds to the main change: adding GRPO examples for multi-step tool calling with the Workplace Assistant environment, which is exactly what the configuration files and tool parser implement.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (5)
nemo_rl/models/generation/vllm/vllm_worker_async.py (1)

28-40: Narrow the exception type for optional parser registration

Catching a blanket Exception here will also swallow real bugs inside nemotron_toolcall_parser_no_streaming and silently disable the parser. If the intent is just to make the dependency optional, this should be limited to ImportError (and optionally log other exceptions).

-try:  # pragma: no cover - optional runtime dependency
-    from nemo_rl.models.generation.vllm import (  # noqa: F401
-        nemotron_toolcall_parser_no_streaming,
-    )
-except Exception:
-    # If the module is missing for some reason, vLLM will simply not
-    # have the `nemotron_json` parser registered and will raise a
-    # clear error when it is requested.
-    pass
+try:  # pragma: no cover - optional runtime dependency
+    from nemo_rl.models.generation.vllm import (  # noqa: F401
+        nemotron_toolcall_parser_no_streaming,
+    )
+except ImportError:
+    # If the module is missing, vLLM will simply not have the
+    # `nemotron_json` parser registered and will raise a clear error
+    # when it is requested.
+    pass

As per coding guidelines, narrowing except to specific error types is preferred.

nemo_rl/models/generation/vllm/nemotron_toolcall_parser_no_streaming.py (2)

73-121: Tighten exception handling in tool-call parsing

The overall structure is good, but both the outer and inner except Exception blocks are broader than needed and conflict with the guideline to catch specific error types. They also make it harder to distinguish genuinely broken outputs from unexpected programming errors.

Consider narrowing and optionally adding a lightweight debug log for malformed entries:

-        try:
-            # Grab the JSON substring inside the TOOLCALL tags.
-            str_tool_calls = self.tool_call_regex.findall(model_output)[0].strip()
+        try:
+            # Grab the JSON substring inside the TOOLCALL tags.
+            matches = self.tool_call_regex.findall(model_output)
+            if not matches:
+                raise ValueError("No <TOOLCALL>...</TOOLCALL> block found.")
+            str_tool_calls = matches[0].strip()
@@
-            json_tool_calls = json.loads(str_tool_calls)
+            json_tool_calls = json.loads(str_tool_calls)
             tool_calls: list[ToolCall] = []
             for tool_call in json_tool_calls:
-                try:
+                try:
                     args = tool_call.get("arguments")
                     if isinstance(args, dict):
                         args_str = json.dumps(args, ensure_ascii=False)
                     else:
-                        args_str = args
+                        args_str = args if isinstance(args, str) else json.dumps(
+                            args, ensure_ascii=False
+                        )
@@
-                except Exception:
-                    # Skip malformed tool call entries rather than failing hard.
-                    continue
+                except (KeyError, TypeError, ValueError) as exc:
+                    # Skip malformed tool call entries rather than failing hard.
+                    logger.debug(
+                        "Skipping malformed tool call entry %r: %s",
+                        tool_call,
+                        exc,
+                    )
+                    continue
@@
-        except Exception:
+        except (json.JSONDecodeError, IndexError, TypeError, ValueError) as exc:
             logger.exception(
-                "Error in extracting tool call from response. Response: %s",
-                model_output,
+                "Error in extracting tool call from response: %s. Response: %s",
+                exc,
+                model_output,
             )

This keeps the robustness you want for malformed model outputs but avoids silently swallowing unexpected programming errors. As per coding guidelines.


57-61: Unused request parameter and static-analysis noise

extract_tool_calls doesn’t use the request argument, but it must stay in the signature to conform to the ToolParser interface. To make that intent explicit and silence ARG002, you can either mark it unused or delete it locally:

     def extract_tool_calls(
         self,
         model_output: str,
-        request: ChatCompletionRequest,
+        request: ChatCompletionRequest,  # noqa: ARG002
@@
-        """
-        Non-streaming extraction: look for a single <TOOLCALL>...</TOOLCALL>
-        block containing a JSON list of tool calls.
-        """
+        """
+        Non-streaming extraction: look for a single <TOOLCALL>...</TOOLCALL>
+        block containing a JSON list of tool calls.
+        """
+        # `request` is unused in this simplified parser but kept for API compatibility.

S105 around the TOOLCALL markers is a false positive in this context; those literals are protocol tokens, not secrets.

examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_12b.yaml (1)

180-215: Confirm intentional cap of max_new_tokens at 2048 vs full context

In this 12B config, policy.max_total_sequence_length and vllm_cfg.max_model_len are 32768, but generation.max_new_tokens is set to 2048. That’s perfectly valid if you want to bound rollout length and memory, but it does diverge from other GRPO/distillation configs that often use the full context window for max_new_tokens.

Please double‑check that 2048 is the intended limit for this workload (vs ${policy.max_total_sequence_length}) given your training budget and task complexity. The tool_parser: nemotron_json wiring itself looks correct. Based on learnings.
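Side by side, the two conventions look roughly like this (an illustrative fragment only; the key nesting is simplified and the actual schema in the repo's configs may differ):

```yaml
policy:
  max_total_sequence_length: 32768
  generation:
    vllm_cfg:
      max_model_len: 32768
    # Option A: cap rollout length explicitly, as in this 12B config.
    max_new_tokens: 2048
    # Option B: allow generation up to the full context window, matching
    # the convention in other GRPO/distillation configs:
    # max_new_tokens: ${policy.max_total_sequence_length}
```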

examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml (1)

225-233: Verify workspace/data path consistency for the 9B config

Here the train/validation paths point to 3rdparty/Gym-workspace/Gym/..., whereas the 12B and Qwen configs in this PR reference 3rdparty/Penguin-workspace/Penguin/... for the same workplace assistant data.

If both workspaces exist this is fine, but if the intent was to share the same dataset location across all three examples, please double‑check that these paths and workspace names are correct and consistent.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e3cfb11 and f3c8e46.

📒 Files selected for processing (5)
  • examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_12b.yaml (1 hunks)
  • examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml (1 hunks)
  • examples/nemo_gym/grpo_workplace_assistant_qwen3_4binstruct.yaml (1 hunks)
  • nemo_rl/models/generation/vllm/nemotron_toolcall_parser_no_streaming.py (1 hunks)
  • nemo_rl/models/generation/vllm/vllm_worker_async.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
!(**/tests/**|**/test_*.py|**/test_*.sh)

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year

Files:

  • examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml
  • nemo_rl/models/generation/vllm/vllm_worker_async.py
  • nemo_rl/models/generation/vllm/nemotron_toolcall_parser_no_streaming.py
  • examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_12b.yaml
  • examples/nemo_gym/grpo_workplace_assistant_qwen3_4binstruct.yaml
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code

Files:

  • nemo_rl/models/generation/vllm/vllm_worker_async.py
  • nemo_rl/models/generation/vllm/nemotron_toolcall_parser_no_streaming.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes

Files:

  • nemo_rl/models/generation/vllm/vllm_worker_async.py
  • nemo_rl/models/generation/vllm/nemotron_toolcall_parser_no_streaming.py
**/*.{py,sh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)

Files:

  • nemo_rl/models/generation/vllm/vllm_worker_async.py
  • nemo_rl/models/generation/vllm/nemotron_toolcall_parser_no_streaming.py
🧠 Learnings (4)
📚 Learning: 2025-09-19T03:00:58.662Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.yaml:85-101
Timestamp: 2025-09-19T03:00:58.662Z
Learning: In distillation and GRPO configurations, max_new_tokens is intentionally set to the full context window (max_total_sequence_length) for consistency across the codebase. Overflow cases when prompt + generation tokens exceed max_model_len are handled by safeguards implemented in vllm_worker.py.

Applied to files:

  • examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml
  • examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_12b.yaml
📚 Learning: 2025-09-10T05:34:35.406Z
Learnt from: bxyu-nvidia
Repo: NVIDIA-NeMo/RL PR: 1110
File: nemo_rl/models/generation/vllm/vllm_worker_async.py:346-359
Timestamp: 2025-09-10T05:34:35.406Z
Learning: In nemo_rl/models/generation/vllm/vllm_worker_async.py, the HTTP server intentionally uses different path structures: `/v1/chat/completions` is under the `/v1` prefix while `/tokenize` is at the root level without the `/v1` prefix. This is the intended design.

Applied to files:

  • nemo_rl/models/generation/vllm/vllm_worker_async.py
📚 Learning: 2025-09-10T05:29:34.349Z
Learnt from: bxyu-nvidia
Repo: NVIDIA-NeMo/RL PR: 1110
File: nemo_rl/models/generation/vllm/vllm_worker_async.py:98-105
Timestamp: 2025-09-10T05:29:34.349Z
Learning: In the _maybe_correct_merged_tokens function in nemo_rl/models/generation/vllm/vllm_worker_async.py, the loop condition `len(candidate_token_ids) < len(actual_token_ids) - 1` is intentionally designed to prevent accessing the final token in actual_token_ids, likely to handle specific tokenization edge cases in the vLLM HTTP server integration.

Applied to files:

  • nemo_rl/models/generation/vllm/vllm_worker_async.py
📚 Learning: 2025-09-18T14:57:31.003Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1006
File: nemo_rl/algorithms/distillation.py:312-354
Timestamp: 2025-09-18T14:57:31.003Z
Learning: The distillation algorithm's cluster setup logic is designed to follow the same patterns used in GRPO for handling distributed training clusters and resource allocation.

Applied to files:

  • examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_12b.yaml
🪛 GitHub Actions: Copyright check
nemo_rl/models/generation/vllm/nemotron_toolcall_parser_no_streaming.py

[error] 1-1: Found files with missing copyright: path= ./nemo_rl/models/generation/vllm/nemotron_toolcall_parser_no_streaming.py

🪛 Ruff (0.14.8)
nemo_rl/models/generation/vllm/nemotron_toolcall_parser_no_streaming.py

50-50: Possible hardcoded password assigned to: "tool_call_start_token"

(S105)


51-51: Possible hardcoded password assigned to: "tool_call_end_token"

(S105)


60-60: Unused method argument: request

(ARG002)


100-102: try-except-continue detected, consider logging the exception

(S112)


100-100: Do not catch blind exception: Exception

(BLE001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
  • GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (2)
nemo_rl/models/generation/vllm/vllm_worker_async.py (1)

123-131: EOS prefix search refinement looks correct

The updated search over range(max_pos) (min of prefix/template lengths) is a sensible guard against pathological length mismatches and still finds the last EOS within the shared prefix region as intended. No functional issues spotted here.
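In isolation, the guarded search can be sketched like this (the function name and return convention are assumptions for illustration; the real logic is embedded inline in _replace_prefix_tokens):

```python
def last_eos_in_common_prefix(
    prefix_ids: list[int], template_ids: list[int], eos_token_id: int
) -> int:
    """Find the last EOS position within the region shared by both sequences.

    Returns -1 if no EOS occurs in the shared prefix.
    """
    # Scan only positions present in BOTH sequences, so a prefix longer
    # than the full template (or vice versa) can never index out of range.
    max_pos = min(len(prefix_ids), len(template_ids))
    last_eos = -1
    for pos in range(max_pos):
        if template_ids[pos] == eos_token_id:
            last_eos = pos
    return last_eos
```

The min() bound is the key refinement: an EOS that appears in the template beyond the length of the prefix is deliberately ignored, and neither sequence is ever indexed past its own length.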

examples/nemo_gym/grpo_workplace_assistant_qwen3_4binstruct.yaml (1)

54-206: Qwen GRPO + vLLM + hermes tool parser config looks consistent

The Qwen3-4B Instruct setup (dtensor-enabled policy, full-context max_new_tokens, make_sequence_length_divisible_by tied to tensor parallel size, and vLLM HTTP server with tool_parser: hermes + enable_auto_tools: true) is internally consistent and aligns with the patterns used in other GRPO configs.

No issues spotted in this configuration section.

Comment on lines +1 to +9
"""
Nemotron JSON tool parser for vLLM, adapted from
`nvidia/NVIDIA-Nemotron-Nano-9B-v2` on Hugging Face:
`https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-9B-v2/blob/main/nemotron_toolcall_parser_no_streaming.py`.

The original file is licensed under the NVIDIA Open Model License /
Apache-2.0-equivalent terms; this variant is trimmed to the pieces needed
for NeMo RL non-streaming tool calling.
"""

⚠️ Potential issue | 🟠 Major

Add NVIDIA copyright header to satisfy CI and licensing

This new Python file is missing the standard NVIDIA copyright header, which is required for non-test .py files and is already causing the copyright check to fail.

You can mirror the header used elsewhere (e.g., vllm_worker_async.py):

+ # Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
 """
 Nemotron JSON tool parser for vLLM, adapted from
 `nvidia/NVIDIA-Nemotron-Nano-9B-v2` on Hugging Face:

As per coding guidelines and the pipeline failure.

🧰 Tools
🪛 GitHub Actions: Copyright check

[error] 1-1: Found files with missing copyright: path= ./nemo_rl/models/generation/vllm/nemotron_toolcall_parser_no_streaming.py

🤖 Prompt for AI Agents
nemo_rl/models/generation/vllm/nemotron_toolcall_parser_no_streaming.py lines
1-9: the file is missing the required NVIDIA copyright header used across
non-test .py files which is causing CI/license checks to fail; add the same
standard NVIDIA copyright/header block used in other files (for example
vllm_worker_async.py) at the top of this file, preserving the license text, year
and copyright owner format, and ensure there is a blank line after the header
before the module docstring.

@bxyu-nvidia

Merging the changes in the PR via #1630
