-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add AcceleratorRegistry
#12180
Add AcceleratorRegistry
#12180
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this! Provides more flexibilities for customized accelerator and it's consistent with Strategy behavior.
Do you think we can remove AcceleratorType after this PR?
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: ananthsub <[email protected]>
68f1df5
to
096afb9
Compare
32dc47f
to
096afb9
Compare
Codecov Report
@@ Coverage Diff @@
## master #12180 +/- ##
=======================================
- Coverage 88% 88% -0%
=======================================
Files 207 209 +2
Lines 17704 17747 +43
=======================================
+ Hits 15524 15561 +37
- Misses 2180 2186 +6 |
accelerator_registry.register( | ||
"gpu", | ||
cls, | ||
description=f"{cls.__class__.__name__}", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
__class__.__name__
is already a string so the fstring is redundant.
@AcceleratorRegistry.register("sota", description="Custom sota accelerator", a=1, b=True) | ||
class SOTAAccelerator(Accelerator): | ||
def __init__(self, a, b): | ||
... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't work:
Traceback (most recent call last):
File "/Users/carmocca/git/pytorch-lightning/thing.py", line 5, in <module>
class SOTAAccelerator:
TypeError: do_register() missing 1 required positional argument: 'accelerator'
@@ -75,7 +75,6 @@ def auto_device_count() -> int: | |||
def is_available() -> bool: | |||
"""Detect if the hardware is available.""" | |||
|
|||
@staticmethod | |||
@abstractmethod | |||
def name() -> str: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't get the reasoning for removing the name
.
If we kept it, the registry could use it to automatically define the name for the class
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Related discussion #12180 (comment)
from typing import Any | ||
|
||
|
||
def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not use is_overridden
?
def register( | ||
self, | ||
name: str, | ||
accelerator: Optional[Callable] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type should be Optional[Type]
, not Callable
from pytorch_lightning.accelerators.tpu import TPUAccelerator # noqa: F401 | ||
|
||
ACCELERATORS_BASE_MODULE = "pytorch_lightning.accelerators" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why expose this string?
If it was to avoid registering these by default, it would not work because to change the path you'd need to import the variable
And at import time, they would get registered anyways.
"""Name of the Accelerator.""" | ||
return "cpu" | ||
@classmethod | ||
def register_accelerators(cls, accelerator_registry: Dict) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The argument should be typed as _AcceleratorRegistry
, not Dict
def name() -> str: | ||
"""Name of the Accelerator.""" | ||
@classmethod | ||
def register_accelerators(cls, accelerator_registry: Dict) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typing is wrong. should be the actual accelerator registry
an accelerator, e.g., "gpu". It also returns Optional description and | ||
parameters to initialize the Accelerator, which were defined during the | ||
registration. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
... an optional description ...
registration. | ||
|
||
The motivation for having a AcceleratorRegistry is to make it convenient | ||
for the Users to try different accelerators by passing mapped aliases |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
users capitalized
What does this PR do?
Follow up to #12030
Does your PR introduce any breaking changes? If yes, please list them.
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:
Did you have fun?
Make sure you had fun coding 🙃
cc @Borda @akihironitta @rohitgr7 @justusschock