Skip to content

Commit

Permalink
Add message that prediction dir exists and that the checkpoint was lo…
Browse files Browse the repository at this point in the history
…aded successfully
  • Loading branch information
ibro45 committed Jan 30, 2024
1 parent cbc17db commit 7083cc7
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
9 changes: 6 additions & 3 deletions lighter/callbacks/writer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path

import torch
from loguru import logger
from pytorch_lightning import Callback, Trainer

from lighter import LighterSystem
Expand All @@ -31,8 +32,7 @@ def __init__(self, directory: str, writer: Union[str, Callable]) -> None:
directory (str): Base directory for saving. A new sub-directory with current date and time will be created inside.
writer (Union[str, Callable]): Name of the writer function registered in `self.writers`, or a custom writer function.
"""
# Create a unique directory using the current date and time
self.directory = Path(directory) / datetime.now().strftime("%Y%m%d_%H%M%S")
self.directory = Path(directory)

# Check if the writer is a string and if it exists in the writers dictionary
if isinstance(writer, str):
Expand Down Expand Up @@ -83,8 +83,11 @@ def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:

# Ensure all distributed nodes write to the same directory
self.directory = trainer.strategy.broadcast(self.directory, src=0)
# Warn if the directory already exists
if self.directory.exists():
logger.warning(f"{self.directory} already exists, existing predictions will be overwritten.")
if trainer.is_global_zero:
self.directory.mkdir(parents=True)
self.directory.mkdir(parents=True, exist_ok=True)
# Wait for rank 0 to create the directory
trainer.strategy.barrier()

Expand Down
2 changes: 2 additions & 0 deletions lighter/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,5 +116,7 @@ def adjust_prefix_and_load_state_dict(
# Log the incompatible keys during checkpoint loading.
if len(incompatible_keys.missing_keys) > 0 or len(incompatible_keys.unexpected_keys) > 0:
logger.info(f"Encountered incompatible keys during checkpoint loading. If intended, ignore.\n{incompatible_keys}")
else:
logger.info("Checkpoint loaded successfully.")

return model

0 comments on commit 7083cc7

Please sign in to comment.