diff --git a/CHANGELOG.md b/CHANGELOG.md
index e89f22d50..e28a4c03b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Changed
 
 - Rename `Olmo` to `OLMo` everywhere in the codebase
+- Disabled automatic garbage collection during training; instead we run it manually at regular intervals to avoid ranks getting out of sync with their own GC.
 
 ### Removed
diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py
index a68914f5f..1dd355a36 100644
--- a/olmo/checkpoint.py
+++ b/olmo/checkpoint.py
@@ -50,7 +50,13 @@ from .exceptions import OLMoCheckpointError
 from .optim import Optimizer, fix_optim_state_dict
 from .safetensors_util import safetensors_file_to_state_dict
-from .torch_util import barrier, get_fs_local_rank, get_global_rank, get_world_size
+from .torch_util import (
+    barrier,
+    gc_cuda,
+    get_fs_local_rank,
+    get_global_rank,
+    get_world_size,
+)
 from .util import (
     _get_s3_client,
     default_thread_count,
@@ -191,7 +197,7 @@ def load_fsdp_model_and_optim_state(
             ),
         )
         del model_state
-        torch.cuda.empty_cache()
+        gc_cuda()
         load_fsdp_optim_state(fsdp_model, optim, optim_state["optim"])
@@ -212,7 +218,7 @@ def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[
             v = state[k]
             if isinstance(v, torch.Tensor):
                 state[k] = v.to(device="cpu")
-    torch.cuda.empty_cache()
+    gc_cuda()
     optim.load_state_dict(fix_optim_state_dict(optim, flattened_osd))
diff --git a/olmo/torch_util.py b/olmo/torch_util.py
index 22afe57f5..a2149fb22 100644
--- a/olmo/torch_util.py
+++ b/olmo/torch_util.py
@@ -1,3 +1,4 @@
+import gc
 import os
 from typing import Optional, TypeVar
 
@@ -130,3 +131,9 @@ def synchronize_value(value: V, device: torch.device) -> V:
 
 def synchronize_flag(flag: bool, device: torch.device) -> bool:
     return synchronize_value(flag, device)
+
+
+def gc_cuda():
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
diff --git a/olmo/train.py b/olmo/train.py
index 3f20bde08..1494a1b49 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import cProfile
+import gc
 import logging
 import math
 import os
@@ -38,6 +39,7 @@ from .optim import Optimizer, Scheduler
 from .torch_util import (
     barrier,
+    gc_cuda,
     get_fs_local_rank,
     get_global_rank,
     get_world_size,
@@ -136,6 +138,7 @@ class Trainer:
     cur_train_loss: float = float("inf")
     indices_file: Optional[TextIO] = None
     _start_time: float = 0.0
+    _gc_init_state: bool = True
     loss_fn: Callable[..., torch.Tensor] = field(default_factory=lambda: cross_entropy_loss)  # type: ignore
     last_sharded_checkpoint_step: Optional[int] = None
     last_unsharded_checkpoint_step: Optional[int] = None
@@ -537,15 +540,19 @@ def restore_unsharded_checkpoint(
     def save_checkpoint(
         self, checkpoint_type: CheckpointType = CheckpointType.sharded
     ) -> Tuple[PathOrStr, Optional[PathOrStr]]:
+        result: Tuple[PathOrStr, Optional[PathOrStr]]
         if checkpoint_type == CheckpointType.sharded:
-            return self.save_sharded_checkpoint()
+            result = self.save_sharded_checkpoint()
         elif checkpoint_type == CheckpointType.unsharded:
-            return self.save_unsharded_checkpoint()
+            result = self.save_unsharded_checkpoint()
         elif checkpoint_type == CheckpointType.sharded_ephemeral:
-            return self.save_ephemeral_checkpoint()
+            result = self.save_ephemeral_checkpoint()
         else:
             raise NotImplementedError(checkpoint_type)
 
+        gc_cuda()
+        return result
+
     def restore_checkpoint(
         self,
         load_path: PathOrStr,
@@ -576,6 +583,8 @@ def restore_checkpoint(
         elif checkpoint_type is not None:
             raise NotImplementedError(checkpoint_type)
 
+        gc_cuda()
+
     def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = CheckpointType.sharded):
         if checkpoint_type == CheckpointType.sharded:
             self.remove_sharded_checkpoint(idx=idx)
@@ -936,6 +945,10 @@ def fit(self):
             self.cfg.stop_at = min(self.cfg.stop_at, self.global_step + self.cfg.stop_after)
 
         self._start_time = time.time()
+        self._gc_init_state = gc.isenabled()  # Cache whether garbage collection is enabled; reset on close.
+
+        # Disable automatic garbage collection; FSDP doesn't work well with it.
+        gc.disable()
 
         if self.cfg.load_path is not None and self.global_step > 0 and self.cfg.eval_on_load:
             eval_metrics = self.eval()
@@ -1141,6 +1154,9 @@ def on_trace_ready(p):
                 if stop_at is not None and self.global_step >= stop_at:
                     break
 
+                # Run generation 1 garbage collection.
+                gc.collect(1)
+
                 # Python Profiler stuff
                 # We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
                 if python_profiler is not None:
@@ -1178,9 +1194,15 @@ def on_trace_ready(p):
             log.info(f"Checkpoint saved to {checkpoint_path}")
 
     def close(self, exit_code: int = 0) -> None:
+        gc_cuda()
+
         if self.indices_file is not None:
             self.indices_file.flush()
             self.indices_file.close()
+        if self._gc_init_state:
+            gc.enable()
+        else:
+            gc.disable()
         if wandb.run is not None:
             wandb.finish(exit_code=exit_code, quiet=True)
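For reference, below is a minimal, self-contained sketch of the garbage-collection pattern this diff introduces; it is not the OLMo trainer itself, and `train_step` and the step count are illustrative placeholders. The idea: disable automatic GC for the duration of training so ranks do not pause for collection at different times, have every rank run a generation-1 collection at the same point in each step, and restore the interpreter's original GC state when training ends. `gc_cuda()` mirrors the helper added to `olmo/torch_util.py`.

import gc

import torch


def gc_cuda() -> None:
    # Full Python garbage collection, then release cached CUDA memory blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def train_step(step: int) -> None:
    # Placeholder for the real forward/backward/optimizer-step work.
    pass


def train(num_steps: int = 10) -> None:
    gc_init_state = gc.isenabled()  # remember the interpreter's GC state
    gc.disable()  # no automatic collections while training
    try:
        for step in range(num_steps):
            train_step(step)
            # Every rank collects generation 1 here, at the same point in the
            # step, so no rank stalls on GC while the others wait in a collective.
            gc.collect(1)
    finally:
        # Restore whatever GC state we started with, then clean up CUDA caches.
        if gc_init_state:
            gc.enable()
        else:
            gc.disable()
        gc_cuda()


if __name__ == "__main__":
    train()

Collecting only generation 1 each step keeps the pause short; the full collection plus `torch.cuda.empty_cache()` in `gc_cuda()` is reserved for checkpoint save/restore and shutdown, as in the diff.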