diff --git a/nemo_skills/pipeline/generate.py b/nemo_skills/pipeline/generate.py index 95e37abb59..cb5b003976 100644 --- a/nemo_skills/pipeline/generate.py +++ b/nemo_skills/pipeline/generate.py @@ -44,109 +44,175 @@ # TODO: add num_jobs here for consistency with eval? -def _create_commandgroup_from_config( - generation_cmd: str, - server_config: Optional[Dict], - with_sandbox: bool, - sandbox_port: Optional[int], +def _create_job_unified( + models: List[str], + server_configs: List[Optional[Dict]], + generation_cmd_params: Dict, cluster_config: Dict, installation_command: Optional[str], get_server_command_fn: Callable, + with_sandbox: bool, + sandbox_port: Optional[int], partition: Optional[str], keep_mounts_for_sandbox: bool, task_name: str, log_dir: str, sbatch_kwargs: Optional[Dict] = None, -) -> CommandGroup: - """Create a CommandGroup from server_config. - - Component ordering: - 1. Server (if server_config provided) - 2. Client command - 3. Sandbox (if with_sandbox=True) +) -> Dict: + """ + Create a job for n models (unified for n=1 and n>1). + + Structure: + - Group 0: Model 0 server + client + (optional sandbox) + - Group 1: Model 1 server (if n>1) + - Group N: Model N server (if n>1) + + For n=1, returns a single-element list. The Pipeline automatically + optimizes single-group lists to use efficient single-group jobs. + + Args: + models: List of model paths + server_configs: List of server configurations (one per model, None if not hosting) + generation_cmd_params: Dict of parameters to build client command lambda + cluster_config: Cluster configuration + installation_command: Installation command to run before client + get_server_command_fn: Function to build server commands + with_sandbox: Whether to include sandbox + sandbox_port: Port for sandbox + partition: Slurm partition + keep_mounts_for_sandbox: Whether to keep mounts for sandbox + task_name: Name for the task + log_dir: Directory for logs + sbatch_kwargs: Additional sbatch kwargs (including qos, time_min, exclusive, etc.) + + Returns: + Job dict with "groups" key (list of CommandGroup objects) """ - components = [] - - # 1. Add server if server_config is provided - if server_config is not None and int(server_config["num_gpus"]) > 0: - server_type = server_config["server_type"] - # Get container from server_config if provided, otherwise fall back to cluster config - if "container" in server_config: - server_container = server_config.pop("container") + num_models = len(models) + groups = [] + server_commands = [] # Track server Command objects for hostname references + + for model_idx, (model_path, server_config) in enumerate(zip(models, server_configs)): + components = [] + + # Track GPU/node requirements for this group (from server config) + group_gpus = 0 + group_nodes = 1 + + # 1. Add server if needed + if server_config is not None and int(server_config.get("num_gpus", 0)) > 0: + server_type = server_config["server_type"] + + server_container = server_config.get("container") or cluster_config["containers"][server_type] + + server_config_for_cmd = {k: v for k, v in server_config.items() if k != "container"} + cmd, num_tasks = get_server_command_fn(**server_config_for_cmd, cluster_config=cluster_config) + + # Set group GPU/node requirements from server config + group_gpus = server_config["num_gpus"] + group_nodes = server_config["num_nodes"] + + metadata = { + "num_tasks": num_tasks, + "gpus": group_gpus, + "nodes": group_nodes, + "log_prefix": f"server_{model_idx}" if num_models > 1 else "server", + "port": server_config["server_port"], + } + + server_cmd = Command( + command=cmd, + container=server_container, + gpus=group_gpus, + nodes=group_nodes, + name=f"model_{model_idx}_server" if num_models > 1 else "server", + metadata=metadata, + ) + components.append(server_cmd) + server_commands.append(server_cmd) else: - server_container = cluster_config["containers"][server_type] + # No server for this model (pre-hosted) + server_commands.append(None) + + # 2. Group 0 gets the client + if model_idx == 0: + client_env = {} + if with_sandbox and sandbox_port is not None: + client_env["NEMO_SKILLS_SANDBOX_PORT"] = str(sandbox_port) + + # Build client command as lambda + def build_client_command(): + runtime_addresses = [] + + for server_idx, server_cmd in enumerate(server_commands): + if server_cmd is not None: + # Self-hosted: construct address from hostname and port refs + addr = f"{server_cmd.hostname_ref()}:{server_cmd.meta_ref('port')}" + else: + # Pre-hosted: use the original address from config + addr = generation_cmd_params["server_addresses_prehosted"][server_idx] + runtime_addresses.append(addr) + + # Build command with runtime-resolved addresses + cmd_str = pipeline_utils.get_generation_cmd( + server_addresses=runtime_addresses, + **{k: v for k, v in generation_cmd_params.items() if k not in ["server_addresses_prehosted"]}, + ) + return pipeline_utils.wrap_python_path(cmd_str) + + client_command = build_client_command + + client_cmd = Command( + command=client_command, + container=cluster_config["containers"]["nemo-skills"], + name="generation_client" if num_models > 1 else task_name, + installation_command=installation_command, + metadata={ + "log_prefix": "main", + "environment": client_env, + }, + ) + components.append(client_cmd) - # Call server command builder directly with cluster_config - cmd, num_tasks = get_server_command_fn(**server_config, cluster_config=cluster_config) + # 3. Add sandbox to group 0 if requested + # Only Group 0 is needed because we do not currently have a way + # to route requests between multiple sandboxes + if with_sandbox: + cmd, metadata = sandbox_command( + cluster_config=cluster_config, + port=sandbox_port, + ) + metadata["log_prefix"] = "sandbox" if num_models == 1 else "sandbox_0" - # Create metadata dict - metadata = { - "num_tasks": num_tasks, - "gpus": server_config["num_gpus"], - "nodes": server_config["num_nodes"], - "log_prefix": "server", - } + sandbox_cmd = Command( + command=cmd, + container=cluster_config["containers"]["sandbox"], + name="sandbox" if num_models == 1 else "sandbox_0", + metadata=metadata, + ) + components.append(sandbox_cmd) - server_cmd = Command( - command=cmd, - container=server_container, - gpus=server_config["num_gpus"], - nodes=server_config["num_nodes"], - name=task_name, - metadata=metadata, - ) - components.append(server_cmd) - - # 2. Add main generation command - # Note: General cluster config env vars are automatically added by get_env_variables() in get_executor() - client_env = {} - if with_sandbox and sandbox_port is not None: - client_env["NEMO_SKILLS_SANDBOX_PORT"] = str(sandbox_port) - - client_cmd = Command( - command=generation_cmd, - container=cluster_config["containers"]["nemo-skills"], - name=task_name, - installation_command=installation_command, - metadata={ - "log_prefix": "main", - "environment": client_env, - }, - ) - components.append(client_cmd) - - # 3. Add sandbox if requested - if with_sandbox: - # Call sandbox command builder directly with cluster_config - cmd, metadata = sandbox_command(cluster_config=cluster_config, port=sandbox_port) - metadata["log_prefix"] = "sandbox" - - sandbox_cmd = Command( - command=cmd, - container=cluster_config["containers"]["sandbox"], - name=task_name, - metadata=metadata, - ) + # Only create group if it has components (skip empty groups for pre-hosted models) + if components: + group = CommandGroup( + commands=components, + hardware=HardwareConfig( + partition=partition, + num_gpus=group_gpus, + num_nodes=group_nodes, + sbatch_kwargs=sbatch_kwargs, + ), + name=f"model_{model_idx}_group" if num_models > 1 else task_name, + log_dir=log_dir, + ) + groups.append(group) - components.append(sandbox_cmd) - - # Find maximum GPUs/nodes needed by any component for the HardwareConfig - # The job-level resource request must be the maximum across all components - max_gpus = max((comp.gpus or 0) for comp in components) - max_nodes = max((comp.nodes or 1) for comp in components) - - return CommandGroup( - commands=components, - hardware=HardwareConfig( - partition=partition, - num_gpus=max_gpus, - num_nodes=max_nodes, - sbatch_kwargs=sbatch_kwargs, - ), - name=task_name, - log_dir=log_dir, - ) + return { + "name": task_name, + "groups": groups, + "dependencies": None, + } @app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) @@ -177,21 +243,45 @@ def generate( "If not specified, will use the registered generation module for the " "generation type (which is required in this case).", ), - model: str = typer.Option(None, help="Path to the model or model name in API"), - server_address: str = typer.Option( - None, help="Use ip:port for self-hosted models or the API url if using model providers" + model: List[str] = typer.Option( + None, + help="Path to the model(s). CLI: space-separated. Python API: string or list. " + "Single value broadcasts to all models for multi-model generation.", + ), + server_address: List[str] = typer.Option( + None, + help="Server address(es). CLI: space-separated. Python API: string or list. " + "Single value broadcasts to all models.", + ), + server_type: List[pipeline_utils.SupportedServers] = typer.Option( + ..., + help="Server type(s). CLI: space-separated. Python API: string or list. " + "Single value broadcasts to all models.", ), - server_type: pipeline_utils.SupportedServers = typer.Option(..., help="Type of server to use"), - server_gpus: int = typer.Option(None, help="Number of GPUs to use if hosting the model"), - server_nodes: int = typer.Option(1, help="Number of nodes required for hosting LLM server"), - server_args: str = typer.Option("", help="Any extra arguments to pass to the server"), - server_entrypoint: str = typer.Option( + server_gpus: List[int] = typer.Option( None, - help="Path to the entrypoint of the server. " - "If not specified, will use the default entrypoint for the server type.", + help="Number of GPUs per model. CLI: space-separated ints. Python API: int or list. " + "Single value broadcasts to all models.", ), - server_container: str = typer.Option( - None, help="Override container image for the hosted server (if server_gpus is set)" + server_nodes: List[int] = typer.Option( + [1], + help="Number of nodes per model. CLI: space-separated ints. Python API: int or list. " + "Single value broadcasts to all models.", + ), + server_args: List[str] = typer.Option( + [""], + help="Server arguments per model. CLI: space-separated. Python API: string or list. " + "Single value broadcasts to all models.", + ), + server_entrypoint: List[str] = typer.Option( + None, + help="Server entrypoint(s). CLI: space-separated. Python API: string or list. " + "Single value broadcasts to all models.", + ), + server_container: List[str] = typer.Option( + None, + help="Container image(s). CLI: space-separated. Python API: string or list. " + "Single value broadcasts to all models.", ), dependent_jobs: int = typer.Option(0, help="Specify this to launch that number of dependent jobs"), mount_paths: str = typer.Option(None, help="Comma separated list of paths to mount on the remote machine"), @@ -282,7 +372,18 @@ def generate( None, help="Internal option to specify task dependencies.", hidden=True ), ): - """Generate LLM completions for a given input file. + """Generate LLM completions for single or multiple models. + + Supports both single-model and multi-model generation through a unified interface. + + Parameter Types: + Multi-model parameters (model, server_*, etc.) use List[T] type hints for Typer CLI + compatibility, but accept both scalars and lists when called from Python: + - CLI: --model m1 m2 (space-separated) → Typer converts to ["m1", "m2"] + - Python API: model="m1" or model=["m1", "m2"] → Both work (normalized internally) + - Single values broadcast to all models: server_gpus=8 → [8, 8, 8] for 3 models + + Multi-model usage requires either --generation-type or --generation-module. Run `python -m nemo_skills.inference.generate --help` for other supported arguments (need to be prefixed with ++, since we use Hydra for that script). @@ -292,10 +393,38 @@ def generate( LOG.info("Starting generation job") LOG.info("Extra arguments that will be passed to the underlying script: %s", extra_arguments) - try: - server_type = server_type.value - except AttributeError: - pass + models_list = pipeline_utils.normalize_models_config(model) + num_models = len(models_list) + + LOG.info(f"Number of models: {num_models}") + for model_idx, model_name in enumerate(models_list): + LOG.info(f" Model {model_idx}: {model_name}") + + def convert_server_type_to_string(server_type): + return server_type.value if hasattr(server_type, "value") else server_type + + if isinstance(server_type, list): + server_type = [convert_server_type_to_string(st) for st in server_type] + else: + server_type = convert_server_type_to_string(server_type) + server_types_list = pipeline_utils.normalize_parameter(server_type, num_models, "server_type") + server_gpus_list = pipeline_utils.normalize_parameter(server_gpus, num_models, "server_gpus") + server_nodes_list = pipeline_utils.normalize_parameter(server_nodes, num_models, "server_nodes") + server_args_list = pipeline_utils.normalize_parameter(server_args, num_models, "server_args") + server_entrypoints_list = pipeline_utils.normalize_parameter(server_entrypoint, num_models, "server_entrypoint") + server_containers_list = pipeline_utils.normalize_parameter(server_container, num_models, "server_container") + + if server_address is not None: + server_addresses_list = pipeline_utils.normalize_parameter(server_address, num_models, "server_address") + else: + server_addresses_list = [None] * num_models + + # Validate multi-model requirements + if num_models > 1: + if generation_type is None and generation_module is None: + raise ValueError( + "Multi-model generation requires either --generation-type or --generation-module to be specified" + ) if log_samples: wandb_parameters = { @@ -311,8 +440,6 @@ def generate( else: wandb_parameters = None - get_random_port = pipeline_utils.should_get_random_port(server_gpus, exclusive) - if random_seeds and num_random_seeds: raise ValueError("Cannot specify both random_seeds and num_random_seeds") if num_random_seeds: @@ -341,8 +468,6 @@ def generate( check_mounted_paths=check_mounted_paths, ) - original_server_address = server_address - if generation_module is not None and generation_type is not None: raise ValueError("Cannot specify both generation_module and generation_type. ") if generation_module is None: @@ -393,36 +518,53 @@ def generate( chunk_id=None, ) for chunk_id in chunk_ids: - # Configure client (same as before) - server_config, server_address, extra_arguments = pipeline_utils.configure_client( - model=model, - server_type=server_type, - server_address=original_server_address, - server_gpus=server_gpus, - server_nodes=server_nodes, - server_args=server_args, - server_entrypoint=server_entrypoint, - server_container=server_container, - extra_arguments=extra_arguments_original, - get_random_port=get_random_port, - ) + server_configs = [] + server_addresses_resolved = [] + # For single model: configure_client returns extra_args with server config appended + # For multi-model: use original extra_args (server config added as lists in get_generation_cmd) + extra_arguments = extra_arguments_original + + for model_idx in range(num_models): + get_random_port_for_server = pipeline_utils.should_get_random_port( + server_gpus_list[model_idx], exclusive + ) - # Build generation command (same as before) - cmd = pipeline_utils.get_generation_cmd( - input_file=input_file, - input_dir=input_dir, - random_seed=seed, - output_dir=output_dir, - extra_arguments=extra_arguments, - chunk_id=chunk_id, - num_chunks=num_chunks, - preprocess_cmd=preprocess_cmd, - postprocess_cmd=postprocess_cmd, - wandb_parameters=wandb_parameters if seed_idx == 0 else None, - script=generation_module, - with_sandbox=with_sandbox, - ) - cmd = pipeline_utils.wrap_python_path(cmd=cmd) + srv_config, srv_address, srv_extra_args = pipeline_utils.configure_client( + model=models_list[model_idx], + server_type=server_types_list[model_idx], + server_address=server_addresses_list[model_idx], + server_gpus=server_gpus_list[model_idx], + server_nodes=server_nodes_list[model_idx], + server_args=server_args_list[model_idx], + server_entrypoint=server_entrypoints_list[model_idx], + server_container=server_containers_list[model_idx], + extra_arguments=extra_arguments_original if model_idx == 0 else "", + get_random_port=get_random_port_for_server, + ) + server_configs.append(srv_config) + server_addresses_resolved.append(srv_address) + + # For single model, capture the extra_args with server config from configure_client + if model_idx == 0 and num_models == 1: + extra_arguments = srv_extra_args + + cmd_params = { + "input_file": input_file, + "input_dir": input_dir, + "random_seed": seed, + "output_dir": output_dir, + "model_names": models_list, + "server_types": server_types_list, + "extra_arguments": extra_arguments, + "chunk_id": chunk_id, + "num_chunks": num_chunks, + "preprocess_cmd": preprocess_cmd, + "postprocess_cmd": postprocess_cmd, + "wandb_parameters": wandb_parameters if seed_idx == 0 else None, + "script": generation_module, + "server_addresses_prehosted": server_addresses_resolved, + } + cmd = cmd_params # Base task name (shared across all dependent jobs in the chain) task_name = f"{expname}-rs{seed}" if seed is not None else expname @@ -435,21 +577,21 @@ def generate( for dep_idx in range(dependent_jobs + 1): # Allocate sandbox port if needed - # This must be done BEFORE creating CommandGroup so client knows the port + # This must be done BEFORE creating job so client knows the port if with_sandbox: - current_sandbox_port = get_free_port(strategy="random") if get_random_port else 6000 + current_sandbox_port = get_free_port(strategy="random") if get_random_port_for_server else 6000 else: current_sandbox_port = None - # Create CommandGroup for this task - cmd_group = _create_commandgroup_from_config( - generation_cmd=cmd, - server_config=server_config.copy() if server_config else None, - with_sandbox=with_sandbox, - sandbox_port=current_sandbox_port, + job_spec = _create_job_unified( + models=models_list, + server_configs=[cfg.copy() if cfg else None for cfg in server_configs], + generation_cmd_params=cmd, cluster_config=cluster_config, installation_command=installation_command, get_server_command_fn=generation_task.get_server_command_fn(), + with_sandbox=with_sandbox, + sandbox_port=current_sandbox_port, partition=partition, keep_mounts_for_sandbox=keep_mounts_for_sandbox, task_name=task_name, @@ -459,6 +601,7 @@ def generate( # Use unique internal job name for dependency tracking, but same task_name internal_job_name = f"{task_name}-dep{dep_idx}" if dep_idx > 0 else task_name + job_spec["name"] = internal_job_name # Build dependencies: first job in chain gets external dependencies, rest chain to previous if dep_idx == 0: @@ -472,11 +615,7 @@ def generate( # Subsequent jobs in chain depend on previous job (use job object, not string) job_deps = [prev_job] - job_spec = { - "name": internal_job_name, - "group": cmd_group, - "dependencies": job_deps, - } + job_spec["dependencies"] = job_deps jobs.append(job_spec) prev_job = job_spec # Track for next iteration diff --git a/nemo_skills/pipeline/utils/__init__.py b/nemo_skills/pipeline/utils/__init__.py index 1e470f3539..3e738a530f 100644 --- a/nemo_skills/pipeline/utils/__init__.py +++ b/nemo_skills/pipeline/utils/__init__.py @@ -49,6 +49,8 @@ get_chunked_rs_filename, get_generation_cmd, get_remaining_jobs, + normalize_models_config, + normalize_parameter, wrap_cmd, ) from nemo_skills.pipeline.utils.mounts import ( diff --git a/nemo_skills/pipeline/utils/declarative.py b/nemo_skills/pipeline/utils/declarative.py index e294a3ed82..707486f114 100644 --- a/nemo_skills/pipeline/utils/declarative.py +++ b/nemo_skills/pipeline/utils/declarative.py @@ -186,11 +186,20 @@ def __post_init__(self): self.command = wrap_command(self.command, self.working_dir, self.env_vars) def hostname_ref(self) -> str: - """Get hostname reference for hetjob cross-component communication.""" + """Get hostname reference for hetjob cross-component communication. + + Returns a shell variable reference that resolves to the master node hostname + for this het group. Uses environment variables automatically exported by nemo-run: + SLURM_MASTER_NODE_HET_GROUP_0, SLURM_MASTER_NODE_HET_GROUP_1, etc. + + These are set via: + export SLURM_MASTER_NODE_HET_GROUP_N=$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_N | head -n1) + """ if self.het_group_index is None: - return "127.0.0.1" # Local fallback - # For heterogeneous SLURM jobs, resolve nodelist to actual hostname - return f"$(scontrol show hostnames $SLURM_JOB_NODELIST_HET_GROUP_{self.het_group_index} | head -n1)" + return "127.0.0.1" # Local fallback for non-heterogeneous jobs + + # Use the environment variable exported by nemo-run + return f"${{SLURM_MASTER_NODE_HET_GROUP_{self.het_group_index}:-localhost}}" def meta_ref(self, key: str) -> str: """Get metadata value (like port). Fails if key not found.""" @@ -571,6 +580,13 @@ def _plan_and_add_job( executors: List = [] het_group_indices: List[int] = [] + # Assign het_group_indices FIRST (before any prepare_for_execution calls) + # This is critical for cross-component references in lambdas + if heterogeneous: + for het_idx, group in enumerate(groups): + for command in group.commands: + command.het_group_index = het_idx + # In heterogeneous jobs, collect environment from all commands for cross-component refs shared_env_vars: Dict[str, str] = {} if heterogeneous: @@ -595,13 +611,6 @@ def _plan_and_add_job( ) for comp_idx, command in enumerate(group.commands): - # Assign het_group_index ONLY for heterogeneous jobs (per-job, not global) - # Non-heterogeneous jobs use localhost, so het_group_index should remain None - if heterogeneous: - command.het_group_index = het_idx - else: - command.het_group_index = None - final_cmd, exec_config = self._prepare_command(command, cluster_config) commands.append(final_cmd) diff --git a/nemo_skills/pipeline/utils/generation.py b/nemo_skills/pipeline/utils/generation.py index cd576053c1..e742e1da53 100644 --- a/nemo_skills/pipeline/utils/generation.py +++ b/nemo_skills/pipeline/utils/generation.py @@ -17,6 +17,7 @@ import shlex import subprocess from collections import defaultdict +from typing import Optional from nemo_skills.pipeline.utils.cluster import get_tunnel from nemo_skills.pipeline.utils.mounts import get_unmounted_path @@ -26,6 +27,81 @@ LOG = logging.getLogger(get_logger_name(__file__)) +def normalize_models_config( + model: Optional[str | list[str]], +) -> list[str]: + """ + Normalize model specification to list. + + Handles both scalar and list inputs: + - CLI (Typer): Converts single values to single-element lists automatically + - Python API: Accepts both strings and lists + + Args: + model: Model path(s) - string or list from Python API, list from CLI + + Returns: + List of model paths + + Raises: + ValueError: If model is None or empty + """ + if model is None: + raise ValueError("Must specify --model") + + # Handle string (Python API with single model) + if isinstance(model, str): + return [model] + + # Handle list + if len(model) == 0: + raise ValueError("Must specify --model") + return list(model) + + +def normalize_parameter( + param_value: any, + num_models: int, + param_name: str, +) -> list[any]: + """ + Normalize a parameter to a per-model list. + + Handles both scalar and list inputs for flexible usage: + - CLI (Typer): Converts single values to single-element lists automatically + - Python API: Accepts both scalars and lists directly + + Broadcast logic: + - Scalar value: Broadcast to all models [value] * num_models + - Single-element list: Broadcast to all models + - Multi-element list: Must match num_models exactly + + Args: + param_value: Parameter value (scalar or list) + num_models: Number of models + param_name: Name of parameter (for error messages) + + Returns: + List of parameter values (one per model) + + Raises: + ValueError: If list length doesn't match num_models + """ + if not isinstance(param_value, list): + return [param_value] * num_models + + if len(param_value) == num_models: + return list(param_value) + + if len(param_value) == 1: + return param_value * num_models + + raise ValueError( + f"Parameter {param_name} has {len(param_value)} values but {num_models} models specified. " + f"Must be 1 value (broadcast) or {num_models} values (per-model)." + ) + + def get_chunked_rs_filename( output_dir: str, random_seed: int = None, @@ -294,6 +370,10 @@ def get_generation_cmd( wandb_parameters=None, with_sandbox: bool = False, script: str = "nemo_skills.inference.generate", + # Optional: for multi-model generation + server_addresses: list[str] | None = None, + model_names: list[str] | None = None, + server_types: list[str] | None = None, ): """Construct the generation command for language model inference.""" if input_file is None and input_dir is None: @@ -321,12 +401,29 @@ def get_generation_cmd( # Handle file paths vs module names common_args = f"++skip_filled=True ++input_file={input_file} ++output_file={output_file}" if script.endswith(".py") or os.sep in script: - # It's a file path, run it directly with .py extension script_path = script if script.endswith(".py") else f"{script}.py" cmd += f"python {script_path} {hydra_config_args} {common_args} " else: # It's a module name, use -m flag cmd += f"python -m {script} {hydra_config_args} {common_args} " + + if server_addresses is not None and model_names is not None: + num_models = len(model_names) + if num_models > 1: + # Multi-model: pass server configuration as lists + # Just pass base_url for all models (both self-hosted and pre-hosted) + # The inference script will configure the client correctly based on server_gpus + + model_names_arg = ",".join(model_names) + cmd += f"++server.model=[{model_names_arg}] " + + server_types_arg = ",".join(server_types) + cmd += f"++server.server_type=[{server_types_arg}] " + + server_addresses_arg = ",".join(server_addresses) + cmd += f"++server.base_url=[{server_addresses_arg}] " + # For n=1: server config is already in extra_arguments from configure_client + job_end_cmd = "" if random_seed is not None and input_dir is None: # if input_dir is not None, we default to greedy generations diff --git a/tests/test_declarative_pipeline.py b/tests/test_declarative_pipeline.py index 92117e403d..3e8b5cf4c8 100644 --- a/tests/test_declarative_pipeline.py +++ b/tests/test_declarative_pipeline.py @@ -16,6 +16,7 @@ import json import os +import subprocess from unittest.mock import MagicMock, patch import pytest @@ -116,13 +117,38 @@ def test_command_hostname_ref_none(self): assert cmd.hostname_ref() == "127.0.0.1" def test_command_hostname_ref_heterogeneous(self): - """Test hostname_ref returns SLURM variable when het_group_index is set.""" - cmd = Command(command="echo test", name="test") - cmd.het_group_index = 2 + """Test hostname_ref returns runtime-resolvable references for heterogeneous jobs.""" + + cmd0 = Command(command="echo test", name="server0") + cmd0.het_group_index = 0 + + cmd1 = Command(command="echo test", name="server1") + cmd1.het_group_index = 1 + + hostname0 = cmd0.hostname_ref() + hostname1 = cmd1.hostname_ref() + + # Each het group should get different references + assert hostname0 != hostname1, "Different het groups should have different hostname references" + + # Test with environment variables set (simulating nemo-run's behavior) + env = os.environ.copy() + env["SLURM_MASTER_NODE_HET_GROUP_0"] = "node-123" + env["SLURM_MASTER_NODE_HET_GROUP_1"] = "node-456" + + # Evaluate the shell references + resolved0 = subprocess.check_output(f"echo {hostname0}", shell=True, env=env, text=True).strip() + resolved1 = subprocess.check_output(f"echo {hostname1}", shell=True, env=env, text=True).strip() - hostname = cmd.hostname_ref() - assert "$SLURM_JOB_NODELIST_HET_GROUP_2" in hostname - assert "scontrol" in hostname + assert resolved0 == "node-123", f"Expected 'node-123', got '{resolved0}'" + assert resolved1 == "node-456", f"Expected 'node-456', got '{resolved1}'" + + # Test fallback to localhost when env vars not set + env_no_slurm = {} + resolved_fallback = subprocess.check_output( + f"echo {hostname0}", shell=True, env=env_no_slurm, text=True + ).strip() + assert resolved_fallback == "localhost", f"Should fallback to localhost, got '{resolved_fallback}'" def test_command_with_installation_command(self): """Test Command with installation_command.""" @@ -475,8 +501,27 @@ def test_het_group_index_heterogeneous(self, mock_env_vars, mock_get_exp): # Commands should have het_group_index 0 and 1 assert cmd1.het_group_index == 0 assert cmd2.het_group_index == 1 - assert "$SLURM_JOB_NODELIST_HET_GROUP_0" in cmd1.hostname_ref() - assert "$SLURM_JOB_NODELIST_HET_GROUP_1" in cmd2.hostname_ref() + + # Test that hostname references actually resolve correctly + import os + import subprocess + + env = os.environ.copy() + env["SLURM_MASTER_NODE_HET_GROUP_0"] = "test-node-123" + env["SLURM_MASTER_NODE_HET_GROUP_1"] = "test-node-456" + + hostname1 = cmd1.hostname_ref() + hostname2 = cmd2.hostname_ref() + + # Different groups should have different references + assert hostname1 != hostname2 + + # Verify they resolve to the correct hostnames + resolved1 = subprocess.check_output(f"echo {hostname1}", shell=True, env=env, text=True).strip() + resolved2 = subprocess.check_output(f"echo {hostname2}", shell=True, env=env, text=True).strip() + + assert resolved1 == "test-node-123", f"Group 0 should resolve to test-node-123, got {resolved1}" + assert resolved2 == "test-node-456", f"Group 1 should resolve to test-node-456, got {resolved2}" @patch("nemo_skills.pipeline.utils.declarative.get_exp") @patch("nemo_skills.pipeline.utils.declarative.get_env_variables") diff --git a/tests/test_generation.py b/tests/test_generation.py index b69b526a0e..c6b720c03f 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -21,7 +21,7 @@ import pytest from nemo_skills.evaluation.metrics import ComputeMetrics -from nemo_skills.pipeline.generate import _create_commandgroup_from_config +from nemo_skills.pipeline.generate import _create_job_unified def test_eval_gsm8k_api(tmp_path): @@ -157,7 +157,11 @@ def test_server_metadata_from_num_tasks(): """Test that metadata dict is properly created from server command returning (cmd, num_tasks).""" mock_server_fn = MagicMock(return_value=("python server.py", 4)) cluster_config = { - "containers": {"vllm": "nvcr.io/nvidia/nemo:vllm", "nemo-skills": "nvcr.io/nvidia/nemo:skills"}, + "containers": { + "vllm": "nvcr.io/nvidia/nemo:vllm", + "nemo-skills": "nvcr.io/nvidia/nemo:skills", + "sandbox": "nvcr.io/nvidia/nemo:sandbox", + }, "executor": "slurm", } server_config = { @@ -168,20 +172,32 @@ def test_server_metadata_from_num_tasks(): "server_port": 5000, } - cmd_group = _create_commandgroup_from_config( - generation_cmd="python generate.py", - server_config=server_config, - with_sandbox=False, - sandbox_port=None, + generation_cmd_params = { + "input_file": "test.jsonl", + "output_dir": "/tmp/output", + "model_names": ["test_model"], + "num_models": 1, + "server_addresses_prehosted": ["localhost:5000"], + } + + job = _create_job_unified( + models=["test_model"], + server_configs=[server_config], + generation_cmd_params=generation_cmd_params, cluster_config=cluster_config, installation_command=None, get_server_command_fn=mock_server_fn, + with_sandbox=False, + sandbox_port=None, partition=None, keep_mounts_for_sandbox=False, task_name="test-task", log_dir="/tmp/logs", + sbatch_kwargs=None, ) + # For single model, job has "groups" key with single group + cmd_group = job["groups"][0] server_cmd = cmd_group.commands[0] assert isinstance(server_cmd.metadata, dict) assert server_cmd.metadata["num_tasks"] == 4