Skip to content

Materialize non-layer modules after FSDP sharding#314

Closed
guan404ming wants to merge 2 commits into
vllm-project:mainfrom
guan404ming:fix/materialize-non-layer-fsdp-modules
Closed

Materialize non-layer modules after FSDP sharding#314
guan404ming wants to merge 2 commits into
vllm-project:mainfrom
guan404ming:fix/materialize-non-layer-fsdp-modules

Conversation

@guan404ming

Copy link
Copy Markdown
Contributor

Why

FSDP model setup only materialized decoder layers, leaving non-layer modules (embed_tokens, lm_head) unhandled if they ended up on meta device.

How

  • Extract FSDP materialization into _materialize_fsdp_model helper
  • Add a second pass to materialize any non-layer top-level modules still on meta device

Copilot AI review requested due to automatic review settings February 27, 2026 15:35

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR addresses an initialization gap in distributed/FSDP training where only decoder layers were being materialized after sharding, potentially leaving other top-level modules (e.g., embeddings / heads) on the meta device.

Changes:

  • Extracted the post-FSDP materialization logic into a _materialize_fsdp_model helper.
  • Added a second pass that attempts to materialize any non-layers top-level children that still have meta parameters.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/speculators/train/trainer.py Outdated
Comment on lines +65 to +68
module.to_empty(device=device)
for sub in module.modules():
if hasattr(sub, "reset_parameters"):
sub.reset_parameters()

Copilot AI Feb 27, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resetting all non-layer modules on meta via reset_parameters() can produce incorrect initialization for models that intentionally preload/freeze weights. For example, Eagle3DraftModel loads verifier embed_tokens/lm_head weights during __init__ and sets requires_grad=False; if those modules end up on meta after sharding (the situation this PR addresses), this logic will initialize them randomly instead of restoring the verifier weights. Consider special-casing modules like embed_tokens/lm_head to reload from the verifier after materialization (or provide a model hook for post-FSDP rehydration) rather than calling reset_parameters() indiscriminately.

Copilot uses AI. Check for mistakes.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this only runs when has_meta is True, meaning parameters are already on meta device with no actual data. In practice, apply_fully_sharded only moves model.layers to meta, so embed_tokens/lm_head retain their verifier weights and this branch is skipped for them.

Comment thread src/speculators/train/trainer.py
Comment thread src/speculators/train/trainer.py Outdated
@fynnsu fynnsu self-requested a review February 27, 2026 16:02
@guan404ming guan404ming force-pushed the fix/materialize-non-layer-fsdp-modules branch 2 times, most recently from 39021fc to 8e14406 Compare March 2, 2026 17:25
@guan404ming guan404ming closed this May 8, 2026
@guan404ming guan404ming force-pushed the fix/materialize-non-layer-fsdp-modules branch from 8e14406 to cc4590e Compare May 8, 2026 17:29
Signed-off-by: Guan-Ming (Wesley) Chiu <guanmingchiu@gmail.com>
Signed-off-by: Guan-Ming (Wesley) Chiu <guanmingchiu@gmail.com>
@guan404ming guan404ming reopened this May 8, 2026
@coderabbitai

coderabbitai Bot commented May 8, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

Summary by CodeRabbit

  • Bug Fixes
    • Improved distributed training initialization to ensure proper placement of model parameters and buffers across accelerator devices, enhancing stability and reliability when using distributed training configurations.

Walkthrough

The PR adds FSDP (Fully Sharded Data Parallel) parameter materialization to the trainer. It imports FSDPModule, introduces a _materialize_fsdp_model helper that initializes sharded parameters and buffers on the accelerator device with parameter reset, and integrates this helper into distributed model setup before state dict synchronization.

Changes

FSDP Parameter Materialization in Trainer

Layer / File(s) Summary
FSDP Import
src/speculators/train/trainer.py
Imports FSDPModule to support FSDP-specific operations.
Materialization Helper
src/speculators/train/trainer.py
Implements _materialize_fsdp_model(model) helper that materializes FSDP-sharded parameters/buffers on the active accelerator device via to_empty, calls reset_parameters where available, and handles modules still on the meta device.
Distributed Setup Integration
src/speculators/train/trainer.py
Calls _materialize_fsdp_model(self.model) in distributed setup_model immediately after fully-sharding the model and before broadcasting/loading the full-state-dict across ranks.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Suggested labels

