diff --git a/lighter/callbacks/writer/base.py b/lighter/callbacks/writer/base.py
index 143688e9..43c57f7a 100644
--- a/lighter/callbacks/writer/base.py
+++ b/lighter/callbacks/writer/base.py
@@ -5,6 +5,7 @@
 from pathlib import Path
 
 import torch
+from loguru import logger
 from pytorch_lightning import Callback, Trainer
 
 from lighter import LighterSystem
@@ -31,8 +32,7 @@ def __init__(self, directory: str, writer: Union[str, Callable]) -> None:
             directory (str): Base directory for saving. A new sub-directory with current date and time will be created inside.
             writer (Union[str, Callable]): Name of the writer function registered in `self.writers`, or a custom writer function.
         """
-        # Create a unique directory using the current date and time
-        self.directory = Path(directory) / datetime.now().strftime("%Y%m%d_%H%M%S")
+        self.directory = Path(directory)
 
         # Check if the writer is a string and if it exists in the writers dictionary
         if isinstance(writer, str):
@@ -83,8 +83,11 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
         # Ensure all distributed nodes write to the same directory
         self.directory = trainer.strategy.broadcast(self.directory, src=0)
 
+        # Warn if the directory already exists
+        if self.directory.exists():
+            logger.warning(f"{self.directory} already exists, existing predictions will be overwritten.")
         if trainer.is_global_zero:
-            self.directory.mkdir(parents=True)
+            self.directory.mkdir(parents=True, exist_ok=True)
 
         # Wait for rank 0 to create the directory
         trainer.strategy.barrier()
diff --git a/lighter/utils/model.py b/lighter/utils/model.py
index fe6d9620..58abe4c8 100644
--- a/lighter/utils/model.py
+++ b/lighter/utils/model.py
@@ -116,5 +116,7 @@ def adjust_prefix_and_load_state_dict(
 
     # Log the incompatible keys during checkpoint loading.
     if len(incompatible_keys.missing_keys) > 0 or len(incompatible_keys.unexpected_keys) > 0:
         logger.info(f"Encountered incompatible keys during checkpoint loading. If intended, ignore.\n{incompatible_keys}")
+    else:
+        logger.info("Checkpoint loaded successfully.")
     return model