From 1334235fb1fa03eaf5144d10a1af665147cb3c0a Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 22 Feb 2024 09:24:16 -0800 Subject: [PATCH 1/2] Cherry-pick FSDP wrap patch from #443. Within the FSDP wrapping function, the "recurse" mode does not behave as expected (for large models anyway). In particular it seems like there's an inherent "size-based" policy being applied that ignores our intent to stop recursing down the module tree at, for example, the level of a transformer block. The API for wrapping functions itself is pretty simple, so I can't imagine how we'd be interpreting it wrong. I suspect this is another FSDP bug. --- olmo/model.py | 43 +++++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/olmo/model.py b/olmo/model.py index 466a37a99..a11eceb71 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -1339,25 +1339,37 @@ def forward( def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None): if wrap_strategy is None: return None + + # The 'recurse' mode for the wrap function does not behave like you'd expect. + # Even if we return False, it may still recurse because PyTorch does what it wants, + # not what you want. This causes issues when, for example, we want to wrap 'ff_out' (a linear layer) + # but not other linear layers within a block. + # So we have to explicitly tell PyTorch which linear layers to wrap, and we also just + # return True in 'recurse' mode for simplicity. 
+ size_based_module_to_wrap = {self.transformer.wte} + if hasattr(self.transformer, "ff_out"): + size_based_module_to_wrap.add(self.transformer.ff_out) + if wrap_strategy == FSDPWrapStrategy.by_block: def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): del nonwrapped_numel + wrap = isinstance(module, OlmoBlock) if recurse: - return True # always recurse for simplicity - return isinstance(module, OlmoBlock) + return True + else: + return wrap return fsdp_wrap_fn elif wrap_strategy == FSDPWrapStrategy.by_block_and_size: def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): del nonwrapped_numel + wrap = isinstance(module, (OlmoBlock,)) or module in size_based_module_to_wrap if recurse: - # Determine if we should recurse. - return not isinstance(module, OlmoBlock) + return True else: - # Determine if we should wrap. - return isinstance(module, (OlmoBlock, nn.Linear, nn.Embedding)) + return wrap return fsdp_wrap_fn elif wrap_strategy == FSDPWrapStrategy.by_block_group: @@ -1368,9 +1380,11 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): del nonwrapped_numel + wrap = isinstance(module, OlmoBlockGroup) if recurse: - return True # always recurse for simplicity - return isinstance(module, OlmoBlockGroup) + return True + else: + return wrap return fsdp_wrap_fn elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size: @@ -1381,12 +1395,11 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): del nonwrapped_numel + wrap = isinstance(module, (OlmoBlockGroup,)) or module in size_based_module_to_wrap if recurse: - # Determine if we should recurse. - return not isinstance(module, OlmoBlockGroup) + return True else: - # Determine if we should wrap. 
- return isinstance(module, (OlmoBlockGroup, nn.Linear, nn.Embedding)) + return wrap return fsdp_wrap_fn elif wrap_strategy == FSDPWrapStrategy.size_based: @@ -1408,9 +1421,11 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): del nonwrapped_numel + wrap = isinstance(module, OlmoBlock) and module.layer_id % c == 0 if recurse: - return True # always recurse for simplicity - return isinstance(module, OlmoBlock) and module.layer_id % c == 0 + return True + else: + return wrap return fsdp_wrap_fn else: From e10145720dd494f0ccaec2ea5e46a47596c59817 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 22 Feb 2024 09:50:15 -0800 Subject: [PATCH 2/2] Add budget field to Beaker CI jobs --- .github/workflows/main.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index adac82794..1093adb3d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -127,6 +127,7 @@ jobs: spec: | version: v2 description: GPU Tests + budget: ai2/oe-training tasks: - name: tests image: