Skip to content

Commit

Permalink
Update model_checkpoint.py
Browse files Browse the repository at this point in the history
Remove trainer/LightningModule type annotations to avoid a circular import

Update gradient_accumulation_scheduler.py
  • Loading branch information
ananthsub committed Sep 19, 2020
1 parent dcfe725 commit 8f201c1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def __init__(self, scheduling: dict):

minimal_epoch = min(scheduling.keys())
if minimal_epoch < 0:
raise IndexError(f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct")
raise IndexError(
f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
)
if minimal_epoch != 0: # if user didnt define first epoch accumulation factor
scheduling.update({0: 1})

Expand Down
15 changes: 15 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,16 @@ def on_validation_end(self, trainer, pl_module):
self._del_model(self.last_model_path)

def _do_check_save(
<<<<<<< HEAD
self, filepath: str, current: torch.Tensor, epoch: int, trainer, pl_module
=======
self,
filepath: str,
current: torch.Tensor,
epoch: int,
trainer,
pl_module,
>>>>>>> Update model_checkpoint.py
):
# remove kth

Expand Down Expand Up @@ -452,7 +461,13 @@ def _do_check_save(
if cur_path != filepath:
self._del_model(cur_path)

<<<<<<< HEAD
def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
=======
def on_save_checkpoint(
self, trainer, pl_module
) -> Dict[str, Any]:
>>>>>>> Update model_checkpoint.py
return {
"best_model_score": self.best_model_score,
"best_model_path": self.best_model_path,
Expand Down

0 comments on commit 8f201c1

Please sign in to comment.