added functionality for gbs to scale up with number of gpus #2551
rsalagame-nvidia wants to merge 1 commit into r0.3.0 from
Conversation
📝 Walkthrough
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~8 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 warning
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@scripts/performance/utils/utils.py`:
- Around line 183-186: The code silently truncates scaled global batch size by
using int(base_config.gbs_scaling_factor * num_gpus); update the logic that
computes gbs_size (when num_gpus != base_config.num_gpus) to compute the
multiplication in integer space using divmod or equivalent (e.g., quotient,
remainder = divmod(base_config.gbs_scaling_factor * num_gpus, 1) or better
compute numerator/denominator so you can divmod exactly) and raise a clear
exception if the result is not an integer so we fail fast; reference the symbols
gbs_size, base_config.gbs_scaling_factor, num_gpus and ensure the error message
makes it obvious this affects the scaling used in set_post_overrides.
```python
elif num_gpus != base_config.num_gpus:
    # Scale GBS with num_gpus so experiment name matches the scaled GBS applied in set_post_overrides
    gbs_size = int(base_config.gbs_scaling_factor * num_gpus)
else:
```
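To see the hazard the reviewer flags in this line, here is a minimal standalone sketch of the truncation; the config values are hypothetical, chosen so the scaled product is not integral (the repository's existing configs all happen to divide evenly):

```python
# Hypothetical base config: gbs_scaling_factor = global_batch_size / num_gpus
base_gbs = 256
base_num_gpus = 24
gbs_scaling_factor = base_gbs / base_num_gpus  # ~10.667, a float

# Scaling to a new GPU count via int() silently drops the fractional part.
num_gpus = 32
scaled = gbs_scaling_factor * num_gpus  # ~341.333
gbs_size = int(scaled)

print(gbs_size)  # 341 -- a third of a batch is discarded without any warning
```

Nothing in this path raises or logs, so a mis-sized config would only show up later as a mismatch between the experiment name and the GBS actually applied.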
🧩 Analysis chain
🏁 Script executed:
```shell
# First, let's read the target file to see the actual code and context
cat -n scripts/performance/utils/utils.py | head -200 | tail -40
```
Repository: NVIDIA-NeMo/Megatron-Bridge
🏁 Script executed:
```shell
# Check file size first
wc -l scripts/performance/utils/utils.py
```
🏁 Script executed:
```shell
# Search for gbs_scaling_factor definition and usage
rg "gbs_scaling_factor" --context=3
```
🏁 Script executed:
```shell
# Search for WorkloadBaseConfig definition and global_batch_size
rg "class WorkloadBaseConfig|global_batch_size" --context=2 -A 5
```
🏁 Script executed:
```shell
# Get more context around the gbs_scaling_factor usage in overrides.py
rg "gbs_scaling_factor" scripts/performance/utils/overrides.py -B 5 -A 5
```
🏁 Script executed:
```shell
# Look for WorkloadBaseConfig class definition
rg "class WorkloadBaseConfig" --context=10 -A 20
```
🏁 Script executed:
```shell
# Check if there are any test files that test this scaling behavior
fd ".*test.*" --type f -e py | xargs grep -l "gbs_scaling_factor\|generate_experiment_config" 2>/dev/null | head -10
```
🏁 Script executed:
```shell
# Find workload_base_configs.py files to understand the actual configs
find . -name "*workload_base_configs*" -type f
```
🏁 Script executed:
```shell
# Let's check what actual base configs exist and their values
rg "WorkloadBaseConfig\(" --context=3 scripts/performance/configs/ -A 3 | head -100
```
🏁 Script executed:
```shell
# Search for any existing tests related to generate_experiment_config
rg "generate_experiment_config" --context=5
```
🏁 Script executed:
```shell
# Check one of the actual workload config files to see structure
cat scripts/performance/configs/llama/llama3_workload_base_configs.py
```
🏁 Script executed:
```shell
# Run the verification script from the review comment
python - <<'PY'
import ast
from pathlib import Path

def int_lit(node):
    return node.value if isinstance(node, ast.Constant) and isinstance(node.value, int) else None

found = False
for path in sorted(Path(".").rglob("*workload_base_configs.py")):
    try:
        tree = ast.parse(path.read_text())
    except Exception:
        continue
    for node in ast.walk(tree):
        if not isinstance(node, ast.Call):
            continue
        fn = node.func
        name = fn.id if isinstance(fn, ast.Name) else (fn.attr if isinstance(fn, ast.Attribute) else None)
        if name != "WorkloadBaseConfig":
            continue
        vals = {kw.arg: int_lit(kw.value) for kw in node.keywords if kw.arg in {"num_gpus", "global_batch_size"}}
        ng = vals.get("num_gpus")
        gbs = vals.get("global_batch_size")
        if ng is None or gbs is None:
            continue
        # Check if any common target GPU counts would cause non-exact division
        problematic = []
        for target_gpus in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]:
            if ng > 0 and (gbs * target_gpus) % ng != 0:
                truncated = int((gbs * target_gpus) // ng)
                exact = (gbs * target_gpus) / ng
                problematic.append(f"target={target_gpus}: truncates {exact:.3f}→{truncated}")
        if problematic:
            found = True
            print(f"{path.name}:{node.lineno}: num_gpus={ng}, gbs={gbs}")
            for p in problematic[:3]:  # Show first 3
                print(f"  {p}")
if not found:
    print("✓ No truncation risks found in literal WorkloadBaseConfig definitions")
PY
```
🏁 Script executed:
```shell
# Check if there are any tests for the generate_experiment_config function
find . -name "*.py" -path "*/test*" -exec grep -l "generate_experiment_config\|gbs_scaling" {} \;
```
🏁 Script executed:
```shell
# Look at how set_post_overrides uses gbs_scaling_factor to understand the relationship better
rg "set_post_overrides" -B 3 -A 15 scripts/performance/utils/overrides.py | head -80
```
🏁 Script executed:
```shell
# Get the full set_post_overrides function to understand the context
sed -n '1,300p' scripts/performance/utils/overrides.py | tail -150
```
🏁 Script executed:
```shell
# Get the complete set_post_overrides function
rg "def set_post_overrides" -A 80 scripts/performance/utils/overrides.py
```
🏁 Script executed:
```shell
# Also check how generate_experiment_config is used - in what context is it called
rg "generate_experiment_config" -B 5 -A 5 scripts/performance/utils/
```
🏁 Script executed:
```shell
# Search more broadly for generate_experiment_config usage
rg "generate_experiment_config" -B 2 -A 2
```
🏁 Script executed:
```shell
# Check what happens after set_post_overrides - does it use the gbs_size value from generate_experiment_config?
cat scripts/performance/utils/utils.py | sed -n '140,200p'
```
Validate GBS scaling is exact; fail fast if target GPU count causes non-integral scaling.
The current code computes `gbs_size = int(base_config.gbs_scaling_factor * num_gpus)`, which silently truncates when the scaled batch size is not an integer. While all existing configs (e.g., 128÷8=16.0, 128÷64=2.0, 256÷64=4.0) divide evenly, future configs may not. Computing in integer space with `divmod()` and raising an explicit error if the scaled GBS is not integral prevents silent data loss and catches config errors early.
Proposed fix

```diff
-    elif num_gpus != base_config.num_gpus:
-        # Scale GBS with num_gpus so experiment name matches the scaled GBS applied in set_post_overrides
-        gbs_size = int(base_config.gbs_scaling_factor * num_gpus)
+    elif num_gpus != base_config.num_gpus:
+        # Keep scaling in integer space to avoid float truncation.
+        scaled_gbs, remainder = divmod(base_config.global_batch_size * num_gpus, base_config.num_gpus)
+        if remainder != 0:
+            raise ValueError(
+                "Scaled global_batch_size is not an integer for the requested GPU count: "
+                f"global_batch_size={base_config.global_batch_size}, "
+                f"base_num_gpus={base_config.num_gpus}, num_gpus={num_gpus}"
+            )
+        gbs_size = scaled_gbs
```
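The proposed logic can be exercised in isolation. This is a sketch with a stand-in dataclass in place of the real `WorkloadBaseConfig`; the `scale_gbs` wrapper and the example values are illustrative, not taken from the repository:

```python
from dataclasses import dataclass


@dataclass
class BaseConfig:
    """Stand-in for the relevant fields of WorkloadBaseConfig."""
    global_batch_size: int
    num_gpus: int


def scale_gbs(base_config: BaseConfig, num_gpus: int) -> int:
    # Keep scaling in integer space to avoid float truncation.
    scaled_gbs, remainder = divmod(base_config.global_batch_size * num_gpus, base_config.num_gpus)
    if remainder != 0:
        raise ValueError(
            "Scaled global_batch_size is not an integer for the requested GPU count: "
            f"global_batch_size={base_config.global_batch_size}, "
            f"base_num_gpus={base_config.num_gpus}, num_gpus={num_gpus}"
        )
    return scaled_gbs


print(scale_gbs(BaseConfig(128, 64), 128))  # 256: doubling GPUs doubles GBS exactly

try:
    scale_gbs(BaseConfig(100, 64), 8)  # 100 * 8 / 64 = 12.5 -> raises instead of truncating
except ValueError as e:
    print(e)
```

Because the multiplication happens before the division, the check is exact for arbitrarily large integers; there is no float rounding anywhere in the path.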
Added functionality for GBS to scale up with the number of GPUs, which was missing previously.