Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion nemo_skills/pipeline/nemo_rl/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
from dataclasses import dataclass
from enum import Enum
Expand All @@ -34,8 +35,10 @@
parse_kwargs,
resolve_mount_paths,
run_exp,
should_get_random_port,
temporary_env_update,
)
from nemo_skills.pipeline.utils.server import SupportedServers, get_free_port
from nemo_skills.utils import (
get_logger_name,
setup_logging,
Expand Down Expand Up @@ -280,6 +283,17 @@ def grpo_nemo_rl(
"Format: START:STOP (1-indexed, STOP exclusive, same as slice syntax arr[start:stop]). "
"Example: '3:5' profiles steps 3 and 4 only. NOTE: START must be ≥ 1, so '0:10' is invalid.",
),
server_model: str = typer.Option(None, help="Path to the model or model name in API for judge server"),
server_address: str = typer.Option(
None, help="Use ip:port for self-hosted models or the API url if using model providers"
),
server_type: SupportedServers = typer.Option(None, help="Type of server to use for judge"),
server_gpus: int = typer.Option(None, help="Number of GPUs to use if hosting the judge model"),
server_nodes: int = typer.Option(1, help="Number of nodes required for hosting judge LLM server"),
n_servers: int = typer.Option(
1, help="Number of independent judge servers to launch (each with server_nodes nodes)"
),
server_args: str = typer.Option("", help="Any extra arguments to pass to the judge server"),
partition: str = typer.Option(
None, help="Can specify if need interactive jobs or a specific non-default partition"
),
Expand Down Expand Up @@ -388,6 +402,41 @@ def grpo_nemo_rl(
if validation_data is not None:
validation_data = get_mounted_path(cluster_config, validation_data)

# Server configuration for LLM-as-a-judge
server_config = None
if server_type is not None:
get_random_port = should_get_random_port(server_gpus, exclusive)
if server_address is None: # we need to host the model
assert server_gpus is not None, "Need to specify server_gpus if hosting the model"
server_port = get_free_port(strategy="random") if get_random_port else 5000

server_config = {
"model_path": server_model,
"server_type": server_type,
"num_gpus": server_gpus,
"num_nodes": server_nodes,
"server_args": server_args,
"server_port": server_port,
}
Comment on lines +405 to +420
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Missing validation: server_model should be required when server_type is specified.

When server_type is provided but server_model is omitted (defaults to None), model_path=None is passed into get_server_command, producing a command string containing the literal string "None". This would fail at runtime with a confusing error.

Proposed fix
     server_config = None
     if server_type is not None:
+        if server_model is None:
+            raise ValueError("server_model is required when server_type is specified")
         get_random_port = should_get_random_port(server_gpus, exclusive)
🤖 Prompt for AI Agents
In `@nemo_skills/pipeline/nemo_rl/grpo.py` around lines 405 - 420, When
server_type is provided but you intend to host the model (server_address is
None), ensure server_model is required and non-empty to avoid passing
model_path=None into get_server_command; add a validation check (e.g., assert or
raise ValueError) before building server_config that server_model is not
None/empty, referencing the variables server_type, server_address, and
server_model and the block that constructs server_config so the code fails fast
with a clear message instead of producing "None" in the command.


client_server_args = {
"server_type": server_type.value,
"port": server_port,
"model": server_model,
"n_servers": n_servers,
}
else: # model is hosted elsewhere
assert n_servers == 1, "Only one server is supported when model is hosted elsewhere"
client_server_args = {
"server_type": server_type.value,
"host": server_address,
"model": server_model,
"n_servers": 1,
}
cluster_config["required_env_vars"] = cluster_config.get("required_env_vars", []) + [
f"JUDGE_SERVER_ARGS='{json.dumps(client_server_args)}'"
]

train_cmd = get_training_cmd(
cluster_config=cluster_config,
partition=partition,
Expand All @@ -408,7 +457,6 @@ def grpo_nemo_rl(
profile_step_range=profile_step_range,
)

server_config = None
env_update = {"RAY_LOG_SYNC_FREQUENCY": 20} if profile_step_range else {}
sbatch_kwargs = parse_kwargs(sbatch_kwargs, exclusive=exclusive, qos=qos, time_min=time_min)

Expand All @@ -426,6 +474,7 @@ def grpo_nemo_rl(
num_nodes=num_nodes,
cluster_config=cluster_config,
server_config=server_config,
n_servers=n_servers,
partition=partition,
run_after=run_after,
reuse_code=reuse_code,
Expand Down
101 changes: 62 additions & 39 deletions nemo_skills/pipeline/utils/exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ def add_task(
keep_mounts_for_sandbox=False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gwarmstrong do we need to update the declarative code path to reflect these changes?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will need to be updated when we want to use the feature on the declarative path, but at the moment I'm not sure there is value to adding it to the declarative path purely for parity sake.

Is there any way to ensure it is covered by some test case (GPU or Slurm, probably?) so that when we convert to the declarative path, we can make sure the functionality isn't dropped?

sandbox_port: int | None = None,
server_config=None,
n_servers: int = 1,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate n_servers inputs to avoid silent no-op behavior.

range(n_servers) silently skips server creation for 0/negative values, and n_servers is currently accepted even when server_config is None.

Proposed fix
 def add_task(
@@
     ray_template: str | None = None,
 ):
@@
+    if n_servers < 1:
+        raise ValueError("n_servers must be >= 1")
+    if server_config is None and n_servers != 1:
+        raise ValueError("n_servers is only supported when server_config is provided")

As per coding guidelines, "Avoid cases where user-passed parameters are unused; code should fail if user specifies an unsupported argument or if a required argument is missing."

Also applies to: 557-557

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_skills/pipeline/utils/exp.py` at line 447, The parameter n_servers is
accepted but can be zero/negative or ignored when server_config is None (e.g.,
code using range(n_servers)), so add input validation at the start of the
function that uses n_servers: if n_servers is not a positive int raise a
ValueError; additionally, if n_servers > 0 ensure server_config is not None and
raise a ValueError if it is, so user-specified servers are never silently
ignored; update any code paths that iterate with range(n_servers) to rely on
this validation.

reuse_code_exp: str | run.Experiment | None = None,
reuse_code: bool = True,
task_dependencies: list[str] = None,
Expand Down Expand Up @@ -528,51 +529,68 @@ def add_task(

het_group = 0
het_group_indices = []
total_het_groups = (server_config is not None) + bool(cmd) + with_sandbox
total_het_groups = (n_servers if server_config is not None else 0) + bool(cmd) + with_sandbox

LOG.info("Adding a task with commands:")

commands = []
executors = []
# assuming server always has the largest resources request, so it needs to go first
if server_config is not None and int(server_config["num_gpus"]) > 0:

# Check if we need to add server first to ensure SLURM allocates GPU partition
# This happens when the client doesn't need GPUs but the server does
server_needs_gpus = server_config is not None and int(server_config.get("num_gpus", 0)) > 0
client_num_gpus = num_gpus or 0
# For Ray heterogeneous jobs, nemo-run assumes the first het group is the main task
# So we send the server last if the job needs gpus
server_goes_first = server_needs_gpus and not client_num_gpus

Comment on lines +542 to +546
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Compute effective client GPU count once and reuse it.

server_goes_first/server overlap are decided from Line 542, but main executor GPU allocation is recomputed at Line 609 with different logic. This can make ordering and overlap inconsistent with actual resource requests.

Proposed fix
-    client_num_gpus = num_gpus or 0
+    client_num_gpus = (num_gpus or 0) if (server_config is None or num_nodes > 1) else 0
     # For Ray heterogeneous jobs, nemo-run assumes the first het group is the main task
     # So we send the server last if the job needs gpus
     server_goes_first = server_needs_gpus and not client_num_gpus
...
-                client_num_gpus = num_gpus if (server_config is None or num_nodes > 1) else 0
                 executors.append(
                     get_executor(
...
                         gpus_per_node=client_num_gpus,
...
                         overlap=(not client_num_gpus),  # Only when the main task does not have gpus

As per coding guidelines, "Keep code simple and elegant; reuse/extend existing functionality when possible, minimize conditional checks..."

Also applies to: 576-576, 609-629

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_skills/pipeline/utils/exp.py` around lines 542 - 546, Compute the
effective client GPU count once (e.g., rename/assign client_num_gpus ->
effective_client_gpus using the same num_gpus/server_needs_gpus inputs) and
reuse that variable everywhere instead of recomputing; update the server
ordering boolean (server_goes_first) to use effective_client_gpus and replace
the later GPU-allocation logic that currently recomputes client GPUs (the block
around the "main executor GPU allocation" code) to reference
effective_client_gpus so ordering, overlap, and allocation decisions are
consistent across client_num_gpus, server_goes_first, server_needs_gpus and the
main executor allocation logic.

def add_server_tasks():
nonlocal het_group
# avoid mutating server_config, as it may be used again later in dependent jobs
server_config = copy.deepcopy(server_config)
_server_config = copy.deepcopy(server_config)
# do not pass container into the command builder
# NOTE: avoid evaluating default (which would index cluster_config) unless needed
server_container = server_config.pop("container", None)
server_container = _server_config.pop("container", None)
if server_container is None:
server_container = cluster_config["containers"][server_config["server_type"]]
server_cmd, num_server_tasks = get_server_command(**server_config, cluster_config=cluster_config)
server_executor = get_executor(
cluster_config=cluster_config,
container=server_container,
num_nodes=server_config["num_nodes"],
tasks_per_node=num_server_tasks,
gpus_per_node=server_config["num_gpus"],
partition=partition,
account=account,
dependencies=dependencies,
job_name=task_name,
log_dir=log_dir,
log_prefix="server",
extra_package_dirs=extra_package_dirs,
sbatch_kwargs=sbatch_kwargs,
heterogeneous=heterogeneous,
het_group=het_group,
total_het_groups=total_het_groups,
with_ray=with_ray,
ray_template=ray_template,
)
if cluster_config["executor"] != "slurm" and num_server_tasks > 1:
server_cmd = f"mpirun --allow-run-as-root -np {num_server_tasks} bash -c {shlex.quote(server_cmd)}"
commands.append(server_cmd)
executors.append(server_executor)
het_group_indices.append(het_group)
het_group += 1
LOG.info("Server command: %s", server_cmd)
server_container = cluster_config["containers"][_server_config["server_type"]]

# then goes the main task(s) unless it's empty
for server_idx in range(n_servers):
server_cmd, num_server_tasks = get_server_command(**_server_config, cluster_config=cluster_config)
server_executor = get_executor(
cluster_config=cluster_config,
container=server_container,
num_nodes=_server_config["num_nodes"],
tasks_per_node=num_server_tasks,
gpus_per_node=_server_config["num_gpus"],
partition=partition,
account=account,
dependencies=dependencies,
job_name=task_name,
log_dir=log_dir,
log_prefix=f"server_{server_idx}" if n_servers > 1 else "server",
extra_package_dirs=extra_package_dirs,
sbatch_kwargs=sbatch_kwargs,
heterogeneous=heterogeneous,
het_group=het_group,
total_het_groups=total_het_groups,
overlap=(not client_num_gpus), # Only overlap when the main task does not have gpus
with_ray=False,
ray_template=ray_template,
)
cmd_to_add = server_cmd
if cluster_config["executor"] != "slurm" and num_server_tasks > 1:
cmd_to_add = f"mpirun --allow-run-as-root -np {num_server_tasks} bash -c {shlex.quote(server_cmd)}"
commands.append(cmd_to_add)
executors.append(server_executor)
het_group_indices.append(het_group)
het_group += 1
LOG.info("Server %d command: %s", server_idx, server_cmd)
Comment on lines +557 to +587
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all servers launched in the loop use the same server_port from server_config, causing port conflicts when n_servers > 1

each server instance needs a unique port

Suggested change
for server_idx in range(n_servers):
server_cmd, num_server_tasks = get_server_command(**server_config, cluster_config=cluster_config)
server_executor = get_executor(
cluster_config=cluster_config,
container=server_container,
num_nodes=server_config["num_nodes"],
tasks_per_node=num_server_tasks,
gpus_per_node=server_config["num_gpus"],
partition=partition,
dependencies=dependencies,
job_name=task_name,
log_dir=log_dir,
log_prefix=f"server_{server_idx}" if n_servers > 1 else "server",
extra_package_dirs=extra_package_dirs,
sbatch_kwargs=sbatch_kwargs,
heterogeneous=heterogeneous,
het_group=het_group,
total_het_groups=total_het_groups,
overlap=(not client_num_gpus), # Only overlap when the main task does not have gpus
with_ray=False,
ray_template=ray_template,
)
cmd_to_add = server_cmd
if cluster_config["executor"] != "slurm" and num_server_tasks > 1:
cmd_to_add = f"mpirun --allow-run-as-root -np {num_server_tasks} bash -c {shlex.quote(server_cmd)}"
commands.append(cmd_to_add)
executors.append(server_executor)
het_group_indices.append(het_group)
het_group += 1
LOG.info("Server %d command: %s", server_idx, server_cmd)
for server_idx in range(n_servers):
# Get a unique port for each server if launching multiple
current_server_config = server_config.copy()
if n_servers > 1:
current_server_config["server_port"] = get_free_port(strategy="random")
server_cmd, num_server_tasks = get_server_command(**current_server_config, cluster_config=cluster_config)
server_executor = get_executor(
cluster_config=cluster_config,
container=server_container,
num_nodes=current_server_config["num_nodes"],
tasks_per_node=num_server_tasks,
gpus_per_node=current_server_config["num_gpus"],
partition=partition,
dependencies=dependencies,
job_name=task_name,
log_dir=log_dir,
log_prefix=f"server_{server_idx}" if n_servers > 1 else "server",
extra_package_dirs=extra_package_dirs,
sbatch_kwargs=sbatch_kwargs,
heterogeneous=heterogeneous,
het_group=het_group,
total_het_groups=total_het_groups,
overlap=(not client_num_gpus), # Only overlap when the main task does not have gpus
with_ray=False,
ray_template=ray_template,
)


# If client doesn't need GPUs but server does, add server first so SLURM allocates GPU partition
if server_goes_first:
add_server_tasks()

Comment on lines +589 to +592
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Server-first ordering breaks sandbox node selection assumptions.

When Line 590 runs, executors[0] becomes a server executor, but Line 659 still derives sandbox num_nodes from executors[0]. That can size sandbox by server nodes instead of main-task nodes.

Proposed fix
-                num_nodes=executors[0].nodes if cluster_config["executor"] == "slurm" else 1,
+                num_nodes=num_nodes if cluster_config["executor"] == "slurm" else 1,

Also applies to: 659-659

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@nemo_skills/pipeline/utils/exp.py` around lines 589 - 592, The server-first
insertion (server_goes_first -> add_server_tasks) mutates executors so later
code that computes sandbox num_nodes from executors[0] can pick the server's
node count; compute the sandbox node count from the intended main-task executor
before potentially calling add_server_tasks (or locate the first non-server/main
executor instead of using executors[0]) so num_nodes is derived from the main
task; update the logic around server_goes_first, add_server_tasks, and the
sandbox num_nodes calculation to use that precomputed/main-task executor
reference (referencing symbols: server_goes_first, add_server_tasks, executors,
and num_nodes).

# Then goes the main task(s) unless it's empty
if cmd:
if isinstance(cmd, str):
cmd = [cmd]
Expand All @@ -588,13 +606,14 @@ def add_task(
with temporary_env_update(cluster_config, {"NEMO_SKILLS_SANDBOX_PORT": sandbox_port}):
cur_cmd = install_packages_wrap(cur_cmd, installation_command)
commands.append(cur_cmd)
client_num_gpus = num_gpus if (server_config is None or num_nodes > 1) else 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

client_num_gpus is calculated here inside the loop, but was already defined at line 531. This shadows the outer variable and is calculated inside the wrong scope (should be outside the for loop at lines 588-617).

Suggested change
client_num_gpus = num_gpus if (server_config is None or num_nodes > 1) else 0
client_num_gpus = num_gpus if (server_config is None or num_nodes > 1) else 0

Move this line before line 588 (before the for cur_idx, (cur_cmd... loop starts).

executors.append(
get_executor(
cluster_config=cluster_config,
container=cur_container,
num_nodes=num_nodes,
tasks_per_node=cur_tasks,
gpus_per_node=num_gpus if server_config is None else 0,
gpus_per_node=client_num_gpus,
partition=partition,
account=account,
dependencies=dependencies,
Expand All @@ -606,7 +625,7 @@ def add_task(
heterogeneous=heterogeneous,
het_group=het_group,
total_het_groups=total_het_groups,
overlap=(server_config is not None) or with_sandbox,
overlap=(not client_num_gpus), # Only when the main task does not have gpus
with_ray=with_ray,
ray_template=ray_template,
)
Expand All @@ -615,7 +634,7 @@ def add_task(
het_group += 1
LOG.info("Main command(s): %s", ", ".join(cmd))

# finally a sandbox if needed
# Then a sandbox if needed
if with_sandbox:
sandbox_env_updates = {
"LISTEN_PORT": sandbox_port,
Expand Down Expand Up @@ -653,7 +672,7 @@ def add_task(
het_group=het_group,
total_het_groups=total_het_groups,
overlap=True,
with_ray=with_ray,
with_ray=False,
ray_template=ray_template,
# Allow the sandbox to survive individual worker crashes (e.g. SIGILL
# from libraries compiled for a different CPU). nemo-run hardcodes
Expand All @@ -673,6 +692,10 @@ def add_task(
het_group += 1
LOG.info("Sandbox command: %s", commands[-1])

# If server wasn't added first (because client needs GPUs or server doesn't need GPUs), add it now
if server_config is not None and not server_goes_first:
add_server_tasks()

if cluster_config["executor"] != "none":
tunnel = get_tunnel(cluster_config)
if reuse_code:
Expand Down