From f85a1d13c33536e1feec252ec20dc92dcc805958 Mon Sep 17 00:00:00 2001
From: Shane A
Date: Mon, 5 Aug 2024 16:53:12 -0700
Subject: [PATCH 1/8] Add config option for saving module outputs

---
 olmo/config.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/olmo/config.py b/olmo/config.py
index 8d3ed0823..d53c71de3 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -1217,6 +1217,11 @@ class TrainConfig(BaseConfig):
     Path to cache directory of HF datasets saved with `datasets.save_to_disk`.
     """
 
+    module_outputs_save_steps: Optional[List[int]] = None
+    """
+    Outputs of model submodules are saved during the provided steps.
+    """
+
     @property
     def autocast_precision(self) -> torch.dtype:
         if self.precision == "amp_bf16":
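A minimal sketch of how this option might be set, assuming the usual `TrainConfig.load` entry point (the config path and step values below are hypothetical):

    from olmo.config import TrainConfig

    cfg = TrainConfig.load("configs/my-run.yaml")  # hypothetical config file
    # Ask the trainer to dump every submodule's input/output at steps 1 and 1000.
    cfg.module_outputs_save_steps = [1, 1000]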
From 28a2d18c121cd0c3f6afe11b409fe09ae0c97581 Mon Sep 17 00:00:00 2001
From: Shane A
Date: Mon, 5 Aug 2024 16:53:44 -0700
Subject: [PATCH 2/8] Add model hooks for saving module outputs

---
 olmo/train.py | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/olmo/train.py b/olmo/train.py
index fd88b205e..341055003 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import cProfile
+import functools
 import gc
 import logging
 import math
@@ -20,6 +21,8 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+import torch.utils
+import torch.utils.hooks
 import wandb
 from packaging import version
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -650,6 +653,53 @@ def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = Chec
         else:
             raise NotImplementedError(checkpoint_type)
 
+    def _setup_module_output_save_hooks(self, micro_batch_idx: int) -> List[torch.utils.hooks.RemovableHandle]:
+        if (
+            self.cfg.module_outputs_save_steps is None
+            or self.global_step not in self.cfg.module_outputs_save_steps
+        ):
+            return []
+
+        if micro_batch_idx != 0 or get_global_rank() != 0:
+            # Hook is currently only used on the first microbatch of rank 0
+            return []
+
+        trace_save_folder = Path(self.cfg.save_folder) / f"traces/step{self.global_step}"
+        if trace_save_folder.exists():
+            if self.cfg.save_overwrite:
+                shutil.rmtree(trace_save_folder)
+            else:
+                raise OLMoConfigurationError(
+                    f"Attempting to overwrite traces at step {self.global_step} without --save_overwrite"
+                )
+        trace_save_folder.mkdir(parents=True)
+
+        def trace_outputs_hook(
+            module_name: str, _: torch.nn.Module, args: Tuple[torch.Tensor, ...], output: torch.Tensor
+        ) -> None:
+            if len(args) == 0:
+                log.info("No input args for module %s, output %s", module_name, output)
+
+            module_input = args[0] if len(args) > 0 else torch.tensor(())
+            trace_save_folder = Path(self.cfg.save_folder) / f"traces/step{self.global_step}"
+            trace_save_folder.mkdir(parents=True, exist_ok=True)
+
+            module_occurence_num = 0
+            while (
+                module_input_filepath := trace_save_folder / f"{module_name}_{module_occurence_num}_input.pt"
+            ).exists():
+                module_occurence_num += 1
+            torch.save(module_input, module_input_filepath)
+
+            module_output_filepath = trace_save_folder / f"{module_name}_{module_occurence_num}_output.pt"
+            torch.save(output, module_output_filepath)
+
+        output_hooks = []
+        for module_name, module in self.model.named_modules(prefix="model"):
+            output_hooks.append(module.register_forward_hook(functools.partial(trace_outputs_hook, module_name)))
+
+        return output_hooks
+
     def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
         # Labels are just input IDs shifted to the left (first item is ignored).
         labels, label_mask, attention_mask, instance_mask = (
@@ -740,6 +790,10 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor
         if micro_batch_idx != num_micro_batches - 1:
             grad_sync_context = self.dist_model.no_sync
 
+        # Register output hooks
+        output_hooks: List[torch.utils.hooks.RemovableHandle] = []
+        output_hooks += self._setup_module_output_save_hooks(micro_batch_idx)
+
         with grad_sync_context():
             with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
                 # Run forward pass.
@@ -756,6 +810,10 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor
             # Run backward pass.
             loss.backward()
 
+        # Remove output hooks
+        for hook in output_hooks:
+            hook.remove()
+
         return ce_batch_loss, z_batch_loss
 
     def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
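The hook wiring above is standard PyTorch. A stripped-down, standalone sketch of the same pattern, using a toy model and a made-up `traces/` directory (the real implementation additionally tracks per-module occurrence counts and gates on rank and microbatch):

    import functools
    from pathlib import Path

    import torch

    def trace_outputs_hook(module_name, _module, args, output):
        # Mimic the trainer hook: save the first positional input and the output.
        save_dir = Path("traces")  # hypothetical location
        save_dir.mkdir(exist_ok=True)
        module_input = args[0] if len(args) > 0 else torch.tensor(())
        torch.save(module_input, save_dir / f"{module_name}_0_input.pt")
        torch.save(output, save_dir / f"{module_name}_0_output.pt")

    model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
    hooks = [
        module.register_forward_hook(functools.partial(trace_outputs_hook, name))
        for name, module in model.named_modules(prefix="model")
    ]
    model(torch.randn(2, 4))
    for hook in hooks:
        hook.remove()  # as in train_batch, hooks are removed right after the step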
From e1073dc578d528f6c3752df358a320734b067b1c Mon Sep 17 00:00:00 2001
From: Shane A
Date: Mon, 5 Aug 2024 16:56:39 -0700
Subject: [PATCH 3/8] Add script for comparing module outputs

---
 scripts/compare_module_outputs.py | 134 ++++++++++++++++++++++++++++++
 1 file changed, 134 insertions(+)
 create mode 100644 scripts/compare_module_outputs.py

diff --git a/scripts/compare_module_outputs.py b/scripts/compare_module_outputs.py
new file mode 100644
index 000000000..44b7653d4
--- /dev/null
+++ b/scripts/compare_module_outputs.py
@@ -0,0 +1,134 @@
+from argparse import ArgumentParser
+import logging
+from pathlib import Path
+from typing import List
+
+import torch
+
+
+logger = logging.getLogger(__name__)
+
+
+def _get_module_names(checkpoint_traces_folder: Path) -> List[str]:
+    module_names = []
+    for trace_file in checkpoint_traces_folder.iterdir():
+        trace_file_name = trace_file.name
+        if trace_file_name.endswith("_input.pt"):
+            module_name = trace_file_name.removesuffix("_input.pt")
+        elif trace_file_name.endswith("_output.pt"):
+            module_name = trace_file_name.removesuffix("_output.pt")
+        else:
+            logger.warning("Cannot get module name from file %s, skipping", trace_file_name)
+            continue
+        module_names.append(module_name)
+
+    return module_names
+
+
+def compare_module_output(
+    base_traces_folder: Path,
+    compare_traces_folder: Path,
+    module_name: str,
+    *,
+    include_non_tensor_outputs: bool = True,
+    verbose: bool = False,
+):
+    base_module_input_path = base_traces_folder / f"{module_name}_input.pt"
+    base_module_output_path = base_traces_folder / f"{module_name}_output.pt"
+    compare_module_input_path = compare_traces_folder / f"{module_name}_input.pt"
+    compare_module_output_path = compare_traces_folder / f"{module_name}_output.pt"
+
+    map_location = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    base_input = torch.load(str(base_module_input_path), map_location=map_location)
+    compare_input = torch.load(str(compare_module_input_path), map_location=map_location)
+
+    if verbose or base_input.dtype != compare_input.dtype:
+        logger.info("%s input dtypes: %s %s", module_name, base_input.dtype, compare_input.dtype)
+    if verbose or base_input.shape != compare_input.shape:
+        logger.info("%s input shapes: %s %s", module_name, base_input.shape, compare_input.shape)
+    if (norm_diff := torch.linalg.vector_norm((compare_input - base_input).float()).item()) != 0.0 or verbose:
+        logger.info("%s input norm diff: %.6f", module_name, norm_diff)
+    if "wte" in module_name:
+        logger.info("%s mis-matching wte elements: %d", module_name, torch.sum(torch.logical_not(torch.eq(base_input, compare_input))))
+
+    base_output = torch.load(str(base_module_output_path), map_location=map_location)
+    compare_output = torch.load(str(compare_module_output_path), map_location=map_location)
+
+    if isinstance(base_output, torch.Tensor):
+        if verbose or base_output.dtype != compare_output.dtype:
+            logger.info("%s output dtypes: %s %s", module_name, base_output.dtype, compare_output.dtype)
+        if (norm_diff := torch.linalg.vector_norm((compare_output - base_output).float()).item()) != 0.0 or verbose:
+            logger.info("%s output norm diff: %.6f", module_name, norm_diff)
+    elif include_non_tensor_outputs:
+        logger.info("%s outputs: %s %s", module_name, base_output, compare_output)
+    else:
+        if verbose:
+            logger.info("Base output is type %s, skipping", type(base_output))
+
+
+def compare_module_outputs(
+    base_traces_folder: Path,
+    compare_traces_folder: Path,
+    *,
+    include_non_tensor_outputs: bool = True,
+    verbose: bool = False,
+):
+    base_modules = set(_get_module_names(base_traces_folder))
+    compare_modules = set(_get_module_names(compare_traces_folder))
+
+    base_only_modules = base_modules - compare_modules
+    if len(base_only_modules) > 0:
+        logger.info("Base-only modules: %s", ", ".join(base_only_modules))
+
+    compare_only_modules = compare_modules - base_modules
+    if len(compare_only_modules) > 0:
+        logger.info("Compare-only modules: %s", ", ".join(compare_only_modules))
+
+    common_modules = base_modules.intersection(compare_modules)
+    for module_name in sorted(common_modules):
+        compare_module_output(
+            base_traces_folder,
+            compare_traces_folder,
+            module_name,
+            include_non_tensor_outputs=include_non_tensor_outputs,
+            verbose=verbose,
+        )
+
+
+def main():
+    logging.basicConfig(encoding="utf-8", level=logging.INFO)
+
+    parser = ArgumentParser()
+    parser.add_argument(
+        "base_model_traces_path",
+        type=Path,
+        help="Path where output traces of the base (i.e. reference) model are stored",
+    )
+    parser.add_argument(
+        "compare_model_traces_path",
+        type=Path,
+        help="Path where output traces of the compare (a.k.a new, different) model are stored",
+    )
+    parser.add_argument(
+        "--skip_non_tensor_outputs",
+        action="store_false",
+        dest="include_non_tensor_outputs",
+        help="If set, do not compare module outputs that are not tensors",
+    )
+    parser.add_argument(
+        "--verbose",
+        action="store_true",
+        help="If set, show extra information",
+    )
+
+    args = parser.parse_args()
+    compare_module_outputs(
+        args.base_model_traces_path,
+        args.compare_model_traces_path,
+        include_non_tensor_outputs=args.include_non_tensor_outputs,
+        verbose=args.verbose,
+    )
+
+
+if __name__ == "__main__":
+    main()
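Assuming two runs were trained with the same `module_outputs_save_steps`, the resulting traces can be compared from the command line, e.g. `python scripts/compare_module_outputs.py run_base/traces/step10 run_new/traces/step10 --verbose` (run folders hypothetical), or programmatically when run from the `scripts/` directory:

    from pathlib import Path

    from compare_module_outputs import compare_module_outputs

    # Hypothetical trace folders written at step 10 of two runs.
    compare_module_outputs(
        Path("run_base/traces/step10"),
        Path("run_new/traces/step10"),
        verbose=True,
    )

A norm diff of 0.0 for every module means the two forward passes matched exactly; non-zero diffs point at where they diverge (note the script iterates over modules in sorted-name order, not execution order).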
From 164333e29bd0bcebe754489d9ec2b6e350a7cdc5 Mon Sep 17 00:00:00 2001
From: Shane A
Date: Mon, 5 Aug 2024 16:58:45 -0700
Subject: [PATCH 4/8] Run ruff

---
 scripts/compare_module_outputs.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/scripts/compare_module_outputs.py b/scripts/compare_module_outputs.py
index 44b7653d4..0fae831fc 100644
--- a/scripts/compare_module_outputs.py
+++ b/scripts/compare_module_outputs.py
@@ -1,11 +1,10 @@
-from argparse import ArgumentParser
 import logging
+from argparse import ArgumentParser
 from pathlib import Path
 from typing import List
 
 import torch
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -49,7 +48,11 @@ def compare_module_output(
     if (norm_diff := torch.linalg.vector_norm((compare_input - base_input).float()).item()) != 0.0 or verbose:
         logger.info("%s input norm diff: %.6f", module_name, norm_diff)
     if "wte" in module_name:
-        logger.info("%s mis-matching wte elements: %d", module_name, torch.sum(torch.logical_not(torch.eq(base_input, compare_input))))
+        logger.info(
+            "%s mis-matching wte elements: %d",
+            module_name,
+            torch.sum(torch.logical_not(torch.eq(base_input, compare_input))),
+        )
 
     base_output = torch.load(str(base_module_output_path), map_location=map_location)
     compare_output = torch.load(str(compare_module_output_path), map_location=map_location)
@@ -57,7 +60,9 @@ def compare_module_output(
     if isinstance(base_output, torch.Tensor):
         if verbose or base_output.dtype != compare_output.dtype:
             logger.info("%s output dtypes: %s %s", module_name, base_output.dtype, compare_output.dtype)
-        if (norm_diff := torch.linalg.vector_norm((compare_output - base_output).float()).item()) != 0.0 or verbose:
+        if (
+            norm_diff := torch.linalg.vector_norm((compare_output - base_output).float()).item()
+        ) != 0.0 or verbose:
             logger.info("%s output norm diff: %.6f", module_name, norm_diff)
     elif include_non_tensor_outputs:
         logger.info("%s outputs: %s %s", module_name, base_output, compare_output)

From ce1b41d1d97fa8ed6712afec8a6fd35298710dcd Mon Sep 17 00:00:00 2001
From: Shane A
Date: Tue, 6 Aug 2024 11:33:32 -0700
Subject: [PATCH 5/8] Update CHANGELOG

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c86f78bc9..f73375dc4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `model.rope_theta` configuration option.
 - Added `model.embedding_layer_norm` configuration option for adding a LN to the embeddings.
 - Added `model.emb_init_std` configuration option to override the standard deviation used to initialize the embeddings.
+- Added ability to save outputs of submodules for debugging purposes.
 
 ### Changed
 

From 5cf4000c79fbd8712eb0775a7def0e75e1390c0d Mon Sep 17 00:00:00 2001
From: Shane A
Date: Mon, 12 Aug 2024 14:20:51 -0700
Subject: [PATCH 6/8] Skip non-tensor outputs by default when comparing module outputs

---
 scripts/compare_module_outputs.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/scripts/compare_module_outputs.py b/scripts/compare_module_outputs.py
index 0fae831fc..487099929 100644
--- a/scripts/compare_module_outputs.py
+++ b/scripts/compare_module_outputs.py
@@ -115,8 +115,8 @@ def main():
         help="Path where output traces of the compare (a.k.a new, different) model are stored",
     )
     parser.add_argument(
-        "--skip_non_tensor_outputs",
-        action="store_false",
+        "--include_non_tensor_outputs",
+        action="store_true",
         dest="include_non_tensor_outputs",
         help="If set, do not compare module outputs that are not tensors",
     )

From cea04f0073865ddd7240e4c2482264e18a8ae00c Mon Sep 17 00:00:00 2001
From: Shane A
Date: Mon, 12 Aug 2024 14:22:16 -0700
Subject: [PATCH 7/8] Update help message for comparing non-tensor module outputs

---
 scripts/compare_module_outputs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/compare_module_outputs.py b/scripts/compare_module_outputs.py
index 487099929..308290e75 100644
--- a/scripts/compare_module_outputs.py
+++ b/scripts/compare_module_outputs.py
@@ -118,7 +118,7 @@ def main():
         "--include_non_tensor_outputs",
         action="store_true",
         dest="include_non_tensor_outputs",
-        help="If set, do not compare module outputs that are not tensors",
+        help="If set, compare module outputs that are not tensors",
     )
     parser.add_argument(
         "--verbose",
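Taken together, the last two patches invert the default: non-tensor module outputs are now skipped unless `--include_non_tensor_outputs` is passed explicitly, e.g. `python scripts/compare_module_outputs.py run_base/traces/step10 run_new/traces/step10 --include_non_tensor_outputs` (run folders hypothetical).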
From 81c07c6c75210c25d7255b7e5665a7076d04027e Mon Sep 17 00:00:00 2001
From: Shane A
Date: Mon, 12 Aug 2024 14:24:30 -0700
Subject: [PATCH 8/8] Refer to comparison script in module_outputs_save_steps

---
 olmo/config.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/olmo/config.py b/olmo/config.py
index d53c71de3..10afa48fb 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -1219,7 +1219,8 @@ class TrainConfig(BaseConfig):
 
     module_outputs_save_steps: Optional[List[int]] = None
     """
-    Outputs of model submodules are saved during the provided steps.
+    Outputs of model submodules are saved during the provided steps. Submodule outputs
+    can be compared using `scripts/compare_module_outputs.py`.
     """
 
     @property
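Because traces are plain `torch.save` files, a single trace can also be inspected directly. A small sketch (the file name is a hypothetical instance of the `{module_name}_{occurrence}_{input|output}.pt` pattern the hook writes; actual module names depend on the model):

    import torch

    out = torch.load(
        "run_base/traces/step10/model.transformer.blocks.0_0_output.pt",
        map_location="cpu",
    )
    print(out.shape, out.dtype)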