-
Notifications
You must be signed in to change notification settings - Fork 163
Add multi-instance pipeline support (gpus_per_node) #1247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
409c8ee
1574526
436e7b3
83c2375
fbfdff9
db294bf
42eb131
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -495,8 +495,15 @@ def get_generation_cmd( | |||||||||||||||||||||||||
| cmd += "++wait_for_sandbox=true " | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if chunk_id is not None: | ||||||||||||||||||||||||||
| cmd += f" ++num_chunks={num_chunks} ++chunk_id={chunk_id} " | ||||||||||||||||||||||||||
| output_file = get_chunked_rs_filename(output_dir, random_seed=random_seed, chunk_id=chunk_id) | ||||||||||||||||||||||||||
| # Check if chunk_id is a shell expression (e.g., "$((0 + $SLURM_LOCALID))") | ||||||||||||||||||||||||||
| is_shell_expr = isinstance(chunk_id, str) and "$" in str(chunk_id) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if is_shell_expr: | ||||||||||||||||||||||||||
| # For shell expressions, use double quotes so shell expands the expression | ||||||||||||||||||||||||||
| cmd += f' ++num_chunks={num_chunks} "++chunk_id={chunk_id}" ' | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| cmd += f" ++num_chunks={num_chunks} ++chunk_id={chunk_id} " | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| donefiles = [] | ||||||||||||||||||||||||||
| # we are always waiting for all chunks in num_chunks, no matter chunk_ids in | ||||||||||||||||||||||||||
| # the current run (as we don't want to merge partial jobs) | ||||||||||||||||||||||||||
|
|
@@ -505,10 +512,23 @@ def get_generation_cmd( | |||||||||||||||||||||||||
| donefile = f"{filename}.done" | ||||||||||||||||||||||||||
| donefiles.append(donefile) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if job_end_cmd: | ||||||||||||||||||||||||||
| job_end_cmd += f" && touch {donefiles[chunk_id]} " | ||||||||||||||||||||||||||
| if is_shell_expr: | ||||||||||||||||||||||||||
| # For shell expression, compute the donefile path at runtime | ||||||||||||||||||||||||||
| # Get the base pattern with _chunk_0 and replace with shell expression | ||||||||||||||||||||||||||
| base_donefile = donefiles[0] # e.g., /path/output_chunk_0.jsonl.done | ||||||||||||||||||||||||||
| # Replace "_chunk_0.jsonl" with "_chunk_$((expr)).jsonl" where expr is expanded by shell | ||||||||||||||||||||||||||
| # Extract the expression part (e.g., "0 + $SLURM_LOCALID" from "$((0 + $SLURM_LOCALID))") | ||||||||||||||||||||||||||
| donefile_pattern = base_donefile.replace("_chunk_0.jsonl", f"_chunk_{chunk_id}.jsonl") | ||||||||||||||||||||||||||
| if job_end_cmd: | ||||||||||||||||||||||||||
| job_end_cmd += f' && touch "{donefile_pattern}" ' | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| job_end_cmd = f'touch "{donefile_pattern}" ' | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| job_end_cmd = f"touch {donefiles[chunk_id]} " | ||||||||||||||||||||||||||
| output_file = get_chunked_rs_filename(output_dir, random_seed=random_seed, chunk_id=chunk_id) | ||||||||||||||||||||||||||
| if job_end_cmd: | ||||||||||||||||||||||||||
| job_end_cmd += f" && touch {donefiles[chunk_id]} " | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| job_end_cmd = f"touch {donefiles[chunk_id]} " | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # getting file name as if there is no chunking since that's where we want to merge | ||||||||||||||||||||||||||
| merged_output_file = get_chunked_rs_filename(output_dir=output_dir, random_seed=random_seed) | ||||||||||||||||||||||||||
|
|
@@ -582,6 +602,7 @@ def configure_client( | |||||||||||||||||||||||||
| get_random_port: bool, | ||||||||||||||||||||||||||
| extra_arguments: str, | ||||||||||||||||||||||||||
| server_container: str | None = None, | ||||||||||||||||||||||||||
| gpus_per_node: int = 1, | ||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
| Utility function to configure a client for the model inference server. | ||||||||||||||||||||||||||
|
|
@@ -597,6 +618,7 @@ def configure_client( | |||||||||||||||||||||||||
| get_random_port: Whether to get a random port for the server. | ||||||||||||||||||||||||||
| extra_arguments: Extra arguments to pass to the command. | ||||||||||||||||||||||||||
| server_container: Container to use for the server. | ||||||||||||||||||||||||||
| gpus_per_node: Number of GPUs per node for multi-instance mode. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||
| A tuple containing: | ||||||||||||||||||||||||||
|
|
@@ -625,9 +647,16 @@ def configure_client( | |||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| if server_container: | ||||||||||||||||||||||||||
| server_config["container"] = server_container | ||||||||||||||||||||||||||
| extra_arguments = ( | ||||||||||||||||||||||||||
| f"++server.host=127.0.0.1 ++server.port={server_port} ++server.model={model} {extra_arguments}" | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| if gpus_per_node > 1: | ||||||||||||||||||||||||||
| # Multi-instance mode: port is computed at runtime based on SLURM_LOCALID | ||||||||||||||||||||||||||
| extra_arguments = ( | ||||||||||||||||||||||||||
| f"++server.host=127.0.0.1 " | ||||||||||||||||||||||||||
| f'"++server.port=$(({server_port} + $SLURM_LOCALID))" ++server.model={model} {extra_arguments}' | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
Comment on lines
+650
to
+655
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Inconsistent
🛡️ Proposed fix — use consistent fallback- f'"++server.port=$(({server_port} + $SLURM_LOCALID))" ++server.model={model} {extra_arguments}'
+ f'"++server.port=$(({server_port} + ${{SLURM_LOCALID:-0}}))" ++server.model={model} {extra_arguments}'📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| extra_arguments = ( | ||||||||||||||||||||||||||
| f"++server.host=127.0.0.1 ++server.port={server_port} ++server.model={model} {extra_arguments}" | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| else: # model is hosted elsewhere | ||||||||||||||||||||||||||
| server_config = None | ||||||||||||||||||||||||||
| extra_arguments = f"++server.base_url={server_address} ++server.model={model} {extra_arguments}" | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -120,9 +120,17 @@ def get_server_command( | |
| server_port: int, | ||
| server_args: str = "", | ||
| server_entrypoint: str | None = None, | ||
| gpus_per_node: int = 1, | ||
| ): | ||
| num_tasks = num_gpus | ||
|
|
||
| if gpus_per_node > 1 and server_type != "generic": | ||
| raise ValueError( | ||
| f"Multi-instance mode (gpus_per_node={gpus_per_node}) is only supported for " | ||
| f"server_type='generic', but got server_type='{server_type}'. " | ||
| f"Use gpus_per_node=1 or switch to server_type='generic'." | ||
| ) | ||
|
|
||
| # check if the model path is mounted if not vllm, sglang, or trtllm; | ||
| # vllm, sglang, trtllm can also pass model name as "model_path" so we need special processing | ||
| if server_type not in ["vllm", "sglang", "trtllm", "generic"]: | ||
|
|
@@ -209,15 +217,29 @@ def get_server_command( | |
| elif server_type == "generic": | ||
| if not server_entrypoint: | ||
| raise ValueError("For 'generic' server type, 'server_entrypoint' must be specified.") | ||
| server_start_cmd = ( | ||
| f"{server_entrypoint} " | ||
| f" --model {model_path} " | ||
| f" --num_gpus {num_gpus} " | ||
| f" --num_nodes {num_nodes} " | ||
| f" --port {server_port} " | ||
| f" {server_args} " | ||
| ) | ||
| num_tasks = 1 | ||
| if gpus_per_node > 1: | ||
| # Multi-instance mode: each SLURM task gets its own GPU and port | ||
| server_start_cmd = ( | ||
| f"echo 'SLURM_LOCALID='$SLURM_LOCALID' SLURM_PROCID='$SLURM_PROCID && " | ||
| f"export CUDA_VISIBLE_DEVICES=${{SLURM_LOCALID:-0}} && " | ||
| f"{server_entrypoint} " | ||
| f" --model {model_path} " | ||
| f" --num_gpus 1 " | ||
| f" --num_nodes 1 " | ||
| f" --port $(({server_port} + ${{SLURM_LOCALID:-0}})) " | ||
| f" {server_args} " | ||
| ) | ||
| num_tasks = gpus_per_node | ||
| else: | ||
| server_start_cmd = ( | ||
| f"{server_entrypoint} " | ||
| f" --model {model_path} " | ||
| f" --num_gpus {num_gpus} " | ||
| f" --num_nodes {num_nodes} " | ||
| f" --port {server_port} " | ||
| f" {server_args} " | ||
| ) | ||
| num_tasks = 1 | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+220
to
+242
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Check the current file structure and the changes
cd /tmp/repo 2>/dev/null || cd .
# 1. Find the server.py file and verify it exists
find . -name "server.py" -path "*/pipeline/utils/*" 2>/dev/null | head -5
# 2. Check git diff if in a repo
if git rev-parse --git-dir > /dev/null 2>&1; then
echo "=== Git status ==="
git status --short
echo "=== Checking git diff for server.py ==="
git diff HEAD -- nemo_skills/pipeline/utils/server.py 2>/dev/null | head -150
else
echo "Not in a git repo"
fiRepository: NVIDIA-NeMo/Skills Length of output: 172 🏁 Script executed: #!/bin/bash
# 3. Search for gpus_per_node in the codebase to see if it's validated
echo "=== Searching for gpus_per_node usage ==="
rg "gpus_per_node" --type py -B 3 -A 3
# 4. Search for any assertions or validations related to GPU counts
echo "=== Searching for GPU count validations ==="
rg "(gpus_per_node|num_gpus)" --type py | rg -E "(assert|raise|if.*<=|if.*>=|if.*<|if.*>)" -B 1 -A 1Repository: NVIDIA-NeMo/Skills Length of output: 21948 🏁 Script executed: #!/bin/bash
# 5. Check the server.py file around lines 220-242 to see current state
echo "=== Reading server.py around lines 220-242 ==="
if [ -f "nemo_skills/pipeline/utils/server.py" ]; then
sed -n '210,250p' nemo_skills/pipeline/utils/server.py | cat -n
else
echo "File not found, searching..."
find . -name "server.py" -path "*/pipeline/utils/*" -exec sed -n '210,250p' {} \; 2>/dev/null
fiRepository: NVIDIA-NeMo/Skills Length of output: 2001 🏁 Script executed: #!/bin/bash
# 6. Check how num_tasks is used after being set
echo "=== Checking num_tasks usage in server.py ==="
rg "num_tasks" nemo_skills/pipeline/utils/server.py -B 2 -A 2
# 7. Check exp.py for num_tasks usage
echo "=== Checking num_tasks in exp.py ==="
rg "num_tasks" nemo_skills/pipeline/utils/exp.py -B 2 -A 2 2>/dev/null || echo "exp.py patterns not found"
# 8. Search broader for num_tasks usage
echo "=== Broader num_tasks search ==="
rg "num_tasks" --type py -l 2>/dev/null | head -10Repository: NVIDIA-NeMo/Skills Length of output: 2518 🏁 Script executed: #!/bin/bash
# 9. Look at the full function context around lines 220-242 to understand what this is
echo "=== Full context of the function ==="
if [ -f "nemo_skills/pipeline/utils/server.py" ]; then
sed -n '100,260p' nemo_skills/pipeline/utils/server.py | cat -n
fiRepository: NVIDIA-NeMo/Skills Length of output: 7162 Add validation for The However, the multi-instance path (lines 121–133) lacks validation that if gpus_per_node > num_gpus:
raise ValueError(
f"gpus_per_node ({gpus_per_node}) cannot exceed num_gpus ({num_gpus})"
)Place this check immediately after the 🤖 Prompt for AI Agents |
||
| else: | ||
| raise ValueError(f"Server type '{server_type}' not supported for model inference.") | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: NVIDIA-NeMo/Skills
Length of output: 792
🏁 Script executed:
Repository: NVIDIA-NeMo/Skills
Length of output: 2945
🏁 Script executed:
Repository: NVIDIA-NeMo/Skills
Length of output: 1663
🏁 Script executed:
Repository: NVIDIA-NeMo/Skills
Length of output: 704
🏁 Script executed:
Repository: NVIDIA-NeMo/Skills
Length of output: 502
🏁 Script executed:
Repository: NVIDIA-NeMo/Skills
Length of output: 947
🏁 Script executed:
Repository: NVIDIA-NeMo/Skills
Length of output: 1159
🏁 Script executed:
Repository: NVIDIA-NeMo/Skills
Length of output: 3276
🏁 Script executed:
Repository: NVIDIA-NeMo/Skills
Length of output: 244
🏁 Script executed:
Repository: NVIDIA-NeMo/Skills
Length of output: 694
Fragile string replacement for shell-expression donefile path — no error if pattern is absent.
base_donefile.replace("_chunk_0.jsonl", f"_chunk_{chunk_id}.jsonl")silently returns the original string when the pattern isn't found (e.g., naming convention change) and could replace in a directory component if the output path itself contains_chunk_0.jsonl. Thetouchwould then create a file at the wrong path and the subsequentmerge_chunkswould stall waiting for a done file that never appears.Limit the replacement to the basename only and validate the pattern was found:
🛡️ Proposed fix
🤖 Prompt for AI Agents