Closed
Changes from all commits (21 commits)
- a7d5a01 Accelerator abstraction (tjruwase, Sep 13, 2022)
- 0937e5a Format fixes (tjruwase, Sep 13, 2022)
- 3895c07 Use ABC (tjruwase, Sep 13, 2022)
- 368d77d Merge branch 'master' into olruwase/accelerator_abstraction (tjruwase, Sep 13, 2022)
- 190c950 Merge branch 'master' into olruwase/accelerator_abstraction (tjruwase, Sep 14, 2022)
- 5858a59 Integration guide (tjruwase, Sep 15, 2022)
- 066219f Merge branch 'master' into olruwase/accelerator_abstraction (tjruwase, Sep 15, 2022)
- 40b4b81 Permit only import exceptions; Check is_available() (tjruwase, Sep 19, 2022)
- 3fadd46 Merge branch 'master' into olruwase/accelerator_abstraction (tjruwase, Sep 19, 2022)
- 89a926c Merge branch 'olruwase/accelerator_abstraction' of github.com:microso… (tjruwase, Sep 19, 2022)
- 5d05bb5 Explicit set/get accelerator (tjruwase, Sep 21, 2022)
- e2a0319 Sanity checks (tjruwase, Sep 22, 2022)
- eeb1671 Merge branch 'master' into olruwase/accelerator_abstraction (tjruwase, Sep 27, 2022)
- c09d2f5 Default cuda device (tjruwase, Sep 27, 2022)
- 916ea71 Merge branch 'master' into olruwase/accelerator_abstraction (tjruwase, Sep 28, 2022)
- 96dcd07 Add device_name() (tjruwase, Sep 29, 2022)
- beb7c54 Merge branch 'olruwase/accelerator_abstraction' of github.com:microso… (tjruwase, Sep 29, 2022)
- 0e3b5b8 Merge branch 'master' into olruwase/accelerator_abstraction (tjruwase, Sep 30, 2022)
- d72c37c More interfaces (tjruwase, Sep 30, 2022)
- 8d27c76 Merge branch 'olruwase/accelerator_abstraction' of github.com:microso… (tjruwase, Sep 30, 2022)
- 7fbfee6 Merge branch 'master' into olruwase/accelerator_abstraction (tjruwase, Oct 11, 2022)
1 change: 1 addition & 0 deletions deepspeed/accelerator/__init__.py
@@ -0,0 +1 @@
from .abstract_accelerator import DeepSpeedAccelerator
161 changes: 161 additions & 0 deletions deepspeed/accelerator/abstract_accelerator.py
@@ -0,0 +1,161 @@
import abc
from abc import ABC


class DeepSpeedAccelerator(ABC):
def __init__(self):
self._name = None
self._communication_backend_name = None
self.BFloat16Tensor = None
self.ByteTensor = None
self.DoubleTensor = None
self.FloatTensor = None
self.HalfTensor = None
self.IntTensor = None
self.LongTensor = None

# Device APIs
@abc.abstractmethod
def device_name(self, device_index):
...

@abc.abstractmethod
def device(self, device_index):
...

@abc.abstractmethod
def set_device(self, device_index):
...

@abc.abstractmethod
def current_device(self):
...

@abc.abstractmethod
def device_count(self):
...

@abc.abstractmethod
def synchronize(self, device_index=None):
...

# RNG APIs
@abc.abstractmethod
def set_rng_state(self, new_state, device_index=None):
...

@abc.abstractmethod
def get_rng_state(self, device_index=None):
...

@abc.abstractmethod
def manual_seed(self, seed):
...

@abc.abstractmethod
def manual_seed_all(self, seed):
...

@abc.abstractmethod
def initial_seed(self):
...

@abc.abstractmethod
def default_generator(self, device_index):
...

# Streams/Events
@abc.abstractmethod
def Stream(self, device_index=None, priority=0, **kwargs):
...

@abc.abstractmethod
def StreamContext(self, stream):
...

@abc.abstractmethod
def current_stream(self, device_index=None):
...

@abc.abstractmethod
def default_stream(self, device_index=None):
...

@abc.abstractmethod
def Event(self, **kwargs):
...

# Memory management
@abc.abstractmethod
def empty_cache(self):
...

@abc.abstractmethod
def memory_allocated(self, device_index=None):
...

@abc.abstractmethod
def max_memory_allocated(self, device_index=None):
...

@abc.abstractmethod
def reset_max_memory_allocated(self, device_index=None):
...

@abc.abstractmethod
def reset_max_memory_cached(self, device_index=None):
...

@abc.abstractmethod
def memory_stats(self, device_index=None):
...

@abc.abstractmethod
def reset_peak_memory_stats(self, device_index=None):
...

@abc.abstractmethod
def memory_reserved(self, device_index=None):
...

@abc.abstractmethod
def max_memory_reserved(self, device_index=None):
...

@abc.abstractmethod
def total_memory(self, device_index=None):
...

# Data types
@abc.abstractmethod
def is_bf16_supported(self):
...

@abc.abstractmethod
def is_fp16_supported(self):
...

# Misc
@abc.abstractmethod
def is_available(self):
...

@abc.abstractmethod
def range_push(self, msg):
...

@abc.abstractmethod
def range_pop(self, msg):

Review comment (Collaborator): range_pop does not take msg as an argument; only range_push does.
https://pytorch.org/docs/stable/generated/torch.cuda.nvtx.range_pop.html

...

@abc.abstractmethod
def lazy_call(self, callback):
...

@abc.abstractmethod
def name(self):
...

@abc.abstractmethod
def communication_backend_name(self):
...
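
To make the review comment above concrete, here is a minimal nvtx usage sketch, assuming a CUDA-enabled PyTorch build; it is illustrative only and not part of this PR:

import torch

def annotated_step(x):
    # range_push() opens a named NVTX range and takes the message argument.
    torch.cuda.nvtx.range_push('annotated_step')
    y = x * 2
    # range_pop() closes the most recent range and takes no message argument.
    torch.cuda.nvtx.range_pop()
    return y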
139 changes: 139 additions & 0 deletions deepspeed/accelerator/cuda_accelerator.py
@@ -0,0 +1,139 @@
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
import torch.cuda


class CUDA_Accelerator(DeepSpeedAccelerator):
def __init__(self):
self._name = 'cuda'
self._communication_backend_name = 'nccl'
self.DoubleTensor = torch.cuda.DoubleTensor
self.LongTensor = torch.cuda.LongTensor
self.FloatTensor = torch.cuda.FloatTensor
self.BFloat16Tensor = torch.cuda.BFloat16Tensor
self.HalfTensor = torch.cuda.HalfTensor
self.IntTensor = torch.cuda.IntTensor
self.ByteTensor = torch.cuda.ByteTensor

# Device APIs
def device_name(self, device_index=None):
idx = torch.cuda.current_device() if device_index is None else device_index
return f'cuda:{idx}'

def device(self, device_index=None):
return torch.cuda.device(device_index)

def set_device(self, device_index):
torch.cuda.set_device(device_index)

def current_device(self):
return torch.cuda.current_device()

def device_count(self):
return torch.cuda.device_count()

def synchronize(self, device_index=None):
return torch.cuda.synchronize(device_index)

# RNG APIs
def set_rng_state(self, new_state, device_index=None):
if device_index is None:
return torch.cuda.set_rng_state(new_state)

return torch.cuda.set_rng_state(new_state, device_index)

def get_rng_state(self, device_index=None):
if device_index is None:
return torch.cuda.get_rng_state()

return torch.cuda.get_rng_state(device_index)

def manual_seed(self, seed):
return torch.cuda.manual_seed(seed)

def manual_seed_all(self, seed):
return torch.cuda.manual_seed_all(seed)

def initial_seed(self):
return torch.cuda.initial_seed()

def default_generator(self, device_index):
return torch.cuda.default_generators[device_index]

# Streams/Events
def Stream(self, device_index=None, priority=0, **kwargs):
return torch.cuda.Stream(device_index, priority, **kwargs)

def StreamContext(self, stream):
return torch.cuda.StreamContext(stream)

def current_stream(self, device_index=None):
return torch.cuda.current_stream(device_index)

def default_stream(self, device_index=None):
return torch.cuda.default_stream(device_index)

def Event(self, **kwargs):
return torch.cuda.Event(**kwargs)

# Memory management
def empty_cache(self):
return torch.cuda.empty_cache()

def memory_allocated(self, device_index=None):
return torch.cuda.memory_allocated(device_index)

def max_memory_allocated(self, device_index=None):
return torch.cuda.max_memory_allocated(device_index)

def reset_max_memory_allocated(self, device_index=None):
return torch.cuda.reset_max_memory_allocated(device_index)

def reset_max_memory_cached(self, device_index=None):
return torch.cuda.reset_max_memory_cached(device_index)

def memory_stats(self, device_index=None):
if hasattr(torch.cuda, 'memory_stats'):
return torch.cuda.memory_stats(device_index)

def reset_peak_memory_stats(self, device_index=None):
if hasattr(torch.cuda, 'reset_peak_memory_stats'):
return torch.cuda.reset_peak_memory_stats(device_index)

def memory_reserved(self, device_index=None):
if hasattr(torch.cuda, 'memory_reserved'):
return torch.cuda.memory_reserved(device_index)

def max_memory_reserved(self, device_index=None):
if hasattr(torch.cuda, 'max_memory_reserved'):
return torch.cuda.max_memory_reserved(device_index)

def total_memory(self, device_index=None):
return torch.cuda.get_device_properties(device_index).total_memory

# Data types
def is_bf16_supported(self):
return torch.cuda.is_bf16_supported()

def is_fp16_supported(self):
return torch.cuda.is_fp16_supported()

# Misc
def is_available(self):
return torch.cuda.is_available()

def range_push(self, msg):
if hasattr(torch.cuda.nvtx, 'range_push'):
return torch.cuda.nvtx.range_push(msg)

def range_pop(self, msg):
if hasattr(torch.cuda.nvtx, 'range_pop'):
return torch.cuda.nvtx.range_pop(msg)

def lazy_call(self, callback):
return torch.cuda._lazy_call(callback)

def name(self):
return self._name

def communication_backend_name(self):
return self._communication_backend_name
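
As a usage sketch of the wrappers added above (via the get_accelerator() helper defined in real_accelerator.py below), calling code can go through the accelerator object instead of torch.cuda directly; the printed values are device-dependent and illustrative only:

import torch
from deepspeed.accelerator.real_accelerator import get_accelerator

acc = get_accelerator()                          # defaults to CUDA_Accelerator
device = acc.device_name(acc.current_device())   # e.g. 'cuda:0'
x = torch.ones(4, device=device)
acc.synchronize()                                # wraps torch.cuda.synchronize()
print(acc.memory_allocated())                    # wraps torch.cuda.memory_allocated()
print(acc.communication_backend_name())          # 'nccl' for the CUDA accelerator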
71 changes: 71 additions & 0 deletions deepspeed/accelerator/real_accelerator.py
@@ -0,0 +1,71 @@
from .abstract_accelerator import DeepSpeedAccelerator

ds_accelerator = None


def _validate_accelerator(accel_obj):
assert isinstance(accel_obj, DeepSpeedAccelerator), \
f'{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator'

assert accel_obj.is_available(), \

Review comment (Collaborator): Our internal testing shows that this is_available() check breaks a unit test. Calling is_available() initializes CUDA too early and causes an initialization error. Specifically, the following test fails because of this assertion:
pytest -k "test_ckpt_arg_none" test_activation_checkpointing.py

Reply (Contributor, author): Can you please share a stack trace for this? Thanks!

Reply (Collaborator): Here is the stack trace of this error in our CUDA environment:

$ pytest -k "test_ckpt_arg_none" test_activation_checkpointing.py
......
Worker 0 exited with code 1
----------------------------- Captured stdout call -----------------------------
[2022-10-10 10:27:14,784] [INFO] [comm.py:639:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
------------------------------------------ Captured stderr call ------------------
Process Process-1:
Traceback (most recent call last):
  File "/home/gma/anaconda3/envs/ds/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/gma/anaconda3/envs/ds/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/gma/mingzhil/CI/frameworks.ai.benchmarking.other.deepspeed/tests/unit/common.py", line 250, in dist_init
    dist.barrier()
  File "/home/gma/mll/CI/frameworks.ai.benchmarking.other.deepspeed/deepspeed/comm/comm.py", line 127, in log_wrapper
    return func(*args, **kwargs)
  File "/home/gma/mll/CI/frameworks.ai.benchmarking.other.deepspeed/deepspeed/comm/comm.py", line 459, in barrier
    return cdb.barrier()
  File "/home/gma/mll/CI/frameworks.ai.benchmarking.other.deepspeed/deepspeed/comm/torch.py", line 153, in barrier
    return torch.distributed.barrier()
  File "/home/gma/anaconda3/envs/ds/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2776, in barrier
    work = default_pg.barrier(opts=opts)
RuntimeError: CUDA error: initialization error
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
......
===================== short test summary info ==========================
FAILED test_activation_checkpointing.py::test_ckpt_arg_none
=================== 1 failed, 21 deselected, 2 warnings in 1.06s ========================

f'{accel_obj.__class__.__name__} accelerator fails is_available() test'


def get_accelerator():
global ds_accelerator
if ds_accelerator is None:
from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator
ds_accelerator = CUDA_Accelerator()
_validate_accelerator(ds_accelerator)
return ds_accelerator


def set_accelerator(accel_obj):
global ds_accelerator
_validate_accelerator(accel_obj)
ds_accelerator = accel_obj


'''
-----------[code] test_get.py -----------
from deepspeed.accelerator.real_accelerator import get_accelerator
my_accelerator = get_accelerator()
print(f'{my_accelerator.name()=}')
print(f'{my_accelerator.communication_backend_name()=}')
print(f'{my_accelerator.HalfTensor().device=}')
print(f'{my_accelerator.total_memory()=}')
-----------[code] test_get.py -----------
---[output] python test_get.py---------
my_accelerator.name()='cuda'
my_accelerator.communication_backend_name()='nccl'
my_accelerator.HalfTensor().device=device(type='cuda', index=0)
my_accelerator.total_memory()=34089730048
---[output] python test_get.py---------
**************************************************************************
-----------[code] test_set.py -----------
from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator
cu_accel = CUDA_Accelerator()
print(f'{id(cu_accel)=}')
from deepspeed.accelerator.real_accelerator import set_accelerator, get_accelerator
set_accelerator(cu_accel)
my_accelerator = get_accelerator()
print(f'{id(my_accelerator)=}')
print(f'{my_accelerator.name()=}')
print(f'{my_accelerator.communication_backend_name()=}')
print(f'{my_accelerator.HalfTensor().device=}')
print(f'{my_accelerator.total_memory()=}')
-----------[code] test_set.py -----------
---[output] python test_set.py---------
id(cu_accel)=139648165478304
my_accelerator=<deepspeed.accelerator.cuda_accelerator.CUDA_Accelerator object at 0x7f025f4bffa0>
my_accelerator.name()='cuda'
my_accelerator.communication_backend_name()='nccl'
my_accelerator.HalfTensor().device=device(type='cuda', index=0)
my_accelerator.total_memory()=34089730048
---[output] python test_set.py---------
'''
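
Relating to the is_available() discussion above, the following is a purely hypothetical sketch (not part of this PR) of deferring the availability probe so that registering an accelerator does not initialize CUDA in the parent process; the names ending in _lazy are invented for illustration:

from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator

_lazy_accelerator = None

def set_accelerator_lazy(accel_obj):
    # Only the cheap isinstance check runs at registration time.
    assert isinstance(accel_obj, DeepSpeedAccelerator), \
        f'{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator'
    global _lazy_accelerator
    _lazy_accelerator = accel_obj

def get_accelerator_lazy():
    # The CUDA-touching is_available() probe runs on first use,
    # typically after worker processes have been spawned or forked.
    assert _lazy_accelerator is not None, 'no accelerator has been set'
    assert _lazy_accelerator.is_available(), \
        f'{_lazy_accelerator.__class__.__name__} accelerator fails is_available() test'
    return _lazy_accelerator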