
Commit 8f7ee69

Author: Myle Ott
[fix] [FSDP] Make _get_default_cuda_device more robust to modules without params (#606)
1 parent 82d6997


2 files changed: +13 −8 lines


fairscale/nn/data_parallel/fully_sharded_data_parallel.py (+8 −5)
```diff
@@ -1540,11 +1540,14 @@ def _print_r0(self, msg: str, restart: bool = False) -> None:
 
 def _get_default_cuda_device(module: nn.Module) -> torch.device:
     """Try to infer CUDA device from module parameters."""
-    compute_device = next(module.parameters()).device
-    if compute_device.type != "cuda":
-        # Fall back to current CUDA device.
-        compute_device = torch.device("cuda")
-    return compute_device
+    try:
+        compute_device = next(module.parameters()).device
+        if compute_device.type == "cuda":
+            return compute_device
+    except StopIteration:
+        pass
+    # Fall back to current CUDA device
+    return torch.device("cuda")
 
 
 @torch.no_grad()
```
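
For quick reference, here is a self-contained sketch of the behavior this fix enables. The helper body is copied from the diff above; the `__main__` driver and the example modules (`nn.ReLU`, `nn.Linear`) are illustrative additions, not part of the commit.

```python
import torch
import torch.nn as nn


def _get_default_cuda_device(module: nn.Module) -> torch.device:
    """Try to infer CUDA device from module parameters."""
    try:
        compute_device = next(module.parameters()).device
        if compute_device.type == "cuda":
            return compute_device
    except StopIteration:
        # The module has no parameters at all; fall through to the default.
        pass
    # Fall back to current CUDA device
    return torch.device("cuda")


if __name__ == "__main__":
    # Parameter-less module: next(module.parameters()) raises StopIteration,
    # which previously crashed the helper and now triggers the fallback.
    print(_get_default_cuda_device(nn.ReLU()))

    # Module with parameters: the parameter's device is returned only when it
    # is already a CUDA device; otherwise the helper still falls back.
    linear = nn.Linear(4, 4)
    if torch.cuda.is_available():
        linear = linear.cuda()
    print(_get_default_cuda_device(linear))
```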

fairscale/nn/wrap/auto_wrap.py (+5 −3)
```diff
@@ -88,9 +88,11 @@ def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: A
         with enable_wrap(**params):
             # Wraps layer in FSDP by default if within context
             self.l1 = wrap(torch.nn.Linear(5, 5))
-            # Wraps children modules based on a different min_num_params
-            my_auto_wrap_policy = functools.partial(auto_wrap_policy, min_num_params=1e7)
-            self.l2 = auto_wrap(TransformerBlock(), shuold_wrap=my_auto_wrap_policy)
+            self.l2 = auto_wrap(
+                TransformerBlock(),
+                # Wraps children modules based on a different min_num_params
+                auto_wrap_policy=functools.partial(default_auto_wrap_policy, min_num_params=1e7)
+            )
 
     Args:
         auto_wrap_policy (Callable, Optional):
```
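
The corrected docstring example can likewise be expanded into a runnable sketch. Only `enable_wrap`, `wrap`, `auto_wrap`, and `default_auto_wrap_policy` come from the file edited here (the import path mirrors that file and may differ across fairscale versions); `MyModel`, `build_model`, the `nn.Sequential` stand-in for `TransformerBlock`, and the `fsdp_params` kwargs are assumptions, and the sketch presumes fairscale is installed and torch.distributed has already been initialized (e.g. via torchrun).

```python
import functools

import torch.distributed as dist
import torch.nn as nn

# Import path mirrors fairscale/nn/wrap/auto_wrap.py as shown in the diff above.
from fairscale.nn.wrap.auto_wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Inside an enable_wrap() context, wrap() wraps the layer in FSDP by default.
        self.l1 = wrap(nn.Linear(5, 5))
        # auto_wrap() recursively wraps children selected by the policy; the
        # threshold is lowered here so the small Linear layers qualify.
        self.l2 = auto_wrap(
            nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
            auto_wrap_policy=functools.partial(default_auto_wrap_policy, min_num_params=10),
        )


def build_model() -> nn.Module:
    # FSDP wrapping needs an initialized process group; launch via torchrun or similar.
    assert dist.is_initialized(), "initialize torch.distributed before wrapping"
    fsdp_params = dict(flatten_parameters=True)  # illustrative wrapper kwargs
    with enable_wrap(**fsdp_params):
        return MyModel()
```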
