Fixes access to callback_metrics in ddp_spawn #7916
Conversation
Hello @edgarriba! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2021-06-17 13:37:47 UTC
Codecov Report

@@            Coverage Diff            @@
##           master    #7916     +/-   ##
=========================================
  Coverage      92%      92%
=========================================
  Files         207      211       +4
  Lines       13375    14557    +1182
=========================================
+ Hits        12245    13347    +1102
- Misses       1130     1210      +80
pytorch_lightning/trainer/connectors/logger_connector/logger_connector_new.py (outdated)
Should we put the callback metrics directly in the queue instead?
We don't want users to have to use a different attribute depending on the accelerator.
How would this impact performance?
force-pushed from a2bb4ac to d6d6c19
@carmocca My initial proposal used trainer.spawn_callback_metrics["val_acc"]; however, as @tchaton proposes, to make it more generic we could follow the approach below: trainer.spawn_extra_parameters["callback_metrics"]["val_acc"]. Open for an API discussion. This is the gist for the entry-point script: https://gist.github.com/edgarriba/af6247edb32586b19e740f17735ff055
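For reference, a minimal sketch of how the two proposed names would read from user code. Neither spawn_callback_metrics nor spawn_extra_parameters is an existing Trainer attribute; both are proposals from this thread, and model stands in for any LightningModule that logs "val_acc":

import pytorch_lightning as pl

trainer = pl.Trainer(accelerator="ddp_spawn", gpus=2)
trainer.fit(model)  # `model` is any LightningModule that logs "val_acc"

# Option A, the initial proposal: a dedicated attribute
val_acc = trainer.spawn_callback_metrics["val_acc"]

# Option B, @tchaton's generic variant: one dict holding all spawned extras
val_acc = trainer.spawn_extra_parameters["callback_metrics"]["val_acc"]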
force-pushed from 9aee3ac to 9b5a97c
force-pushed from 88dd15e to c637998
force-pushed from 1671e57 to 0cd331f
force-pushed from 82bcdbf to 90ff74e
@carmocca @awaelchli I believe your comments have been addressed.
Great. I think just the changelog entries are missing now.
Adding the 1.4 milestone.
force-pushed from 6c8e0a9 to 6c219e1
Co-authored-by: Carlos Mocholí <[email protected]>
def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
    """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue.

    To avoid issues with memory sharing, we cast the data to numpy.

    Args:
        queue: the instance of the queue to append the data.
    """
    callback_metrics: dict = apply_to_collection(
        self.trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
    )  # send as numpy to avoid issues with memory sharing
    queue.put(callback_metrics)

def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
    """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue.

    To preserve consistency, we cast back the data to ``torch.Tensor``.

    Args:
        queue: the instance of the queue from where to get the data.
    """
    # NOTE: `add_to_queue` needs to be called before
    callback_metrics: dict = queue.get()
    self.trainer.callback_metrics.update(
        apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x))
    )
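Because these hooks live on the LightningModule, a user can override them to ship additional state out of the spawned process. A minimal sketch, assuming the plugin invokes the hooks symmetrically on both sides (as the base implementations above imply); the running_stat attribute and the "extra_value" key are made up for illustration:

import torch
import torch.multiprocessing
from pytorch_lightning import LightningModule

class MyModel(LightningModule):
    def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
        super().add_to_queue(queue)  # default behaviour: put callback_metrics first
        # extras go *after* the parent class's put; the gets in `get_from_queue`
        # below must mirror this order exactly
        queue.put({"extra_value": self.running_stat.cpu().numpy()})  # `running_stat` is made up

    def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
        super().get_from_queue(queue)  # default behaviour: pop callback_metrics first
        extras = queue.get()
        self.running_stat = torch.tensor(extras["extra_value"])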
Is this the only alternative to populate these metrics? Why is this on the user interface of the LightningModule? What happens if someone overrides this? Is it meant to be overridden?
It feels like the LightningModule is used as a go-between for different parts of the trainer, in particular because the training type plugin technically has no reference to the trainer.
Structurally, we are repeatedly reaching through the LightningModule to access the trainer in a very roundabout way. Another example: https://github.com/PyTorchLightning/pytorch-lightning/blob/55a90af7fc0805855684e93dbad669f5bbe76eee/pytorch_lightning/plugins/training_type/sharded.py#L42-L57
It feels backwards, and it also makes efforts like #7315 harder to work through when we keep relying on the trainer like this.
The point was to let users add to and get from these. See #7916 (comment) and the rest of the discussions in this PR.
@@ -202,6 +202,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
         self.mp_queue.put(best_model_path)
         self.mp_queue.put(last_path)
         self.mp_queue.put(results)
+        self.lightning_module.add_to_queue(self.mp_queue)  # adds the `callback_metrics` to the queue

     def save(self, state_dict: Dict, path: str) -> None:
@edgarriba is there a reason you call add_to_queue in tpu_spawn, but don't call get_from_queue?
Ah, never mind, I think it's just because tpu_spawn doesn't override post_dispatch.
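For context, a rough paraphrase of the parent-process side (not verbatim plugin code): after the spawned workers return, the plugin drains mp_queue in the same order the worker filled it, which is where get_from_queue mirrors the add_to_queue call shown in the diff above:

# paraphrased parent-process flow after the spawn call returns (not verbatim source)
def post_dispatch(self):
    best_model_path = self.mp_queue.get()  # mirrors the first `put` in the worker
    last_path = self.mp_queue.get()
    results = self.mp_queue.get()
    # mirrors `add_to_queue`, repopulating trainer.callback_metrics in the parent
    self.lightning_module.get_from_queue(self.mp_queue)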
""" | ||
# NOTE: `add_to_queue` needs to be called before | ||
callback_metrics: dict = queue.get() | ||
self.trainer.callback_metrics.update( |
why do we have to update the callback metrics here?
I'll answer for Edgar:
The purpose of this PR was to provide a mechanism for users to add items to the queue and consume them from callbacks in the spawn environment.
Hence why we update the callback metrics here: callbacks read metrics off that dictionary.
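Concretely, this is the pattern the fix enables from the parent process. A sketch, assuming the LightningModule logged a "val_acc" metric during fit:

from pytorch_lightning import Trainer

trainer = Trainer(accelerator="ddp_spawn", gpus=2)
trainer.fit(model)  # `model` is any LightningModule that logs "val_acc"

# before this PR, this dict was empty in the parent process under ddp_spawn;
# with add_to_queue/get_from_queue it now holds the values logged in the workers
val_acc = trainer.callback_metrics["val_acc"]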
@carmocca thanks so much!!
What does this PR do?
Fixes #7671
Fixes access to callback_metrics in ddp_spawn
TODO:
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines.
Did you have fun?
Make sure you had fun coding 🙃