From dfe2f67c5813e37dc2ef5dce921510c57b521a62 Mon Sep 17 00:00:00 2001
From: Myle Ott
Date: Wed, 14 Apr 2021 08:43:40 -0700
Subject: [PATCH 1/2] Make _get_default_cuda_device more robust to modules
 without params

---
 .../nn/data_parallel/fully_sharded_data_parallel.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 25762fd93..24410c8c9 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1540,11 +1540,14 @@ def _print_r0(self, msg: str, restart: bool = False) -> None:
 
 def _get_default_cuda_device(module: nn.Module) -> torch.device:
     """Try to infer CUDA device from module parameters."""
-    compute_device = next(module.parameters()).device
-    if compute_device.type != "cuda":
-        # Fall back to current CUDA device.
-        compute_device = torch.device("cuda")
-    return compute_device
+    try:
+        compute_device = next(module.parameters()).device
+        if compute_device.type == "cuda":
+            return compute_device
+    except StopIteration:
+        pass
+    # Fall back to current CUDA device
+    return torch.device("cuda")
 
 
 @torch.no_grad()

From 75cd3d3234ffc9db8ebe7ce042de4969d33adb47 Mon Sep 17 00:00:00 2001
From: Myle Ott
Date: Wed, 14 Apr 2021 08:50:44 -0700
Subject: [PATCH 2/2] Fix auto_wrap docstring

---
 fairscale/nn/wrap/auto_wrap.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/fairscale/nn/wrap/auto_wrap.py b/fairscale/nn/wrap/auto_wrap.py
index 6b53dc448..abfbebf0b 100644
--- a/fairscale/nn/wrap/auto_wrap.py
+++ b/fairscale/nn/wrap/auto_wrap.py
@@ -88,9 +88,11 @@ def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: A
         with enable_wrap(**params):
             # Wraps layer in FSDP by default if within context
             self.l1 = wrap(torch.nn.Linear(5, 5))
-            # Wraps children modules based on a different min_num_params
-            my_auto_wrap_policy = functools.partial(auto_wrap_policy, min_num_params=1e7)
-            self.l2 = auto_wrap(TransformerBlock(), shuold_wrap=my_auto_wrap_policy)
+            self.l2 = auto_wrap(
+                TransformerBlock(),
+                # Wraps children modules based on a different min_num_params
+                auto_wrap_policy=functools.partial(default_auto_wrap_policy, min_num_params=1e7)
+            )
 
     Args:
         auto_wrap_policy (Callable, Optional):
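
Note: below is a minimal standalone sketch of the behavior the first patch targets. A module with no
parameters (e.g. nn.ReLU()) makes next(module.parameters()) raise StopIteration, which the pre-patch
helper did not catch. The free-function name get_default_cuda_device and the __main__ check are
illustrative only; the actual helper is the private _get_default_cuda_device inside fairscale.

    import torch
    import torch.nn as nn


    def get_default_cuda_device(module: nn.Module) -> torch.device:
        # Mirrors the patched logic: prefer the device of the first parameter,
        # and fall back to the current CUDA device when the module has no
        # parameters or its parameters live on CPU.
        try:
            compute_device = next(module.parameters()).device
            if compute_device.type == "cuda":
                return compute_device
        except StopIteration:
            pass  # module has no parameters, so there is no device to infer from
        return torch.device("cuda")


    if __name__ == "__main__":
        # Pre-patch, the nn.ReLU() case raised StopIteration instead of
        # falling back to the current CUDA device.
        print(get_default_cuda_device(nn.ReLU()))        # device(type='cuda')
        print(get_default_cuda_device(nn.Linear(2, 2)))  # device(type='cuda')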