Commit 7b19df8

Make checkpoint loading more informative. Remove incorrect Metric type check. Make TableWriter expect a path instead of dir. (#126)

* Remove the incorrect Metric check from coderabbit

* Change TableWriter to expect a csv path instead of a dir

* Update the no-overlap error between model and checkpoint state dicts to show their key names
ibro45 authored Jun 14, 2024
1 parent e531d53 commit 7b19df8
Showing 4 changed files with 29 additions and 33 deletions.
39 changes: 18 additions & 21 deletions lighter/callbacks/writer/base.py
@@ -21,19 +21,12 @@ class LighterBaseWriter(ABC, Callback):
     2) `self.write()` method to specify the saving strategy for a prediction.
 
     Args:
-        directory (str): Base directory for saving. A new sub-directory with current date and time will be created inside.
+        path (Union[str, Path]): Path for saving. It can be a directory or a specific file.
         writer (Union[str, Callable]): Name of the writer function registered in `self.writers`, or a custom writer function.
     """
 
-    def __init__(self, directory: str, writer: Union[str, Callable]) -> None:
-        """
-        Initialize the LighterBaseWriter.
-
-        Args:
-            directory (str): Base directory for saving. A new sub-directory with current date and time will be created inside.
-            writer (Union[str, Callable]): Name of the writer function registered in `self.writers`, or a custom writer function.
-        """
-        self.directory = Path(directory)
+    def __init__(self, path: Union[str, Path], writer: Union[str, Callable]) -> None:
+        self.path = Path(path)
 
         # Check if the writer is a string and if it exists in the writers dictionary
         if isinstance(writer, str):
@@ -70,30 +63,34 @@ def write(self, tensor: torch.Tensor, id: int) -> None:

     def setup(self, trainer: Trainer, pl_module: LighterSystem, stage: str) -> None:
         """
-        Callback function to set up necessary prerequisites: prediction count and prediction directory.
+        Callback function to set up necessary prerequisites: prediction count and prediction file or directory.
 
         When executing in a distributed environment, it ensures that:
         1. Each distributed node initializes a prediction count based on its rank.
-        2. All distributed nodes write predictions to the same directory.
-        3. The directory is accessible to all nodes, i.e., all nodes share the same storage.
+        2. All distributed nodes write predictions to the same path.
+        3. The path is accessible to all nodes, i.e., all nodes share the same storage.
         """
         if stage != "predict":
             return
 
         # Initialize the prediction count with the rank of the current process
         self._pred_counter = torch.distributed.get_rank() if trainer.world_size > 1 else 0
 
-        # Ensure all distributed nodes write to the same directory
-        self.directory = trainer.strategy.broadcast(self.directory, src=0)
-        # Warn if the directory already exists
-        if self.directory.exists():
-            logger.warning(f"{self.directory} already exists, existing predictions will be overwritten.")
+        # Ensure all distributed nodes write to the same path
+        self.path = trainer.strategy.broadcast(self.path, src=0)
+        directory = self.path.parent if self.path.suffix else self.path
+
+        # Warn if the path already exists
+        if self.path.exists():
+            logger.warning(f"{self.path} already exists, existing predictions will be overwritten.")
 
         if trainer.is_global_zero:
-            self.directory.mkdir(parents=True, exist_ok=True)
+            directory.mkdir(parents=True, exist_ok=True)
 
         # Wait for rank 0 to create the directory
         trainer.strategy.barrier()
 
-        # Ensure all distributed nodes have access to the directory
-        if not self.directory.exists():
+        # Ensure all distributed nodes have access to the path
+        if not directory.exists():
             raise RuntimeError(
                 f"Rank {trainer.global_rank} does not share storage with rank 0. Ensure nodes have common storage access."
             )
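
As an aside on the new behavior: `setup()` now infers the directory to create from `self.path`, treating any path with a suffix as a file. A minimal standalone sketch of that rule, using only `pathlib` (the sample paths are invented):

from pathlib import Path

def target_directory(path: Path) -> Path:
    # A path with a suffix (e.g. ".csv") is treated as a file, so its
    # parent is the directory to create; a suffix-less path is treated
    # as the directory itself.
    return path.parent if path.suffix else path

print(target_directory(Path("outputs/predictions.csv")))  # outputs
print(target_directory(Path("outputs/run1")))             # outputs/run1

Note that this heuristic classifies any suffixed path as a file, so a directory named e.g. `run.v2` would be treated as a file path.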
10 changes: 4 additions & 6 deletions lighter/callbacks/writer/table.py
@@ -16,14 +16,14 @@ class LighterTableWriter(LighterBaseWriter):
     Writer for saving predictions in a table format.
 
     Args:
-        directory (Path): Directory where the CSV will be saved.
+        path (Path): CSV filepath.
         writer (Union[str, Callable]): Name of the writer function registered in `self.writers` or a custom writer function.
             Available writers: "tensor". A custom writer function must take a single argument: `tensor`, and return the record
             to be saved in the CSV file. The tensor will be a single tensor without the batch dimension.
     """
 
-    def __init__(self, directory: Union[str, Path], writer: Union[str, Callable]) -> None:
-        super().__init__(directory, writer)
+    def __init__(self, path: Union[str, Path], writer: Union[str, Callable]) -> None:
+        super().__init__(path, writer)
         self.csv_records = {}
 
     @property
@@ -52,8 +52,6 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:
         If training was done in a distributed setting, it gathers predictions from all processes
         and then saves them from the rank 0 process.
         """
-        csv_path = self.directory / "predictions.csv"
-
         # Sort the records by ID and convert the dictionary to a list
         self.csv_records = [self.csv_records[id] for id in sorted(self.csv_records)]

@@ -69,4 +67,4 @@ def on_predict_epoch_end(self, trainer: Trainer, pl_module: LighterSystem) -> None:

         # Save the records to a CSV file
         if trainer.is_global_zero:
-            pd.DataFrame(self.csv_records).to_csv(csv_path)
+            pd.DataFrame(self.csv_records).to_csv(self.path)
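
A minimal usage sketch of the new interface (the output filename is illustrative; the class and the "tensor" writer name come from this diff):

from lighter.callbacks.writer.table import LighterTableWriter

# Before this commit the writer took a directory and wrote
# "<directory>/predictions.csv"; now the CSV filepath is passed directly.
writer = LighterTableWriter(path="outputs/predictions.csv", writer="tensor")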
2 changes: 0 additions & 2 deletions lighter/system.py
@@ -254,8 +254,6 @@ def _log_stats(
         # Metrics
         if metrics is not None:
             for name, metric in metrics.items():
-                if not isinstance(metric, Metric):
-                    raise TypeError(f"Expected type for metric is 'Metric', got '{type(metric).__name__}' instead.")
                 on_step_log(f"{mode}/metrics/{name}/step", metric)
                 on_epoch_log(f"{mode}/metrics/{name}/epoch", metric)
         # Optimizer's lr, momentum, beta. Logged in train mode and once per epoch.
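
For context on what remains after the removed guard: logging a torchmetrics `Metric` object per step and per epoch is the standard Lightning pattern. A minimal self-contained sketch (the module and metric below are invented for illustration, not taken from this repo):

import torch
import torchmetrics
from pytorch_lightning import LightningModule

class TinyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 2)
        # Passing the Metric object itself to self.log() lets Lightning
        # drive per-step updates and the epoch-level reduction.
        self.accuracy = torchmetrics.classification.MulticlassAccuracy(num_classes=2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.net(x)
        self.accuracy(logits, y)
        self.log("train/metrics/accuracy/step", self.accuracy, on_step=True, on_epoch=False)
        self.log("train/metrics/accuracy/epoch", self.accuracy, on_step=False, on_epoch=True)
        return torch.nn.functional.cross_entropy(logits, y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)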
11 changes: 7 additions & 4 deletions lighter/utils/model.py
@@ -98,11 +98,14 @@ def adjust_prefix_and_load_state_dict(
     # Add the model_prefix before the current key name if there's no specific ckpt_prefix
     ckpt = {f"{model_prefix}{key}": value for key, value in ckpt.items() if ckpt_prefix in key}
 
-    # Check if there is no overlap between the checkpoint's and model's state_dict.
-    if not set(ckpt.keys()) & set(model.state_dict().keys()):
+    # Check if the checkpoint's and model's state_dicts have no overlap.
+    model_keys = list(model.state_dict().keys())
+    ckpt_keys = list(ckpt.keys())
+    if not set(ckpt_keys) & set(model_keys):
         raise ValueError(
-            "There is no overlap between checkpoint's and model's state_dict. Check their "
-            "`state_dict` keys and adjust accordingly using `ckpt_prefix` and `model_prefix`."
+            "There is no overlap between checkpoint's and model's state_dict."
+            f"\nModel keys: '{model_keys[0]}', ..., '{model_keys[-1]}', "
+            f"\nCheckpoint keys: '{ckpt_keys[0]}', ..., '{ckpt_keys[-1]}'"
         )
 
     # Remove the layers that are not to be loaded.
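
To see what the new error produces, here is a tiny self-contained repro with invented key names (a prefix mismatch, which this message now makes visible):

model_keys = ["backbone.conv1.weight", "backbone.fc.bias"]              # hypothetical
ckpt_keys = ["model.backbone.conv1.weight", "model.backbone.fc.bias"]   # hypothetical

if not set(ckpt_keys) & set(model_keys):
    raise ValueError(
        "There is no overlap between checkpoint's and model's state_dict."
        f"\nModel keys: '{model_keys[0]}', ..., '{model_keys[-1]}', "
        f"\nCheckpoint keys: '{ckpt_keys[0]}', ..., '{ckpt_keys[-1]}'"
    )

Printing the first and last key of each side is enough to spot a stray prefix (here the checkpoint's extra `model.`), which `ckpt_prefix`/`model_prefix` can then strip or add.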
