Materialize non-layer modules after FSDP sharding #314
Pull request overview
This PR addresses an initialization gap in distributed/FSDP training where only decoder layers were being materialized after sharding, potentially leaving other top-level modules (e.g., embeddings / heads) on the meta device.
Changes:
- Extracted the post-FSDP materialization logic into a `_materialize_fsdp_model` helper.
- Added a second pass that attempts to materialize any non-`layers` top-level children that still have meta parameters.
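The second pass described above can be sketched roughly as follows. This is a minimal illustration on plain `torch.nn` modules, not the PR's actual code: the names `has_meta_tensors`, `materialize_meta_children`, and `Toy` are hypothetical, FSDP wrapping is left out, and the target device is CPU so the snippet runs without an accelerator.

```python
import torch
from torch import nn


def has_meta_tensors(module: nn.Module) -> bool:
    """Return True if any parameter or buffer of `module` is on the meta device."""
    params_and_buffers = list(module.parameters()) + list(module.buffers())
    return any(t.is_meta for t in params_and_buffers)


def materialize_meta_children(model, device, skip=("layers",)):
    """Allocate and re-initialize top-level children that are still on meta.

    Hypothetical helper: children named in `skip` are assumed to be handled
    by a separate pass. Returns the names of the children it materialized.
    """
    materialized = []
    for name, child in model.named_children():
        if name in skip or not has_meta_tensors(child):
            continue  # modules holding real weights are left untouched
        child.to_empty(device=device)  # allocate uninitialized storage
        for sub in child.modules():
            if hasattr(sub, "reset_parameters"):
                sub.reset_parameters()  # fill the new storage with default init
        materialized.append(name)
    return materialized


class Toy(nn.Module):
    """Toy model: one child with real weights, one child left on meta."""

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(8, 4)         # real CPU weights
        self.lm_head = nn.Linear(4, 8, device="meta")  # no storage yet
```

Calling `materialize_meta_children(Toy(), torch.device("cpu"))` only touches `lm_head`; `embed_tokens` keeps its existing values because it has no meta tensors.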
    module.to_empty(device=device)
    for sub in module.modules():
        if hasattr(sub, "reset_parameters"):
            sub.reset_parameters()
Resetting all non-layer modules on meta via reset_parameters() can produce incorrect initialization for models that intentionally preload/freeze weights. For example, Eagle3DraftModel loads verifier embed_tokens/lm_head weights during __init__ and sets requires_grad=False; if those modules end up on meta after sharding (the situation this PR addresses), this logic will initialize them randomly instead of restoring the verifier weights. Consider special-casing modules like embed_tokens/lm_head to reload from the verifier after materialization (or provide a model hook for post-FSDP rehydration) rather than calling reset_parameters() indiscriminately.
I think this only runs when has_meta is True, meaning parameters are already on meta device with no actual data. In practice, apply_fully_sharded only moves model.layers to meta, so embed_tokens/lm_head retain their verifier weights and this branch is skipped for them.
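If preloaded modules ever did end up on meta, the rehydration idea raised in the review could look something like the sketch below. It is an assumption-laden illustration, not code from this PR: `materialize_module` and its `preloaded_state` argument are hypothetical, and plain `nn.Embedding` / `nn.Linear` modules stand in for the verifier's `embed_tokens` / `lm_head`.

```python
import torch
from torch import nn


def materialize_module(module, device, preloaded_state=None):
    """Materialize a meta module, preferring saved weights over random init.

    Hypothetical helper: `preloaded_state` is a state dict captured before the
    module lost its storage (for example, the verifier's weights).
    """
    module.to_empty(device=device)  # allocate uninitialized storage
    if preloaded_state is not None:
        # Restore the saved weights instead of re-initializing randomly.
        module.load_state_dict(preloaded_state)
        return "reloaded"
    for sub in module.modules():
        if hasattr(sub, "reset_parameters"):
            sub.reset_parameters()
    return "reset"
```

The design choice here is that the caller decides which modules carry weights worth restoring; anything without a saved state falls back to `reset_parameters()`.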
Signed-off-by: Guan-Ming (Wesley) Chiu <guanmingchiu@gmail.com>
Summary by CodeRabbit
Walkthrough
The PR adds FSDP (Fully Sharded Data Parallel) parameter materialization to the trainer.
Changes: FSDP Parameter Materialization in Trainer
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~12 minutes
Merge Protections
Your pull request matches the following merge protections and will not be merged until they are valid.
🔴 Require two reviews: this rule is failing. PRs labelled "two-reviews" must have at least two approving reviews before merging.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/speculators/train/trainer.py`:
- Around line 55-56: The code computes device via device = "cuda" if acc is None
else acc.type which omits the device index and can break distributed FSDP
materialization; change the FSDP/materialization call sites in trainer.py to
accept and use an explicit torch.device constructed with the local rank (e.g.,
torch.device("cuda", local_rank)) instead of the string in device, pass that
torch.device through any functions that call
torch.accelerator.current_accelerator() or rely on device (refer to the acc
variable and device variable in trainer.py and the FSDP materialization/wrapping
calls), and ensure the caller provides local_rank so device placement is
explicit and rank-scoped rather than relying on global accelerator context.
📒 Files selected for processing (1)
src/speculators/train/trainer.py
    acc = torch.accelerator.current_accelerator()
    device = "cuda" if acc is None else acc.type
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "Checking distributed device pinning and materialization call sites..."
rg -n -C3 'set_device\(|current_accelerator\(|to_empty\(|local_rank|apply_fully_sharded\(|_materialize_fsdp_model\(' --type=py
echo
echo "Checking launch/config paths for local rank propagation..."
rg -n -C2 'LOCAL_RANK|--local-rank|torchrun|cuda:' --glob '*.py' --glob '*.sh' --glob '*.md'
Repository: vllm-project/speculators
Length of output: 31330
🏁 Script executed:
#!/bin/bash
# Examine the _materialize_fsdp_model function and its call site
sed -n '53,78p' src/speculators/train/trainer.py
echo "---"
sed -n '150,165p' src/speculators/train/trainer.py
Repository: vllm-project/speculators
Length of output: 1797
🏁 Script executed:
#!/bin/bash
# Check what torch.accelerator.current_device_index() returns and compare to pattern
grep -A 3 "current_device_index" src/speculators/utils/util.py
Repository: vllm-project/speculators
Length of output: 166
🏁 Script executed:
#!/bin/bash
# Verify to_empty device behavior and distributed device context
grep -B 2 -A 5 "to_empty" src/speculators/train/trainer.py
Repository: vllm-project/speculators
Length of output: 738
🏁 Script executed:
#!/bin/bash
# Check if there are tests for distributed FSDP materialization
grep -r "_materialize_fsdp_model\|test.*materialize\|test.*fsdp" tests/ --include="*.py" | head -20
Repository: vllm-project/speculators
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Examine the test setup to see if device pinning is verified in distributed tests
sed -n '430,475p' tests/unit/train/test_setup_model.py
Repository: vllm-project/speculators
Length of output: 1640
🏁 Script executed:
#!/bin/bash
# Check the actual behavior of torch.accelerator calls at materialization time
grep -B 5 "set_device_index\|init_process_group" src/speculators/train/utils.py | head -30
Repository: vllm-project/speculators
Length of output: 465
🏁 Script executed:
#!/bin/bash
# Check if there's explicit verification that device placement is correct after materialization
sed -n '460,520p' tests/unit/train/test_setup_model.py
Repository: vllm-project/speculators
Length of output: 2153
🏁 Script executed (no clone):
# Check PyTorch documentation behavior for to_empty with device strings
python3 << 'EOF'
import torch
# Test what happens with to_empty and device strings
print("Testing to_empty behavior with device strings:")
print("torch.__version__:", torch.__version__)
# Create a test module
m = torch.nn.Linear(10, 10)
print("Original device:", m.weight.device)
# Try to_empty with string device (this is what the code does)
try:
    m.to_empty(device="cuda")
    print("to_empty(device='cuda') - device set to:", m.weight.device)
except Exception as e:
    print("to_empty(device='cuda') error:", type(e).__name__, str(e)[:100])
# Try to_empty with explicit torch.device
try:
    m2 = torch.nn.Linear(10, 10)
    m2.to_empty(device=torch.device("cuda:0"))
    print("to_empty(device=torch.device('cuda:0')) - device set to:", m2.weight.device)
except Exception as e:
    print("to_empty with explicit device error:", type(e).__name__)
EOF
Length of output: 149
🏁 Script executed:
#!/bin/bash
# Look for explicit device assertions or checks in the distributed test
sed -n '460,550p' tests/unit/train/test_setup_model.py | grep -i "device\|cuda"
Repository: vllm-project/speculators
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Check if the actual code in setup_model passes device context correctly
sed -n '130,165p' src/speculators/train/trainer.py
Repository: vllm-project/speculators
Length of output: 1406
🏁 Script executed:
#!/bin/bash
# Verify what happens in single-GPU (non-distributed) case at line 142
sed -n '140,145p' src/speculators/train/trainer.py
Repository: vllm-project/speculators
Length of output: 325
🏁 Script executed:
#!/bin/bash
# Check if there's anywhere that actually constructs proper device string with rank
grep -n "torch.device\|cuda:" src/speculators/train/trainer.py
Repository: vllm-project/speculators
Length of output: 50
Pass explicit rank-scoped device to FSDP materialization instead of relying on implicit context.
The function uses device = "cuda" if acc is None else acc.type which omits the device index, creating fragility in distributed training. While torch.accelerator.set_device_index(local_rank) is called during setup, relying on implicit global device context violates distributed correctness principles. Pass an explicit torch.device with the local rank from the call site.
Suggested fix
-def _materialize_fsdp_model(model: torch.nn.Module):
+def _materialize_fsdp_model(model: torch.nn.Module, device: torch.device) -> None:
     """Materialize and reset parameters for a freshly sharded FSDP model."""
-    acc = torch.accelerator.current_accelerator()
-    device = "cuda" if acc is None else acc.type
     for m in model.layers.children():  # type: ignore[union-attr]
         if not isinstance(m, FSDPModule):
             continue
@@
-        _materialize_fsdp_model(self.model)
+        acc = torch.accelerator.current_accelerator()
+        if acc is None:
+            raise RuntimeError("No accelerator available for distributed FSDP setup")
+        materialize_device = torch.device(f"{acc.type}:{self.local_rank}")
+        _materialize_fsdp_model(self.model, materialize_device)

Per coding guidelines: "Pay close attention to distributed training correctness: verify device placement is consistent, and FSDP wrapping is correct."
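To make the difference between a bare device type and a rank-scoped device concrete, here is a small sketch. The helper names `rank_scoped_device` and `local_rank_from_env` are hypothetical; the only assumption about the launcher is that it exports `LOCAL_RANK`, as `torchrun` does. Constructing a `torch.device` does not require the accelerator to be present, so this runs on a CPU-only machine.

```python
import os

import torch


def rank_scoped_device(device_type: str, local_rank: int) -> torch.device:
    """Build a device that names both the backend and the per-process index."""
    return torch.device(device_type, local_rank)


def local_rank_from_env(default: int = 0) -> int:
    """Read the local rank from LOCAL_RANK, falling back to `default`."""
    return int(os.environ.get("LOCAL_RANK", str(default)))


# A bare type string carries no index, so placement depends on whatever
# device is current in the process at call time.
implicit = torch.device("cuda")

# An explicit device pins placement to one index regardless of global state.
explicit = rank_scoped_device("cuda", 1)
```

Here `implicit.index` is `None` while `explicit.index` is `1`, which is the distinction the review comment is drawing.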
Closing as I don't think this is needed since #333 was merged. That PR updated a lot of the loading/model setup logic, and we are no longer moving the model layers to the meta device before sharding. So the weights should be initialized (by the transformers
Why
FSDP model setup only materialized decoder layers, leaving non-layer modules (embed_tokens, lm_head) unhandled if they ended up on meta device.
How