Closed
29 changes: 29 additions & 0 deletions src/speculators/train/trainer.py
@@ -8,6 +8,7 @@
    StateDictOptions,
    set_model_state_dict,
)
from torch.distributed.fsdp import FSDPModule
from torch.utils.data import DataLoader
from tqdm import TqdmExperimentalWarning
from tqdm.rich import tqdm
@@ -49,6 +50,33 @@ class TrainerConfig(NamedTuple):
    log_freq: int = 1


def _materialize_fsdp_model(model: torch.nn.Module):
    """Materialize and reset parameters for a freshly sharded FSDP model."""
    acc = torch.accelerator.current_accelerator()
    device = "cuda" if acc is None else acc.type
Comment on lines +55 to +56


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Checking distributed device pinning and materialization call sites..."
rg -n -C3 'set_device\(|current_accelerator\(|to_empty\(|local_rank|apply_fully_sharded\(|_materialize_fsdp_model\(' --type=py

echo
echo "Checking launch/config paths for local rank propagation..."
rg -n -C2 'LOCAL_RANK|--local-rank|torchrun|cuda:' --glob '*.py' --glob '*.sh' --glob '*.md'

Repository: vllm-project/speculators

Length of output: 31330


🏁 Script executed:

#!/bin/bash
# Examine the _materialize_fsdp_model function and its call site
sed -n '53,78p' src/speculators/train/trainer.py
echo "---"
sed -n '150,165p' src/speculators/train/trainer.py

Repository: vllm-project/speculators

Length of output: 1797


🏁 Script executed:

#!/bin/bash
# Check what torch.accelerator.current_device_index() returns and compare to pattern
grep -A 3 "current_device_index" src/speculators/utils/util.py

Repository: vllm-project/speculators

Length of output: 166


🏁 Script executed:

#!/bin/bash
# Verify to_empty device behavior and distributed device context
grep -B 2 -A 5 "to_empty" src/speculators/train/trainer.py

Repository: vllm-project/speculators

Length of output: 738


🏁 Script executed:

#!/bin/bash
# Check if there are tests for distributed FSDP materialization
grep -r "_materialize_fsdp_model\|test.*materialize\|test.*fsdp" tests/ --include="*.py" | head -20

Repository: vllm-project/speculators

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Examine the test setup to see if device pinning is verified in distributed tests
sed -n '430,475p' tests/unit/train/test_setup_model.py

Repository: vllm-project/speculators

Length of output: 1640


🏁 Script executed:

#!/bin/bash
# Check the actual behavior of torch.accelerator calls at materialization time
grep -B 5 "set_device_index\|init_process_group" src/speculators/train/utils.py | head -30

Repository: vllm-project/speculators

Length of output: 465


🏁 Script executed:

#!/bin/bash
# Check if there's explicit verification that device placement is correct after materialization
sed -n '460,520p' tests/unit/train/test_setup_model.py

Repository: vllm-project/speculators

Length of output: 2153


🏁 Script executed (no clone):

# Check PyTorch documentation behavior for to_empty with device strings
python3 << 'EOF'
import torch

# Test what happens with to_empty and device strings
print("Testing to_empty behavior with device strings:")
print("torch.__version__:", torch.__version__)

# Create a test module
m = torch.nn.Linear(10, 10)
print("Original device:", m.weight.device)

# Try to_empty with string device (this is what the code does)
try:
    m.to_empty(device="cuda")
    print("to_empty(device='cuda') - device set to:", m.weight.device)
except Exception as e:
    print("to_empty(device='cuda') error:", type(e).__name__, str(e)[:100])

# Try to_empty with explicit torch.device
try:
    m2 = torch.nn.Linear(10, 10)
    m2.to_empty(device=torch.device("cuda:0"))
    print("to_empty(device=torch.device('cuda:0')) - device set to:", m2.weight.device)
except Exception as e:
    print("to_empty with explicit device error:", type(e).__name__)
EOF

Length of output: 149


🏁 Script executed:

#!/bin/bash
# Look for explicit device assertions or checks in the distributed test
sed -n '460,550p' tests/unit/train/test_setup_model.py | grep -i "device\|cuda"

Repository: vllm-project/speculators

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Check if the actual code in setup_model passes device context correctly
sed -n '130,165p' src/speculators/train/trainer.py

Repository: vllm-project/speculators

Length of output: 1406


🏁 Script executed:

#!/bin/bash
# Verify what happens in single-GPU (non-distributed) case at line 142
sed -n '140,145p' src/speculators/train/trainer.py

Repository: vllm-project/speculators

Length of output: 325


🏁 Script executed:

#!/bin/bash
# Check if there's anywhere that actually constructs proper device string with rank
grep -n "torch.device\|cuda:" src/speculators/train/trainer.py

Repository: vllm-project/speculators

Length of output: 50


Pass explicit rank-scoped device to FSDP materialization instead of relying on implicit context.

The function computes `device = "cuda" if acc is None else acc.type`, which omits the device index. An index-less device resolves to whichever device is current when `to_empty` runs, so placement is only correct if `torch.accelerator.set_device_index(local_rank)` has already been called in this process. That call does happen during setup, but depending on global device state makes the helper fragile in distributed training. Pass an explicit `torch.device` built from the local rank at the call site instead.

Suggested fix
-def _materialize_fsdp_model(model: torch.nn.Module):
+def _materialize_fsdp_model(model: torch.nn.Module, device: torch.device) -> None:
     """Materialize and reset parameters for a freshly sharded FSDP model."""
-    acc = torch.accelerator.current_accelerator()
-    device = "cuda" if acc is None else acc.type

     for m in model.layers.children():  # type: ignore[union-attr]
         if not isinstance(m, FSDPModule):
             continue
@@
-            _materialize_fsdp_model(self.model)
+            acc = torch.accelerator.current_accelerator()
+            if acc is None:
+                raise RuntimeError("No accelerator available for distributed FSDP setup")
+            materialize_device = torch.device(f"{acc.type}:{self.local_rank}")
+            _materialize_fsdp_model(self.model, materialize_device)

Per coding guidelines: "Pay close attention to distributed training correctness: verify device placement is consistent, and FSDP wrapping is correct."
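
As a rough illustration of what "rank-scoped" means here, the sketch below builds an indexed device string from the `LOCAL_RANK` environment variable that `torchrun` sets for each worker. The helper name `rank_scoped_device_str` and the fallback to rank 0 are assumptions made for illustration and are not part of this PR; the resulting string is the kind of value one would hand to `torch.device(...)`.

```python
import os


def rank_scoped_device_str(acc_type: str, env=None) -> str:
    """Return an indexed device string such as "cuda:3".

    Hypothetical helper, not part of speculators: it reads LOCAL_RANK
    (set per worker by torchrun) and falls back to 0 when it is unset.
    """
    env = os.environ if env is None else env
    local_rank = int(env.get("LOCAL_RANK", "0"))
    if local_rank < 0:
        raise ValueError(f"LOCAL_RANK must be non-negative, got {local_rank}")
    return f"{acc_type}:{local_rank}"


# Use an explicit mapping so the result does not depend on the shell environment.
print(rank_scoped_device_str("cuda", {"LOCAL_RANK": "3"}))  # prints cuda:3
```

Taking the accelerator type as a parameter, rather than hard-coding `"cuda"`, keeps the sketch in line with the suggested fix above, which derives the type from `acc.type`.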

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/speculators/train/trainer.py` around lines 55 - 56, `device` is computed
as `"cuda" if acc is None else acc.type`, which omits the device index and can
break distributed FSDP materialization. Change `_materialize_fsdp_model` to
accept an explicit `torch.device` and use it in place of the `device` string.
At the call site in `setup_model`, build that device from the accelerator type
and the local rank (e.g., `torch.device(acc.type, local_rank)`) and pass it in,
so device placement is explicit and rank-scoped rather than dependent on the
global accelerator context.


    for m in model.layers.children():  # type: ignore[union-attr]
        if not isinstance(m, FSDPModule):
            continue
        m.to_empty(device=device)  # type: ignore[attr-defined]
        for sub_module in m.modules():  # type: ignore[attr-defined]
            if hasattr(sub_module, "reset_parameters"):
                sub_module.reset_parameters()  # type: ignore[operator]

    for name, module in model.named_children():
        if name == "layers":
            continue
        tensors = list(module.parameters(recurse=True)) + list(
            module.buffers(recurse=True)
        )
        has_meta = any(t.device.type == "meta" for t in tensors)
        if has_meta:
            module.to_empty(device=device)
            for sub in module.modules():
                if hasattr(sub, "reset_parameters"):
                    sub.reset_parameters()  # type: ignore[operator]
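
The materialization pattern in the helper above (detect meta tensors, allocate storage with `to_empty`, then call `reset_parameters`) can be tried without FSDP or a GPU. The following is a minimal sketch under that assumption: it uses a plain `torch.nn.Linear` on CPU, and the function name `materialize` is hypothetical rather than taken from the PR.

```python
import torch


def materialize(module: torch.nn.Module, device: torch.device) -> None:
    """Allocate storage for meta tensors on `device`, then re-initialize."""
    tensors = list(module.parameters()) + list(module.buffers())
    if not any(t.device.type == "meta" for t in tensors):
        return  # nothing to materialize
    # to_empty allocates uninitialized storage; values are arbitrary until reset.
    module.to_empty(device=device)
    for sub in module.modules():
        if hasattr(sub, "reset_parameters"):
            sub.reset_parameters()


# Build a layer on the meta device: shapes only, no storage.
layer = torch.nn.Linear(4, 2, device="meta")
materialize(layer, torch.device("cpu"))
print(layer.weight.device)  # prints cpu
```

Passing a `torch.device` object rather than a bare string mirrors the signature proposed in the review; in a distributed run the caller would supply an indexed accelerator device instead of CPU.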


class Trainer:
    def __init__(
        self,
@@ -127,6 +155,7 @@ def setup_model(self):
        if load_checkpoint:
            self.checkpointer.load_model_state_dict(self.model)
        else:
            _materialize_fsdp_model(self.model)
            # Broadcast full state dict from rank 0 to all ranks
            set_model_state_dict(
                self.model,