Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LambdaCallback #5347

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
0df6bd9
Add LambdaCallback
marload Jan 4, 2021
aa13ddf
docs
marload Jan 4, 2021
01bd0a5
add pr link
Borda Jan 4, 2021
b0953dd
convention
marload Jan 4, 2021
7863e67
Fix Callback Typo
marload Jan 4, 2021
6792408
Update pytorch_lightning/callbacks/lambda_cb.py
marload Jan 4, 2021
d934a23
Update pytorch_lightning/callbacks/lambda_cb.py
marload Jan 4, 2021
9fc981a
Update pytorch_lightning/callbacks/lambda_cb.py
marload Jan 4, 2021
a93e468
use Misconfigureation
marload Jan 5, 2021
2ef199f
update docs
marload Jan 5, 2021
cb294e0
sort export
marload Jan 5, 2021
aadde9e
use inspect
marload Jan 5, 2021
8c10b14
string fill
marload Jan 5, 2021
39b1970
use fast dev run
marload Jan 5, 2021
dc11767
isort
marload Jan 5, 2021
0cfef59
remove unused import
marload Jan 5, 2021
6835771
sort
marload Jan 5, 2021
0263e3a
hilightning
marload Jan 5, 2021
7249a10
highlighting
marload Jan 5, 2021
3038d2f
highlighting
marload Jan 5, 2021
c400b98
remove debug log
marload Jan 5, 2021
8518382
eq
marload Jan 5, 2021
8bfe53e
res
marload Jan 5, 2021
9fd4c6b
results
marload Jan 5, 2021
c4563c7
add misconfig exception test
marload Jan 5, 2021
a329d4a
use pytest raises
marload Jan 5, 2021
571b941
Merge remote-tracking branch 'upstream/release/1.2-dev' into feature/…
marload Jan 5, 2021
d1f8d4a
fix
marload Jan 5, 2021
7293115
Apply suggestions from code review
Borda Jan 6, 2021
4d85f59
Update pytorch_lightning/callbacks/lambda_cb.py
marload Jan 6, 2021
c9ecb8a
hc
marload Jan 6, 2021
2044291
rm pt
marload Jan 6, 2021
5359ce6
Merge branch 'release/1.2-dev' into feature/lambdacallback
tchaton Jan 6, 2021
556ea09
fix
marload Jan 8, 2021
d190e15
try fix
rohitgr7 Jan 9, 2021
d7bfc4a
Merge branch 'release/1.2-dev' into feature/lambdacallback
rohitgr7 Jan 9, 2021
a27dbff
whitespace
rohitgr7 Jan 9, 2021
d1bd19a
new hook
rohitgr7 Jan 9, 2021
afe018a
add raise
marload Jan 10, 2021
709fb5b
fix
marload Jan 10, 2021
9b93a2c
remove unused
marload Jan 10, 2021
72f3f0c
rename
marload Jan 12, 2021
7ed3eea
Merge branch 'release/1.2-dev' into feature/lambdacallback
SkafteNicki Jan 12, 2021
2ce0131
Merge branch 'release/1.2-dev' into feature/lambdacallback
SkafteNicki Jan 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 115 additions & 11 deletions pytorch_lightning/callbacks/lambda_cb.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@

"""

import inspect
from typing import Callable, Optional

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class LambdaCallback(Callback):
Expand All @@ -40,12 +39,117 @@ class LambdaCallback(Callback):
>>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))])
"""

def __init__(self, **kwargs):
hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)]
for k, v in kwargs.items():
if k not in hooks:
raise MisconfigurationException(
f"The event function: `{k}` does not exist in supported callbacks function."
f" Currently, `Callback` implements the following functions {hooks}"
)
setattr(self, k, v)
def __init__(
self,
setup: Optional[Callable] = None,
teardown: Optional[Callable] = None,
on_init_start: Optional[Callable] = None,
on_init_end: Optional[Callable] = None,
on_fit_start: Optional[Callable] = None,
on_fit_end: Optional[Callable] = None,
on_sanity_check_start: Optional[Callable] = None,
on_sanity_check_end: Optional[Callable] = None,
on_train_batch_start: Optional[Callable] = None,
on_train_batch_end: Optional[Callable] = None,
on_train_epoch_start: Optional[Callable] = None,
on_train_epoch_end: Optional[Callable] = None,
on_validation_epoch_start: Optional[Callable] = None,
on_validation_epoch_end: Optional[Callable] = None,
on_test_epoch_start: Optional[Callable] = None,
on_test_epoch_end: Optional[Callable] = None,
on_epoch_start: Optional[Callable] = None,
on_epoch_end: Optional[Callable] = None,
on_batch_start: Optional[Callable] = None,
on_validation_batch_start: Optional[Callable] = None,
on_validation_batch_end: Optional[Callable] = None,
on_test_batch_start: Optional[Callable] = None,
on_test_batch_end: Optional[Callable] = None,
on_batch_end: Optional[Callable] = None,
on_train_start: Optional[Callable] = None,
on_train_end: Optional[Callable] = None,
on_pretrain_routine_start: Optional[Callable] = None,
on_pretrain_routine_end: Optional[Callable] = None,
on_validation_start: Optional[Callable] = None,
on_validation_end: Optional[Callable] = None,
on_test_start: Optional[Callable] = None,
on_test_end: Optional[Callable] = None,
on_keyboard_interrupt: Optional[Callable] = None,
on_save_checkpoint: Optional[Callable] = None,
on_load_checkpoint: Optional[Callable] = None,
on_after_backward: Optional[Callable] = None,
on_before_zero_grad: Optional[Callable] = None,
):
if setup is not None:
self.setup = setup
if teardown is not None:
self.teardown = teardown
if on_init_start is not None:
self.on_init_start = on_init_start
if on_init_end is not None:
self.on_init_end = on_init_end
if on_fit_start is not None:
self.on_fit_start = on_fit_start
if on_fit_end is not None:
self.on_fit_end = on_fit_end
if on_sanity_check_start is not None:
self.on_sanity_check_start = on_sanity_check_start
if on_sanity_check_end is not None:
self.on_sanity_check_end = on_sanity_check_end
if on_train_batch_start is not None:
self.on_train_batch_start = on_train_batch_start
if on_train_batch_end is not None:
self.on_train_batch_end = on_train_batch_end
if on_train_epoch_start is not None:
self.on_train_epoch_start = on_train_epoch_start
if on_train_epoch_end is not None:
self.on_train_epoch_end = on_train_epoch_end
if on_validation_epoch_start is not None:
self.on_validation_epoch_start = on_validation_epoch_start
if on_validation_epoch_end is not None:
self.on_validation_epoch_end = on_validation_epoch_end
if on_test_epoch_start is not None:
self.on_test_epoch_start = on_test_epoch_start
if on_test_epoch_end is not None:
self.on_test_epoch_end = on_test_epoch_end
if on_epoch_start is not None:
self.on_epoch_start = on_epoch_start
if on_epoch_end is not None:
self.on_epoch_end = on_epoch_end
if on_batch_start is not None:
self.on_batch_start = on_batch_start
if on_validation_batch_start is not None:
self.on_validation_batch_start = on_validation_batch_start
if on_validation_batch_end is not None:
self.on_validation_batch_end = on_validation_batch_end
if on_test_batch_start is not None:
self.on_test_batch_start = on_test_batch_start
if on_test_batch_end is not None:
self.on_test_batch_end = on_test_batch_end
if on_batch_end is not None:
self.on_batch_end = on_batch_end
if on_train_start is not None:
self.on_train_start = on_train_start
if on_train_end is not None:
self.on_train_end = on_train_end
if on_pretrain_routine_start is not None:
self.on_pretrain_routine_start = on_pretrain_routine_start
if on_pretrain_routine_end is not None:
self.on_pretrain_routine_end = on_pretrain_routine_end
if on_validation_start is not None:
self.on_validation_start = on_validation_start
if on_validation_end is not None:
self.on_validation_end = on_validation_end
if on_test_start is not None:
self.on_test_start = on_test_start
if on_test_end is not None:
self.on_test_end = on_test_end
if on_keyboard_interrupt is not None:
self.on_keyboard_interrupt = on_keyboard_interrupt
if on_save_checkpoint is not None:
self.on_save_checkpoint = on_save_checkpoint
if on_load_checkpoint is not None:
self.on_load_checkpoint = on_load_checkpoint
if on_after_backward is not None:
self.on_after_backward = on_after_backward
if on_before_zero_grad is not None:
self.on_before_zero_grad = on_before_zero_grad
6 changes: 0 additions & 6 deletions tests/callbacks/test_lambda_cb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,9 @@

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LambdaCallback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel


def test_lambda_raise_misconfiguration():
with pytest.raises(MisconfigurationException, match='does not exist in supported callbacks function'):
LambdaCallback(invalid=lambda *args: args)


marload marked this conversation as resolved.
Show resolved Hide resolved
def test_lambda_call(tmpdir):
seed_everything(42)

Expand Down