10 changes: 9 additions & 1 deletion scripts/performance/utils/utils.py
Expand Up @@ -177,7 +177,15 @@ def get_exp_name_config(
else base_config.expert_tensor_parallel_size
)
mbs_size = args.micro_batch_size if args.micro_batch_size is not None else base_config.micro_batch_size
gbs_size = args.global_batch_size if args.global_batch_size is not None else base_config.global_batch_size

if args.global_batch_size is not None:
gbs_size = args.global_batch_size
elif num_gpus != base_config.num_gpus:
# Scale GBS with num_gpus so experiment name matches the scaled GBS applied in set_post_overrides
gbs_size = int(base_config.gbs_scaling_factor * num_gpus)
else:
Comment on lines +183 to +186
⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, let's read the target file to see the actual code and context
cat -n scripts/performance/utils/utils.py | head -200 | tail -40

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1902


🏁 Script executed:

# Check file size first
wc -l scripts/performance/utils/utils.py


Length of output: 109


🏁 Script executed:

# Search for gbs_scaling_factor definition and usage
rg "gbs_scaling_factor" --context=3


Length of output: 1786


🏁 Script executed:

# Search for WorkloadBaseConfig definition and global_batch_size
rg "class WorkloadBaseConfig|global_batch_size" --context=2 -A 5


Length of output: 50386


🏁 Script executed:

# Get more context around the gbs_scaling_factor usage in overrides.py
rg "gbs_scaling_factor" scripts/performance/utils/overrides.py -B 5 -A 5


Length of output: 570


🏁 Script executed:

# Look for WorkloadBaseConfig class definition
rg "class WorkloadBaseConfig" --context=10 -A 20


Length of output: 1781


🏁 Script executed:

# Check if there are any test files that test this scaling behavior
fd ".*test.*" --type f -e py | xargs grep -l "gbs_scaling_factor\|generate_experiment_config" 2>/dev/null | head -10


Length of output: 53


🏁 Script executed:

# Find workload_base_configs.py files to understand the actual configs
find . -name "*workload_base_configs*" -type f


Length of output: 564


🏁 Script executed:

# Let's check what actual base configs exist and their values
rg "WorkloadBaseConfig\(" --context=3 scripts/performance/configs/ -A 3 | head -100


Length of output: 5232


🏁 Script executed:

# Search for any existing tests related to generate_experiment_config
rg "generate_experiment_config" --context=5


Length of output: 53


🏁 Script executed:

# Check one of the actual workload config files to see structure
cat scripts/performance/configs/llama/llama3_workload_base_configs.py


Length of output: 20305


🏁 Script executed:

# Run the verification script from the review comment
python - <<'PY'
import ast
from pathlib import Path

def int_lit(node):
    return node.value if isinstance(node, ast.Constant) and isinstance(node.value, int) else None

found = False
for path in sorted(Path(".").rglob("*workload_base_configs.py")):
    try:
        tree = ast.parse(path.read_text())
    except Exception as e:
        continue
    for node in ast.walk(tree):
        if not isinstance(node, ast.Call):
            continue
        fn = node.func
        name = fn.id if isinstance(fn, ast.Name) else (fn.attr if isinstance(fn, ast.Attribute) else None)
        if name != "WorkloadBaseConfig":
            continue
        vals = {kw.arg: int_lit(kw.value) for kw in node.keywords if kw.arg in {"num_gpus", "global_batch_size"}}
        ng = vals.get("num_gpus")
        gbs = vals.get("global_batch_size")
        if ng is None or gbs is None:
            continue
        # Check if any common target GPU counts would cause non-exact division
        problematic = []
        for target_gpus in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]:
            if ng > 0 and (gbs * target_gpus) % ng != 0:
                truncated = int((gbs * target_gpus) // ng)
                exact = (gbs * target_gpus) / ng
                problematic.append(f"target={target_gpus}: truncates {exact:.3f}→{truncated}")
        if problematic:
            found = True
            print(f"{path.name}:{node.lineno}: num_gpus={ng}, gbs={gbs}")
            for p in problematic[:3]:  # Show first 3
                print(f"  {p}")
        
if not found:
    print("✓ No truncation risks found in literal WorkloadBaseConfig definitions")
PY


Length of output: 140


🏁 Script executed:

# Check if there are any tests for the generate_experiment_config function
find . -name "*.py" -path "*/test*" -exec grep -l "generate_experiment_config\|gbs_scaling" {} \;


Length of output: 53


🏁 Script executed:

# Look at how set_post_overrides uses gbs_scaling_factor to understand the relationship better
rg "set_post_overrides" -B 3 -A 15 scripts/performance/utils/overrides.py | head -80


Length of output: 542


🏁 Script executed:

# Get the full set_post_overrides function to understand the context
sed -n '1,300p' scripts/performance/utils/overrides.py | tail -150


Length of output: 7151


🏁 Script executed:

# Get the complete set_post_overrides function
rg "def set_post_overrides" -A 80 scripts/performance/utils/overrides.py


Length of output: 1836


🏁 Script executed:

# Also check how generate_experiment_config is used - in what context is it called
rg "generate_experiment_config" -B 5 -A 5 scripts/performance/utils/


Length of output: 53


🏁 Script executed:

# Search more broadly for generate_experiment_config usage
rg "generate_experiment_config" -B 2 -A 2


Length of output: 53


🏁 Script executed:

# Check what happens after set_post_overrides - does it use the gbs_size value from generate_experiment_config?
cat scripts/performance/utils/utils.py | sed -n '140,200p'


Length of output: 2385


Validate GBS scaling is exact; fail fast if target GPU count causes non-integral scaling.

The current code computes gbs_size = int(base_config.gbs_scaling_factor * num_gpus), which silently truncates when the scaled batch size is not an integer. While all existing configs (e.g., 128÷8=16.0, 128÷64=2.0, 256÷64=4.0) divide evenly, future configs may not. Computing in integer space with divmod() and raising an explicit error if the scaled GBS is not integral prevents silent data loss and catches config errors early.
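To make the failure mode concrete, here is a small, self-contained sketch (hypothetical numbers, not taken from the repository's configs; the helper name `scale_gbs` is illustrative) contrasting the float path with the integer-space path:

```python
def scale_gbs(base_gbs: int, base_gpus: int, target_gpus: int) -> int:
    """Scale a global batch size in integer space; fail fast if inexact."""
    scaled, remainder = divmod(base_gbs * target_gpus, base_gpus)
    if remainder != 0:
        raise ValueError(
            f"global_batch_size={base_gbs} does not scale evenly "
            f"from {base_gpus} to {target_gpus} GPUs"
        )
    return scaled

# Float path: 100 / 8 * 3 = 37.5, and int() silently truncates to 37.
print(int((100 / 8) * 3))      # 37
# Integer path: an even scaling succeeds...
print(scale_gbs(128, 8, 16))   # 256
# ...and an uneven one raises instead of returning a wrong value.
try:
    scale_gbs(100, 8, 3)
except ValueError as err:
    print(err)
```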

Proposed fix
-    elif num_gpus != base_config.num_gpus:
-        # Scale GBS with num_gpus so experiment name matches the scaled GBS applied in set_post_overrides
-        gbs_size = int(base_config.gbs_scaling_factor * num_gpus)
+    elif num_gpus != base_config.num_gpus:
+        # Keep scaling in integer space to avoid float truncation.
+        scaled_gbs, remainder = divmod(base_config.global_batch_size * num_gpus, base_config.num_gpus)
+        if remainder != 0:
+            raise ValueError(
+                "Scaled global_batch_size is not an integer for the requested GPU count: "
+                f"global_batch_size={base_config.global_batch_size}, "
+                f"base_num_gpus={base_config.num_gpus}, num_gpus={num_gpus}"
+            )
+        gbs_size = scaled_gbs
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@scripts/performance/utils/utils.py` around lines 183 - 186, the code silently
truncates the scaled global batch size via int(base_config.gbs_scaling_factor *
num_gpus). In the branch that computes gbs_size when num_gpus !=
base_config.num_gpus, perform the scaling in integer space instead, e.g.
quotient, remainder = divmod(base_config.global_batch_size * num_gpus,
base_config.num_gpus), and raise a clear exception when the remainder is nonzero
so misconfigured runs fail fast. Keep the symbols gbs_size,
base_config.gbs_scaling_factor, and num_gpus referenced, and make the error
message state that this affects the scaling applied in set_post_overrides.

gbs_size = base_config.global_batch_size

exp_config = f"gpus{num_gpus}_tp{tp_size}_pp{pp_size}_cp{cp_size}_vp{vp_size}_ep{ep_size}_etp{etp_size}_mbs{mbs_size}_gbs{gbs_size}"
return exp_config
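Read together, the gbs_size selection in this hunk (explicit override, scaled GBS, or base GBS) can be sketched as one standalone function; the helper name `compute_gbs_size` and the exact-division check are illustrative assumptions, not the repository's code:

```python
def compute_gbs_size(cli_gbs, num_gpus, base_num_gpus, base_gbs):
    """Pick the GBS used in the experiment name, mirroring the diff's branches."""
    if cli_gbs is not None:
        return cli_gbs                 # explicit args override wins
    if num_gpus != base_num_gpus:
        # Scale with num_gpus, in integer space, so the name matches
        # the scaled GBS applied in set_post_overrides.
        scaled, remainder = divmod(base_gbs * num_gpus, base_num_gpus)
        if remainder != 0:
            raise ValueError(
                f"gbs={base_gbs} does not scale evenly from "
                f"{base_num_gpus} to {num_gpus} GPUs"
            )
        return scaled
    return base_gbs                    # unchanged GPU count: keep the base GBS

print(compute_gbs_size(None, 16, 8, 128))   # 256
print(compute_gbs_size(64, 16, 8, 128))     # 64
print(compute_gbs_size(None, 8, 8, 128))    # 128
```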

Expand Down