Skip to content

Commit

Permalink
Merge pull request #462 from allenai/epwalsh/fsdp-wrap-patch
Browse files (browse the repository at this point in the history)
Cherry-pick FSDP wrap patch from #443
  • Loading branch information
epwalsh authored Feb 22, 2024
2 parents cc36709 + e101457 commit 37ca789
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ jobs:
spec: |
version: v2
description: GPU Tests
budget: ai2/oe-training
tasks:
- name: tests
image:
Expand Down
43 changes: 29 additions & 14 deletions olmo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,25 +1339,37 @@ def forward(
def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
if wrap_strategy is None:
return None

# The 'recurse' mode of the wrap function does not behave the way you might expect:
# even if we return False, PyTorch may still recurse into the module's children.
# This causes issues when, for example, we want to wrap 'ff_out' (a linear layer)
# but not the other linear layers within a block.
# So we explicitly tell PyTorch which modules to wrap, and, for simplicity,
# we always return True in 'recurse' mode.
size_based_module_to_wrap = {self.transformer.wte}
if hasattr(self.transformer, "ff_out"):
size_based_module_to_wrap.add(self.transformer.ff_out)

if wrap_strategy == FSDPWrapStrategy.by_block:

def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
wrap = isinstance(module, OlmoBlock)
if recurse:
return True # always recurse for simplicity
return isinstance(module, OlmoBlock)
return True
else:
return wrap

return fsdp_wrap_fn
elif wrap_strategy == FSDPWrapStrategy.by_block_and_size:

def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
wrap = isinstance(module, (OlmoBlock,)) or module in size_based_module_to_wrap
if recurse:
# Determine if we should recurse.
return not isinstance(module, OlmoBlock)
return True
else:
# Determine if we should wrap.
return isinstance(module, (OlmoBlock, nn.Linear, nn.Embedding))
return wrap

return fsdp_wrap_fn
elif wrap_strategy == FSDPWrapStrategy.by_block_group:
Expand All @@ -1368,9 +1380,11 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):

def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
wrap = isinstance(module, OlmoBlockGroup)
if recurse:
return True # always recurse for simplicity
return isinstance(module, OlmoBlockGroup)
return True
else:
return wrap

return fsdp_wrap_fn
elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size:
Expand All @@ -1381,12 +1395,11 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):

def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
wrap = isinstance(module, (OlmoBlockGroup,)) or module in size_based_module_to_wrap
if recurse:
# Determine if we should recurse.
return not isinstance(module, OlmoBlockGroup)
return True
else:
# Determine if we should wrap.
return isinstance(module, (OlmoBlockGroup, nn.Linear, nn.Embedding))
return wrap

return fsdp_wrap_fn
elif wrap_strategy == FSDPWrapStrategy.size_based:
Expand All @@ -1408,9 +1421,11 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):

def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
wrap = isinstance(module, OlmoBlock) and module.layer_id % c == 0
if recurse:
return True # always recurse for simplicity
return isinstance(module, OlmoBlock) and module.layer_id % c == 0
return True
else:
return wrap

return fsdp_wrap_fn
else:
Expand Down

0 comments on commit 37ca789

Please sign in to comment.