bug, training, two-reviews

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 33.33% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately describes the main change: materializing non-layer modules after FSDP sharding, which directly matches the core objective of the PR.
Description check ✅ Passed The description clearly explains the problem (non-layer modules left on meta device) and the solution (extraction of materialization helper and second pass), directly related to the changeset.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot added bug Something isn't working training two-reviews labels May 8, 2026
@mergify

mergify Bot commented May 8, 2026

Copy link
Copy Markdown

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 Require two reviews

Waiting for

  • #approved-reviews-by >= 2
This rule is failing.

PRs labelled "two-reviews" must have at least two approving reviews before merging.

  • #approved-reviews-by >= 2

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/speculators/train/trainer.py`:
- Around line 55-56: The code computes device via device = "cuda" if acc is None
else acc.type which omits the device index and can break distributed FSDP
materialization; change the FSDP/materialization call sites in trainer.py to
accept and use an explicit torch.device constructed with the local rank (e.g.,
torch.device("cuda", local_rank)) instead of the string in device, pass that
torch.device through any functions that call
torch.accelerator.current_accelerator() or rely on device (refer to the acc
variable and device variable in trainer.py and the FSDP materialization/wrapping
calls), and ensure the caller provides local_rank so device placement is
explicit and rank-scoped rather than relying on global accelerator context.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6f04133e-c1d9-4009-9d31-ee60a15a07c8

📥 Commits

Reviewing files that changed from the base of the PR and between cc4590e and 71aef0b.

📒 Files selected for processing (1)
  • src/speculators/train/trainer.py

Comment on lines +55 to +56
acc = torch.accelerator.current_accelerator()
device = "cuda" if acc is None else acc.type

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Checking distributed device pinning and materialization call sites..."
rg -n -C3 'set_device\(|current_accelerator\(|to_empty\(|local_rank|apply_fully_sharded\(|_materialize_fsdp_model\(' --type=py

echo
echo "Checking launch/config paths for local rank propagation..."
rg -n -C2 'LOCAL_RANK|--local-rank|torchrun|cuda:' --glob '*.py' --glob '*.sh' --glob '*.md'

Repository: vllm-project/speculators

Length of output: 31330


🏁 Script executed:

#!/bin/bash
# Examine the _materialize_fsdp_model function and its call site
sed -n '53,78p' src/speculators/train/trainer.py
echo "---"
sed -n '150,165p' src/speculators/train/trainer.py

Repository: vllm-project/speculators

Length of output: 1797


🏁 Script executed:

#!/bin/bash
# Check what torch.accelerator.current_device_index() returns and compare to pattern
grep -A 3 "current_device_index" src/speculators/utils/util.py

Repository: vllm-project/speculators

Length of output: 166


🏁 Script executed:

#!/bin/bash
# Verify to_empty device behavior and distributed device context
grep -B 2 -A 5 "to_empty" src/speculators/train/trainer.py

Repository: vllm-project/speculators

Length of output: 738


🏁 Script executed:

#!/bin/bash
# Check if there are tests for distributed FSDP materialization
grep -r "_materialize_fsdp_model\|test.*materialize\|test.*fsdp" tests/ --include="*.py" | head -20

Repository: vllm-project/speculators

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Examine the test setup to see if device pinning is verified in distributed tests
sed -n '430,475p' tests/unit/train/test_setup_model.py

Repository: vllm-project/speculators

Length of output: 1640


🏁 Script executed:

#!/bin/bash
# Check the actual behavior of torch.accelerator calls at materialization time
grep -B 5 "set_device_index\|init_process_group" src/speculators/train/utils.py | head -30

Repository: vllm-project/speculators

Length of output: 465


🏁 Script executed:

#!/bin/bash
# Check if there's explicit verification that device placement is correct after materialization
sed -n '460,520p' tests/unit/train/test_setup_model.py

Repository: vllm-project/speculators

Length of output: 2153


🏁 Script executed (no clone):

# Check PyTorch documentation behavior for to_empty with device strings
python3 << 'EOF'
import torch

# Test what happens with to_empty and device strings
print("Testing to_empty behavior with device strings:")
print("torch.__version__:", torch.__version__)

# Create a test module
m = torch.nn.Linear(10, 10)
print("Original device:", m.weight.device)

# Try to_empty with string device (this is what the code does)
try:
    m.to_empty(device="cuda")
    print("to_empty(device='cuda') - device set to:", m.weight.device)
except Exception as e:
    print("to_empty(device='cuda') error:", type(e).__name__, str(e)[:100])

# Try to_empty with explicit torch.device
try:
    m2 = torch.nn.Linear(10, 10)
    m2.to_empty(device=torch.device("cuda:0"))
    print("to_empty(device=torch.device('cuda:0')) - device set to:", m2.weight.device)
except Exception as e:
    print("to_empty with explicit device error:", type(e).__name__)
EOF

Length of output: 149


🏁 Script executed:

#!/bin/bash
# Look for explicit device assertions or checks in the distributed test
sed -n '460,550p' tests/unit/train/test_setup_model.py | grep -i "device\|cuda"

Repository: vllm-project/speculators

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Check if the actual code in setup_model passes device context correctly
sed -n '130,165p' src/speculators/train/trainer.py

Repository: vllm-project/speculators

Length of output: 1406


🏁 Script executed:

#!/bin/bash
# Verify what happens in single-GPU (non-distributed) case at line 142
sed -n '140,145p' src/speculators/train/trainer.py

Repository: vllm-project/speculators

Length of output: 325


🏁 Script executed:

#!/bin/bash
# Check if there's anywhere that actually constructs proper device string with rank
grep -n "torch.device\|cuda:" src/speculators/train/trainer.py

Repository: vllm-project/speculators

Length of output: 50


Pass explicit rank-scoped device to FSDP materialization instead of relying on implicit context.

The function uses device = "cuda" if acc is None else acc.type which omits the device index, creating fragility in distributed training. While torch.accelerator.set_device_index(local_rank) is called during setup, relying on implicit global device context violates distributed correctness principles. Pass an explicit torch.device with the local rank from the call site.

Suggested fix
-def _materialize_fsdp_model(model: torch.nn.Module):
+def _materialize_fsdp_model(model: torch.nn.Module, device: torch.device) -> None:
     """Materialize and reset parameters for a freshly sharded FSDP model."""
-    acc = torch.accelerator.current_accelerator()
-    device = "cuda" if acc is None else acc.type

     for m in model.layers.children():  # type: ignore[union-attr]
         if not isinstance(m, FSDPModule):
             continue
@@
-            _materialize_fsdp_model(self.model)
+            acc = torch.accelerator.current_accelerator()
+            if acc is None:
+                raise RuntimeError("No accelerator available for distributed FSDP setup")
+            materialize_device = torch.device(f"{acc.type}:{self.local_rank}")
+            _materialize_fsdp_model(self.model, materialize_device)

Per coding guidelines: "Pay close attention to distributed training correctness: verify device placement is consistent, and FSDP wrapping is correct."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/speculators/train/trainer.py` around lines 55 - 56, The code computes
device via device = "cuda" if acc is None else acc.type which omits the device
index and can break distributed FSDP materialization; change the
FSDP/materialization call sites in trainer.py to accept and use an explicit
torch.device constructed with the local rank (e.g., torch.device("cuda",
local_rank)) instead of the string in device, pass that torch.device through any
functions that call torch.accelerator.current_accelerator() or rely on device
(refer to the acc variable and device variable in trainer.py and the FSDP
materialization/wrapping calls), and ensure the caller provides local_rank so
device placement is explicit and rank-scoped rather than relying on global
accelerator context.

@fynnsu

fynnsu commented May 21, 2026

Copy link
Copy Markdown
Collaborator

Closing as I don't think this is needed since #333 was merged. That pr updated a lot of the loading/model setup logic and we are no longer moving the model layers to meta device before sharding. So the weights should be initialized (by the transformers post_init() step) and don't need to be re-materialized.

@fynnsu fynnsu closed this May 21, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working training two-reviews

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants