-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support kwargs input for LayerSummary
#17709
Conversation
LayerSummary
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@leng-yue CI is raising some issues. Looks like Traceback (most recent call last):
File "/Users/runner/work/lightning/lightning/tests/legacy/simple_classif_training.py", line 56, in <module>
main_train(path_dir)
File "/Users/runner/work/lightning/lightning/tests/legacy/simple_classif_training.py", line 46, in main_train
trainer.fit(model, datamodule=dm)
File "/Users/runner/hostedtoolcache/Python/3.10.11/x64/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
call._call_and_handle_interrupt(
File "/Users/runner/hostedtoolcache/Python/3.10.11/x64/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 42, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/Users/runner/hostedtoolcache/Python/3.10.11/x64/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 577, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/Users/runner/hostedtoolcache/Python/3.10.11/x64/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 962, in _run
call._call_callback_hooks(self, "on_fit_start")
File "/Users/runner/hostedtoolcache/Python/3.10.11/x64/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 189, in _call_callback_hooks
fn(trainer, trainer.lightning_module, *args, **kwargs)
File "/Users/runner/hostedtoolcache/Python/3.10.11/x64/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_summary.py", line 60, in on_fit_start
model_summary = self._summary(trainer, pl_module)
File "/Users/runner/hostedtoolcache/Python/3.10.11/x64/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_summary.py", line 74, in _summary
return summarize(pl_module, max_depth=self._max_depth)
File "/Users/runner/hostedtoolcache/Python/3.10.11/x64/lib/python3.10/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 462, in summarize
return ModelSummary(lightning_module, max_depth=max_depth)
File "/Users/runner/hostedtoolcache/Python/3.10.11/x64/lib/python3.10/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 201, in __init__
self._layer_summary = self.summarize()
File "/Users/runner/hostedtoolcache/Python/3.10.11/x64/lib/python3.10/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 260, in summarize
summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
File "/Users/runner/hostedtoolcache/Python/3.10.11/x64/lib/python3.10/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 260, in <genexpr>
summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
File "/Users/runner/hostedtoolcache/Python/3.10.11/x64/lib/python3.10/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 79, in __init__
self._hook_handle = self._register_hook()
File "/Users/runner/hostedtoolcache/Python/3.10.11/x64/lib/python3.10/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 107, in _register_hook
handle = self._module.register_forward_hook(hook, with_kwargs=True)
TypeError: Module.register_forward_hook() got an unexpected keyword argument 'with_kwargs'
Exception ignored in: <function LayerSummary.__del__ at 0x1300868c0>
Traceback (most recent call last):
File "/Users/runner/hostedtoolcache/Python/3.10.11/x64/lib/python3.10/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 84, in __del__
self.detach_hook()
File "/Users/runner/hostedtoolcache/Python/3.10.11/x64/lib/python3.10/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py", line 115, in detach_hook
if self._hook_handle is not None:
AttributeError: 'LayerSummary' object has no attribute '_hook_handle' You need to pass it conditionally based on And for the 2.0 tests, the following are failing: FAILED utilities/test_model_summary.py::test_example_input_array_types[example_input3-?-1] - AssertionError: assert [[1, 2, 3]] == ['?']
At index 0 diff: [1, 2, 3] != '?'
Full diff:
- ['?']
+ [[1, 2, 3]]
FAILED utilities/test_model_summary.py::test_example_input_array_types[example_input3-?--1] - AssertionError: assert [[1, 2, 3]] == ['?']
At index 0 diff: [1, 2, 3] != '?'
Full diff:
- ['?']
+ [[1, 2, 3]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Implementation looks clean. Thank you!
I removed the bug label because PyTorch added this functionality in newer versions and so the expectation can't be that Lightning was supposed to support this in the past. This is a new (and great) feature and extends the flexibility of the layer summary 🎉 |
Co-authored-by: Adrian Wälchli <[email protected]>
for more information, see https://pre-commit.ci
What does this PR do?
Support kwargs input for LayerSummary #17676, and add unit test.
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist