Accelerator abstraction #2320
New file (@@ -0,0 +1 @@) — re-exports the abstract base class from the accelerator package:

from .abstract_accelerator import DeepSpeedAccelerator
New file (@@ -0,0 +1,161 @@) — deepspeed.accelerator.abstract_accelerator, the DeepSpeedAccelerator abstract base class:

import abc
from abc import ABC


class DeepSpeedAccelerator(ABC):
    def __init__(self):
        self._name = None
        self._communication_backend_name = None
        self.BFloat16Tensor = None
        self.ByteTensor = None
        self.DoubleTensor = None
        self.FloatTensor = None
        self.HalfTensor = None
        self.IntTensor = None
        self.LongTensor = None

    # Device APIs
    @abc.abstractmethod
    def device_name(self, device_index):
        ...

    @abc.abstractmethod
    def device(self, device_index):
        ...

    @abc.abstractmethod
    def set_device(self):
        ...

    @abc.abstractmethod
    def current_device(self):
        ...

    @abc.abstractmethod
    def device_count(self):
        ...

    @abc.abstractmethod
    def synchronize(self, device_index=None):
        ...

    # RNG APIs
    @abc.abstractmethod
    def set_rng_state(self, new_state, device_index=None):
        ...

    @abc.abstractmethod
    def get_rng_state(self, device_index=None):
        ...

    @abc.abstractmethod
    def manual_seed(self, seed):
        ...

    @abc.abstractmethod
    def manual_seed_all(self, seed):
        ...

    @abc.abstractmethod
    def initial_seed(self):
        ...

    @abc.abstractmethod
    def default_generator(self, device_index):
        ...

    # Streams/Events
    @abc.abstractmethod
    def Stream(self, device_index=None, priority=0, **kwargs):
        ...

    @abc.abstractmethod
    def StreamContext(self, stream):
        ...

    @abc.abstractmethod
    def current_stream(self, device_index=None):
        ...

    @abc.abstractmethod
    def default_stream(self, device_index=None):
        ...

    @abc.abstractmethod
    def Event(self, **kwargs):
        ...

    # Memory management
    @abc.abstractmethod
    def empty_cache(self):
        ...

    @abc.abstractmethod
    def memory_allocated(self, device_index=None):
        ...

    @abc.abstractmethod
    def max_memory_allocated(self, device_index=None):
        ...

    @abc.abstractmethod
    def reset_max_memory_allocated(self, device_index=None):
        ...

    @abc.abstractmethod
    def reset_max_memory_cached(self, device_index=None):
        ...

    @abc.abstractmethod
    def memory_stats(self, device_index=None):
        ...

    @abc.abstractmethod
    def reset_peak_memory_stats(self, device_index=None):
        ...

    @abc.abstractmethod
    def memory_reserved(self, device_index=None):
        ...

    @abc.abstractmethod
    def max_memory_reserved(self, device_index=None):
        ...

    @abc.abstractmethod
    def total_memory(self, device_index=None):
        ...

    # Data types
    @abc.abstractmethod
    def is_bf16_supported(self):
        ...

    @abc.abstractmethod
    def is_fp16_supported(self):
        ...

    # Misc
    @abc.abstractmethod
    def is_available(self):
        ...

    @abc.abstractmethod
    def range_push(self, msg):
        ...

    @abc.abstractmethod
    def range_pop(self, msg):
        ...

Collaborator (inline comment on range_pop): `range_pop` does not take `msg` as an argument; only `range_push` does.

    @abc.abstractmethod
    def lazy_call(self, callback):
        ...

    @abc.abstractmethod
    def name(self):
        ...

    @abc.abstractmethod
    def communication_backend_name(self):
        ...
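A brief illustration (not part of the diff) of what the abstract base class buys: a backend that leaves any abstract method unimplemented cannot be instantiated. The `IncompleteAccelerator` name below is hypothetical, and the sketch assumes the package from this PR is importable.

```python
# Hypothetical sketch: abc refuses to instantiate a subclass that leaves
# abstract methods (device, set_device, synchronize, ...) unimplemented.
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator


class IncompleteAccelerator(DeepSpeedAccelerator):
    def device_name(self, device_index):
        return f'fake:{device_index}'
    # the remaining abstract methods are intentionally not implemented


try:
    IncompleteAccelerator()
except TypeError as err:
    print(err)  # "Can't instantiate abstract class IncompleteAccelerator ..."
```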
New file (@@ -0,0 +1,139 @@) — deepspeed.accelerator.cuda_accelerator, the CUDA implementation of the abstraction:

from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
import torch.cuda


class CUDA_Accelerator(DeepSpeedAccelerator):
    def __init__(self):
        self._name = 'cuda'
        self._communication_backend_name = 'nccl'
        self.DoubleTensor = torch.cuda.DoubleTensor
        self.LongTensor = torch.cuda.LongTensor
        self.FloatTensor = torch.cuda.FloatTensor
        self.BFloat16Tensor = torch.cuda.BFloat16Tensor
        self.HalfTensor = torch.cuda.HalfTensor
        self.IntTensor = torch.cuda.IntTensor
        self.ByteTensor = torch.cuda.ByteTensor

    # Device APIs
    def device_name(self, device_index=None):
        idx = torch.cuda.current_device() if device_index is None else device_index
        return f'cuda:{idx}'

    def device(self, device_index=None):
        return torch.cuda.device(device_index)

    def set_device(self, device_index):
        torch.cuda.set_device(device_index)

    def current_device(self):
        return torch.cuda.current_device()

    def device_count(self):
        return torch.cuda.device_count()

    def synchronize(self, device_index=None):
        return torch.cuda.synchronize(device_index)

    # RNG APIs
    def set_rng_state(self, new_state, device_index=None):
        if device_index is None:
            return torch.cuda.set_rng_state(new_state)

        return torch.cuda.set_rng_state(new_state, device_index)

    def get_rng_state(self, device_index=None):
        if device_index is None:
            return torch.cuda.get_rng_state()

        return torch.cuda.get_rng_state(device_index)

    def manual_seed(self, seed):
        return torch.cuda.manual_seed(seed)

    def manual_seed_all(self, seed):
        return torch.cuda.manual_seed_all(seed)

    def initial_seed(self, seed):
        return torch.cuda.initial_seed(seed)

    def default_generator(self, device_index):
        return torch.cuda.default_generators[device_index]

    # Streams/Events
    def Stream(self, device_index=None, priority=0, **kwargs):
        return torch.cuda.Stream(device_index, priority, **kwargs)

    def StreamContext(self, stream):
        return torch.cuda.StreamContext(stream)

    def current_stream(self, device_index=None):
        return torch.cuda.current_stream(device_index)

    def default_stream(self, device_index=None):
        return torch.cuda.default_stream(device_index)

    def Event(self, **kwargs):
        return torch.cuda.Event(**kwargs)

    # Memory management
    def empty_cache(self):
        return torch.cuda.empty_cache()

    def memory_allocated(self, device_index=None):
        return torch.cuda.memory_allocated(device_index)

    def max_memory_allocated(self, device_index=None):
        return torch.cuda.max_memory_allocated(device_index)

    def reset_max_memory_allocated(self, device_index=None):
        return torch.cuda.reset_max_memory_allocated(device_index)

    def reset_max_memory_cached(self, device_index=None):
        return torch.cuda.reset_max_memory_cached(device_index)

    def memory_stats(self, device_index=None):
        if hasattr(torch.cuda, 'memory_stats'):
            return torch.cuda.memory_stats(device_index)

    def reset_peak_memory_stats(self, device_index=None):
        if hasattr(torch.cuda, 'reset_peak_memory_stats'):
            return torch.cuda.reset_peak_memory_stats(device_index)

    def memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, 'memory_reserved'):
            return torch.cuda.memory_reserved(device_index)

    def max_memory_reserved(self, device_index=None):
        if hasattr(torch.cuda, 'max_memory_reserved'):
            return torch.cuda.max_memory_reserved(device_index)

    def total_memory(self, device_index=None):
        return torch.cuda.get_device_properties(device_index).total_memory

    # Data types
    def is_bf16_supported(self):
        return torch.cuda.is_bf16_supported()

    def is_fp16_supported(self):
        return torch.cuda.is_fp16_supported()

    # Misc
    def is_available(self):
        return torch.cuda.is_available()

    def range_push(self, msg):
        if hasattr(torch.cuda.nvtx, 'range_push'):
            return torch.cuda.nvtx.range_push(msg)

    def range_pop(self, msg):
        if hasattr(torch.cuda.nvtx, 'range_pop'):
            return torch.cuda.nvtx.range_pop(msg)

    def lazy_call(self, callback):
        return torch.cuda._lazy_call(callback)

    def name(self):
        return self._name

    def communication_backend_name(self):
        return self._communication_backend_name
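A short usage sketch (not from the PR) of how this thin wrapper maps onto `torch.cuda`; it assumes a CUDA-capable machine and that the module above is importable. Tensor sizes and the device index are arbitrary.

```python
# Hedged sketch: exercising a few of the wrappers defined above.
import torch
from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator

accel = CUDA_Accelerator()
if accel.is_available():
    accel.set_device(0)
    x = torch.ones(1024, device=accel.device_name(0))  # device_name(0) -> 'cuda:0'
    stream = accel.Stream()                            # wraps torch.cuda.Stream
    with torch.cuda.stream(stream):
        y = x * 2.0
    accel.synchronize()
    print(accel.memory_allocated(0), accel.total_memory(0))
```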
New file (@@ -0,0 +1,71 @@) — deepspeed.accelerator.real_accelerator, accelerator selection (get_accelerator / set_accelerator):

from .abstract_accelerator import DeepSpeedAccelerator

ds_accelerator = None


def _validate_accelerator(accel_obj):
    assert isinstance(accel_obj, DeepSpeedAccelerator), \
        f'{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator'

    assert accel_obj.is_available(), \
        f'{accel_obj.__class__.__name__} accelerator fails is_available() test'

Collaborator (inline comment on the `is_available()` assertion): Our internal tests show that this `is_available()` check breaks a unit test: calling `is_available()` initializes CUDA too early and causes an initialization error. Specifically, this test is broken by this assertion.

Contributor (PR author): Can you please share a stack trace for this? Thanks!

Collaborator: Here is the stack trace of this error in our CUDA environment.
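Purely as an illustration of the pattern being discussed above — not code from this PR — one way to defer the probe is to check `is_available()` on first real use rather than at selection time. The names `_accel_validated` and `lazy_validate` are hypothetical.

```python
# Hypothetical pattern sketch only: postpone the is_available() probe until
# the accelerator is actually used, instead of checking it at selection time.
_accel_validated = False


def lazy_validate(accel_obj):
    global _accel_validated
    if not _accel_validated:
        assert accel_obj.is_available(), \
            f'{accel_obj.__class__.__name__} accelerator fails is_available() test'
        _accel_validated = True
```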
def get_accelerator():
    global ds_accelerator
    if ds_accelerator is None:
        from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator
        ds_accelerator = CUDA_Accelerator()
        _validate_accelerator(ds_accelerator)
    return ds_accelerator


def set_accelerator(accel_obj):
    global ds_accelerator
    _validate_accelerator(accel_obj)
    ds_accelerator = accel_obj
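A small sketch (not in the diff) of the guard these helpers provide: passing `set_accelerator` an object that does not subclass `DeepSpeedAccelerator` trips the first assertion in `_validate_accelerator` before `is_available()` is ever called. `NotAnAccelerator` is a made-up name.

```python
# Hedged sketch: set_accelerator() rejects objects that are not
# DeepSpeedAccelerator subclasses via _validate_accelerator().
from deepspeed.accelerator.real_accelerator import set_accelerator


class NotAnAccelerator:
    pass


try:
    set_accelerator(NotAnAccelerator())
except AssertionError as err:
    print(err)  # "NotAnAccelerator accelerator is not subclass of DeepSpeedAccelerator"
```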
'''
-----------[code] test_get.py -----------
from deepspeed.accelerator.real_accelerator import get_accelerator
my_accelerator = get_accelerator()
print(f'{my_accelerator.name=}')
print(f'{my_accelerator.communication_backend=}')
print(f'{my_accelerator.HalfTensor().device=}')
print(f'{my_accelerator.total_memory()=}')
-----------[code] test_get.py -----------
---[output] python test_get.py---------
my_accelerator.name()='cuda'
my_accelerator.communication_backend='nccl'
my_accelerator.HalfTensor().device=device(type='cuda', index=0)
my_accelerator.total_memory()=34089730048
---[output] python test_get.py---------
**************************************************************************
-----------[code] test_set.py -----------
from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator
cu_accel = CUDA_Accelerator()
print(f'{id(cu_accel)=}')
from deepspeed.accelerator.real_accelerator import set_accelerator, get_accelerator
set_accelerator(cu_accel)
my_accelerator = get_accelerator()
print(f'{id(my_accelerator)=}')
print(f'{my_accelerator.name=}')
print(f'{my_accelerator.communication_backend=}')
print(f'{my_accelerator.HalfTensor().device=}')
print(f'{my_accelerator.total_memory()=}')
-----------[code] test_set.py -----------
---[output] python test_set.py---------
id(cu_accel)=139648165478304
my_accelerator=<deepspeed.accelerator.cuda_accelerator.CUDA_Accelerator object at 0x7f025f4bffa0>
my_accelerator.name='cuda'
my_accelerator.communication_backend='nccl'
my_accelerator.HalfTensor().device=device(type='cuda', index=0)
my_accelerator.total_memory()=34089730048
---[output] python test_set.py---------
'''
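Beyond the test snippets in the docstring above, a hedged sketch of the call pattern the abstraction aims for: engine code asks `get_accelerator()` for devices, streams, and memory statistics instead of hard-coding `torch.cuda`. It assumes a CUDA machine with this PR installed; shapes and the device index are arbitrary.

```python
# Hedged sketch: device-agnostic call sites go through the accelerator object.
import torch
from deepspeed.accelerator.real_accelerator import get_accelerator

accel = get_accelerator()                         # CUDA_Accelerator by default
dev = accel.device_name(accel.current_device())   # e.g. 'cuda:0'
x = torch.zeros(8, device=dev)
accel.synchronize()
print(accel.name(), accel.communication_backend_name(), accel.memory_allocated())
```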