diff --git a/nemo_skills/pipeline/nemo_rl/grpo.py b/nemo_skills/pipeline/nemo_rl/grpo.py
index eeeff72474..aff0891d3d 100644
--- a/nemo_skills/pipeline/nemo_rl/grpo.py
+++ b/nemo_skills/pipeline/nemo_rl/grpo.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import logging
 from dataclasses import dataclass
 from enum import Enum
@@ -34,8 +35,10 @@
     parse_kwargs,
     resolve_mount_paths,
     run_exp,
+    should_get_random_port,
     temporary_env_update,
 )
+from nemo_skills.pipeline.utils.server import SupportedServers, get_free_port
 from nemo_skills.utils import (
     get_logger_name,
     setup_logging,
@@ -280,6 +283,17 @@ def grpo_nemo_rl(
         "Format: START:STOP (1-indexed, STOP exclusive, same as slice syntax arr[start:stop]). "
         "Example: '3:5' profiles steps 3 and 4 only. NOTE: START must be ≥ 1, so '0:10' is invalid.",
     ),
+    server_model: str = typer.Option(None, help="Path to the model or model name in API for judge server"),
+    server_address: str = typer.Option(
+        None, help="Use ip:port for self-hosted models or the API url if using model providers"
+    ),
+    server_type: SupportedServers = typer.Option(None, help="Type of server to use for judge"),
+    server_gpus: int = typer.Option(None, help="Number of GPUs to use if hosting the judge model"),
+    server_nodes: int = typer.Option(1, help="Number of nodes required for hosting judge LLM server"),
+    n_servers: int = typer.Option(
+        1, help="Number of independent judge servers to launch (each with server_nodes nodes)"
+    ),
+    server_args: str = typer.Option("", help="Any extra arguments to pass to the judge server"),
     partition: str = typer.Option(
         None, help="Can specify if need interactive jobs or a specific non-default partition"
     ),
@@ -388,6 +402,41 @@ def grpo_nemo_rl(
     if validation_data is not None:
         validation_data = get_mounted_path(cluster_config, validation_data)
 
+    # Server configuration for LLM-as-a-judge
+    server_config = None
+    if server_type is not None:
+        get_random_port = should_get_random_port(server_gpus, exclusive)
+        if server_address is None:  # we need to host the model
+            assert server_gpus is not None, "Need to specify server_gpus if hosting the model"
+            server_port = get_free_port(strategy="random") if get_random_port else 5000
+
+            server_config = {
+                "model_path": server_model,
+                "server_type": server_type,
+                "num_gpus": server_gpus,
+                "num_nodes": server_nodes,
+                "server_args": server_args,
+                "server_port": server_port,
+            }
+
+            client_server_args = {
+                "server_type": server_type.value,
+                "port": server_port,
+                "model": server_model,
+                "n_servers": n_servers,
+            }
+        else:  # model is hosted elsewhere
+            assert n_servers == 1, "Only one server is supported when model is hosted elsewhere"
+            client_server_args = {
+                "server_type": server_type.value,
+                "host": server_address,
+                "model": server_model,
+                "n_servers": 1,
+            }
+        cluster_config["required_env_vars"] = cluster_config.get("required_env_vars", []) + [
+            f"JUDGE_SERVER_ARGS='{json.dumps(client_server_args)}'"
+        ]
+
     train_cmd = get_training_cmd(
         cluster_config=cluster_config,
         partition=partition,
@@ -408,7 +457,6 @@ def grpo_nemo_rl(
         profile_step_range=profile_step_range,
     )
 
-    server_config = None
     env_update = {"RAY_LOG_SYNC_FREQUENCY": 20} if profile_step_range else {}
 
     sbatch_kwargs = parse_kwargs(sbatch_kwargs, exclusive=exclusive, qos=qos, time_min=time_min)
@@ -426,6 +474,7 @@ def grpo_nemo_rl(
             num_nodes=num_nodes,
             cluster_config=cluster_config,
             server_config=server_config,
+            n_servers=n_servers,
             partition=partition,
             run_after=run_after,
             reuse_code=reuse_code,
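The grpo.py half serializes the judge client settings into a `JUDGE_SERVER_ARGS` entry of `required_env_vars` via `json.dumps`. Below is a minimal sketch of how a consumer inside the training job might recover those settings; the helper name is hypothetical (not part of this diff), and the defensive quote-stripping assumes the single quotes from the f-string could survive into the environment value (normally the shell would consume them at export time).

```python
import json
import os


def load_judge_server_args() -> dict:
    """Hypothetical helper: decode the judge client settings that grpo.py
    packs into the JUDGE_SERVER_ARGS environment variable."""
    raw = os.environ["JUDGE_SERVER_ARGS"]
    # grpo.py wraps the JSON in single quotes when building the
    # required_env_vars entry; the shell normally consumes them at export
    # time, but strip defensively in case they survive.
    return json.loads(raw.strip("'"))


# A self-hosted judge (server_address is None) decodes to keys like:
#   {"server_type": "vllm", "port": 5000, "model": "...", "n_servers": 2}
# An external judge (server_address given) carries "host" instead of "port".
```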
diff --git a/nemo_skills/pipeline/utils/exp.py b/nemo_skills/pipeline/utils/exp.py
index 023f8b2350..a47f53cfc6 100644
--- a/nemo_skills/pipeline/utils/exp.py
+++ b/nemo_skills/pipeline/utils/exp.py
@@ -444,6 +444,7 @@ def add_task(
     keep_mounts_for_sandbox=False,
     sandbox_port: int | None = None,
     server_config=None,
+    n_servers: int = 1,
     reuse_code_exp: str | run.Experiment | None = None,
     reuse_code: bool = True,
     task_dependencies: list[str] = None,
@@ -528,51 +529,68 @@ def add_task(
 
     het_group = 0
     het_group_indices = []
-    total_het_groups = (server_config is not None) + bool(cmd) + with_sandbox
+    total_het_groups = (n_servers if server_config is not None else 0) + bool(cmd) + with_sandbox
 
     LOG.info("Adding a task with commands:")
     commands = []
     executors = []
-    # assuming server always has the largest resources request, so it needs to go first
-    if server_config is not None and int(server_config["num_gpus"]) > 0:
+
+    # Check if we need to add server first to ensure SLURM allocates GPU partition
+    # This happens when the client doesn't need GPUs but the server does
+    server_needs_gpus = server_config is not None and int(server_config.get("num_gpus", 0)) > 0
+    client_num_gpus = num_gpus or 0
+    # For ray heterogeneous jobs, nemo-run assumes the first het group is the main task
+    # So we send the server last if the job needs gpus
+    server_goes_first = server_needs_gpus and not client_num_gpus
+
+    def add_server_tasks():
+        nonlocal het_group
         # avoid mutating server_config, as it may be used again later in dependent jobs
-        server_config = copy.deepcopy(server_config)
+        _server_config = copy.deepcopy(server_config)
         # do not pass container into the command builder
         # NOTE: avoid evaluating default (which would index cluster_config) unless needed
-        server_container = server_config.pop("container", None)
+        server_container = _server_config.pop("container", None)
         if server_container is None:
-            server_container = cluster_config["containers"][server_config["server_type"]]
-        server_cmd, num_server_tasks = get_server_command(**server_config, cluster_config=cluster_config)
-        server_executor = get_executor(
-            cluster_config=cluster_config,
-            container=server_container,
-            num_nodes=server_config["num_nodes"],
-            tasks_per_node=num_server_tasks,
-            gpus_per_node=server_config["num_gpus"],
-            partition=partition,
-            account=account,
-            dependencies=dependencies,
-            job_name=task_name,
-            log_dir=log_dir,
-            log_prefix="server",
-            extra_package_dirs=extra_package_dirs,
-            sbatch_kwargs=sbatch_kwargs,
-            heterogeneous=heterogeneous,
-            het_group=het_group,
-            total_het_groups=total_het_groups,
-            with_ray=with_ray,
-            ray_template=ray_template,
-        )
-        if cluster_config["executor"] != "slurm" and num_server_tasks > 1:
-            server_cmd = f"mpirun --allow-run-as-root -np {num_server_tasks} bash -c {shlex.quote(server_cmd)}"
-        commands.append(server_cmd)
-        executors.append(server_executor)
-        het_group_indices.append(het_group)
-        het_group += 1
-        LOG.info("Server command: %s", server_cmd)
+            server_container = cluster_config["containers"][_server_config["server_type"]]
 
-    # then goes the main task(s) unless it's empty
+        for server_idx in range(n_servers):
+            server_cmd, num_server_tasks = get_server_command(**_server_config, cluster_config=cluster_config)
+            server_executor = get_executor(
+                cluster_config=cluster_config,
+                container=server_container,
+                num_nodes=_server_config["num_nodes"],
+                tasks_per_node=num_server_tasks,
+                gpus_per_node=_server_config["num_gpus"],
+                partition=partition,
+                account=account,
+                dependencies=dependencies,
+                job_name=task_name,
+                log_dir=log_dir,
+                log_prefix=f"server_{server_idx}" if n_servers > 1 else "server",
+                extra_package_dirs=extra_package_dirs,
+                sbatch_kwargs=sbatch_kwargs,
+                heterogeneous=heterogeneous,
+                het_group=het_group,
+                total_het_groups=total_het_groups,
+                overlap=(not client_num_gpus),  # Only overlap when the main task does not have gpus
+                with_ray=False,
+                ray_template=ray_template,
+            )
+            cmd_to_add = server_cmd
+            if cluster_config["executor"] != "slurm" and num_server_tasks > 1:
+                cmd_to_add = f"mpirun --allow-run-as-root -np {num_server_tasks} bash -c {shlex.quote(server_cmd)}"
+            commands.append(cmd_to_add)
+            executors.append(server_executor)
+            het_group_indices.append(het_group)
+            het_group += 1
+            LOG.info("Server %d command: %s", server_idx, server_cmd)
+
+    # If client doesn't need GPUs but server does, add server first so SLURM allocates GPU partition
+    if server_goes_first:
+        add_server_tasks()
+
+    # Then goes the main task(s) unless it's empty
     if cmd:
         if isinstance(cmd, str):
             cmd = [cmd]
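The reordering above is easier to see in isolation. This sketch (illustrative only, not code from the diff) reduces the new launch-order rules to a pure function; `has_server` stands in for `server_config is not None`, and the main group is shown unconditionally even though `add_task` only adds it when `cmd` is non-empty.

```python
def het_group_order(
    has_server: bool, server_gpus: int, client_gpus: int, with_sandbox: bool
) -> list[str]:
    """Illustrative reduction of the new add_task het-group ordering."""
    # The server leads only when it needs GPUs and the client holds none,
    # so that SLURM still allocates a GPU partition for the job.
    server_goes_first = has_server and server_gpus > 0 and not client_gpus
    order = ["server"] if server_goes_first else []
    # For ray jobs, nemo-run treats the first het group as the main task.
    order.append("main")
    if with_sandbox:
        order.append("sandbox")
    if has_server and not server_goes_first:
        order.append("server")  # repeated n_servers times in the real code
    return order


assert het_group_order(True, 8, 0, with_sandbox=False) == ["server", "main"]
assert het_group_order(True, 8, 8, with_sandbox=True) == ["main", "sandbox", "server"]
```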
@@ -588,13 +606,14 @@ def add_task(
         with temporary_env_update(cluster_config, {"NEMO_SKILLS_SANDBOX_PORT": sandbox_port}):
             cur_cmd = install_packages_wrap(cur_cmd, installation_command)
         commands.append(cur_cmd)
+        client_num_gpus = num_gpus if (server_config is None or num_nodes > 1) else 0
         executors.append(
             get_executor(
                 cluster_config=cluster_config,
                 container=cur_container,
                 num_nodes=num_nodes,
                 tasks_per_node=cur_tasks,
-                gpus_per_node=num_gpus if server_config is None else 0,
+                gpus_per_node=client_num_gpus,
                 partition=partition,
                 account=account,
                 dependencies=dependencies,
@@ -606,7 +625,7 @@ def add_task(
                 heterogeneous=heterogeneous,
                 het_group=het_group,
                 total_het_groups=total_het_groups,
-                overlap=(server_config is not None) or with_sandbox,
+                overlap=(not client_num_gpus),  # Only when the main task does not have gpus
                 with_ray=with_ray,
                 ray_template=ray_template,
             )
@@ -615,7 +634,7 @@ def add_task(
         het_group += 1
     LOG.info("Main command(s): %s", ", ".join(cmd))
 
-    # finally a sandbox if needed
+    # Then a sandbox if needed
     if with_sandbox:
         sandbox_env_updates = {
             "LISTEN_PORT": sandbox_port,
@@ -653,7 +672,7 @@ def add_task(
                 het_group=het_group,
                 total_het_groups=total_het_groups,
                 overlap=True,
-                with_ray=with_ray,
+                with_ray=False,
                 ray_template=ray_template,
                 # Allow the sandbox to survive individual worker crashes (e.g. SIGILL
                 # from libraries compiled for a different CPU). nemo-run hardcodes
@@ -673,6 +692,10 @@ def add_task(
         het_group += 1
         LOG.info("Sandbox command: %s", commands[-1])
 
+    # If server wasn't added first (because client needs GPUs or server doesn't need GPUs), add it now
+    if server_config is not None and not server_goes_first:
+        add_server_tasks()
+
     if cluster_config["executor"] != "none":
         tunnel = get_tunnel(cluster_config)
         if reuse_code:
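The client-side executor settings changed in tandem: `gpus_per_node` and `overlap` are now both derived from whether the main task keeps its own GPUs. A sketch of that decision (illustrative names; the real values are computed inline in `add_task`):

```python
def client_executor_settings(
    num_gpus: int, num_nodes: int, has_server: bool
) -> tuple[int, bool]:
    """Illustrative: (gpus_per_node, overlap) for the main task's executor."""
    # A single-node client co-scheduled with a server requests no GPUs of
    # its own and overlaps with the server's allocation; multi-node clients
    # keep their GPU request and run without overlap.
    client_num_gpus = num_gpus if (not has_server or num_nodes > 1) else 0
    return client_num_gpus, not client_num_gpus


assert client_executor_settings(8, num_nodes=1, has_server=True) == (0, True)
assert client_executor_settings(8, num_nodes=2, has_server=True) == (8, False)
assert client_executor_settings(8, num_nodes=1, has_server=False) == (8, False)
```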