diff --git a/CHANGELOG.md b/CHANGELOG.md
index 382cf638e..fc1a819da 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `model.embedding_layer_norm` configuration option for adding a LN to the embeddings.
 - Added `model.emb_init_std` configuration option to override the standard deviation used to initialize the embeddings.
 - Added `CosLinearEnvelope` scheduler, which is a pointwise product of a cosine schedule and a linear decay.
+- Added ability to save outputs of submodules for debugging purposes.
 
 ### Changed
 
diff --git a/olmo/config.py b/olmo/config.py
index d3f94f37c..ae454bb19 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -1218,6 +1218,12 @@ class TrainConfig(BaseConfig):
     Path to cache directory of HF datasets saved with `datasets.save_to_disk`.
     """
 
+    module_outputs_save_steps: Optional[List[int]] = None
+    """
+    Outputs of model submodules are saved during the provided steps. Submodule outputs
+    can be compared using `scripts/compare_module_outputs.py`.
+    """
+
     @property
     def autocast_precision(self) -> torch.dtype:
         if self.precision == "amp_bf16":
diff --git a/olmo/train.py b/olmo/train.py
index fd88b205e..341055003 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import cProfile
+import functools
 import gc
 import logging
 import math
@@ -20,6 +21,8 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+import torch.utils
+import torch.utils.hooks
 import wandb
 from packaging import version
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -650,6 +653,53 @@ def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = Chec
         else:
             raise NotImplementedError(checkpoint_type)
 
+    def _setup_module_output_save_hooks(self, micro_batch_idx: int) -> List[torch.utils.hooks.RemovableHandle]:
+        if (
+            self.cfg.module_outputs_save_steps is None
+            or self.global_step not in self.cfg.module_outputs_save_steps
+        ):
+            return []
+
+        if micro_batch_idx != 0 or get_global_rank() != 0:
+            # Hook is currently only used on the first microbatch of rank 0
+            return []
+
+        trace_save_folder = Path(self.cfg.save_folder) / f"traces/step{self.global_step}"
+        if trace_save_folder.exists():
+            if self.cfg.save_overwrite:
+                shutil.rmtree(trace_save_folder)
+            else:
+                raise OLMoConfigurationError(
+                    f"Attempting to overwrite traces at step {self.global_step} without --save_overwrite"
+                )
+        trace_save_folder.mkdir(parents=True)
+
+        def trace_outputs_hook(
+            module_name: str, _: torch.nn.Module, args: Tuple[torch.Tensor, ...], output: torch.Tensor
+        ) -> None:
+            if len(args) == 0:
+                log.info("No input args for module %s, output %s", module_name, output)
+
+            module_input = args[0] if len(args) > 0 else torch.tensor(())
+            trace_save_folder = Path(self.cfg.save_folder) / f"traces/step{self.global_step}"
+            trace_save_folder.mkdir(parents=True, exist_ok=True)
+
+            module_occurence_num = 0
+            while (
+                module_input_filepath := trace_save_folder / f"{module_name}_{module_occurence_num}_input.pt"
+            ).exists():
+                module_occurence_num += 1
+            torch.save(module_input, module_input_filepath)
+
+            module_output_filepath = trace_save_folder / f"{module_name}_{module_occurence_num}_output.pt"
+            torch.save(output, module_output_filepath)
+
+        output_hooks = []
+        for module_name, module in self.model.named_modules(prefix="model"):
+            output_hooks.append(module.register_forward_hook(functools.partial(trace_outputs_hook, module_name)))
+
+        return output_hooks
+
     def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
         # Labels are just input IDs shifted to the left (first item is ignored).
         labels, label_mask, attention_mask, instance_mask = (
@@ -740,6 +790,10 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor
                 if micro_batch_idx != num_micro_batches - 1:
                     grad_sync_context = self.dist_model.no_sync
 
+            # Register output hooks
+            output_hooks: List[torch.utils.hooks.RemovableHandle] = []
+            output_hooks += self._setup_module_output_save_hooks(micro_batch_idx)
+
             with grad_sync_context():
                 with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
                     # Run forward pass.
@@ -756,6 +810,10 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor
                 # Run backward pass.
                 loss.backward()
 
+            # Remove output hooks
+            for hook in output_hooks:
+                hook.remove()
+
         return ce_batch_loss, z_batch_loss
 
     def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
diff --git a/scripts/compare_module_outputs.py b/scripts/compare_module_outputs.py
new file mode 100644
index 000000000..308290e75
--- /dev/null
+++ b/scripts/compare_module_outputs.py
@@ -0,0 +1,140 @@
+import logging
+from argparse import ArgumentParser
+from pathlib import Path
+from typing import List
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def _get_module_names(checkpoint_traces_folder: Path) -> List[str]:
+    module_names = []
+    for trace_file in checkpoint_traces_folder.iterdir():
+        trace_file_name = trace_file.name
+        if trace_file_name.endswith("_input.pt"):
+            module_name = trace_file_name.removesuffix("_input.pt")
+        elif trace_file_name.endswith("_output.pt"):
+            module_name = trace_file_name.removesuffix("_output.pt")
+        else:
+            logger.warning("Cannot get parameter from file %s, skipping", trace_file_name)
+            continue
+
+        module_names.append(module_name)
+
+    return module_names
+
+
+def compare_module_output(
+    base_traces_folder: Path,
+    compare_traces_folder: Path,
+    module_name: str,
+    *,
+    include_non_tensor_outputs: bool = True,
+    verbose: bool = False,
+):
+    base_module_input_path = base_traces_folder / f"{module_name}_input.pt"
+    base_module_output_path = base_traces_folder / f"{module_name}_output.pt"
+    compare_module_input_path = compare_traces_folder / f"{module_name}_input.pt"
+    compare_module_output_path = compare_traces_folder / f"{module_name}_output.pt"
+
+    map_location = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    base_input = torch.load(str(base_module_input_path), map_location=map_location)
+    compare_input = torch.load(str(compare_module_input_path), map_location=map_location)
+
+    if verbose or base_input.dtype != compare_input.dtype:
+        logger.info("%s input dtypes: %s %s", module_name, base_input.dtype, compare_input.dtype)
+    if verbose or base_input.shape != compare_input.shape:
+        logger.info("%s input shapes: %s %s", module_name, base_input.shape, compare_input.shape)
+    if (norm_diff := torch.linalg.vector_norm((compare_input - base_input).float()).item()) != 0.0 or verbose:
+        logger.info("%s input norm diff: %.6f", module_name, norm_diff)
+    if "wte" in module_name:
+        logger.info(
+            "%s mis-matching wte elements: %d",
+            module_name,
+            torch.sum(torch.logical_not(torch.eq(base_input, compare_input))),
+        )
+
+    base_output = torch.load(str(base_module_output_path), map_location=map_location)
+    compare_output = torch.load(str(compare_module_output_path), map_location=map_location)
+
+    if isinstance(base_output, torch.Tensor):
+        if verbose or base_output.dtype != compare_output.dtype:
+            logger.info("%s output dtypes: %s %s", module_name, base_output.dtype, compare_output.dtype)
+        if (
+            norm_diff := torch.linalg.vector_norm((compare_output - base_output).float()).item()
+        ) != 0.0 or verbose:
+            logger.info("%s output norm diff: %.6f", module_name, norm_diff)
+    elif include_non_tensor_outputs:
+        logger.info("%s outputs: %s %s", module_name, base_output, compare_output)
+    else:
+        if verbose:
+            logger.info("Base output is type %s, skipping", type(base_output))
+
+
+def compare_module_outputs(
+    base_traces_folder: Path,
+    compare_traces_folder: Path,
+    *,
+    include_non_tensor_outputs: bool = True,
+    verbose: bool = False,
+):
+    base_modules = set(_get_module_names(base_traces_folder))
+    compare_modules = set(_get_module_names(compare_traces_folder))
+
+    base_only_modules = base_modules - compare_modules
+    if len(base_only_modules) > 0:
+        logger.info("Base-only modules: %s", ", ".join(base_only_modules))
+
+    compare_only_modules = compare_modules - base_modules
+    if len(compare_only_modules) > 0:
+        logger.info("Compare-only modules: %s", ", ".join(compare_only_modules))
+
+    common_modules = base_modules.intersection(compare_modules)
+    for module_name in sorted(common_modules):
+        compare_module_output(
+            base_traces_folder,
+            compare_traces_folder,
+            module_name,
+            include_non_tensor_outputs=include_non_tensor_outputs,
+            verbose=verbose,
+        )
+
+
+def main():
+    logging.basicConfig(encoding="utf-8", level=logging.INFO)
+
+    parser = ArgumentParser()
+    parser.add_argument(
+        "base_model_traces_path",
+        type=Path,
+        help="Path where output traces of the base (i.e. reference) model are stored",
+    )
+    parser.add_argument(
+        "compare_model_traces_path",
+        type=Path,
+        help="Path where output traces of the compare (a.k.a new, different) model are stored",
+    )
+    parser.add_argument(
+        "--include_non_tensor_outputs",
+        action="store_true",
+        dest="include_non_tensor_outputs",
+        help="If set, compare module outputs that are not tensors",
+    )
+    parser.add_argument(
+        "--verbose",
+        action="store_true",
+        help="If set, show extra information",
+    )
+
+    args = parser.parse_args()
+    compare_module_outputs(
+        args.base_model_traces_path,
+        args.compare_model_traces_path,
+        include_non_tensor_outputs=args.include_non_tensor_outputs,
+        verbose=args.verbose,
+    )
+
+
+if __name__ == "__main__":
+    main()
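
Usage note: when `module_outputs_save_steps` includes the current step, rank 0 saves the first micro-batch's input and output of every named submodule to `<save_folder>/traces/step<N>` as `<module_name>_<k>_input.pt` and `<module_name>_<k>_output.pt`, where `k` counts repeated invocations of the same module within the forward pass. Below is a minimal sketch of inspecting one such pair by hand; the run directory and module name are hypothetical examples.

    # Inspect a single saved trace pair. The paths and the module name are hypothetical;
    # real names come from `self.model.named_modules(prefix="model")`.
    from pathlib import Path

    import torch

    trace_dir = Path("runs/base/traces/step10")
    module_input = torch.load(trace_dir / "model.transformer.blocks.0_0_input.pt", map_location="cpu")
    module_output = torch.load(trace_dir / "model.transformer.blocks.0_0_output.pt", map_location="cpu")

    print("input:", module_input.shape, module_input.dtype)
    if isinstance(module_output, torch.Tensor):
        print("output:", module_output.shape, module_output.dtype)
    else:
        # Some submodules save non-tensor outputs (e.g. a tuple), which the compare
        # script only prints when --include_non_tensor_outputs is set.
        print("output is a non-tensor of type", type(module_output))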
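The comparison can also be driven from Python rather than through `python scripts/compare_module_outputs.py <base_traces> <compare_traces> [--include_non_tensor_outputs] [--verbose]`. A minimal sketch, assuming `scripts/` is importable (for example via `PYTHONPATH`) and using hypothetical run directories:

    # Programmatic equivalent of:
    #   python scripts/compare_module_outputs.py runs/base/traces/step10 runs/candidate/traces/step10 --verbose
    import logging
    from pathlib import Path

    from compare_module_outputs import compare_module_outputs

    # Differences are reported through `logging`, so a handler is needed to see them.
    logging.basicConfig(level=logging.INFO)

    compare_module_outputs(
        Path("runs/base/traces/step10"),       # traces from the reference run (hypothetical path)
        Path("runs/candidate/traces/step10"),  # traces from the run being checked (hypothetical path)
        include_non_tensor_outputs=False,      # matches the CLI default when the flag is not passed
        verbose=True,
    )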