diff --git a/benchmarks/communication/all_gather.py b/benchmarks/communication/all_gather.py index d99d2aa0e4c9..3e4e38964636 100644 --- a/benchmarks/communication/all_gather.py +++ b/benchmarks/communication/all_gather.py @@ -1,5 +1,6 @@ from benchmarks.communication.utils import * from benchmarks.communication.constants import * +from deepspeed.accelerator import get_accelerator import time @@ -83,16 +84,20 @@ def run_all_gather(local_rank, args): try: mat = torch.ones(world_size, M, - dtype=getattr(torch, - args.dtype)).cuda(local_rank) + dtype=getattr( + torch, + args.dtype)).to( + get_accelerator().device_name(local_rank)) sync_all() input = ((mat.mul_(float(global_rank))).view(-1)) # Delete original mat to avoid OOM del mat - torch.cuda.empty_cache() + get_accelerator().empty_cache() output = torch.zeros(input.nelement() * world_size, - dtype=getattr(torch, - args.dtype)).cuda(local_rank) + dtype=getattr( + torch, + args.dtype)).to( + get_accelerator().device_name(local_rank)) except RuntimeError as e: if 'out of memory' in str(e): if dist.get_rank() == 0: @@ -121,15 +126,17 @@ def run_all_gather(local_rank, args): try: mat = torch.ones(elements_per_gpu, dtype=getattr(torch, - args.dtype)).cuda(local_rank) + args.dtype)).to( + get_accelerator().device_name(local_rank)) # multiply each GPU's tensor by the rank to ease debugging input = ((mat.mul_(float(global_rank))).view(-1)) # Delete original mat to avoid OOM del mat - torch.cuda.empty_cache() - output = torch.zeros(elements_per_gpu * world_size, - dtype=getattr(torch, - args.dtype)).cuda(local_rank) + get_accelerator().empty_cache() + output = torch.zeros( + elements_per_gpu * world_size, + dtype=getattr(torch, + args.dtype)).to(get_accelerator().device_name(local_rank)) except RuntimeError as e: if 'out of memory' in str(e): if dist.get_rank() == 0: diff --git a/benchmarks/communication/all_reduce.py b/benchmarks/communication/all_reduce.py index e31f51733609..a0f91e03dcf5 100644 --- a/benchmarks/communication/all_reduce.py +++ b/benchmarks/communication/all_reduce.py @@ -1,5 +1,6 @@ from benchmarks.communication.utils import * from benchmarks.communication.constants import * +from deepspeed.accelerator import get_accelerator import time @@ -62,8 +63,10 @@ def run_all_reduce(local_rank, args): try: mat = torch.ones(world_size, M, - dtype=getattr(torch, - args.dtype)).cuda(local_rank) + dtype=getattr( + torch, + args.dtype)).to( + get_accelerator().device_name(local_rank)) sync_all() input = ((mat.mul_(float(global_rank))).view(-1)) except RuntimeError as e: @@ -86,7 +89,8 @@ def run_all_reduce(local_rank, args): try: mat = torch.ones(elements_per_gpu, dtype=getattr(torch, - args.dtype)).cuda(local_rank) + args.dtype)).to( + get_accelerator().device_name(local_rank)) input = ((mat.mul_(float(global_rank))).view(-1)) except RuntimeError as e: if 'out of memory' in str(e): diff --git a/benchmarks/communication/all_to_all.py b/benchmarks/communication/all_to_all.py index 6ee99a48ee62..063b877b5b3d 100644 --- a/benchmarks/communication/all_to_all.py +++ b/benchmarks/communication/all_to_all.py @@ -1,5 +1,6 @@ from benchmarks.communication.utils import * from benchmarks.communication.constants import * +from deepspeed.accelerator import get_accelerator import time @@ -61,8 +62,10 @@ def run_all_to_all(local_rank, args): try: mat = torch.ones(world_size, M, - dtype=getattr(torch, - args.dtype)).cuda(local_rank) + dtype=getattr( + torch, + args.dtype)).to( + get_accelerator().device_name(local_rank)) assert mat.numel() % world_size == 0, 
f"tensor cannot be divided in {world_size} chunks" sync_all() input = ((mat.mul_(float(global_rank))).view(-1)) @@ -86,15 +89,17 @@ def run_all_to_all(local_rank, args): try: mat = torch.ones(elements_per_gpu, dtype=getattr(torch, - args.dtype)).cuda(local_rank) + args.dtype)).to( + get_accelerator().device_name(local_rank)) assert mat.numel() % world_size == 0, f"tensor with {mat.numel()} elements cannot be divided in {world_size} chunks" input = ((mat.mul_(float(global_rank))).view(-1)) # Delete original mat to avoid OOM del mat - torch.cuda.empty_cache() - output = torch.zeros(elements_per_gpu, - dtype=getattr(torch, - args.dtype)).cuda(local_rank) + get_accelerator().empty_cache() + output = torch.zeros( + elements_per_gpu, + dtype=getattr(torch, + args.dtype)).to(get_accelerator().device_name(local_rank)) except RuntimeError as e: if 'out of memory' in str(e): if dist.get_rank() == 0: diff --git a/benchmarks/communication/broadcast.py b/benchmarks/communication/broadcast.py index e9d89779ec66..54f840934013 100644 --- a/benchmarks/communication/broadcast.py +++ b/benchmarks/communication/broadcast.py @@ -1,6 +1,7 @@ import torch from benchmarks.communication.utils import * from benchmarks.communication.constants import * +from deepspeed.accelerator import get_accelerator import time @@ -63,8 +64,10 @@ def run_broadcast(local_rank, args): try: mat = torch.ones(world_size, M, - dtype=getattr(torch, - args.dtype)).cuda(local_rank) + dtype=getattr( + torch, + args.dtype)).to( + get_accelerator().device_name(local_rank)) sync_all() input = ((mat.mul_(float(global_rank))).view(-1)) except RuntimeError as e: @@ -87,7 +90,8 @@ def run_broadcast(local_rank, args): try: mat = torch.ones(elements_per_gpu, dtype=getattr(torch, - args.dtype)).cuda(local_rank) + args.dtype)).to( + get_accelerator().device_name(local_rank)) input = ((mat.mul_(float(global_rank))).view(-1)) except RuntimeError as e: if 'out of memory' in str(e): diff --git a/benchmarks/communication/constants.py b/benchmarks/communication/constants.py index 4b3356894b5f..c17557e8012c 100644 --- a/benchmarks/communication/constants.py +++ b/benchmarks/communication/constants.py @@ -1,7 +1,9 @@ +from deepspeed.accelerator import get_accelerator + DEFAULT_WARMUPS = 5 DEFAULT_TRIALS = 50 DEFAULT_TYPE = 'float' -DEFAULT_BACKEND = 'nccl' +DEFAULT_BACKEND = get_accelerator().communication_backend_name() DEFAULT_UNIT = 'Gbps' DEFAULT_DIST = 'deepspeed' DEFAULT_MAXSIZE = 24 diff --git a/benchmarks/communication/pt2pt.py b/benchmarks/communication/pt2pt.py index cb99b20b9097..37e9c3be93d4 100644 --- a/benchmarks/communication/pt2pt.py +++ b/benchmarks/communication/pt2pt.py @@ -1,5 +1,6 @@ from benchmarks.communication.utils import * from benchmarks.communication.constants import * +from deepspeed.accelerator import get_accelerator import time @@ -81,8 +82,10 @@ def run_pt2pt(local_rank, args): try: mat = torch.ones(world_size, M, - dtype=getattr(torch, - args.dtype)).cuda(local_rank) + dtype=getattr( + torch, + args.dtype)).to( + get_accelerator().device_name(local_rank)) sync_all() input = ((mat.mul_(float(global_rank))).view(-1)) except RuntimeError as e: @@ -105,7 +108,8 @@ def run_pt2pt(local_rank, args): try: mat = torch.ones(elements_per_gpu, dtype=getattr(torch, - args.dtype)).cuda(local_rank) + args.dtype)).to( + get_accelerator().device_name(local_rank)) input = ((mat.mul_(float(global_rank))).view(-1)) except RuntimeError as e: if 'out of memory' in str(e): diff --git a/benchmarks/communication/utils.py 
b/benchmarks/communication/utils.py index 305f2f3dad37..097c828ce262 100644 --- a/benchmarks/communication/utils.py +++ b/benchmarks/communication/utils.py @@ -3,6 +3,7 @@ import math import argparse from benchmarks.communication.constants import * +from deepspeed.accelerator import get_accelerator global dist @@ -12,7 +13,7 @@ def init_torch_distributed(backend): import torch.distributed as dist torch.distributed.init_process_group(backend) local_rank = int(os.environ['LOCAL_RANK']) - torch.cuda.set_device(local_rank) + get_accelerator().set_device(local_rank) def init_deepspeed_comm(backend): @@ -21,7 +22,7 @@ def init_deepspeed_comm(backend): import deepspeed.comm as dist deepspeed.init_distributed(dist_backend=backend) local_rank = int(os.environ['LOCAL_RANK']) - torch.cuda.set_device(local_rank) + get_accelerator().set_device(local_rank) def init_processes(local_rank, args): @@ -99,14 +100,13 @@ def get_metric_strings(args, tput, busbw, duration): def sync_all(): - torch.cuda.synchronize() + get_accelerator().synchronize() dist.barrier() def max_numel(comm_op, dtype, mem_factor, local_rank, args): dtype_size = _element_size(dtype) - max_memory_per_gpu = torch.cuda.get_device_properties( - local_rank).total_memory * mem_factor + max_memory_per_gpu = get_accelerator().total_memory(local_rank) * mem_factor if comm_op == 'all_reduce' or comm_op == 'pt2pt' or comm_op == 'broadcast': elements_per_gpu = int(max_memory_per_gpu // dtype_size) elif comm_op == 'all_gather': @@ -183,7 +183,8 @@ def benchmark_parser(): parser.add_argument("--backend", type=str, default=DEFAULT_BACKEND, - choices=['nccl'], + choices=['nccl', + 'ccl'], help='Communication library to use') parser.add_argument("--dist", type=str, diff --git a/benchmarks/inference/bert-bench.py b/benchmarks/inference/bert-bench.py index e576d67f7d82..f10d38efecf4 100644 --- a/benchmarks/inference/bert-bench.py +++ b/benchmarks/inference/bert-bench.py @@ -3,6 +3,7 @@ import deepspeed import argparse from transformers import pipeline +from deepspeed.accelerator import get_accelerator parser = argparse.ArgumentParser() parser.add_argument("--model", "-m", type=str, help="hf model name") @@ -44,7 +45,7 @@ def print_latency(latency_set, title, warmup=3): print("\t999 Latency: {0:8.2f} ms".format(p999 * 1000)) -deepspeed.init_distributed("nccl") +deepspeed.init_distributed() print(args.model, args.max_tokens, args.dtype) @@ -72,10 +73,10 @@ def print_latency(latency_set, title, warmup=3): times = [] mtimes = [] for i in range(args.trials): - torch.cuda.synchronize() + get_accelerator().synchronize() start = time.time() r = pipe("Hello I'm a [MASK] model") - torch.cuda.synchronize() + get_accelerator().synchronize() end = time.time() responses.append(r) times.append((end - start)) diff --git a/benchmarks/inference/gpt-bench.py b/benchmarks/inference/gpt-bench.py index 9d3905946e1b..cb490c18a4b6 100644 --- a/benchmarks/inference/gpt-bench.py +++ b/benchmarks/inference/gpt-bench.py @@ -4,6 +4,7 @@ import deepspeed import argparse from transformers import pipeline +from deepspeed.accelerator import get_accelerator parser = argparse.ArgumentParser() parser.add_argument("--model", "-m", type=str, help="hf model name") @@ -61,7 +62,7 @@ def print_latency(latency_set, title, warmup=3): print("\t999 Latency: {0:8.2f} ms".format(p999 * 1000)) -deepspeed.init_distributed("nccl") +deepspeed.init_distributed() if args.local_rank == 0: print("BENCHMARK SETTINGS:") @@ -101,10 +102,10 @@ def print_latency(latency_set, title, warmup=3): times = [] mtimes = 
[] for i in range(args.trials): - torch.cuda.synchronize() + get_accelerator().synchronize() start = time.time() r = pipe("DeepSpeed is", do_sample=False, max_new_tokens=args.max_tokens) - torch.cuda.synchronize() + get_accelerator().synchronize() end = time.time() responses.append(r) times.append(end - start) # / (args.max_tokens - 3)) diff --git a/csrc/aio/py_test/aio_bench_perf_sweep.py b/csrc/aio/py_test/aio_bench_perf_sweep.py index f16e962b1648..040cc0957dd1 100644 --- a/csrc/aio/py_test/aio_bench_perf_sweep.py +++ b/csrc/aio/py_test/aio_bench_perf_sweep.py @@ -15,6 +15,7 @@ from test_ds_aio_utils import refine_integer_value from perf_sweep_utils import READ_OP_DESC, WRITE_OP_DESC, BENCH_LOG_DIR, \ READ_IO_DIR, WRITE_IO_DIR, READ_LOG_DIR, WRITE_LOG_DIR +from deepspeed.accelerator import get_accelerator OTHER_OPTIONS = '--handle' PERF_SCRIPT = 'test_ds_aio.py' @@ -277,8 +278,7 @@ def script_path(): def async_io_setup(): - from deepspeed.ops.aio import AsyncIOBuilder - return AsyncIOBuilder().is_compatible() + return get_accelerator().create_op_builder("AsyncIOBuilder").is_compatible() def get_block_size_and_count(io_bytes): diff --git a/csrc/aio/py_test/ds_aio_basic.py b/csrc/aio/py_test/ds_aio_basic.py index cf70b6655e9c..62e9f79f1b94 100755 --- a/csrc/aio/py_test/ds_aio_basic.py +++ b/csrc/aio/py_test/ds_aio_basic.py @@ -8,9 +8,9 @@ import torch import os import time -from deepspeed.ops.aio import AsyncIOBuilder from multiprocessing import Pool, Barrier from test_ds_aio_utils import report_results, task_log, task_barrier +from deepspeed.accelerator import get_accelerator def pre_basic(args, tid, read_op): @@ -19,7 +19,10 @@ def pre_basic(args, tid, read_op): file = args.read_file if read_op else f'{args.write_file}.{tid}' task_log(tid, f'Allocate tensor of size {num_bytes} bytes') - buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory() + buffer = get_accelerator().pin_memory( + torch.empty(num_bytes, + dtype=torch.uint8, + device='cpu')) task_log( tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}' @@ -56,13 +59,14 @@ def post_basic(pool_params): def main_basic_read(pool_params): args, tid, ctxt = pool_params start_time = time.time() - AsyncIOBuilder().load().aio_read(ctxt['buffer'], - ctxt['file'], - args.block_size, - args.queue_depth, - args.single_submit, - args.overlap_events, - args.validate) + get_accelerator().create_op_builder("AsyncIOBuilder").load().aio_read( + ctxt['buffer'], + ctxt['file'], + args.block_size, + args.queue_depth, + args.single_submit, + args.overlap_events, + args.validate) end_time = time.time() ctxt['elapsed_sec'] += end_time - start_time @@ -72,13 +76,14 @@ def main_basic_read(pool_params): def main_basic_write(pool_params): args, tid, ctxt = pool_params start_time = time.time() - AsyncIOBuilder().load().aio_write(ctxt['buffer'], - ctxt['file'], - args.block_size, - args.queue_depth, - args.single_submit, - args.overlap_events, - args.validate) + get_accelerator().create_op_builder("AsyncIOBuilder").load().aio_write( + ctxt['buffer'], + ctxt['file'], + args.block_size, + args.queue_depth, + args.single_submit, + args.overlap_events, + args.validate) end_time = time.time() ctxt['elapsed_sec'] += end_time - start_time diff --git a/csrc/aio/py_test/ds_aio_handle.py b/csrc/aio/py_test/ds_aio_handle.py index 947ee2e6cb63..d6db9a59460c 100755 --- a/csrc/aio/py_test/ds_aio_handle.py +++ b/csrc/aio/py_test/ds_aio_handle.py @@ -9,8 +9,8 @@ import os import time from multiprocessing import 
Pool, Barrier -from deepspeed.ops.aio import AsyncIOBuilder from test_ds_aio_utils import report_results, task_log, task_barrier +from deepspeed.accelerator import get_accelerator def pre_handle(args, tid, read_op): @@ -20,20 +20,26 @@ def pre_handle(args, tid, read_op): task_log(tid, f'Allocate tensor of size {num_bytes} bytes') if args.gpu: - buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cuda') + buffer = torch.empty(num_bytes, + dtype=torch.uint8, + device=get_accelerator().device_name()) else: - buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory() + buffer = get_accelerator().pin_memory( + torch.empty(num_bytes, + dtype=torch.uint8, + device='cpu')) task_log( tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}' ) io_parallel = args.io_parallel if args.io_parallel else 1 - handle = AsyncIOBuilder().load().aio_handle(args.block_size, - args.queue_depth, - args.single_submit, - args.overlap_events, - io_parallel) + handle = get_accelerator().create_op_builder("AsyncIOBuilder").load().aio_handle( + args.block_size, + args.queue_depth, + args.single_submit, + args.overlap_events, + io_parallel) task_log(tid, f'created deepspeed aio handle') ctxt = {} diff --git a/csrc/aio/py_test/validate_async_io.py b/csrc/aio/py_test/validate_async_io.py index 4db25fe1bc33..cebe1ef22417 100644 --- a/csrc/aio/py_test/validate_async_io.py +++ b/csrc/aio/py_test/validate_async_io.py @@ -4,5 +4,5 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices. """ -from deepspeed.ops.aio import AsyncIOBuilder -assert AsyncIOBuilder().is_compatible() +from deepspeed.accelerator import get_accelerator +assert get_accelerator().create_op_builder("AsyncIOBuilder").is_compatible() diff --git a/deepspeed/accelerator/__init__.py b/deepspeed/accelerator/__init__.py new file mode 100644 index 000000000000..aae3e8f9d32e --- /dev/null +++ b/deepspeed/accelerator/__init__.py @@ -0,0 +1,2 @@ +from .abstract_accelerator import DeepSpeedAccelerator +from .real_accelerator import get_accelerator, set_accelerator diff --git a/deepspeed/accelerator/abstract_accelerator.py b/deepspeed/accelerator/abstract_accelerator.py new file mode 100644 index 000000000000..551096d879ba --- /dev/null +++ b/deepspeed/accelerator/abstract_accelerator.py @@ -0,0 +1,198 @@ +import abc +from abc import ABC + + +class DeepSpeedAccelerator(ABC): + def __init__(self): + self._name = None + self._communication_backend_name = None + self.BFloat16Tensor = None + self.ByteTensor = None + self.DoubleTensor = None + self.FloatTensor = None + self.HalfTensor = None + self.IntTensor = None + self.LongTensor = None + + # Device APIs + @abc.abstractmethod + def device_name(self, device_index): + ... + + @abc.abstractmethod + def device(self, device_index): + ... + + @abc.abstractmethod + def set_device(self, device_index): + ... + + @abc.abstractmethod + def current_device(self): + ... + + @abc.abstractmethod + def current_device_name(self): + ... + + @abc.abstractmethod + def device_count(self): + ... + + @abc.abstractmethod + def synchronize(self, device_index=None): + ... + + # RNG APIs + @abc.abstractmethod + def random(self): + ... + + @abc.abstractmethod + def set_rng_state(self, new_state, device_index=None): + ... + + @abc.abstractmethod + def get_rng_state(self, device_index=None): + ... + + @abc.abstractmethod + def manual_seed(self, seed): + ... + + @abc.abstractmethod + def manual_seed_all(self, seed): + ... + + @abc.abstractmethod + def initial_seed(self): + ... 
+ + @abc.abstractmethod + def default_generator(self, device_index): + ... + + # Streams/Events + @abc.abstractmethod + def Stream(self, device=None, priority=0, **kwargs): + ... + + @abc.abstractmethod + def StreamContext(self, stream): + ... + + @abc.abstractmethod + def stream(self, stream): + ... + + @abc.abstractmethod + def current_stream(self, device_index=None): + ... + + @abc.abstractmethod + def default_stream(self, device_index=None): + ... + + @abc.abstractmethod + def Event(self, **kwargs): + ... + + # Memory management + @abc.abstractmethod + def empty_cache(self): + ... + + @abc.abstractmethod + def memory_allocated(self, device_index=None): + ... + + @abc.abstractmethod + def max_memory_allocated(self, device_index=None): + ... + + @abc.abstractmethod + def reset_max_memory_allocated(self, device_index=None): + ... + + @abc.abstractmethod + def memory_cached(self, device_index=None): + ... + + @abc.abstractmethod + def max_memory_cached(self, device_index=None): + ... + + @abc.abstractmethod + def reset_max_memory_cached(self, device_index=None): + ... + + @abc.abstractmethod + def memory_stats(self, device_index=None): + ... + + @abc.abstractmethod + def reset_peak_memory_stats(self, device_index=None): + ... + + @abc.abstractmethod + def memory_reserved(self, device_index=None): + ... + + @abc.abstractmethod + def max_memory_reserved(self, device_index=None): + ... + + @abc.abstractmethod + def total_memory(self, device_index=None): + ... + + # Data types + @abc.abstractmethod + def is_bf16_supported(self): + ... + + @abc.abstractmethod + def is_fp16_supported(self): + ... + + # Misc + @abc.abstractmethod + def amp(self): + ... + + @abc.abstractmethod + def is_available(self): + ... + + @abc.abstractmethod + def range_push(self, msg): + ... + + @abc.abstractmethod + def range_pop(self): + ... + + @abc.abstractmethod + def lazy_call(self, callback): + ... + + @abc.abstractmethod + def communication_backend_name(self): + ... + + # Tensor operations + @abc.abstractmethod + def pin_memory(self, tensor): + ... + + @abc.abstractmethod + def on_accelerator(self, tensor): + ... + + @abc.abstractmethod + def create_op_builder(self, class_name): + ... + + @abc.abstractmethod + def build_extension(self): + ... 
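The DeepSpeedAccelerator interface above is what allows every hard-coded torch.cuda.* call in the rest of this patch to be routed through get_accelerator(). A minimal, illustrative sketch of that calling pattern is shown below (not part of the patch; the file name and tensor size are made up, and a non-CUDA backend such as the XPU accelerator is only used if its package is installed):

-----------[code] accelerator_usage_sketch.py -----------
import torch
from deepspeed.accelerator import get_accelerator

accel = get_accelerator()   # CUDA_Accelerator by default, another backend if installed
accel.set_device(0)         # replaces torch.cuda.set_device(0)

# allocate on whatever device the accelerator exposes, instead of calling .cuda(0)
x = torch.ones(1024, dtype=torch.float16).to(accel.device_name(0))

accel.synchronize()                         # replaces torch.cuda.synchronize()
print(accel.communication_backend_name())   # e.g. 'nccl' for the CUDA accelerator
print(accel.total_memory(0))                # replaces torch.cuda.get_device_properties(0).total_memory
-----------[code] accelerator_usage_sketch.py -----------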
diff --git a/deepspeed/accelerator/cuda_accelerator.py b/deepspeed/accelerator/cuda_accelerator.py new file mode 100644 index 000000000000..4faaf5c015b3 --- /dev/null +++ b/deepspeed/accelerator/cuda_accelerator.py @@ -0,0 +1,203 @@ +from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator +import torch.cuda + + +class CUDA_Accelerator(DeepSpeedAccelerator): + def __init__(self): + self._name = 'cuda' + self._communication_backend_name = 'nccl' + self.DoubleTensor = torch.cuda.DoubleTensor + self.LongTensor = torch.cuda.LongTensor + self.FloatTensor = torch.cuda.FloatTensor + self.BFloat16Tensor = torch.cuda.BFloat16Tensor + self.HalfTensor = torch.cuda.HalfTensor + self.IntTensor = torch.cuda.IntTensor + self.ByteTensor = torch.cuda.ByteTensor + + # Device APIs + def device_name(self, device_index=None): + if device_index is None: + return 'cuda' + return 'cuda:{}'.format(device_index) + + def device(self, device_index=None): + return torch.cuda.device(device_index) + + def set_device(self, device_index): + torch.cuda.set_device(device_index) + + def current_device(self): + return torch.cuda.current_device() + + def current_device_name(self): + return 'cuda:{}'.format(torch.cuda.current_device()) + + def device_count(self): + return torch.cuda.device_count() + + def synchronize(self, device_index=None): + return torch.cuda.synchronize(device_index) + + # RNG APIs + def random(self): + return torch.random + + def set_rng_state(self, new_state, device_index=None): + if device_index is None: + return torch.cuda.set_rng_state(new_state) + + return torch.cuda.set_rng_state(new_state, device_index) + + def get_rng_state(self, device_index=None): + if device_index is None: + return torch.cuda.get_rng_state() + + return torch.cuda.get_rng_state(device_index) + + def manual_seed(self, seed): + return torch.cuda.manual_seed(seed) + + def manual_seed_all(self, seed): + return torch.cuda.manual_seed_all(seed) + + def initial_seed(self): + return torch.cuda.initial_seed() + + def default_generator(self, device_index): + return torch.cuda.default_generators[device_index] + + # Streams/Events + def Stream(self, device=None, priority=0, **kwargs): + return torch.cuda.Stream(device, priority, **kwargs) + + def StreamContext(self, stream): + return torch.cuda.StreamContext(stream) + + def stream(self, stream): + return torch.cuda.stream(stream) + + def current_stream(self, device_index=None): + return torch.cuda.current_stream(device_index) + + def default_stream(self, device_index=None): + return torch.cuda.default_stream(device_index) + + def Event(self, **kwargs): + return torch.cuda.Event(**kwargs) + + # Memory management + def empty_cache(self): + return torch.cuda.empty_cache() + + def memory_allocated(self, device_index=None): + return torch.cuda.memory_allocated(device_index) + + def max_memory_allocated(self, device_index=None): + return torch.cuda.max_memory_allocated(device_index) + + def reset_max_memory_allocated(self, device_index=None): + return torch.cuda.reset_max_memory_allocated(device_index) + + def memory_cached(self, device_index=None): + return torch.cuda.memory_cached(device_index) + + def max_memory_cached(self, device_index=None): + return torch.cuda.max_memory_cached(device_index) + + def reset_max_memory_cached(self, device_index=None): + return torch.cuda.reset_max_memory_cached(device_index) + + def memory_stats(self, device_index=None): + if hasattr(torch.cuda, 'memory_stats'): + return torch.cuda.memory_stats(device_index) + + def 
reset_peak_memory_stats(self, device_index=None): + if hasattr(torch.cuda, 'reset_peak_memory_stats'): + return torch.cuda.reset_peak_memory_stats(device_index) + + def memory_reserved(self, device_index=None): + if hasattr(torch.cuda, 'memory_reserved'): + return torch.cuda.memory_reserved(device_index) + + def max_memory_reserved(self, device_index=None): + if hasattr(torch.cuda, 'max_memory_reserved'): + return torch.cuda.max_memory_reserved(device_index) + + def total_memory(self, device_index=None): + return torch.cuda.get_device_properties(device_index).total_memory + + # Data types + def is_bf16_supported(self): + return torch.cuda.is_bf16_supported() + + def is_fp16_supported(self): + major, _ = torch.cuda.get_device_capability() + if major >= 7: + return True + else: + return False + + # Misc + def amp(self): + if hasattr(torch.cuda, 'amp'): + return torch.cuda.amp + return None + + def is_available(self): + return torch.cuda.is_available() + + def range_push(self, msg): + if hasattr(torch.cuda.nvtx, 'range_push'): + return torch.cuda.nvtx.range_push(msg) + + def range_pop(self): + if hasattr(torch.cuda.nvtx, 'range_pop'): + return torch.cuda.nvtx.range_pop() + + def lazy_call(self, callback): + return torch.cuda._lazy_call(callback) + + def communication_backend_name(self): + return self._communication_backend_name + + # Tensor operations + def pin_memory(self, tensor): + return tensor.pin_memory() + + def on_accelerator(self, tensor): + device_str = str(tensor.device) + if device_str.startswith('cuda:'): + return True + else: + return False + + def create_op_builder(self, class_name): + from deepspeed.ops.op_builder import AsyncIOBuilder, CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, FusedLambBuilder, QuantizerBuilder, SparseAttnBuilder, StochasticTransformerBuilder, TransformerBuilder, InferenceBuilder, UtilsBuilder + if class_name == "AsyncIOBuilder": + return AsyncIOBuilder() + elif class_name == "CPUAdagradBuilder": + return CPUAdagradBuilder() + elif class_name == "CPUAdamBuilder": + return CPUAdamBuilder() + elif class_name == "FusedAdamBuilder": + return FusedAdamBuilder() + elif class_name == "FusedLambBuilder": + return FusedLambBuilder() + elif class_name == "QuantizerBuilder": + return QuantizerBuilder() + elif class_name == "SparseAttnBuilder": + return SparseAttnBuilder() + elif class_name == "StochasticTransformerBuilder": + return StochasticTransformerBuilder() + elif class_name == "TransformerBuilder": + return TransformerBuilder() + elif class_name == "InferenceBuilder": + return InferenceBuilder() + elif class_name == "UtilsBuilder": + return UtilsBuilder() + else: + return None + + def build_extension(self): + from torch.utils.cpp_extension import BuildExtension + return BuildExtension diff --git a/deepspeed/accelerator/real_accelerator.py b/deepspeed/accelerator/real_accelerator.py new file mode 100644 index 000000000000..5d26a47277dd --- /dev/null +++ b/deepspeed/accelerator/real_accelerator.py @@ -0,0 +1,81 @@ +from .abstract_accelerator import DeepSpeedAccelerator + +ds_accelerator = None + + +def _validate_accelerator(accel_obj): + assert isinstance(accel_obj, DeepSpeedAccelerator), \ + f'{accel_obj.__class__.__name__} accelerator is not subclass of DeepSpeedAccelerator' + + # TODO: turn off is_available test since this breaks tests + #assert accel_obj.is_available(), \ + # f'{accel_obj.__class__.__name__} accelerator fails is_available() test' + + +def get_accelerator(): + global ds_accelerator + if ds_accelerator is None: + try: + from 
intel_extension_for_deepspeed import XPU_Accelerator + except ImportError as e: + pass + else: + ds_accelerator = XPU_Accelerator() + _validate_accelerator(ds_accelerator) + return ds_accelerator + + from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator + ds_accelerator = CUDA_Accelerator() + _validate_accelerator(ds_accelerator) + return ds_accelerator + + +def set_accelerator(accel_obj): + global ds_accelerator + _validate_accelerator(accel_obj) + ds_accelerator = accel_obj + + +''' +-----------[code] test_get.py ----------- +from deepspeed.accelerator import get_accelerator +my_accelerator = get_accelerator() +print(f'{my_accelerator._name=}') +print(f'{my_accelerator._communication_backend=}') +print(f'{my_accelerator.HalfTensor().device=}') +print(f'{my_accelerator.total_memory()=}') +-----------[code] test_get.py ----------- + +---[output] python test_get.py--------- +my_accelerator.name()='cuda' +my_accelerator.communication_backend='nccl' +my_accelerator.HalfTensor().device=device(type='cuda', index=0) +my_accelerator.total_memory()=34089730048 +---[output] python test_get.py--------- + +************************************************************************** +-----------[code] test_set.py ----------- +from deepspeed.accelerator.cuda_accelerator import CUDA_Accelerator +cu_accel = CUDA_Accelerator() +print(f'{id(cu_accel)=}') +from deepspeed.accelerator import set_accelerator, get_accelerator +set_accelerator(cu_accel) + +my_accelerator = get_accelerator() +print(f'{id(my_accelerator)=}') +print(f'{my_accelerator._name=}') +print(f'{my_accelerator._communication_backend=}') +print(f'{my_accelerator.HalfTensor().device=}') +print(f'{my_accelerator.total_memory()=}') +-----------[code] test_set.py ----------- + + +---[output] python test_set.py--------- +id(cu_accel)=139648165478304 +my_accelerator= +my_accelerator.name='cuda' +my_accelerator.communication_backend='nccl' +my_accelerator.HalfTensor().device=device(type='cuda', index=0) +my_accelerator.total_memory()=34089730048 +---[output] python test_set.py--------- +''' diff --git a/deepspeed/autotuning/autotuner.py b/deepspeed/autotuning/autotuner.py index b8a67075b55e..56c16e6d285e 100755 --- a/deepspeed/autotuning/autotuner.py +++ b/deepspeed/autotuning/autotuner.py @@ -1,6 +1,5 @@ import shutil import subprocess -import torch import time import datetime import math @@ -16,6 +15,7 @@ from .scheduler import ResourceManager from .tuner import GridSearchTuner, RandomTuner, ModelBasedTuner from .utils import * +from deepspeed.accelerator import get_accelerator try: from tabulate import tabulate @@ -249,7 +249,7 @@ def fp16_enabled(self): return False def get_gpu_memory_info(self): - return torch.cuda.get_device_properties(0).total_memory + return get_accelerator().total_memory() def get_activation_memory_per_gpu(self): if self.model_info and "activation_mem_per_gpu" in self.model_info: diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py index 98536b1e7d21..99496cf063f2 100644 --- a/deepspeed/comm/comm.py +++ b/deepspeed/comm/comm.py @@ -31,6 +31,7 @@ from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout from .constants import * +from deepspeed.accelerator import get_accelerator class ReduceOp(Enum): @@ -127,7 +128,7 @@ def log_wrapper(*args, **kwargs): finally: if comms_logger.enabled: # Need to make op blocking for accurate logging - torch.cuda.synchronize() + get_accelerator().synchronize() # If we're using MPI, we can't simply sync the stream if cdb.using_mpi: cdb.barrier() @@ -574,7 
+575,7 @@ def get_global_rank(group=None, group_rank=0): # Main DeepSpeed Comms. public API. -def init_distributed(dist_backend="nccl", +def init_distributed(dist_backend=None, auto_mpi_discovery=True, distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True, @@ -629,6 +630,8 @@ def init_distributed(dist_backend="nccl", utils.logger.info('Distributed backend already initialized') else: assert isinstance(timeout, timedelta) + if dist_backend == None: + dist_backend = get_accelerator().communication_backend_name() if int(os.getenv('RANK', '0')) == 0: utils.logger.info( 'Initializing TorchBackend in DeepSpeed with backend {}'.format( diff --git a/deepspeed/env_report.py b/deepspeed/env_report.py index c5949a8c4d31..38f75ce83459 100644 --- a/deepspeed/env_report.py +++ b/deepspeed/env_report.py @@ -2,8 +2,9 @@ import deepspeed import subprocess import argparse -from .ops.op_builder import ALL_OPS +from .ops.op_builder.all_ops import ALL_OPS from .git_version_info import installed_ops, torch_info +from deepspeed.accelerator import get_accelerator GREEN = '\033[92m' RED = '\033[91m' @@ -82,28 +83,46 @@ def debug_report(): hip_version = None if hasattr(torch.version, 'hip'): hip_version = torch.version.hip + if get_accelerator().device_name() == 'cuda': + report = [ + ("torch install path", + torch.__path__), + ("torch version", + torch.__version__), + ("torch cuda version", + torch.version.cuda), + ("torch hip version", + hip_version), + ("nvcc version", + (None if hip_version else nvcc_version())), + ("deepspeed install path", + deepspeed.__path__), + ("deepspeed info", + f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}" + ), + ("deepspeed wheel compiled w.", + f"torch {torch_info['version']}, " + + (f"hip {torch_info['hip_version']}" + if hip_version else f"cuda {torch_info['cuda_version']}")), + ] + else: + report = [ + ("torch install path", + torch.__path__), + ("torch version", + torch.__version__), + ("torch hip version", + hip_version), + ("deepspeed install path", + deepspeed.__path__), + ("deepspeed info", + f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}" + ), + ("deepspeed wheel compiled w.", + f"torch {torch_info['version']} " + + (f", hip {torch_info['hip_version']}" if hip_version else "")), + ] - report = [ - ("torch install path", - torch.__path__), - ("torch version", - torch.__version__), - ("torch cuda version", - torch.version.cuda), - ("torch hip version", - hip_version), - ("nvcc version", - (None if hip_version else nvcc_version())), - ("deepspeed install path", - deepspeed.__path__), - ("deepspeed info", - f"{deepspeed.__version__}, {deepspeed.__git_hash__}, {deepspeed.__git_branch__}" - ), - ("deepspeed wheel compiled w.", - f"torch {torch_info['version']}, " + - (f"hip {torch_info['hip_version']}" - if hip_version else f"cuda {torch_info['cuda_version']}")), - ] print("DeepSpeed general environment info:") for name, value in report: print(name, "." 
* (max_dots - len(name)), value) diff --git a/deepspeed/git_version_info.py b/deepspeed/git_version_info.py index 5cd6d9f2f940..071d4289b948 100644 --- a/deepspeed/git_version_info.py +++ b/deepspeed/git_version_info.py @@ -11,7 +11,7 @@ git_hash = '[none]' git_branch = '[none]' - from .ops.op_builder import ALL_OPS + from .ops.op_builder.all_ops import ALL_OPS installed_ops = dict.fromkeys(ALL_OPS.keys(), False) compatible_ops = dict.fromkeys(ALL_OPS.keys(), False) torch_info = {'version': "0.0", "cuda_version": "0.0", "hip_version": "0.0"} diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index d51cfda79001..7f77154467c6 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -20,6 +20,7 @@ from ..moe.utils import has_moe_layers from ..runtime.zero import GatheredParameters from ..module_inject import LinearAllreduce, LinearLayer, Normalize, ReplaceWithTensorSlicing +from deepspeed.accelerator import get_accelerator from ..module_inject.replace_policy import DSPolicy DS_INFERENCE_ENABLED = False @@ -107,7 +108,7 @@ def __init__(self, # This is a hack to remove the prepare_mask function on HF side for BLOOM architecture self.remove_mask_prepare_for_bloom() - if enable_cuda_graph: + if get_accelerator().device_name() == 'cuda' and enable_cuda_graph: assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \ "If you want to use cuda graph, please upgrade torch to at least v1.10" @@ -157,13 +158,14 @@ def __init__(self, save_mp_checkpoint_path=save_mp_checkpoint_path, base_dir=base_dir) - device = torch.cuda.current_device() + device = get_accelerator().current_device_name() self.module.to(device) if self.mp_world_size > 1: - _rng_state = torch.cuda.get_rng_state().to(torch.cuda.current_device()) + _rng_state = get_accelerator().get_rng_state().to( + get_accelerator().current_device_name()) dist.broadcast(_rng_state, 0) - torch.cuda.set_rng_state(_rng_state.cpu()) + get_accelerator().set_rng_state(_rng_state.cpu()) if self.mp_world_size > 1: assert not self.enable_cuda_graph, "Cuda graph is not supported for model parallelism" @@ -184,11 +186,11 @@ def remove_mask_prepare_for_bloom(self): self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask def _pre_forward_hook(self, module, *inputs, **kwargs): - torch.cuda.synchronize() + get_accelerator().synchronize() self._start = time.time() def _post_forward_hook(self, module, input, output): - torch.cuda.synchronize() + get_accelerator().synchronize() self._end = time.time() self._model_times.append(self._end - self._start) @@ -198,7 +200,7 @@ def _create_model_parallel_group(self): init_distributed() local_rank = int(os.getenv('LOCAL_RANK', '0')) - torch.cuda.set_device(local_rank) + get_accelerator().set_device(local_rank) ranks = [i for i in range(self.mp_world_size)] self.mp_group = dist.new_group(ranks) @@ -312,7 +314,7 @@ def load(module, state_dict, prefix): state_dict[prefix + 'bias']) else: data = state_dict[prefix + 'bias'] - data = data.to(torch.cuda.current_device()) + data = data.to(get_accelerator().current_device_name()) module.bias = self.mp_replace.copy(module.bias, data) layer_policies = { @@ -441,7 +443,8 @@ def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None): for i in range(1, len(sd_loader)): if not dist.is_initialized() or dist.get_rank() == 0: print(f"loading checkpoint ({i})") - self.sd = torch.load(sd_loader[i], map_location='cuda') + self.sd = torch.load(sd_loader[i], + 
map_location=get_accelerator().device_name()) self.key_list = list(self.sd.keys()) self.load_model_with_checkpoint(self.module) else: @@ -503,12 +506,12 @@ def _convert_to_dtype(self): def _create_cuda_graph(self, *inputs, **kwargs): # warmup to create the workspace and cublas handle - cuda_stream = torch.cuda.Stream() - cuda_stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(cuda_stream): + cuda_stream = get_accelerator().Stream() + cuda_stream.wait_stream(get_accelerator().current_stream()) + with get_accelerator().stream(cuda_stream): for i in range(3): ret = self.module(*inputs, **kwargs) - torch.cuda.current_stream().wait_stream(cuda_stream) + get_accelerator().current_stream().wait_stream(cuda_stream) # create cuda_graph and assign static_inputs and static_outputs self._cuda_graphs = torch.cuda.CUDAGraph() @@ -550,11 +553,12 @@ def forward(self, *inputs, **kwargs): **kwargs: variable length keyword arguments """ start = None - if self.model_profile_enabled and self.enable_cuda_graph: - torch.cuda.synchronize() + if self.model_profile_enabled and get_accelerator().device_name( + ) == 'cuda' and self.enable_cuda_graph: + get_accelerator().synchronize() start = time.time() - if self.enable_cuda_graph: + if get_accelerator().device_name() == 'cuda' and self.enable_cuda_graph: if self.cuda_graph_created: outputs = self._graph_replay(*inputs, **kwargs) else: @@ -564,7 +568,7 @@ def forward(self, *inputs, **kwargs): outputs = self.module(*inputs, **kwargs) if self.model_profile_enabled and self.enable_cuda_graph: - torch.cuda.synchronize() + get_accelerator().synchronize() duration = time.time() - start self._model_times.append(duration) diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py index 9709f304947c..aa407cee8d10 100644 --- a/deepspeed/launcher/multinode_runner.py +++ b/deepspeed/launcher/multinode_runner.py @@ -5,7 +5,7 @@ import warnings from shlex import split from abc import ABC, abstractmethod - +from deepspeed.accelerator import get_accelerator from ..utils import logger from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE @@ -220,7 +220,8 @@ def __init__(self, args, world_info_base64, resource_pool): self.add_export('MV2_DEBUG_SHOW_BACKTRACE', '1') # Enabled cuda-aware communication - self.add_export('MV2_USE_CUDA', '1') + if get_accelerator().device_name() == 'cuda': + self.add_export('MV2_USE_CUDA', '1') # Support deep learning frameworks: http://hidl.cse.ohio-state.edu/userguide/horovod/ self.add_export('MV2_SUPPORT_DL', '1') diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index bb78c0d09eb1..1860bcfcbc68 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -16,7 +16,6 @@ from copy import deepcopy import signal import time -import torch.cuda from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER @@ -25,6 +24,7 @@ from ..utils import logger from ..autotuning import Autotuner +from deepspeed.accelerator import get_accelerator DLTS_HOSTFILE = "/job/hostfile" EXPORT_ENVS = ['NCCL', 'PYTHON', 'MV2', 'UCX'] @@ -378,7 +378,7 @@ def main(args=None): multi_node_exec = True if not resource_pool: resource_pool = {} - device_count = torch.cuda.device_count() + device_count = get_accelerator().device_count() if device_count == 0: raise RuntimeError("Unable to proceed, no GPU resources available") resource_pool['localhost'] = device_count diff --git 
a/deepspeed/module_inject/encoder.py b/deepspeed/module_inject/encoder.py index 41d593daed92..580b344b26ee 100644 --- a/deepspeed/module_inject/encoder.py +++ b/deepspeed/module_inject/encoder.py @@ -2,6 +2,7 @@ Copyright 2022 The Microsoft DeepSpeed Team ''' import torch +from deepspeed.accelerator import get_accelerator class DSClipEncoder(torch.nn.Module): @@ -18,7 +19,7 @@ def _build_causal_attention_mask(self, bsz, seq_len, dtype): seq_len, seq_len, dtype=dtype, - device=torch.cuda.current_device()) + device=get_accelerator().current_device_name()) mask.fill_(torch.tensor(torch.finfo(dtype).min)) mask.triu_(1) mask = mask.unsqueeze(1) diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py index 1ae5bd5ccf9f..c456367b043c 100644 --- a/deepspeed/module_inject/layers.py +++ b/deepspeed/module_inject/layers.py @@ -4,6 +4,7 @@ from torch.nn import functional as F from torch.nn.parameter import Parameter +from deepspeed.accelerator import get_accelerator class LinearAllreduce(nn.Module): @@ -32,11 +33,11 @@ def __init__(self, weight_shape=None, dtype=None, weight=None, bias=None): self.weight = Parameter( torch.empty(weight_shape, dtype=dtype, - device=torch.cuda.current_device())) + device=get_accelerator().current_device_name())) self.bias = Parameter( torch.empty(weight_shape[0], dtype=dtype, - device=torch.cuda.current_device())) + device=get_accelerator().current_device_name())) def forward(self, input): output = torch.matmul(input, self.weight.transpose(-1, -2)) @@ -48,7 +49,9 @@ def forward(self, input): class Normalize(nn.Module): def __init__(self, dim, dtype=torch.float, eps=1e-5): super(Normalize, self).__init__() - self.norm = nn.LayerNorm(dim, eps=eps).to(dtype).to(torch.cuda.current_device()) + self.norm = nn.LayerNorm(dim, + eps=eps).to(dtype).to( + get_accelerator().current_device_name()) self.weight = self.norm.weight self.bias = self.norm.bias @@ -63,7 +66,7 @@ def __init__(self, weight_shape, dtype=torch.float): torch.empty(weight_shape[0], weight_shape[1], dtype=dtype, - device=torch.cuda.current_device())) + device=get_accelerator().current_device_name())) def forward(self, input): return F.embedding(input, self.weight) diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index 560f1bc83bc7..557ce654a590 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -4,6 +4,7 @@ from .layers import LinearLayer, Normalize, EmbeddingLayer import torch import gc +from deepspeed.accelerator import get_accelerator def load_model_with_checkpoint(r_module, @@ -46,9 +47,10 @@ def load_parameters(module, prefix): if type(sd[0][prefix + n]) is list: tmp_data, scale = sd[0][prefix + n] tmp_data = tmp_data - scale = scale.to(torch.cuda.current_device()) + scale = scale.to(get_accelerator().current_device_name()) else: - tmp_data = sd[0][prefix + n].to(torch.cuda.current_device()) + tmp_data = sd[0][prefix + n].to( + get_accelerator().current_device_name()) scale = None src_shape = tmp_data.shape dst_shape = p.shape @@ -74,7 +76,8 @@ def load_parameters(module, prefix): weight_partition = torch.split( tmp_data, dst_shape[dim1], - dim=dim)[rank].to(torch.cuda.current_device()) + dim=dim)[rank].to( + get_accelerator().current_device_name()) assert tmp_data.dtype != torch.int8 or scale.numel() > weight_quantizer.num_groups * (rank+1), \ '''ERROR: We require the quantization scales for larger TP-size when loading INT8 checkpoint!\ Please use the FP16 checkpoint to 
generate INT8 checkpoint with the sharding parameters!''' @@ -90,17 +93,19 @@ def load_parameters(module, prefix): all_data = [ sd[j][prefix + n] if type(sd[j][prefix + n]) is list else - sd[j][prefix + n].to(torch.cuda.current_device()) + sd[j][prefix + n].to( + get_accelerator().current_device_name()) for j in range(len(sd)) ] weight_partition = torch.cat([ - ad[0].to(torch.cuda.current_device()) + ad[0].to(get_accelerator().current_device_name()) if type(ad) is list else ad for ad in all_data ], dim=dim) if tmp_data.dtype == torch.int8: scale = torch.cat([ - ad[1].to(torch.cuda.current_device()) + ad[1].to( + get_accelerator().current_device_name()) for ad in all_data ], dim=dim) @@ -123,15 +128,15 @@ def load_parameters(module, prefix): if src_shape[0] > dst_shape[0]: bias_split = torch.split( tmp_data, - dst_shape[-1])[rank].to( - torch.cuda.current_device()).contiguous() + dst_shape[-1])[rank].to(get_accelerator( + ).current_device_name()).contiguous() p.data.copy_(bias_split) else: p.data.copy_( torch.cat( [sd[j][prefix + n] for j in range(len(sd))], - dim=0).to(torch.cuda.current_device()). - contiguous()) + dim=0).to(get_accelerator( + ).current_device_name()).contiguous()) load_parameters(module, prefix) for n, child in module.named_children(): diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index d7fa50eca4ce..bc1ac5ecb412 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -3,6 +3,7 @@ import tqdm import deepspeed import deepspeed.ops.transformer as transformer_inference +from deepspeed.accelerator import get_accelerator from .replace_policy import HFBertLayerPolicy, HFGPT2LayerPolicy, BLOOMLayerPolicy from .replace_policy import replace_policies, generic_policies #from ..runtime.weight_quantizer import WeightQuantization @@ -60,10 +61,10 @@ def qkv_copy(self, dst, src): axis=self.out_dim) for i in range(len(qkv_split[0])) ] dst.data.copy_(weight_split[self.gpu_index].to( - torch.cuda.current_device()).contiguous()) + get_accelerator().current_device_name()).contiguous()) else: dst.data.copy_(src_split[self.gpu_index].to( - torch.cuda.current_device()).contiguous()) + get_accelerator().current_device_name()).contiguous()) else: if src_shape[0] == dst_shape[0]: return torch.nn.parameter.Parameter(src) @@ -75,10 +76,10 @@ def qkv_copy(self, dst, src): axis=0) for i in range(len(qkv_split[0])) ] dst.data.copy_(bias_split[self.gpu_index].to( - torch.cuda.current_device()).contiguous()) + get_accelerator().current_device_name()).contiguous()) else: dst.data.copy_(src_split[self.gpu_index].to( - torch.cuda.current_device()).contiguous()) + get_accelerator().current_device_name()).contiguous()) return torch.nn.parameter.Parameter(dst) @@ -98,22 +99,23 @@ def copy(self, dst, src): src, dst_shape[self.in_dim], dim=self.in_dim)[self.gpu_index].to( - torch.cuda.current_device()).contiguous() + get_accelerator().current_device_name()).contiguous() else: self.merge_assert(src_shape[self.out_dim], dst_shape[self.out_dim]) weight_split = torch.split( src.data, dst_shape[self.out_dim], dim=self.out_dim)[self.gpu_index].to( - torch.cuda.current_device()).contiguous() + get_accelerator().current_device_name()).contiguous() dst.data.copy_(weight_split.contiguous()) else: if src_shape[0] == dst_shape[0]: dst.data.copy_(src) else: - bias_split = torch.split(src.data, - dst_shape[-1])[self.gpu_index].to( - torch.cuda.current_device()).contiguous() + bias_split = torch.split( + src.data, + 
dst_shape[-1])[self.gpu_index].to( + get_accelerator().current_device_name()).contiguous() dst.data.copy_(bias_split) dst = torch.nn.parameter.Parameter(dst, requires_grad=False) if hasattr(src, 'scale'): @@ -149,7 +151,7 @@ def quantize(self, inputs, qkv=True, count=1, parallel_dim=0): inputs.scale = torch.empty(1) return inputs q_range = 2**self.num_bits - inputs = inputs.to(torch.cuda.current_device()) + inputs = inputs.to(get_accelerator().current_device_name()) input_flat = inputs.reshape(self.num_groups, -1).contiguous() input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float() input_max = torch.max(input_flat, dim=1, keepdim=True)[0].float() @@ -216,7 +218,7 @@ def replace_attn(child, policy, layer_id): def transpose(data): data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1)) data = data.reshape(data.shape[-1], data.shape[-2]) - data.to(torch.cuda.current_device()) + data.to(get_accelerator().current_device_name()) return data if len(policy_attn) == 5: @@ -229,7 +231,8 @@ def transpose(data): attn_module.attn_qkvb = None attn_module.attn_ow.data = transpose(attn_ow.data) - attn_module.attn_ob.data.copy_(attn_ob.data.to(torch.cuda.current_device())) + attn_module.attn_ob.data.copy_( + attn_ob.data.to(get_accelerator().current_device_name())) return attn_module if isinstance(module, torch.nn.Module): @@ -516,7 +519,7 @@ def transpose(data): data = data.to('cpu') data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1)) data = data.reshape(data.shape[-1], data.shape[-2]) - data.to(torch.cuda.current_device()) + data.to(get_accelerator().current_device_name()) return data attn_block = new_module.attention @@ -642,28 +645,31 @@ def _transpose(x): for ep_index in range(local_ep_size): mpl_block[ep_index].inter_w.data = _h4h_w[ gpu_index * local_ep_size + ep_index].to( - torch.cuda.current_device()) + get_accelerator().current_device_name()) mpl_block[ep_index].inter_b.data = _h4h_b[ gpu_index * local_ep_size + ep_index].to( - torch.cuda.current_device()) + get_accelerator().current_device_name()) mpl_block[ep_index].output_w.data = _4hh_w[ gpu_index * local_ep_size + ep_index].to( - torch.cuda.current_device()) + get_accelerator().current_device_name()) mpl_block[ep_index].output_b.data = _4hh_b[ gpu_index * local_ep_size + ep_index].to( - torch.cuda.current_device()) - new_module.attn_nw.data = attn_nw.to(torch.cuda.current_device()) - new_module.attn_nb.data = attn_nb.to(torch.cuda.current_device()) + get_accelerator().current_device_name()) + new_module.attn_nw.data = attn_nw.to( + get_accelerator().current_device_name()) + new_module.attn_nb.data = attn_nb.to( + get_accelerator().current_device_name()) if moe_type == 'residual': new_module.res_mlp.inter_w.data = _res_h4h_w.to( - torch.cuda.current_device()) + get_accelerator().current_device_name()) new_module.res_mlp.inter_b.data = _res_h4h_b.to( - torch.cuda.current_device()) + get_accelerator().current_device_name()) new_module.res_mlp.output_w.data = _res_4hh_w.to( - torch.cuda.current_device()) + get_accelerator().current_device_name()) new_module.res_mlp.output_b.data = _res_4hh_b.to( - torch.cuda.current_device()) - new_module.res_coef.data = _res_coef.to(torch.cuda.current_device()) + get_accelerator().current_device_name()) + new_module.res_coef.data = _res_coef.to( + get_accelerator().current_device_name()) else: if _4hh_w.numel() == 0 or _4hh_w.is_meta: @@ -709,14 +715,14 @@ def _transpose(x): else: with GatheredParameters([attn_nw, attn_nb], modifier_rank=0): 
new_module.mlp.attn_nw.data.copy_( - attn_nw.to(torch.cuda.current_device())) + attn_nw.to(get_accelerator().current_device_name())) new_module.mlp.attn_nb.data.copy_( - attn_nb.to(torch.cuda.current_device())) + attn_nb.to(get_accelerator().current_device_name())) else: new_module.mlp.attn_nw.data.copy_( - attn_nw.to(torch.cuda.current_device())) + attn_nw.to(get_accelerator().current_device_name())) new_module.mlp.attn_nb.data.copy_( - attn_nb.to(torch.cuda.current_device())) + attn_nb.to(get_accelerator().current_device_name())) if input_nw.is_meta or input_nw.numel() == 0: if input_nw.is_meta or input_nw.ds_tensor.numel( @@ -725,12 +731,14 @@ def _transpose(x): else: with GatheredParameters([input_nw, input_nb], modifier_rank=0): new_module.norm_w.data.copy_( - input_nw.to(torch.cuda.current_device())) + input_nw.to(get_accelerator().current_device_name())) new_module.norm_b.data.copy_( - input_nb.to(torch.cuda.current_device())) + input_nb.to(get_accelerator().current_device_name())) else: - new_module.norm_w.data.copy_(input_nw.to(torch.cuda.current_device())) - new_module.norm_b.data.copy_(input_nb.to(torch.cuda.current_device())) + new_module.norm_w.data.copy_( + input_nw.to(get_accelerator().current_device_name())) + new_module.norm_b.data.copy_( + input_nb.to(get_accelerator().current_device_name())) else: transformer_config = deepspeed.DeepSpeedTransformerConfig( batch_size=micro_batch_size if micro_batch_size > 0 else 1, @@ -809,7 +817,7 @@ def _replace(child, name, conv_linear_layer): elif child.bias is not None: new_bias.data.copy_(child.bias.data) return LinearAllreduce(data, child.bias if child.bias is None else \ - torch.nn.parameter.Parameter(new_bias.to(torch.cuda.current_device())), mp_group) + torch.nn.parameter.Parameter(new_bias.to(get_accelerator().current_device_name())), mp_group) else: new_weight = torch.empty(( (weight_shape[1] if conv_linear_layer else weight_shape[0]) // @@ -839,12 +847,13 @@ def _replace(child, name, conv_linear_layer): with deepspeed.zero.GatheredParameters(child.bias, modifier_rank=0): bias_data = None if child.bias is None else mp_replace.copy( new_bias, - child.bias.data).to(torch.cuda.current_device()) + child.bias.data).to(get_accelerator().current_device_name()) else: bias_data = None if child.bias is None else mp_replace.copy( new_bias, - child.bias.data).to(torch.cuda.current_device()) - return LinearLayer(weight=data.to(torch.cuda.current_device()), + child.bias.data).to(get_accelerator().current_device_name()) + return LinearLayer(weight=data.to( + get_accelerator().current_device_name()), bias=bias_data) def _slice_embedding(child, name, conv_linear_layer): diff --git a/deepspeed/ops/adagrad/cpu_adagrad.py b/deepspeed/ops/adagrad/cpu_adagrad.py index 2527259b1382..98dbcb15fb35 100755 --- a/deepspeed/ops/adagrad/cpu_adagrad.py +++ b/deepspeed/ops/adagrad/cpu_adagrad.py @@ -3,7 +3,7 @@ ''' import torch -from ..op_builder import CPUAdagradBuilder +from deepspeed.accelerator import get_accelerator from deepspeed.utils.logging import should_log_le @@ -24,7 +24,8 @@ def __init__(self, self.opt_id = DeepSpeedCPUAdagrad.optimizer_id DeepSpeedCPUAdagrad.optimizer_id = DeepSpeedCPUAdagrad.optimizer_id + 1 self.fp32_optimizer_states = fp32_optimizer_states - self.ds_opt_adagrad = CPUAdagradBuilder().load() + self.ds_opt_adagrad = get_accelerator().create_op_builder( + "CPUAdagradBuilder").load() self.ds_opt_adagrad.create_adagrad(self.opt_id, lr, diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py index 
911e4924dfbc..e0529bd4aaa9 100755 --- a/deepspeed/ops/adam/cpu_adam.py +++ b/deepspeed/ops/adam/cpu_adam.py @@ -4,9 +4,9 @@ import torch from cpuinfo import get_cpu_info -from ..op_builder import CPUAdamBuilder from deepspeed.utils import logger from deepspeed.utils.logging import should_log_le +from deepspeed.accelerator import get_accelerator class DeepSpeedCPUAdam(torch.optim.Optimizer): @@ -91,7 +91,7 @@ def __init__(self, DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1 self.adam_w_mode = adamw_mode self.fp32_optimizer_states = fp32_optimizer_states - self.ds_opt_adam = CPUAdamBuilder().load() + self.ds_opt_adam = get_accelerator().create_op_builder("CPUAdamBuilder").load() self.ds_opt_adam.create_adam(self.opt_id, lr, diff --git a/deepspeed/ops/adam/fused_adam.py b/deepspeed/ops/adam/fused_adam.py index 5a1a1ddcaed3..989823be99d4 100644 --- a/deepspeed/ops/adam/fused_adam.py +++ b/deepspeed/ops/adam/fused_adam.py @@ -9,7 +9,7 @@ from .multi_tensor_apply import MultiTensorApply multi_tensor_applier = MultiTensorApply(2048 * 32) -from ..op_builder import FusedAdamBuilder +from deepspeed.accelerator import get_accelerator class FusedAdam(torch.optim.Optimizer): @@ -69,9 +69,9 @@ def __init__(self, self.adam_w_mode = 1 if adam_w_mode else 0 self.set_grad_none = set_grad_none - fused_adam_cuda = FusedAdamBuilder().load() + fused_adam_cuda = get_accelerator().create_op_builder("FusedAdamBuilder").load() # Skip buffer - self._dummy_overflow_buf = torch.cuda.IntTensor([0]) + self._dummy_overflow_buf = get_accelerator().IntTensor([0]) self.multi_tensor_adam = fused_adam_cuda.multi_tensor_adam def zero_grad(self): diff --git a/deepspeed/ops/lamb/fused_lamb.py b/deepspeed/ops/lamb/fused_lamb.py index e9210cdda9bc..42bcaebbcf2f 100644 --- a/deepspeed/ops/lamb/fused_lamb.py +++ b/deepspeed/ops/lamb/fused_lamb.py @@ -6,7 +6,7 @@ ''' import types import torch -from ..op_builder import FusedLambBuilder +from deepspeed.accelerator import get_accelerator class FusedLamb(torch.optim.Optimizer): @@ -48,7 +48,8 @@ def __init__(self, max_coeff=10.0, min_coeff=0.01, amsgrad=False): - self.fused_lamb_cuda = FusedLambBuilder().load() + self.fused_lamb_cuda = get_accelerator().create_op_builder( + "FusedLambBuilder").load() if amsgrad: raise RuntimeError('FusedLamb does not support the AMSGrad variant.') diff --git a/deepspeed/ops/quantizer/quantizer.py b/deepspeed/ops/quantizer/quantizer.py index 6b25d02d87e7..5a5f3d84617d 100755 --- a/deepspeed/ops/quantizer/quantizer.py +++ b/deepspeed/ops/quantizer/quantizer.py @@ -3,7 +3,7 @@ ''' import torch -from ..op_builder import QuantizerBuilder +from deepspeed.accelerator import get_accelerator # Cuda modules will be imported if needed quantizer_cuda_module = None @@ -13,7 +13,8 @@ def ds_quantizer(input, groups=1, bit_num=8, sr=False, asym=False): # Load cuda modules if needed global quantizer_cuda_module if quantizer_cuda_module is None: - quantizer_cuda_module = QuantizerBuilder().load() + quantizer_cuda_module = get_accelerator().create_op_builder( + "QuantizerBuilder").load() if sr: if asym: quantize_func = quantizer_cuda_module.ds_sr_quantize_asym_fp16 if input.dtype == torch.half else quantizer_cuda_module.ds_sr_quantize_asym_fp32 diff --git a/deepspeed/ops/sparse_attention/matmul.py b/deepspeed/ops/sparse_attention/matmul.py index 986666bde57f..554b0b18ce6f 100755 --- a/deepspeed/ops/sparse_attention/matmul.py +++ b/deepspeed/ops/sparse_attention/matmul.py @@ -6,6 +6,7 @@ import triton import triton.language as tl import 
triton._C.libtriton as libtriton +from deepspeed.accelerator import get_accelerator @triton.jit @@ -948,7 +949,7 @@ def _validate_inputs(self, a, b): raise ValueError( f"Inputs must be on the same device; got {a.device} for tensor A " f"and {b.device} for tensor B") - if not a.is_cuda: + if not get_accelerator().on_accelerator(a): raise ValueError("Only GPU devices are supported for now") # When autocast is enabled, torch.matmul autocasts to float16, so we do the same here diff --git a/deepspeed/ops/transformer/inference/attention.py b/deepspeed/ops/transformer/inference/attention.py index f8bad34d5bff..4902bcbd68e7 100644 --- a/deepspeed/ops/transformer/inference/attention.py +++ b/deepspeed/ops/transformer/inference/attention.py @@ -4,10 +4,11 @@ import math import torch from torch.autograd import Function -from ... import op_builder import torch.nn as nn from packaging import version as pkg_version from deepspeed.utils.logging import log_dist +from deepspeed.accelerator import get_accelerator + # Cuda modules will be imported if needed inference_cuda_module = None minus_inf = -10000.0 @@ -169,14 +170,15 @@ def __init__( self.config = config self.config.layer_id = DeepSpeedAttention.layer_id DeepSpeedAttention.layer_id += 1 - device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' + device = get_accelerator().current_device_name( + ) if config.bigscience_bloom else 'cpu' qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float data_type_fp = torch.half if config.fp16 else torch.float global inference_cuda_module if inference_cuda_module is None: - builder = op_builder.InferenceBuilder() + builder = get_accelerator().create_op_builder("InferenceBuilder") inference_cuda_module = builder.load() if DeepSpeedAttention.layer_id == 1: diff --git a/deepspeed/ops/transformer/inference/moe_inference.py b/deepspeed/ops/transformer/inference/moe_inference.py index ca4b5b7a9702..52f0a4caf5a4 100644 --- a/deepspeed/ops/transformer/inference/moe_inference.py +++ b/deepspeed/ops/transformer/inference/moe_inference.py @@ -5,7 +5,6 @@ import math import torch from torch.autograd import Function -from ... 
import op_builder #from ...inference.engine import inference_cuda_module, specialized_mode # Cuda modules will be imported if needed inference_cuda_module = None @@ -14,6 +13,7 @@ from .transformer_inference import DeepSpeedSelfAttention, DeepSpeedInferenceConfig from ....moe.sharded_moe import TopKGate from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator class DeepSpeedMoEInferenceConfig(DeepSpeedInferenceConfig): @@ -238,15 +238,13 @@ def __init__(self, global specialized_mode if inference_cuda_module is None: specialized_mode = False - if hasattr(op_builder, 'InferenceSpecializedBuilder'): - builder = op_builder.InferenceSpecializedBuilder() - if builder.is_compatible(): - inference_cuda_module = builder.load() - specialized_mode = True - else: - inference_cuda_module = op_builder.InferenceBuilder().load() + builder = get_accelerator().create_op_builder("InferenceSpecializedBuilder") + if builder is not None and builder.is_compatible(): + inference_cuda_module = builder.load() + specialized_mode = True else: - inference_cuda_module = op_builder.InferenceBuilder().load() + inference_cuda_module = get_accelerator().create_op_builder( + "InferenceBuilder").load() self.config.specialized_mode = specialized_mode DeepSpeedMoEInference.layer_id += 1 diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index a7e043676875..e650a66939e0 100644 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -5,10 +5,10 @@ import math import torch from torch.autograd import Function -from ... import op_builder import torch.nn as nn from deepspeed import comm as dist from deepspeed.utils.logging import log_dist +from deepspeed.accelerator import get_accelerator from deepspeed.utils.types import ActivationFuncType # Cuda modules will be imported if needed @@ -194,7 +194,7 @@ def split_tensor_along_last_dim(tensor, return tensor_list def backup_attention(mixed_x_layer, layer_past, alibi, input_mask, norm_factor): - alibi = alibi.to(torch.cuda.current_device()) + alibi = alibi.to(get_accelerator().current_device_name()) head_dim = hidden_size_per_partition // num_attention_heads_per_partition new_tensor_shape = mixed_x_layer.size()[:-1] + ( num_attention_heads_per_partition, @@ -501,7 +501,8 @@ def __init__(self, data_type_fp = torch.half if config.fp16 else torch.float self.config.layer_id = DeepSpeedSelfAttention.num_layers DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1 - device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' + device = get_accelerator().current_device_name( + ) if config.bigscience_bloom else 'cpu' qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size, qkv_size_per_partition, @@ -667,7 +668,8 @@ def __init__(self, self.config = config data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float data_type_fp = torch.half if config.fp16 else torch.float - device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' + device = get_accelerator().current_device_name( + ) if config.bigscience_bloom else 'cpu' self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), @@ -773,7 +775,7 @@ def __init__(self, data_type = torch.half if config.fp16 else torch.float global inference_cuda_module if 
inference_cuda_module is None: - builder = op_builder.InferenceBuilder() + builder = get_accelerator().create_op_builder("InferenceBuilder") inference_cuda_module = builder.load() if DeepSpeedTransformerInference.layer_id == 1: @@ -792,7 +794,8 @@ def __init__(self, merge_count, mlp_extra_grouping) - device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' + device = get_accelerator().current_device_name( + ) if config.bigscience_bloom else 'cpu' self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type, device=device), diff --git a/deepspeed/ops/transformer/transformer.py b/deepspeed/ops/transformer/transformer.py index 7963d11774e3..48770a78a27d 100755 --- a/deepspeed/ops/transformer/transformer.py +++ b/deepspeed/ops/transformer/transformer.py @@ -6,8 +6,7 @@ import torch from torch import nn from torch.autograd import Function - -from ..op_builder import TransformerBuilder, StochasticTransformerBuilder +from deepspeed.accelerator import get_accelerator # Cuda modules will be imported if needed transformer_cuda_module = None @@ -481,7 +480,7 @@ def __init__(self, config, initial_weights=None, initial_biases=None): print("DeepSpeed Transformer config is ", self.config.__dict__) if self.config.local_rank >= 0: - torch.cuda.set_device(self.config.local_rank) + get_accelerator().set_device(self.config.local_rank) if initial_weights is None and initial_biases is None: self.attn_qkvw = nn.Parameter( @@ -531,9 +530,11 @@ def __init__(self, config, initial_weights=None, initial_biases=None): # Load cuda modules if needed global transformer_cuda_module, stochastic_transformer_cuda_module if transformer_cuda_module is None and not self.config.stochastic_mode: - transformer_cuda_module = TransformerBuilder().load() + transformer_cuda_module = get_accelerator().create_op_builder( + "TransformerBuilder").load() if stochastic_transformer_cuda_module is None and self.config.stochastic_mode: - stochastic_transformer_cuda_module = StochasticTransformerBuilder().load() + stochastic_transformer_cuda_module = get_accelerator().create_op_builder( + "StochasticTransformerBuilder").load() # create the layer in cuda kernels. 
cuda_module = stochastic_transformer_cuda_module if self.config.stochastic_mode else transformer_cuda_module diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py index d599736b8f1e..750829b1a984 100644 --- a/deepspeed/profiling/flops_profiler/profiler.py +++ b/deepspeed/profiling/flops_profiler/profiler.py @@ -6,6 +6,7 @@ from typing import List, Optional from collections import OrderedDict import numpy as np +from deepspeed.accelerator import get_accelerator Tensor = torch.Tensor @@ -98,7 +99,7 @@ def post_hook(module, input, output): module.__post_hook_handle__ = module.register_forward_hook(post_hook) def start_time_hook(module, input): - torch.cuda.synchronize() + get_accelerator().synchronize() module.__start_time__ = time.time() if not hasattr(module, "__start_time_hook_handle"): @@ -106,7 +107,7 @@ def start_time_hook(module, input): start_time_hook) def end_time_hook(module, input, output): - torch.cuda.synchronize() + get_accelerator().synchronize() module.__duration__ += time.time() - module.__start_time__ if not hasattr(module, "__end_time_hook_handle__"): diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index e3a945f2eff6..8f20b390934f 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -20,12 +20,12 @@ import mmap from torch import _C -from torch.cuda import _lazy_call, device as device_ctx_manager from deepspeed.runtime.config import DeepSpeedConfig from deepspeed.utils import logger from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers +from deepspeed.accelerator import get_accelerator # DeepSpeed Checkpointing Enabled or Disabled deepspeed_checkpointing_enabled = False @@ -98,25 +98,25 @@ def _set_cuda_rng_state(new_state, device=-1): if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): # older PyTorch def cb(): - with device_ctx_manager(device): + with get_accelerator().device(device): _C._cuda_setRNGState(new_state) else: # newer PyTorch if device == -1: - device = torch.device('cuda') + device = torch.device(get_accelerator().device_name()) elif isinstance(device, str): device = torch.device(device) elif isinstance(device, int): - device = torch.device('cuda', device) + device = torch.device(get_accelerator().device_name(), device) def cb(): idx = device.index if idx is None: - idx = torch.cuda.current_device() - default_generator = torch.cuda.default_generators[idx] + idx = get_accelerator().current_device() + default_generator = get_accelerator().default_generator(idx) default_generator.set_state(new_state) - _lazy_call(cb) + get_accelerator().lazy_call(cb) class CudaRNGStatesTracker: @@ -158,10 +158,10 @@ def add(self, name, seed): if name in self.states_: raise Exception('cuda rng state {} already exists'.format(name)) # Get the current rng state. - orig_rng_state = torch.cuda.get_rng_state() + orig_rng_state = get_accelerator().get_rng_state() # Set the new state and store it. - torch.cuda.manual_seed(seed) - self.states_[name] = torch.cuda.get_rng_state() + get_accelerator().manual_seed(seed) + self.states_[name] = get_accelerator().get_rng_state() # Reset rng state to what it was. 
_set_cuda_rng_state(orig_rng_state) @@ -173,7 +173,7 @@ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): if name not in self.states_: raise Exception('cuda rng state {} is not added'.format(name)) # Store current rng state. - orig_cuda_rng_state = torch.cuda.get_rng_state() + orig_cuda_rng_state = get_accelerator().get_rng_state() # Set rng state to the desired one _set_cuda_rng_state(self.states_[name]) # Do the stuff we wanted to do. @@ -181,7 +181,7 @@ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): yield finally: # Update the current rng state for later use. - self.states_[name] = torch.cuda.get_rng_state() + self.states_[name] = get_accelerator().get_rng_state() # And set the state to the original state we started with. _set_cuda_rng_state(orig_cuda_rng_state) @@ -199,7 +199,7 @@ def model_parallel_cuda_manual_seed(seed): """Initialize model parallel cuda seed. This function should be called after the model parallel is - initialized. Also, no torch.cuda.manual_seed should be called + initialized. Also, no get_accelerator().manual_seed should be called after this function. Basically, this is replacement for that function. Two set of RNG states are tracked: @@ -235,7 +235,7 @@ def model_parallel_cuda_manual_seed(seed): ) _CUDA_RNG_STATE_TRACKER.reset() # Set the default state. - torch.cuda.manual_seed(data_parallel_seed) + get_accelerator().manual_seed(data_parallel_seed) # and model parallel state. _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed) @@ -270,6 +270,12 @@ def gather_partitioned_activations(tensors, device=None): inputs.append(item) continue + # don't need to do all_gather if model parallel size is 1 + if mp_size == 1: + item = item.view(list(size.numpy())) + inputs.append(item) + continue + partition_size = item.numel() tensor_size = partition_size * mp_size if device is not None: @@ -511,7 +517,7 @@ def save_args_for_backward(*all_args): ctx.tensor_flags = tensor_flags if SYNCHRONIZE: - torch.cuda.synchronize() + get_accelerator().synchronize() if timers is None and PROFILE_TIME: timers = Timers() @@ -554,8 +560,8 @@ def save_args_for_backward(*all_args): logger.info(f"----Synchronization {SYNCHRONIZE}") logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}") - cuda_device = torch.cuda.current_device() - transport_stream = torch.cuda.Stream(device=cuda_device) + cuda_device = get_accelerator().current_device_name() + transport_stream = get_accelerator().Stream(device=cuda_device) if PARTITION_ACTIVATIONS: inputs = partition_activations(args, @@ -573,7 +579,7 @@ def save_args_for_backward(*all_args): # Copy the rng states. ctx.fwd_cpu_rng_state = torch.get_rng_state() - ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() + ctx.fwd_cuda_rng_state = get_accelerator().get_rng_state() ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() see_memory_usage("Before running forward on the layer", force=False) @@ -601,7 +607,7 @@ def save_args_for_backward(*all_args): timers('forward').stop() timers.log(['forward']) if SYNCHRONIZE: - torch.cuda.synchronize() + get_accelerator().synchronize() # Tensors returned from forward() may not be differentiable. 
if torch.is_tensor(outputs): @@ -628,7 +634,7 @@ def backward(ctx, *grads): # so that they can be garbage collected once the checkpoints # have been used if SYNCHRONIZE: - torch.cuda.synchronize() + get_accelerator().synchronize() if PROFILE_TIME: timers('backward').start() @@ -654,7 +660,7 @@ def backward(ctx, *grads): global cuda_device, transport_stream, PARTITION_ACTIVATIONS if PARTITION_ACTIVATIONS: - # with torch.cuda.stream(transport_stream): + # with get_accelerator().stream(transport_stream): inputs = gather_partitioned_activations( ctx.deepspeed_saved_tensors, device=cuda_device if CPU_CHECKPOINT else None) @@ -675,7 +681,7 @@ def backward(ctx, *grads): # Store the current states. bwd_cpu_rng_state = torch.get_rng_state() - bwd_cuda_rng_state = torch.cuda.get_rng_state() + bwd_cuda_rng_state = get_accelerator().get_rng_state() bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() # Set the states to what it used to be before the forward pass. @@ -684,7 +690,7 @@ def backward(ctx, *grads): get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) # if PARTITION_ACTIVATIONS: - # current_stream=torch.cuda.current_stream() + # current_stream=get_accelerator().current_stream() # current_stream.wait_stream(transport_stream) see_memory_usage("In backward checkpointing code before forward", force=False) @@ -729,7 +735,7 @@ def backward(ctx, *grads): timers('backward').stop() timers.log(['backward']) if SYNCHRONIZE: - torch.cuda.synchronize() + get_accelerator().synchronize() ret_list = [None, None] # first None for ctx for inp in detached_inputs: if torch.is_tensor(inp): @@ -856,7 +862,7 @@ def configure( checkpoint_in_cpu: Optional: Moves the activation checkpoint to CPU. Only works with partition_activation. Default is false. Will overwrite deepspeed_config if provided - synchronize: Optional: Performs torch.cuda.synchronize() at the beginning and end of + synchronize: Optional: Performs get_accelerator().synchronize() at the beginning and end of each call to deepspeed.checkpointing.checkpoint for both forward and backward pass. By default false. 
Will overwrite deepspeed_config if provided diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 49c51fa7c91f..c03474a9d61b 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -8,7 +8,7 @@ import os from deepspeed import comm as dist from deepspeed.runtime.constants import PIPE_REPLICATED -from deepspeed.ops.op_builder import UtilsBuilder +from deepspeed.accelerator import get_accelerator from deepspeed.runtime import ZeROOptimizer from packaging import version as pkg_version @@ -63,7 +63,7 @@ def __init__(self, ] # Load pre-built or JIT compile (un)flatten ops - util_ops = UtilsBuilder().load() + util_ops = get_accelerator().create_op_builder("UtilsBuilder").load() self.flatten = util_ops.flatten self.unflatten = util_ops.unflatten diff --git a/deepspeed/runtime/comm/nccl.py b/deepspeed/runtime/comm/nccl.py index c892316a3dc5..692c6fc33236 100644 --- a/deepspeed/runtime/comm/nccl.py +++ b/deepspeed/runtime/comm/nccl.py @@ -8,6 +8,7 @@ import numpy as np from deepspeed.runtime.compression.cupy import CupyBackend +from deepspeed.accelerator import get_accelerator class NcclBackend(object): @@ -100,7 +101,8 @@ def compressed_allreduce(self, recvbuf_scale = [ torch.zeros(1, dtype=worker_scale.dtype, - device=torch.device(local_rank)) for i in range(self.size) + device=torch.device(get_accelerator().device_name(local_rank))) + for i in range(self.size) ] # communication phase 1 diff --git a/deepspeed/runtime/dataloader.py b/deepspeed/runtime/dataloader.py index 84b206e1eecd..2f1fefac7486 100644 --- a/deepspeed/runtime/dataloader.py +++ b/deepspeed/runtime/dataloader.py @@ -2,9 +2,9 @@ Copyright 2019 The Microsoft DeepSpeed Team ''' -import torch from torch.utils.data import DataLoader, RandomSampler from torch.utils.data.distributed import DistributedSampler +from deepspeed.accelerator import get_accelerator class RepeatingLoader: @@ -55,7 +55,7 @@ def __init__(self, else: if data_sampler is None: data_sampler = RandomSampler(dataset) - device_count = torch.cuda.device_count() + device_count = get_accelerator().device_count() batch_size *= device_count if num_local_io_workers is None: diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 0b54642abb0a..e77078fe6a0e 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -68,7 +68,6 @@ from .pipe.module import PipelineModule from .utils import ensure_directory_exists, get_ma_status -from ..ops.op_builder import UtilsBuilder from ..ops.adam import FusedAdam from ..moe.sharded_moe import TopKGate, MOELayer from ..moe.layer import MoE @@ -78,6 +77,8 @@ from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler from deepspeed.utils.logging import print_json_dist +from deepspeed.accelerator import get_accelerator + # Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init dist = None @@ -98,11 +99,12 @@ def split_half_float_double_sparse(tensors): + device_type = get_accelerator().device_name() supported_types = [ - "torch.cuda.HalfTensor", - "torch.cuda.FloatTensor", - "torch.cuda.DoubleTensor", - "torch.cuda.BFloat16Tensor", + "torch.{}.HalfTensor".format(device_type), + "torch.{}.FloatTensor".format(device_type), + "torch.{}.DoubleTensor".format(device_type), + "torch.{}.BFloat16Tensor".format(device_type), SparseTensor.type() ] @@ -215,7 +217,7 @@ def __init__( self.eigenvalue = None self.block_eigenvalue = None self.gas_boundary_ctr = 0 - self.dist_backend = "nccl" + 
self.dist_backend = get_accelerator().communication_backend_name() self.has_moe_layers = False self.num_experts = [] self.gate_modules = [] @@ -371,7 +373,7 @@ def __init__( print_configuration(self, "DeepSpeedEngine") # Load pre-installed or JIT compile (un)flatten ops - util_ops = UtilsBuilder().load() + util_ops = get_accelerator().create_op_builder("UtilsBuilder").load() self.flatten = util_ops.flatten self.unflatten = util_ops.unflatten @@ -865,14 +867,14 @@ def _set_distributed_vars(self, args): args, 'device_rank') else self.local_rank if device_rank >= 0: - torch.cuda.set_device(device_rank) - self.device = torch.device("cuda", device_rank) + get_accelerator().set_device(device_rank) + self.device = torch.device(get_accelerator().device_name(), device_rank) self.world_size = dist.get_world_size() self.global_rank = dist.get_rank() else: self.world_size = 1 self.global_rank = 0 - self.device = torch.device("cuda") + self.device = torch.device(get_accelerator().device_name()) # Configure based on command line arguments def _configure_with_arguments(self, args, mpu): diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index aeed2f4b18e1..68152e418468 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -14,6 +14,7 @@ from deepspeed.utils import groups, logger, log_dist from deepspeed import comm as dist from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD +from deepspeed.accelerator import get_accelerator class FP16_Optimizer(DeepSpeedOptimizer): @@ -41,8 +42,8 @@ def __init__(self, self.deepspeed = deepspeed self.has_moe_layers = has_moe_layers self.using_pipeline = self.deepspeed.pipeline_parallelism - if not torch.cuda.is_available: - raise SystemError("Cannot use fp16 without CUDA.") + if not get_accelerator().is_available(): + raise SystemError("No accelerator or accelerator does not support FP16.") self.optimizer = init_optimizer # param flattened by groups @@ -457,7 +458,7 @@ def load_state_dict(self, state_dict, load_optimizer_states=True): will call ``model.load_state_dict()`` before ``fp16_optimizer_instance.load_state_dict()`` is called. Example:: - model = torch.nn.Linear(D_in, D_out).cuda().half() + model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) ... 
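For reference, the hunks in this patch repeat one pattern: device placement, op-builder loading, and cache management go through the accelerator abstraction instead of torch.cuda. The minimal sketch below assumes a DeepSpeed install that exposes deepspeed.accelerator.get_accelerator(); the helper names are illustrative, and only accelerator calls that already appear in these hunks are used.

import torch
from deepspeed.accelerator import get_accelerator

def to_accelerator(tensor: torch.Tensor, local_rank: int = 0) -> torch.Tensor:
    # Replaces tensor.cuda(local_rank): the accelerator resolves the device
    # string (e.g. 'cuda:0' or another backend) instead of hard-coding CUDA.
    return tensor.to(get_accelerator().device_name(local_rank))

def load_utils_op():
    # Replaces UtilsBuilder().load(): the active accelerator selects the op
    # builder implementation appropriate for the current hardware.
    return get_accelerator().create_op_builder("UtilsBuilder").load()

buf = to_accelerator(torch.ones(4))
print(buf.device, get_accelerator().current_device_name())
get_accelerator().empty_cache()  # replaces torch.cuda.empty_cache()
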
diff --git a/deepspeed/runtime/fp16/onebit/adam.py b/deepspeed/runtime/fp16/onebit/adam.py index 706d2a3dac1f..5eb22fb64d73 100644 --- a/deepspeed/runtime/fp16/onebit/adam.py +++ b/deepspeed/runtime/fp16/onebit/adam.py @@ -4,6 +4,7 @@ import types import torch import numpy as np +from deepspeed.accelerator import get_accelerator from deepspeed import comm as dist @@ -174,12 +175,12 @@ def step(self, closure=None, grads=None): (self.size * self.divider))) state['server_chunk_size'] = state[ 'corrected_tensor_size'] // self.size - torch.cuda.empty_cache() + get_accelerator().empty_cache() state['worker_error'] = torch.zeros(state['corrected_tensor_size'], device=p.device) state['server_error'] = torch.zeros(state['server_chunk_size'], device=p.device) - torch.cuda.empty_cache() + get_accelerator().empty_cache() self.adam_freeze_key = True if not self.initialize and dist.get_rank() == 0: print("Cupy Buffers Initialized Successfully.") diff --git a/deepspeed/runtime/fp16/onebit/lamb.py b/deepspeed/runtime/fp16/onebit/lamb.py index 696550ca41ba..87c24695e23d 100644 --- a/deepspeed/runtime/fp16/onebit/lamb.py +++ b/deepspeed/runtime/fp16/onebit/lamb.py @@ -6,6 +6,7 @@ import numpy as np from deepspeed import comm as dist from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from deepspeed.accelerator import get_accelerator class OnebitLamb(torch.optim.Optimizer): @@ -283,7 +284,7 @@ def step(self, closure=None, grads=None): p.data = q.data if self.initialize and len(self.worker_errors) == 0: - torch.cuda.empty_cache() + get_accelerator().empty_cache() for i in range(len(self.exp_avg_flat)): self.worker_errors.append( torch.zeros(self.corrected_tensor_sizes[i], @@ -291,20 +292,20 @@ def step(self, closure=None, grads=None): self.server_errors.append( torch.zeros(self.server_chunk_sizes[i], device=self.exp_avg_flat[i].device)) - torch.cuda.empty_cache() + get_accelerator().empty_cache() if self.lamb_freeze_key: if self.size > 1: for i in range(len(self.exp_avg_flat)): if not self.initialize: - torch.cuda.empty_cache() + get_accelerator().empty_cache() self.worker_errors.append( torch.zeros(self.corrected_tensor_sizes[i], device=self.exp_avg_flat[i].device)) self.server_errors.append( torch.zeros(self.server_chunk_sizes[i], device=self.exp_avg_flat[i].device)) - torch.cuda.empty_cache() + get_accelerator().empty_cache() if dist.get_rank() == 0: print("Cupy Buffers Initialized Successfully.") diff --git a/deepspeed/runtime/fp16/onebit/zoadam.py b/deepspeed/runtime/fp16/onebit/zoadam.py index 53dfde99ae38..f86ae86f36cb 100644 --- a/deepspeed/runtime/fp16/onebit/zoadam.py +++ b/deepspeed/runtime/fp16/onebit/zoadam.py @@ -4,6 +4,7 @@ import types import torch import numpy as np +from deepspeed.accelerator import get_accelerator from deepspeed import comm as dist @@ -185,14 +186,14 @@ def step(self, closure=None, grads=None): (self.size * self.divider))) state['server_chunk_size'] = state[ 'corrected_tensor_size'] // self.size - torch.cuda.empty_cache() + get_accelerator().empty_cache() state['worker_error'] = torch.zeros(state['corrected_tensor_size'], device=p.device) state['server_error'] = torch.zeros(state['server_chunk_size'], device=p.device) # Accumulation of momentum, i.e., the u variable in the 0/1 Adam paper state['momentum_accumulator'] = torch.zeros_like(p.data) - torch.cuda.empty_cache() + get_accelerator().empty_cache() # self.freeze_key = True if not self.initialize and dist.get_rank() == 0: print("Cupy Buffers Initialized Successfully.") diff --git 
a/deepspeed/runtime/fp16/unfused_optimizer.py b/deepspeed/runtime/fp16/unfused_optimizer.py index 88f0df443405..0a6f21f4e9d2 100755 --- a/deepspeed/runtime/fp16/unfused_optimizer.py +++ b/deepspeed/runtime/fp16/unfused_optimizer.py @@ -14,6 +14,7 @@ from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE from deepspeed.utils import logger from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT +from deepspeed.accelerator import get_accelerator from deepspeed import comm as dist @@ -40,8 +41,8 @@ def __init__(self, if dist.get_rank() == 0: logger.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ') - if not torch.cuda.is_available: - raise SystemError("Cannot use fp16 without CUDA.") + if not get_accelerator().is_available(): + raise SystemError("No accelerator or accelerator does not support FP16.") self.optimizer = init_optimizer # param groups @@ -387,7 +388,7 @@ def load_state_dict(self, state_dict, load_optimizer_states=True): will call ``model.load_state_dict()`` before ``fp16_optimizer_instance.load_state_dict()`` is called. Example:: - model = torch.nn.Linear(D_in, D_out).cuda().half() + model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) ... @@ -432,13 +433,13 @@ def initialize_optimizer_states(self): for param in group: param.grad = torch.zeros(param.size(), dtype=param.dtype, - device=torch.cuda.current_device()) + device=get_accelerator().current_device_name()) for i, group in enumerate(self.fp32_groups): for param in group: param.grad = torch.zeros(param.size(), dtype=param.dtype, - device=torch.cuda.current_device()) + device=get_accelerator().current_device_name()) self.optimizer.step() diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 1a57bb4e84a2..0e1dda98f1e1 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -7,6 +7,7 @@ from deepspeed.utils import logger from deepspeed.utils.timer import ThroughputTimer +from deepspeed.accelerator import get_accelerator from ..engine import DeepSpeedEngine, MEMORY_OPT_ALLREDUCE_SIZE from ..utils import PartitionedTensor @@ -1272,14 +1273,14 @@ def mem_status(self, msg, print_rank=-1, reset_max=False): if print_rank != -1 and rank != print_rank: return - torch.cuda.synchronize() + get_accelerator().synchronize() if reset_max: - torch.cuda.reset_max_memory_cached() - torch.cuda.reset_max_memory_allocated() + get_accelerator().reset_max_memory_cached() + get_accelerator().reset_max_memory_allocated() - new_alloced = torch.cuda.memory_allocated() - new_cached = torch.cuda.memory_cached() + new_alloced = get_accelerator().memory_allocated() + new_cached = get_accelerator().memory_cached() delta_alloced = new_alloced - mem_alloced delta_cached = new_cached - mem_cached @@ -1287,8 +1288,8 @@ def mem_status(self, msg, print_rank=-1, reset_max=False): mem_cached = new_cached mem_alloced = new_alloced - max_alloced = torch.cuda.max_memory_allocated() - max_cached = torch.cuda.max_memory_cached() + max_alloced = get_accelerator().max_memory_allocated() + max_cached = get_accelerator().max_memory_cached() # convert to GB for printing new_alloced /= 1024**3 diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 294db38b3bfb..27272e80fc42 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -14,6 +14,7 @@ from 
..activation_checkpointing import checkpointing from .topology import PipeDataParallelTopology, PipelineParallelGrid from deepspeed.runtime.state_dict_factory import SDLoaderFactory +from deepspeed.accelerator import get_accelerator class PipelineError(Exception): @@ -192,12 +193,12 @@ def forward(self, inputs): self.tied_weight_attrs = {} # Offset the random seed by the stage ID. - #newseed = torch.cuda.initial_seed() + self._grid.get_stage_id() + #newseed = get_accelerator().initial_seed() + self._grid.get_stage_id() #ds_utils.set_random_seed(newseed) - #with torch.random.fork_rng(devices=[torch.cuda.current_device()]): + #with torch.random.fork_rng(devices=[get_accelerator().current_device_name()]): self._build() - self.to(f'cuda:{self.local_rank}') + self.to(get_accelerator().device_name(self.local_rank)) self.tied_comms = self._index_tied_modules() self._synchronize_tied_weights() diff --git a/deepspeed/runtime/pipe/p2p.py b/deepspeed/runtime/pipe/p2p.py index d0c9c2f9f364..93c61114c558 100644 --- a/deepspeed/runtime/pipe/p2p.py +++ b/deepspeed/runtime/pipe/p2p.py @@ -11,6 +11,7 @@ # To query whether we have send/recv support from packaging.version import Version from deepspeed.git_version_info import torch_info +from deepspeed.accelerator import get_accelerator _groups = None _grid = None @@ -92,7 +93,7 @@ def wait(): op.wait() _async = [] - torch.cuda.synchronize() + get_accelerator().synchronize() def send_obj(msg: typing.Any, dest: int): @@ -110,10 +111,12 @@ def send_obj(msg: typing.Any, dest: int): # serialize the message msg = pickle.dumps(msg) # construct a tensor to send - msg = torch.ByteTensor(torch.ByteStorage.from_buffer(msg)).cuda() + msg = torch.ByteTensor(torch.ByteStorage.from_buffer(msg)).to( + get_accelerator().device_name()) # Send meta and message - length_tensor = torch.tensor([len(msg)], dtype=torch.long).cuda() + length_tensor = torch.tensor([len(msg)], + dtype=torch.long).to(get_accelerator().device_name()) dist.send(length_tensor, dst=dest) dist.send(msg, dst=dest) @@ -128,11 +131,12 @@ def recv_obj(sender: int) -> typing.Any: sender (int): The rank sending the message. 
""" # Get message meta - length = torch.tensor([0], dtype=torch.long).cuda() + length = torch.tensor([0], dtype=torch.long).to(get_accelerator().device_name()) dist.recv(length, src=sender) # Receive and deserialize - msg = torch.empty(length.item(), dtype=torch.uint8).cuda() + msg = torch.empty(length.item(), + dtype=torch.uint8).to(get_accelerator().device_name()) dist.recv(msg, src=sender) msg = pickle.loads(msg.cpu().numpy().tobytes()) @@ -140,7 +144,7 @@ def recv_obj(sender: int) -> typing.Any: def _to(x): """Recursively move to the current device.""" if torch.is_tensor(x): - return x.cuda() + return x.to(get_accelerator().device_name()) if isinstance(x, (tuple, list)): ret = [_to(x_) for x_ in x] if isinstance(x, tuple): diff --git a/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py b/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py index c83a69544d56..2c4e338e9994 100644 --- a/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py +++ b/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py @@ -8,7 +8,7 @@ import torch from deepspeed.utils.logging import logger -from deepspeed.ops.aio import AsyncIOBuilder +from deepspeed.accelerator import get_accelerator from deepspeed import comm as dist from deepspeed.runtime.swap_tensor.constants import * @@ -44,7 +44,7 @@ def __init__(self, dtype, timers) - aio_op = AsyncIOBuilder().load() + aio_op = get_accelerator().create_op_builder("AsyncIOBuilder").load() self.aio_handle = aio_op.aio_handle(aio_config[AIO_BLOCK_SIZE], aio_config[AIO_QUEUE_DEPTH], aio_config[AIO_SINGLE_SUBMIT], diff --git a/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py b/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py index 22e11b01f0f4..b2596b4a57c0 100644 --- a/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py +++ b/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py @@ -10,8 +10,7 @@ from enum import Enum import torch from deepspeed import comm as dist - -from deepspeed.ops.aio import AsyncIOBuilder +from deepspeed.accelerator import get_accelerator from .constants import * from .utils import swap_in_tensors, swap_out_tensors, MIN_AIO_BYTES, AIO_ALIGNED_BYTES, print_object, SwapBufferPool @@ -35,7 +34,8 @@ class PartitionedParamStatus(Enum): class AsyncPartitionedParameterSwapper(object): def __init__(self, ds_config, model_dtype): - aio_op = AsyncIOBuilder().load(verbose=False) + aio_op = get_accelerator().create_op_builder("AsyncIOBuilder").load( + verbose=False) self.aio_handle = aio_op.aio_handle self.dtype = model_dtype @@ -107,11 +107,10 @@ def _configure_aio(self, ds_config): self.available_buffer_ids = [i for i in range(self.param_buffer_count)] self.reserved_buffer_ids = [] - self.buffers = torch.empty(int(self.aligned_elements_per_buffer * - self.param_buffer_count), - dtype=self.dtype, - pin_memory=True, - requires_grad=False) + self.buffers = get_accelerator().pin_memory( + torch.empty(int(self.aligned_elements_per_buffer * self.param_buffer_count), + dtype=self.dtype, + requires_grad=False)) self.aio_read_handle = self.aio_handle(self.aio_config[AIO_BLOCK_SIZE], self.aio_config[AIO_QUEUE_DEPTH], @@ -393,9 +392,10 @@ def _is_io_aligned(self, numel): def reserve_partitioned_swap_space(self, partition_num_elems): aligned_numel = sum( [self._io_aligned_numel(numel) for numel in partition_num_elems]) - self.partitioned_swap_buffer = torch.zeros(aligned_numel, - device='cpu', - dtype=self.dtype).pin_memory() + self.partitioned_swap_buffer = get_accelerator().pin_memory( + 
torch.zeros(aligned_numel, + device='cpu', + dtype=self.dtype)) self.partitioned_swap_pool = SwapBufferPool([self.partitioned_swap_buffer]) def swap_out_partitioned_params(self, dst_fp16_params, src_fp32_params): diff --git a/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py b/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py index c74a40ca7891..648813c961d5 100644 --- a/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py +++ b/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py @@ -5,7 +5,7 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices. """ -from deepspeed.ops.aio import AsyncIOBuilder +from deepspeed.accelerator import get_accelerator from deepspeed import comm as dist from deepspeed.runtime.swap_tensor.constants import * @@ -72,7 +72,7 @@ def __init__(self, dtype, timers) - aio_op = AsyncIOBuilder().load() + aio_op = get_accelerator().create_op_builder("AsyncIOBuilder").load() self.write_aio_handle = aio_op.aio_handle(aio_config[AIO_BLOCK_SIZE], aio_config[AIO_QUEUE_DEPTH], aio_config[AIO_SINGLE_SUBMIT], diff --git a/deepspeed/runtime/swap_tensor/utils.py b/deepspeed/runtime/swap_tensor/utils.py index 3059e0d15798..1d236bd7fcac 100644 --- a/deepspeed/runtime/swap_tensor/utils.py +++ b/deepspeed/runtime/swap_tensor/utils.py @@ -7,6 +7,7 @@ import torch from deepspeed.utils.logging import logger +from deepspeed.accelerator import get_accelerator from deepspeed import comm as dist @@ -179,9 +180,10 @@ def __init__(self, num_elems, count, dtype): self.count = count self.dtype = dtype self.all_buffers = [ - torch.zeros(num_elems, - device='cpu', - dtype=dtype).pin_memory() for _ in range(count) + get_accelerator().pin_memory( + torch.zeros(num_elems, + device='cpu', + dtype=dtype)) for _ in range(count) ] self.free_buffer_index = [i for i in range(count)] self.used_buffer_index = {} diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index fa90486cf8d5..c56cfe18c54c 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -22,16 +22,10 @@ from deepspeed.utils import groups, logger from deepspeed.runtime.constants import PIPE_REPLICATED from numpy import prod +from deepspeed.accelerator import get_accelerator -# pt-1.9 deprecations -if hasattr(torch.cuda, "memory_reserved"): - torch_memory_reserved = torch.cuda.memory_reserved -else: - torch_memory_reserved = torch.cuda.memory_allocated -if hasattr(torch.cuda, "max_memory_reserved"): - torch_max_memory_reserved = torch.cuda.max_memory_reserved -else: - torch_max_memory_reserved = torch.cuda.memory_cached +torch_memory_reserved = get_accelerator().memory_reserved +torch_max_memory_reserved = get_accelerator().max_memory_reserved class DummyOptim(): @@ -191,7 +185,7 @@ def __init__(self, def check_using_norm(self, norm_group, reduce_overflow=True): # TODO: I don't think reduce_overflow is needed if mpu is None overflow = -1 in norm_group - overflow_gpu = torch.cuda.FloatTensor([overflow]) + overflow_gpu = get_accelerator().FloatTensor([overflow]) if self.has_moe_params: # In this case, we need to do an all_reduce across # the expert_parallel_group, so that if there was @@ -242,7 +236,7 @@ def has_overflow(self, params, has_moe_params=None): overflow = self.has_overflow_serial(params) # Since each model parallel GPU carries only part of the model, # make sure overflow flag is synced across all the model parallel GPUs - overflow_gpu = torch.cuda.ByteTensor([overflow]) + overflow_gpu = get_accelerator().ByteTensor([overflow]) # 
deepspeeed.comm.all_reduce(overflow_gpu, # op=deepspeed.comm.ReduceOp.MAX, # group=mpu.get_model_parallel_group()) @@ -352,7 +346,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None): norm_type = float(norm_type) if norm_type == inf: total_norm = max(p.grad.data.abs().max() for p in parameters) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) # Take max across all GPUs. if mpu is not None: dist.all_reduce(total_norm_cuda, @@ -372,7 +366,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None): total_norm += param_norm.item()**norm_type # Sum across all model parallel GPUs. - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) if mpu is not None: dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, @@ -383,7 +377,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None): pg = groups._get_data_parallel_group() scaled_norm = total_norm * 1.0 / float(dist.get_world_size(group=pg)) - scaled_norm_tensor = torch.cuda.FloatTensor([float(scaled_norm)]) + scaled_norm_tensor = get_accelerator().FloatTensor([float(scaled_norm)]) dist.all_reduce(scaled_norm_tensor, group=pg) total_norm = scaled_norm_tensor.item() @@ -418,7 +412,7 @@ def get_grad_norm(parameters, norm_type=2, mpu=None): norm_type = float(norm_type) if norm_type == inf: total_norm = max(p.grad.data.abs().max() for p in parameters) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) # Take max across all GPUs. if mpu is not None: dist.all_reduce(total_norm_cuda, @@ -442,7 +436,7 @@ def get_grad_norm(parameters, norm_type=2, mpu=None): total_norm += param_norm.item()**norm_type # Sum across all model parallel GPUs. - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) if mpu is not None: dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, @@ -488,7 +482,7 @@ def get_grad_zeros(parameters, mpu=None): total_zeros += count_zeros.item() # Sum across all model parallel GPUs. - total_zeros_cuda = torch.cuda.FloatTensor([float(total_zeros)]) + total_zeros_cuda = get_accelerator().FloatTensor([float(total_zeros)]) if mpu is not None: dist.all_reduce(total_zeros_cuda, op=dist.ReduceOp.SUM, @@ -521,7 +515,7 @@ def get_weight_norm(parameters, norm_type=2, mpu=None): norm_type = float(norm_type) if norm_type == inf: total_norm = max(p.data.abs().max() for p in parameters) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) # Take max across all GPUs. if mpu is not None: dist.all_reduce(total_norm_cuda, @@ -545,7 +539,7 @@ def get_weight_norm(parameters, norm_type=2, mpu=None): total_norm += param_norm**norm_type # Sum across all model parallel GPUs. 
- total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) if mpu is not None: dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, @@ -669,7 +663,7 @@ def __init__(self, tensor, group, partition_meta=None): self.local_data, self.partition = self._partition_tensor(tensor) @classmethod - def from_meta(cls, meta, local_part, group, device='cuda'): + def from_meta(cls, meta, local_part, group, device=get_accelerator().device_name()): assert meta.dtype == torch.long dummy = torch.ones(dist.get_world_size(group=group)) part_obj = cls(tensor=dummy, group=group) @@ -773,14 +767,14 @@ def memory_status(msg, print_rank=-1, reset_max=False): if print_rank != -1 and rank != print_rank: return - torch.cuda.synchronize() + get_accelerator().synchronize() if reset_max: - torch.cuda.reset_max_memory_cached() - torch.cuda.reset_max_memory_allocated() + get_accelerator().reset_max_memory_cached() + get_accelerator().reset_max_memory_allocated() - new_alloced = torch.cuda.memory_allocated() - new_cached = torch.cuda.memory_cached() + new_alloced = get_accelerator().memory_allocated() + new_cached = get_accelerator().memory_cached() delta_alloced = new_alloced - mem_alloced delta_cached = new_cached - mem_cached @@ -788,8 +782,8 @@ def memory_status(msg, print_rank=-1, reset_max=False): mem_cached = new_cached mem_alloced = new_alloced - max_alloced = torch.cuda.max_memory_allocated() - max_cached = torch.cuda.max_memory_cached() + max_alloced = get_accelerator().max_memory_allocated() + max_cached = get_accelerator().max_memory_cached() # convert to GB for printing new_alloced /= 1024**3 @@ -802,7 +796,7 @@ def memory_status(msg, print_rank=-1, reset_max=False): print( f'RANK={rank} MEMSTATS', msg, - f'device={torch.cuda.current_device()} ' + f'device={get_accelerator().current_device_name()} ' f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) ' f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)' ) @@ -811,7 +805,7 @@ def memory_status(msg, print_rank=-1, reset_max=False): def get_ma_status(): if dist.is_initialized() and not dist.get_rank() == 0: return 0 - return torch.cuda.memory_allocated() + return get_accelerator().memory_allocated() def see_memory_usage(message, force=False): @@ -826,8 +820,8 @@ def see_memory_usage(message, force=False): # Print message except when distributed but not rank 0 logger.info(message) logger.info( - f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024),2 )} GB \ - Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),2)} GB \ + f"MA {round(get_accelerator().memory_allocated() / (1024 * 1024 * 1024),2 )} GB \ + Max_MA {round(get_accelerator().max_memory_allocated() / (1024 * 1024 * 1024),2)} GB \ CA {round(torch_memory_reserved() / (1024 * 1024 * 1024),2)} GB \ Max_CA {round(torch_max_memory_reserved() / (1024 * 1024 * 1024))} GB ") @@ -837,8 +831,7 @@ def see_memory_usage(message, force=False): f'CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%') # get the peak memory to report correct data, so reset the counter for the next call - if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+ - torch.cuda.reset_peak_memory_stats() + get_accelerator().reset_peak_memory_stats() def call_to_str(base, *args, **kwargs): @@ -912,7 +905,7 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): norm_type = float(norm_type) if norm_type == inf: 
total_norm = max(t.data.abs().max() for t in input_tensors) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) if mpu is not None: dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, @@ -921,7 +914,7 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): else: total_norm = sum( [t.data.float().norm(norm_type).item()**norm_type for t in input_tensors]) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) if mpu is not None: dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, diff --git a/deepspeed/runtime/weight_quantizer.py b/deepspeed/runtime/weight_quantizer.py index 8b7b1eb9d8ea..491493df3045 100644 --- a/deepspeed/runtime/weight_quantizer.py +++ b/deepspeed/runtime/weight_quantizer.py @@ -1,5 +1,6 @@ import torch from ..module_inject.replace_policy import HFBertLayerPolicy, replace_policies +from deepspeed.accelerator import get_accelerator class WeightQuantization(object): @@ -44,9 +45,11 @@ def Quantize(self, value_list, quantize_bits, groups, key, merge_dim=0): q_scale.append(data_scale) value_list[index] = data_int index += 1 - q_scale = (1 / torch.cat(q_scale, - dim=merge_dim).to( - torch.cuda.current_device()).view(-1).unsqueeze(0)) + q_scale = ( + 1 / + torch.cat(q_scale, + dim=merge_dim).to( + get_accelerator().current_device_name()).view(-1).unsqueeze(0)) if "mlp.dense_4h_to_h.weight" in key: self.mlp4hh_scales.append(q_scale) elif "mlp.dense_h_to_4h.weight" in key: @@ -63,7 +66,7 @@ def merge_layer_scales(self, layer_scales): torch.cat((s, torch.zeros((1, max_dim - s.shape[-1]), - device=torch.cuda.current_device())), + device=get_accelerator().current_device_name())), dim=-1) if s.shape[-1] < max_dim else s for s in layer_scales ] return torch.cat(layer_scales).unsqueeze(0) @@ -134,9 +137,8 @@ def quantize_fn(layer, policy_cls): else: data_quantized, data_scale = self.quantize_data(keys[key], quantize_bits, groups) keys[key].copy_(data_quantized) - layer_scales.append( - (1 / - data_scale.to(torch.cuda.current_device()).view(-1).unsqueeze(0))) + layer_scales.append((1 / data_scale.to( + get_accelerator().current_device_name()).view(-1).unsqueeze(0))) all_scales.append(self.merge_layer_scales(layer_scales)) return layer diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py index 02a585b02b01..7dd849dfba11 100644 --- a/deepspeed/runtime/zero/linear.py +++ b/deepspeed/runtime/zero/linear.py @@ -19,6 +19,7 @@ from torch.nn.modules.module import Module from deepspeed.runtime.utils import noop_decorator from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator tensor_map = {} @@ -29,8 +30,13 @@ def print_rank_0(message, debug=False, force=False): try: - autocast_custom_fwd = torch.cuda.amp.custom_fwd - autocast_custom_bwd = torch.cuda.amp.custom_bwd + device = get_accelerator().device_name() + if device == 'cuda': + autocast_custom_fwd = torch.cuda.amp.custom_fwd + autocast_custom_bwd = torch.cuda.amp.custom_bwd + else: + autocast_custom_fwd = noop_decorator + autocast_custom_bwd = noop_decorator except (ImportError, AttributeError) as exp: autocast_custom_fwd = noop_decorator autocast_custom_bwd = noop_decorator diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index f0f0c4f56e40..b9178c39e287 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ 
b/deepspeed/runtime/zero/parameter_offload.py @@ -5,7 +5,6 @@ import sys import torch -from torch.cuda import Stream from collections import OrderedDict from deepspeed.runtime.utils import see_memory_usage from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum @@ -13,6 +12,7 @@ from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, iter_params from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator FWD_MODULE_STACK = list() @@ -217,8 +217,8 @@ def __init__(self, self._prefetch_bucket_sz = int(prefetch_bucket_size) self._max_reuse_distance_in_numel = int(max_reuse_distance) self._max_available_parameters_in_numel = int(max_live_parameters) - self.__allgather_stream = Stream( - ) if overlap_comm else torch.cuda.default_stream() + self.__allgather_stream = get_accelerator().Stream( + ) if overlap_comm else get_accelerator().default_stream() self.forward_hooks = [] self.backward_hooks = [] diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index b6bd5ed645f9..ec1df03c9e09 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -31,6 +31,7 @@ debug_module2name, debug_param2name_id, debug_param2name_id_shape_status) +from deepspeed.accelerator import get_accelerator from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus param_count = 0 @@ -191,7 +192,8 @@ def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype) -> Callable: def wrapped_fn(*args, **kwargs) -> Tensor: if kwargs.get("device", None) is None: - kwargs['device'] = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) + kwargs['device'] = torch.device(get_accelerator().device_name( + os.environ["LOCAL_RANK"])) tensor: Tensor = fn(*args, **kwargs) if tensor.is_floating_point(): tensor = tensor.to(target_fp_dtype) @@ -203,7 +205,7 @@ def wrapped_fn(*args, **kwargs) -> Tensor: def get_new_tensor_fn_for_dtype(dtype: torch.dtype) -> Callable: def new_tensor(cls, *args) -> Tensor: - device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) + device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])) tensor = _orig_torch_empty(0, device=device).new_empty(*args) if tensor.is_floating_point(): tensor = tensor.to(dtype) @@ -231,10 +233,10 @@ def recurse(cl): def free_param(param: Parameter) -> None: """Free underlying storage of a parameter.""" assert not param.ds_active_sub_modules, param.ds_summary() - if param.data.is_cuda: + if get_accelerator().on_accelerator(param.data): # need to make sure that we don't free the parameter while it is still # being used for computation - param.data.record_stream(torch.cuda.current_stream()) + param.data.record_stream(get_accelerator().current_stream()) # param.data doesn't store anything meaningful in partitioned state param.data = torch.empty(0, dtype=param.dtype, device=param.device) param.ds_status = ZeroParamStatus.NOT_AVAILABLE @@ -526,7 +528,7 @@ def wait(self) -> None: param.ds_status = ZeroParamStatus.AVAILABLE for part_to_copy in partitions: - part_to_copy.record_stream(torch.cuda.current_stream()) + part_to_copy.record_stream(get_accelerator().current_stream()) param_offset += param.ds_tensor.ds_numel @@ -672,8 +674,9 @@ def get_model(): # Local device is the device where the parameters are consumed, must be default device. 
# It is the device where parameters are fully instantiated using allgather - self.local_device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) - torch.cuda.set_device(self.local_device) + self.local_device = torch.device(get_accelerator().device_name( + os.environ["LOCAL_RANK"])) + get_accelerator().set_device(self.local_device) if _ds_config is not None and _ds_config.zero_config.offload_param is not None: remote_device = _ds_config.zero_config.offload_param.device @@ -747,7 +750,7 @@ def _post_init_method(self, module): f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}" ) - if param.is_cuda: + if get_accelerator().on_accelerator(param): dist.broadcast(param, 0, self.ds_process_group) else: if dist.get_rank() == 0: @@ -839,11 +842,11 @@ def all_gather_coalesced(params: Iterable[Parameter], param_buffer = torch.empty( math.ceil(param.ds_numel / self.world_size) * self.world_size, dtype=param.dtype, - device=torch.cuda.current_device(), + device=get_accelerator().current_device_name(), requires_grad=False, ) handle = _dist_allgather_fn( - param.ds_tensor.to(torch.cuda.current_device()), + param.ds_tensor.to(get_accelerator().current_device_name()), param_buffer, self.ds_process_group) param.data = param_buffer.narrow(0, @@ -856,7 +859,7 @@ def all_gather_coalesced(params: Iterable[Parameter], flat_tensor = torch.empty(partition_sz * self.world_size, dtype=get_only_unique_item(p.dtype for p in params), - device=torch.cuda.current_device(), + device=get_accelerator().current_device_name(), requires_grad=False) partitions: List[Parameter] = [] for i in range(self.world_size): @@ -865,9 +868,11 @@ def all_gather_coalesced(params: Iterable[Parameter], partition_sz * i, partition_sz)) - instrument_w_nvtx(torch.cat)( - [p.ds_tensor.to(torch.cuda.current_device()) for p in params], - out=partitions[self.rank]) + instrument_w_nvtx(torch.cat)([ + p.ds_tensor.to(get_accelerator().current_device_name()) + for p in params + ], + out=partitions[self.rank]) handle = _dist_allgather_fn(partitions[self.rank], flat_tensor, self.ds_process_group) @@ -1103,7 +1108,8 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): device=OffloadDeviceEnum.cpu if self.remote_device == OffloadDeviceEnum.nvme else self.remote_device) if self.pin_memory: - partitioned_tensor = partitioned_tensor.pin_memory() + partitioned_tensor = get_accelerator().pin_memory( + partitioned_tensor) partitioned_tensor.requires_grad = False param.ds_tensor = partitioned_tensor @@ -1195,7 +1201,7 @@ def _allgather_param(self, param, async_op=False, hierarchy=0): f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ', force=False) - torch.cuda.synchronize() + get_accelerator().synchronize() print_rank_0( f"{'--'* hierarchy}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}" @@ -1209,7 +1215,8 @@ def _allgather_param(self, param, async_op=False, hierarchy=0): if self.use_all_gather_base: # try the _all_gather_base on PyTorch master branch handle = dist.all_gather_base(flat_tensor, - param.ds_tensor.cuda(), + param.ds_tensor.to( + get_accelerator().device_name()), group=self.ds_process_group, async_op=async_op) else: @@ -1243,7 +1250,7 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0): local_tensors = [] for param in param_list: partition_sizes.append(param.ds_tensor.ds_numel) - local_tensors.append(param.ds_tensor.cuda()) + 
local_tensors.append(param.ds_tensor.to(get_accelerator().device_name())) # allocate memory for allgather params allgather_params = [] @@ -1274,7 +1281,7 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0): psize = partition_sizes[param_idx] partition = allgather_params[param_idx].narrow(0, i * psize, psize) output_list.append(partition) - if not partition.is_cuda: + if not get_accelerator().on_accelerator(partition): logger.warning( f'param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}' ) @@ -1297,7 +1304,7 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0): param.ds_numel).view(param.ds_shape).data # guarantee the communication to be completed - torch.cuda.synchronize() + get_accelerator().synchronize() return None diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 1dcff3f1c12f..9bf5a3c67883 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -7,7 +7,6 @@ import collections from collections import UserDict from typing import Deque, Set -from torch.cuda import Event, Stream from deepspeed import comm as dist from deepspeed.utils.logging import logger @@ -15,6 +14,7 @@ from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id +from deepspeed.accelerator import get_accelerator def debug_rank0(message: str) -> None: @@ -66,7 +66,7 @@ def __init__( prefetch_bucket_sz: int, max_reuse_distance_in_numel: int, max_available_parameters_in_numel: int, - allgather_stream: Stream, + allgather_stream: get_accelerator().Stream, prefetch_nvme: bool = False, ) -> None: # mapping of param -> handle for each param that is currently in flight @@ -95,7 +95,7 @@ def __init__( self.hierarchy: int = 0 # stream that will be used for allgather operations - self.__allgather_stream: Stream = allgather_stream + self.__allgather_stream: get_accelerator().Stream = allgather_stream # limit the number of fetch events that can be queued at once # otherwise, what happens is memory is allocated by the host thread at the @@ -106,7 +106,7 @@ def __init__( # cudaMallocAsync/cudaFreeAsync. Choosing to not expose this to the user now # because ideally in the future its replaced by an async allocation # mechanism which doesn't require any configuration by the user. - self.__ongoing_fetch_events: Deque[Event] = collections.deque() + self.__ongoing_fetch_events: Deque[get_accelerator().Event] = collections.deque() # TODO. 
make this configurable via JSON self.__max_ongoing_fetch_events: int = 2 @@ -260,7 +260,7 @@ def fetch_sub_module(self, current_submodule: Module) -> None: param.ds_active_sub_modules.add(current_submodule.id) debug_rank0(f"-wait: {param.ds_summary()}") if param in self.__inflight_param_registry: - with torch.cuda.stream(self.__allgather_stream): + with get_accelerator().stream(self.__allgather_stream): while self.__ongoing_fetch_events and self.__ongoing_fetch_events[ 0].query(): self.__ongoing_fetch_events.popleft() @@ -270,12 +270,12 @@ def fetch_sub_module(self, current_submodule: Module) -> None: self.__inflight_param_registry.pop(param).wait() - event = Event() + event = get_accelerator().Event() event.record() self.__ongoing_fetch_events.append(event) assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() - torch.cuda.current_stream().wait_stream(self.__allgather_stream) + get_accelerator().current_stream().wait_stream(self.__allgather_stream) # kick off parameter prefetches for upcoming modules # don't prefetch if we dont have a completed model trace @@ -393,7 +393,7 @@ def __all_gather_params(self, params: Set[Parameter]) -> None: self.__n_available_params += param.ds_numel if partitioned_params: - with torch.cuda.stream(self.__allgather_stream): + with get_accelerator().stream(self.__allgather_stream): handle = partitioned_params[0].all_gather_coalesced(partitioned_params) for param in partitioned_params: diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 021e6317bc51..1577242c1623 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -7,7 +7,6 @@ import gc import collections from typing import Deque, Dict, Tuple -from torch.cuda import Event, Stream from torch._six import inf from deepspeed.runtime import ZeROOptimizer @@ -20,11 +19,11 @@ from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload from deepspeed.ops.adam import DeepSpeedCPUAdam -from deepspeed.ops.op_builder import UtilsBuilder from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FP32_FLAT_GROUPS, PARTITION_COUNT, ZERO_STAGE +from deepspeed.accelerator import get_accelerator # Toggle this to true to enable correctness test # with gradient partitioning and without @@ -122,13 +121,13 @@ def __init__(self, # - assume all params requires grad # - flat by groups, not keeping state. TODO: remove state explicitly? # - master grad and unflat master weight never exist. TODO: a way to save out unflat master? 
- if not torch.cuda.is_available: - raise SystemError("Cannot use fp16 without CUDA.") + if not get_accelerator().is_available(): + raise SystemError("Cannot use fp16 without accelerator.") self.optimizer = init_optimizer # Load pre-built or JIT compile (un)flatten ops - util_ops = UtilsBuilder().load() + util_ops = get_accelerator().create_op_builder("UtilsBuilder").load() self.flatten = util_ops.flatten self.unflatten = util_ops.unflatten self.dtype = self.optimizer.param_groups[0]['params'][0].dtype @@ -170,17 +169,17 @@ def __init__(self, self.__inf_or_nan_tracker: Tensor = torch.zeros( 1, dtype=torch.bool, - device=torch.cuda.current_device(), + device=get_accelerator().current_device_name(), requires_grad=False) self.deepspeed_adam_offload = (self.offload_optimizer and type(init_optimizer) == DeepSpeedCPUAdam) - self.device = torch.cuda.current_device( + self.device = get_accelerator().current_device_name( ) if not self.offload_optimizer else OffloadDeviceEnum.cpu ### streams used for overlapping computation with communication - self.__reduce_and_partition_stream = Stream( - ) if overlap_comm else torch.cuda.default_stream() + self.__reduce_and_partition_stream = get_accelerator().Stream( + ) if overlap_comm else get_accelerator().default_stream() ############################################################################ @@ -265,7 +264,7 @@ def __init__(self, self.__params_in_ipg_bucket: List[Parameter] = [] self.is_gradient_accumulation_boundary: bool = True - self.__param_reduce_events: Deque[Event] = collections.deque() + self.__param_reduce_events: Deque[get_accelerator().Event] = collections.deque() # TODO. make this configurable via JSON self.__max_param_reduce_events: int = 2 @@ -373,18 +372,19 @@ def _setup_for_real_optimizer(self): self.__ipg_bucket_flat_buffer: Tensor = torch.empty( self.reduce_bucket_size, dtype=self.dtype, - device=torch.cuda.current_device()) + device=get_accelerator().current_device_name()) grad_partitions_flat_buffer = None self.__param_id_to_grad_partition: Dict[int, Tensor] = {} all_params = list(itertools.chain.from_iterable(self.fp16_groups)) - grad_partitions_flat_buffer: Tensor = torch.zeros( - sum(p.partition_numel() for p in all_params), - dtype=self.dtype, - device=self.device, - pin_memory=self.offload_optimizer_pin_memory) + grad_partitions_flat_buffer: Tensor = torch.zeros(sum(p.partition_numel() + for p in all_params), + dtype=self.dtype, + device=self.device) + if self.offload_optimizer_pin_memory: + get_accelerator().pin_memory(grad_partitions_flat_buffer) offset = 0 for param in all_params: @@ -431,7 +431,7 @@ def defragment(tensors: List[Tensor]) -> Tensor: offset += tensor_numel gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() # copy tensors (now flattened and contiguous) back to GPU device_buffer = cpu_buffer.to(orig_device) @@ -540,9 +540,9 @@ def _create_param_groups_fp16_flat_cpu_memory(self): print_rank_0(f"group {j} flat buffer size {flat_buffer_size}", force=False) self.param_groups_fp16_flat_cpu_memory.append( - torch.empty(int(flat_buffer_size), - dtype=self.dtype, - pin_memory=True)) + get_accelerator().pin_memory( + torch.empty(int(flat_buffer_size), + dtype=self.dtype))) else: print_rank_0( f"No flat buffer size. 
Param group size was {params_in_group}", @@ -912,7 +912,8 @@ def initialize_optimizer_states(self): dtype=gradient_dtype, device=self.device) if self.offload_optimizer_pin_memory: - subgroup_gradient_buffer = subgroup_gradient_buffer.pin_memory() + subgroup_gradient_buffer = get_accelerator().pin_memory( + subgroup_gradient_buffer) self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer else: @@ -1089,19 +1090,20 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): @instrument_w_nvtx @torch.no_grad() def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: - self.__reduce_and_partition_stream.wait_stream(torch.cuda.default_stream()) + self.__reduce_and_partition_stream.wait_stream( + get_accelerator().default_stream()) if self.contiguous_gradients and self.elements_in_ipg_bucket + param.grad.numel( ) < self.reduce_bucket_size: # move the gradient to a contiguous buffer - with torch.cuda.stream(self.__reduce_and_partition_stream): + with get_accelerator().stream(self.__reduce_and_partition_stream): # move the parameter's gradient to the contiguous flat buffer new_grad_tensor = self.__ipg_bucket_flat_buffer.narrow( 0, self.elements_in_ipg_bucket, param.grad.numel()).view_as(param.grad) new_grad_tensor.copy_(param.grad, non_blocking=True) - param.grad.record_stream(torch.cuda.current_stream()) + param.grad.record_stream(get_accelerator().current_stream()) param.grad.data = new_grad_tensor self.__params_in_ipg_bucket.append(param) @@ -1128,7 +1130,7 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: if len(self.__param_reduce_events) > self.__max_param_reduce_events: self.__param_reduce_events.popleft().synchronize() - with torch.cuda.stream(self.__reduce_and_partition_stream): + with get_accelerator().stream(self.__reduce_and_partition_stream): if safe_mode: assert_ints_same_as_other_ranks( [p.ds_id for p in self.__params_in_ipg_bucket]) @@ -1138,7 +1140,7 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: self.__params_in_ipg_bucket.clear() - event = Event() + event = get_accelerator().Event() event.record() self.__param_reduce_events.append(event) @@ -1202,7 +1204,7 @@ def set_norm_for_param_grad_in_gpu(self, param): self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(param.grad) def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor): - with torch.cuda.stream(self.copy_grad_stream): + with get_accelerator().stream(self.copy_grad_stream): param_id = self.get_param_id(param) src_tensor = param.grad.view(-1).float() #print(f"src_tensor {src_tensor.size()} and fp32 grad {fp32_grad_tensor.size()}") @@ -1220,7 +1222,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): total_norm += param_norm.item()**2 # Sum across all model parallel GPUs. 
- total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, @@ -1258,7 +1260,7 @@ def __partition_grads(self, # ensure grad buffer is a CUDA buffer to speed up the next few # operations and so it can be used asynchronously grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) - elif grad_buffer.is_cuda: + elif get_accelerator().on_accelerator(grad_buffer): grad_buffer.add_(grad_partition) else: # if dst is CPU, copy first to src device, do the addition @@ -1303,7 +1305,7 @@ def __partition_grads(self, fp32_grad_tensor.copy_(grad_buffer) # free the gradient - param.grad.record_stream(torch.cuda.current_stream()) + param.grad.record_stream(get_accelerator().current_stream()) param.grad = None if self.offload_optimizer and self.swap_optimizer: @@ -1416,7 +1418,7 @@ def allreduce_bucket(self, bucket, rank=None, log=None): # if rank is specified do a reduction instead of an allreduce def allreduce_and_copy(self, small_bucket, rank=None, log=None): - with torch.cuda.stream(self.reduction_stream): + with get_accelerator().stream(self.reduction_stream): allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) if rank is None or rank == dist.get_rank(group=self.dp_process_group): for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): @@ -1507,8 +1509,8 @@ def zero_grad(self, set_grads_to_None=True): for group in self.fp16_groups: for p in group: if set_grads_to_None: - if p.grad is not None and p.grad.is_cuda: - p.grad.record_stream(torch.cuda.current_stream()) + if p.grad is not None and get_accelerator().on_accelerator(p.grad): + p.grad.record_stream(get_accelerator().current_stream()) p.grad = None else: if p.grad is not None: @@ -1544,7 +1546,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): norm_type = float(norm_type) if norm_type == inf: total_norm = max(g.data.abs().max() for g in gradients) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group) @@ -1558,7 +1560,9 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): grad_norms = [] for g, p in zip(gradients, params): if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): - grad_norms.append(g.cuda(non_blocking=True).double().norm(2)) + grad_norms.append( + g.to(get_accelerator().device_name(), + non_blocking=True).double().norm(2)) # Sum across all model parallel GPUs. 
total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2)) @@ -1697,8 +1701,9 @@ def _prepare_fp32_grad_for_sub_group(self, sub_group_id): # release all the gradient since we have already created a necessary copy in dp_grad_partition self.zero_grad() - for grad in filter(lambda g: g.is_cuda, self.averaged_gradients[sub_group_id]): - grad.record_stream(torch.cuda.current_stream()) + for grad in filter(lambda g: get_accelerator().on_accelerator(g), + self.averaged_gradients[sub_group_id]): + grad.record_stream(get_accelerator().current_stream()) self.averaged_gradients[sub_group_id] = None @@ -1918,9 +1923,8 @@ def step(self, closure=None): self._post_step(timer_names) # warn user about caching allocator flushes - alloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] if hasattr( - torch.cuda, - "memory_stats") else 0 + memory_stats = get_accelerator().memory_stats() + alloc_retries = memory_stats["num_alloc_retries"] if memory_stats != None else 0 if alloc_retries > self.__n_caching_allocator_flushes: if dist.get_rank() == 0: logger.warning( @@ -1929,7 +1933,7 @@ def step(self, closure=None): "performance. if this is happening frequently consider adjusting " "settings to reduce memory consumption. If you are unable to " "make the cache flushes go away consider adding " - "torch.cuda.empty_cache() calls in your training loop to ensure " + "get_accelerator().empty_cache() calls in your training loop to ensure " "that all ranks flush their caches at the same time", alloc_retries - self.__n_caching_allocator_flushes) self.__n_caching_allocator_flushes = alloc_retries @@ -2000,13 +2004,13 @@ def has_overflow_partitioned_grads_serial(self): @instrument_w_nvtx def has_overflow(self, partition_gradients=True): if partition_gradients: - with torch.cuda.stream(self.__reduce_and_partition_stream): + with get_accelerator().stream(self.__reduce_and_partition_stream): self.local_overflow = bool(self.__inf_or_nan_tracker.item()) self.__inf_or_nan_tracker.zero_() overflow = self.local_overflow #overflow = self.has_overflow_partitioned_grads_serial() - overflow_gpu = torch.cuda.ByteTensor([overflow]) + overflow_gpu = get_accelerator().ByteTensor([overflow]) dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group) @@ -2018,7 +2022,7 @@ def has_overflow(self, partition_gradients=True): params.append(param) overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients) - overflow_gpu = torch.cuda.ByteTensor([overflow]) + overflow_gpu = get_accelerator().ByteTensor([overflow]) # Since each model parallel GPU carries only part of the model, # make sure overflow flag is synced across all the model parallel GPUs @@ -2357,7 +2361,7 @@ def load_state_dict(self, will call ``model.load_state_dict()`` before ``fp16_optimizer_instance.load_state_dict()`` is called. Example:: - model = torch.nn.Linear(D_in, D_out).cuda().half() + model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) ... 
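The stage3.py hunks above all follow one substitution: device placement, pinned host buffers, and the stream/event machinery used to overlap communication with computation are routed through the accelerator interface instead of torch.cuda. Below is a minimal sketch of that mapping, assuming only the get_accelerator() methods that already appear in these hunks (device_name, current_device_name, pin_memory, Stream, Event, stream, current_stream); it is an illustration, not DeepSpeed source code.

    # Illustration of the torch.cuda -> get_accelerator() mapping used in this patch.
    import torch
    from deepspeed.accelerator import get_accelerator

    acc = get_accelerator()

    # device placement: tensor.cuda(rank) becomes tensor.to(acc.device_name(rank))
    buf = torch.empty(1024, device=acc.current_device_name())

    # pinned host memory: tensor.pin_memory() becomes acc.pin_memory(tensor)
    host_buf = acc.pin_memory(torch.zeros(1024))

    # overlap a host-to-device copy on a side stream, then make the default
    # stream wait for it, mirroring the allgather/reduce stream handling above
    side_stream = acc.Stream()
    with acc.stream(side_stream):
        buf.copy_(host_buf, non_blocking=True)
        done = acc.Event()
        done.record()
    acc.current_stream().wait_stream(side_stream)
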
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index c2c079e386e6..374f88d6bfa1 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -21,12 +21,13 @@ from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.ops.adam import DeepSpeedCPUAdam -from deepspeed.ops.op_builder import UtilsBuilder from deepspeed.utils import logger from deepspeed.moe.utils import is_moe_param from deepspeed.git_version_info import version from deepspeed.runtime.constants import PIPE_REPLICATED +from deepspeed.accelerator import get_accelerator + from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, @@ -48,11 +49,12 @@ def input(msg): def split_half_float_double(tensors): + device_type = get_accelerator().device_name() dtypes = [ - "torch.cuda.HalfTensor", - "torch.cuda.FloatTensor", - "torch.cuda.DoubleTensor", - "torch.cuda.BFloat16Tensor" + "torch.{}.HalfTensor".format(device_type), + "torch.{}.FloatTensor".format(device_type), + "torch.{}.DoubleTensor".format(device_type), + "torch.{}.BFloat16Tensor".format(device_type) ] buckets = [] for i, dtype in enumerate(dtypes): @@ -153,12 +155,12 @@ def __init__(self, # - assume all params requires grad # - flat by groups, not keeping state. TODO: remove state explicitly? # - master grad and unflat master weight never exist. TODO: a way to save out unflat master? - if not torch.cuda.is_available: + if not get_accelerator().is_available(): raise SystemError("Cannot use fp16 without CUDA.") self.optimizer = init_optimizer # Load pre-built or JIT compile (un)flatten ops - util_ops = UtilsBuilder().load() + util_ops = get_accelerator().create_op_builder("UtilsBuilder").load() self.flatten = util_ops.flatten self.unflatten = util_ops.unflatten @@ -175,7 +177,8 @@ def __init__(self, self.deepspeed_adam_offload = cpu_offload - self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu' + self.device = get_accelerator().current_device_name( + ) if not self.cpu_offload else 'cpu' self.dp_process_group = dp_process_group @@ -207,9 +210,11 @@ def __init__(self, if mpu is None: self.model_parallel_group = None + self.model_parallel_world_size = 1 self.model_parallel_rank = 0 else: self.model_parallel_group = mpu.get_model_parallel_group() + self.model_parallel_world_size = mpu.get_model_parallel_world_size() self.model_parallel_rank = bwc_tensor_model_parallel_rank(mpu) self.overflow = False @@ -316,8 +321,8 @@ def __init__(self, self.flatten_dense_tensors_aligned( self.round_robin_bit16_groups[i], self.nccl_start_alignment_factor * - dist.get_world_size(group=self.real_dp_process_group[i])).cuda( - torch.cuda.current_device())) + dist.get_world_size(group=self.real_dp_process_group[i])).to( + get_accelerator().current_device_name())) see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False) @@ -396,10 +401,11 @@ def __init__(self, self.reduce_bucket_size = int(reduce_bucket_size) self.allgather_bucket_size = int(allgather_bucket_size) - self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False) - self.reduction_stream = torch.cuda.Stream() - self.cpu_computation_stream = torch.cuda.Stream() - self.copy_grad_stream = torch.cuda.Stream() + self.reduction_event = get_accelerator().Event(enable_timing=False, + blocking=False) + self.reduction_stream = get_accelerator().Stream() + self.cpu_computation_stream = 
get_accelerator().Stream() + self.copy_grad_stream = get_accelerator().Stream() self.callback_queued = False self.param_dict = {} @@ -444,13 +450,13 @@ def __init__(self, self.norm_for_param_grads = {} self.local_overflow = False self.grad_position = {} - self.temp_grad_buffer_for_cpu_offload = torch.zeros( - largest_param_numel, - device=self.device, - dtype=self.dtype).pin_memory() + self.temp_grad_buffer_for_cpu_offload = get_accelerator().pin_memory( + torch.zeros(largest_param_numel, + device=self.device, + dtype=self.dtype)) self.temp_grad_buffer_for_gpu_offload = torch.zeros( largest_param_numel, - device=torch.cuda.current_device(), + device=get_accelerator().current_device_name(), dtype=self.dtype) for i, params_group in enumerate(self.bit16_groups): self.get_grad_position(i, @@ -638,9 +644,8 @@ def initialize_optimizer_states(self): int(self.partition_size[i]), dtype=self.single_partition_of_fp32_groups[i].dtype, device=self.device) - self.single_partition_of_fp32_groups[ - i].grad = single_grad_partition.pin_memory( - ) if self.cpu_offload else single_grad_partition + self.single_partition_of_fp32_groups[i].grad = get_accelerator().pin_memory( + single_grad_partition) if self.cpu_offload else single_grad_partition self.optimizer.step() @@ -662,7 +667,7 @@ def reduce_gradients(self, pipeline_parallel=False): self.ipg_buffer = [] buf_0 = torch.empty(int(self.reduce_bucket_size), dtype=self.dtype, - device=torch.cuda.current_device()) + device=get_accelerator().current_device_name()) self.ipg_buffer.append(buf_0) self.ipg_index = 0 @@ -723,7 +728,7 @@ def independent_gradient_partition_epilogue(self): self.params_already_reduced[i] = False if self.overlap_comm: - torch.cuda.synchronize() + get_accelerator().synchronize() # It is safe to clear previously reduced grads of other partitions self._clear_previous_reduced_grads() @@ -736,15 +741,16 @@ def independent_gradient_partition_epilogue(self): self.first_offset[i], self.partition_size[i], dtype=self.dtype, - device=torch.cuda.current_device(), + device=get_accelerator().current_device_name(), return_tensor_list=True) else: - avg_new = self.get_flat_partition(self.params_in_partition[i], - self.first_offset[i], - self.partition_size[i], - dtype=self.dtype, - device=torch.cuda.current_device(), - return_tensor_list=True) + avg_new = self.get_flat_partition( + self.params_in_partition[i], + self.first_offset[i], + self.partition_size[i], + dtype=self.dtype, + device=get_accelerator().current_device_name(), + return_tensor_list=True) for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i], avg_new): accumulated_grad.add_(new_avg_grad) @@ -937,12 +943,12 @@ def gradient_reduction_w_predivide(self, tensor): def average_tensor(self, tensor): if self.overlap_comm: - torch.cuda.synchronize() + get_accelerator().synchronize() stream = self.reduction_stream else: - stream = torch.cuda.current_stream() + stream = get_accelerator().current_stream() - with torch.cuda.stream(stream): + with get_accelerator().stream(stream): if not self.reduce_scatter: self.gradient_reduction_w_predivide(tensor) return @@ -1081,9 +1087,10 @@ def async_accumulate_grad_in_cpu_via_gpu(self, param): #buffer for storing gradients for this parameter in CPU def buffer_to_accumulate_to_in_cpu(): if not self.fp16_master_weights_and_gradients: - return torch.zeros(param.numel(), - dtype=param.dtype, - device=self.device).pin_memory() + return get_accelerator().pin_memory( + torch.zeros(param.numel(), + dtype=param.dtype, + device=self.device)) else: return 
self.single_partition_of_fp32_groups[i].grad.view(-1).narrow( 0, @@ -1202,7 +1209,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): """ # Sum across all model parallel GPUs. - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group) @@ -1241,9 +1248,10 @@ def copy_grads_in_partition(self, param): total_size += param_in_partition.numel() see_memory_usage(f"before copying {total_size} gradients into partition") - self.grads_in_partition = torch.empty(int(total_size), - dtype=self.dtype, - device=torch.cuda.current_device()) + self.grads_in_partition = torch.empty( + int(total_size), + dtype=self.dtype, + device=get_accelerator().current_device_name()) see_memory_usage(f"after copying {total_size} gradients into partition") # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer @@ -1277,13 +1285,13 @@ def reduce_ipg_grads(self): stream = self.reduction_stream elif self.cpu_offload: # TODO: copy_grad_stream is disabled because of race with reduce. This hurts perf and should be fixed. - # torch.cuda.synchronize() + # get_accelerator().synchronize() # stream = self.copy_grad_stream - stream = torch.cuda.current_stream() + stream = get_accelerator().current_stream() else: - stream = torch.cuda.current_stream() + stream = get_accelerator().current_stream() - with torch.cuda.stream(stream): + with get_accelerator().stream(stream): for _, param, param_id in self.params_in_ipg_bucket: assert self.params_already_reduced[param_id] == False, \ @@ -1425,14 +1433,14 @@ def _clear_previous_reduced_grads(self): # if rank is specified do a reduction instead of an allreduce def allreduce_and_copy(self, small_bucket, rank=None, log=None): if self.overlap_comm: - torch.cuda.synchronize() + get_accelerator().synchronize() # It is safe to clear the previously reduced grads of other partitions self._clear_previous_reduced_grads() stream = self.reduction_stream else: - stream = torch.cuda.current_stream() + stream = get_accelerator().current_stream() - with torch.cuda.stream(stream): + with get_accelerator().stream(stream): allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) if rank is None or rank == dist.get_rank(group=self.dp_process_group): for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): @@ -1545,7 +1553,7 @@ def zero_grad(self, set_grads_to_None=True): def _model_parallel_all_reduce(self, tensor, op): """ Perform all reduce within model parallel group, if any. """ - if self.model_parallel_group is None: + if self.model_parallel_group is None or self.model_parallel_world_size == 1: pass else: dist.all_reduce(tensor=tensor, op=op, group=self.model_parallel_group) @@ -1570,7 +1578,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): norm_type = float(norm_type) if norm_type == inf: total_norm = max(g.data.abs().max() for g in gradients) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group) @@ -1590,7 +1598,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): param_norm = g.data.double().norm(2) total_norm += param_norm.item()**2 # Sum across all model parallel GPUs. 
- total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group) @@ -1890,7 +1898,7 @@ def _average_expert_grad_norms(self, norm_groups): scaled_norm = norm * 1.0 / float( dist.get_world_size(group=self.real_dp_process_group[i])) scaled_norm_tensor = torch.tensor(scaled_norm, - device='cuda', + device=get_accelerator().device_name(), dtype=torch.float) dist.all_reduce(scaled_norm_tensor, group=self.real_dp_process_group[i]) norm_groups[i] = scaled_norm_tensor.item() @@ -1934,7 +1942,7 @@ def has_overflow(self, partition_gradients=True): if partition_gradients: overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial( ) - overflow_gpu = torch.cuda.ByteTensor([overflow]) + overflow_gpu = get_accelerator().ByteTensor([overflow]) '''This will capture overflow across all data parallel and expert parallel process Since expert parallel process are a subset of data parallel process''' dist.all_reduce(overflow_gpu, @@ -1948,7 +1956,7 @@ def has_overflow(self, partition_gradients=True): params.append(param) overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients) - overflow_gpu = torch.cuda.ByteTensor([overflow]) + overflow_gpu = get_accelerator().ByteTensor([overflow]) # Since each model parallel GPU carries only part of the model, # make sure overflow flag is synced across all the model parallel GPUs @@ -1993,14 +2001,14 @@ def backward(self, loss, retain_graph=False): self.ipg_buffer = [] buf_0 = torch.empty(int(self.reduce_bucket_size), dtype=self.dtype, - device=torch.cuda.current_device()) + device=get_accelerator().current_device_name()) self.ipg_buffer.append(buf_0) # Use double buffers to avoid data access conflict when overlap_comm is enabled. if self.overlap_comm: buf_1 = torch.empty(int(self.reduce_bucket_size), dtype=self.dtype, - device=torch.cuda.current_device()) + device=get_accelerator().current_device_name()) self.ipg_buffer.append(buf_1) self.ipg_index = 0 @@ -2284,7 +2292,7 @@ def _load_legacy_checkpoint(self, will call ``model.load_state_dict()`` before ``fp16_optimizer_instance.load_state_dict()`` is called. Example:: - model = torch.nn.Linear(D_in, D_out).cuda().half() + model = torch.nn.Linear(D_in, D_out).to(get_accelerator().device_name()).half() optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) ... 
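The stage_1_and_2.py changes apply the same pattern, with one device-specific detail: split_half_float_double buckets gradients by the string returned from Tensor.type(), which embeds the device name (for example torch.cuda.HalfTensor versus torch.xpu.HalfTensor), so the type list is now built from get_accelerator().device_name(). A minimal sketch of that bucketing, written for illustration rather than copied from DeepSpeed:

    import torch
    from deepspeed.accelerator import get_accelerator

    def split_by_dtype(tensors):
        # e.g. 'cuda' or 'xpu'; substituted into the tensor type strings
        device_type = get_accelerator().device_name()
        dtypes = [
            "torch.{}.HalfTensor".format(device_type),
            "torch.{}.FloatTensor".format(device_type),
            "torch.{}.DoubleTensor".format(device_type),
            "torch.{}.BFloat16Tensor".format(device_type),
        ]
        # one list per dtype, in the fixed half/float/double/bfloat16 order,
        # preserving the original order of the input tensors within each list
        return [[t for t in tensors if t.type() == dtype] for dtype in dtypes]
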
diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index 665802ce8e93..c9a8807b396d 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -7,6 +7,7 @@ from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.ops.adam import FusedAdam from deepspeed.utils.nvtx import instrument_w_nvtx +from deepspeed.accelerator import get_accelerator def _initialize_parameter_parallel_groups(parameter_parallel_size=None): @@ -63,8 +64,8 @@ def get_lst_from_rank0(lst: List[int]) -> None: lst_tensor = torch.tensor( lst if dist.get_rank() == 0 else [-1] * len(lst), dtype=int, - # device=torch.cuda.current_device(), - device=torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])), + # device=get_accelerator().current_device_name(), + device=torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])), requires_grad=False, ) dist.broadcast(lst_tensor, src=0, async_op=False) diff --git a/deepspeed/utils/nvtx.py b/deepspeed/utils/nvtx.py index 11579ffbb17f..2db7e4660fc1 100644 --- a/deepspeed/utils/nvtx.py +++ b/deepspeed/utils/nvtx.py @@ -1,14 +1,17 @@ import torch +from deepspeed.accelerator import get_accelerator def instrument_w_nvtx(func): """decorator that causes an NVTX range to be recorded for the duration of the function call.""" - if hasattr(torch.cuda.nvtx, "range"): + if get_accelerator().device_name() == 'cuda' and hasattr(torch.cuda.nvtx, "range"): def wrapped_fn(*args, **kwargs): - with torch.cuda.nvtx.range(func.__qualname__): - return func(*args, **kwargs) + get_accelerator().range_push(func.__qualname__) + ret_val = func(*args, **kwargs) + get_accelerator().range_pop() + return ret_val return wrapped_fn else: diff --git a/deepspeed/utils/timer.py b/deepspeed/utils/timer.py index ae5174508457..84c4cb603a04 100755 --- a/deepspeed/utils/timer.py +++ b/deepspeed/utils/timer.py @@ -3,9 +3,9 @@ """ import time -import torch from numpy import mean from deepspeed.utils.logging import log_dist +from deepspeed.accelerator import get_accelerator from deepspeed import comm as dist try: @@ -18,12 +18,14 @@ class CudaEventTimer(object): - def __init__(self, start_event: torch.cuda.Event, end_event: torch.cuda.Event): + def __init__(self, + start_event: get_accelerator().Event, + end_event: get_accelerator().Event): self.start_event = start_event self.end_event = end_event def get_elapsed_msec(self): - torch.cuda.current_stream().wait_event(self.end_event) + get_accelerator().current_stream().wait_event(self.end_event) self.end_event.synchronize() return self.start_event.elapsed_time(self.end_event) @@ -42,14 +44,14 @@ def __init__(self, name): def start(self): """Start the timer.""" assert not self.started_, f"{self.name_} timer has already been started" - self.start_event = torch.cuda.Event(enable_timing=True) + self.start_event = get_accelerator().Event(enable_timing=True) self.start_event.record() self.started_ = True def stop(self, reset=False, record=False): """Stop the timer.""" assert self.started_, "timer is not started" - end_event = torch.cuda.Event(enable_timing=True) + end_event = get_accelerator().Event(enable_timing=True) end_event.record() self.event_timers.append(CudaEventTimer(self.start_event, end_event)) self.start_event = None @@ -100,14 +102,14 @@ def __call__(self, name): @staticmethod def memory_usage(): - alloc = "mem_allocated: {:.4f} GB".format(torch.cuda.memory_allocated() / + alloc = "mem_allocated: {:.4f} GB".format(get_accelerator().memory_allocated() / (1024 * 1024 * 1024)) max_alloc = 
"max_mem_allocated: {:.4f} GB".format( - torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024)) - cache = "cache_allocated: {:.4f} GB".format(torch.cuda.memory_cached() / + get_accelerator().max_memory_allocated() / (1024 * 1024 * 1024)) + cache = "cache_allocated: {:.4f} GB".format(get_accelerator().memory_cached() / (1024 * 1024 * 1024)) max_cache = "max_cache_allocated: {:.4f} GB".format( - torch.cuda.max_memory_cached() / (1024 * 1024 * 1024)) + get_accelerator().max_memory_cached() / (1024 * 1024 * 1024)) return " | {} | {} | {} | {}".format(alloc, max_alloc, cache, max_cache) def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False, ranks=None): @@ -176,7 +178,7 @@ def start(self): self._init_timer() self.started = True if self.total_step_count >= self.start_step: - torch.cuda.synchronize() + get_accelerator().synchronize() self.start_time = time.time() def stop(self, report_speed=True): @@ -186,7 +188,7 @@ def stop(self, report_speed=True): self.total_step_count += 1 self.local_step_count += 1 if self.total_step_count > self.start_step: - torch.cuda.synchronize() + get_accelerator().synchronize() self.end_time = time.time() duration = self.end_time - self.start_time self.total_elapsed_time += duration @@ -197,14 +199,15 @@ def stop(self, report_speed=True): if report_speed: self.logging( "{}/{}, RunningAvgSamplesPerSec={}, CurrSamplesPerSec={}, MemAllocated={}GB, MaxMemAllocated={}GB" - .format(self.epoch_count, - self.local_step_count, - self.avg_samples_per_sec(), - curr_samples_sec, - round(torch.cuda.memory_allocated() / 1024**3, - 2), - round(torch.cuda.max_memory_allocated() / 1024**3, - 2))) + .format( + self.epoch_count, + self.local_step_count, + self.avg_samples_per_sec(), + curr_samples_sec, + round(get_accelerator().memory_allocated() / 1024**3, + 2), + round(get_accelerator().max_memory_allocated() / 1024**3, + 2))) if self.monitor_memory: virt_mem = psutil.virtual_memory() swap = psutil.swap_memory() diff --git a/docs/_tutorials/cifar-10.md b/docs/_tutorials/cifar-10.md index 11a05a78a749..74ee04502f18 100644 --- a/docs/_tutorials/cifar-10.md +++ b/docs/_tutorials/cifar-10.md @@ -140,7 +140,8 @@ Here we initialize DeepSpeed with CIFAR-10 model (`net`), `args`, `parameters` a After initializing DeepSpeed, the original `device` and `optimizer` are removed: ```python - #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + #from deepspeed.accelerator import get_accelerator + #device = torch.device(get_accelerator().device_name(0) if get_accelerator().is_available() else "cpu") #net.to(device) #optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) diff --git a/op_builder/__init__.py b/op_builder/__init__.py index dcac71011aa8..527f67482b12 100755 --- a/op_builder/__init__.py +++ b/op_builder/__init__.py @@ -1,6 +1,7 @@ """ Copyright 2020 The Microsoft DeepSpeed Team """ +from deepspeed.accelerator import get_accelerator from .cpu_adam import CPUAdamBuilder from .cpu_adagrad import CPUAdagradBuilder from .fused_adam import FusedAdamBuilder @@ -13,20 +14,3 @@ from .transformer_inference import InferenceBuilder from .quantizer import QuantizerBuilder from .builder import get_default_compute_capabilities, OpBuilder - -# TODO: infer this list instead of hard coded -# List of all available ops -__op_builders__ = [ - CPUAdamBuilder(), - CPUAdagradBuilder(), - FusedAdamBuilder(), - FusedLambBuilder(), - SparseAttnBuilder(), - TransformerBuilder(), - StochasticTransformerBuilder(), - AsyncIOBuilder(), - UtilsBuilder(), - 
QuantizerBuilder(), - InferenceBuilder() -] -ALL_OPS = {op.name: op for op in __op_builders__} diff --git a/op_builder/all_ops.py b/op_builder/all_ops.py new file mode 100644 index 000000000000..2d7725984026 --- /dev/null +++ b/op_builder/all_ops.py @@ -0,0 +1,21 @@ +""" +Copyright 2020 The Microsoft DeepSpeed Team +""" +from deepspeed.accelerator import get_accelerator + +# TODO: infer this list instead of hard coded +# List of all available ops +__op_builders__ = [ + get_accelerator().create_op_builder("CPUAdamBuilder"), + get_accelerator().create_op_builder("CPUAdagradBuilder"), + get_accelerator().create_op_builder("FusedAdamBuilder"), + get_accelerator().create_op_builder("FusedLambBuilder"), + get_accelerator().create_op_builder("SparseAttnBuilder"), + get_accelerator().create_op_builder("TransformerBuilder"), + get_accelerator().create_op_builder("StochasticTransformerBuilder"), + get_accelerator().create_op_builder("AsyncIOBuilder"), + get_accelerator().create_op_builder("UtilsBuilder"), + get_accelerator().create_op_builder("QuantizerBuilder"), + get_accelerator().create_op_builder("InferenceBuilder") +] +ALL_OPS = {op.name: op for op in __op_builders__ if op is not None} diff --git a/setup.py b/setup.py index 969060018e7c..bccb6d80ac14 100755 --- a/setup.py +++ b/setup.py @@ -20,17 +20,19 @@ from setuptools import setup, find_packages from setuptools.command import egg_info import time +from deepspeed.accelerator import get_accelerator torch_available = True try: import torch - from torch.utils.cpp_extension import BuildExtension + from torch.utils.cpp_extension import BuildExtension # noqa: F401 except ImportError: torch_available = False print('[WARNING] Unable to import torch, pre-compiling ops will be disabled. ' \ 'Please visit https://pytorch.org/ to see how to properly install torch on your system.') -from op_builder import ALL_OPS, get_default_compute_capabilities, OpBuilder +from op_builder import get_default_compute_capabilities, OpBuilder +from op_builder.all_ops import ALL_OPS # fetch rocm state is_rocm_pytorch = OpBuilder.is_rocm_pytorch() @@ -89,7 +91,8 @@ def fetch_requirements(path): # For any pre-installed ops force disable ninja if torch_available: - cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False) + cmdclass['build_ext'] = get_accelerator().build_extension().with_options( + use_ninja=False) if torch_available: TORCH_MAJOR = torch.__version__.split('.')[0] diff --git a/tests/accelerator/ds_config.json b/tests/accelerator/ds_config.json new file mode 100644 index 000000000000..8e9ac6b889ea --- /dev/null +++ b/tests/accelerator/ds_config.json @@ -0,0 +1,19 @@ +{ + "train_batch_size": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + "weight_decay": 1e-2 + } + }, + "fp16": { + "enabled": false, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + } +} diff --git a/tests/accelerator/test_ds_init.py b/tests/accelerator/test_ds_init.py new file mode 100644 index 000000000000..55121aa96921 --- /dev/null +++ b/tests/accelerator/test_ds_init.py @@ -0,0 +1,42 @@ +import os +import torch +import deepspeed +from deepspeed.accelerator import get_accelerator + + +class OneLayerNet(torch.nn.Module): + def __init__(self, D_in, D_out): + """ + In the constructor we instantiate two nn.Linear modules and assign them as + member variables. 
+ """ + super(OneLayerNet, self).__init__() + self.linear1 = torch.nn.Linear(D_in, D_out) + + def forward(self, x): + """ + In the forward function we accept a Variable of input data and we must return + a Variable of output data. We can use Modules defined in the constructor as + well as arbitrary operators on Variables. + """ + h_relu = self.linear1(x).clamp(min=0) + y_pred = self.linear1(h_relu) + return y_pred + + +def test_literal_device(): + model = OneLayerNet(128, 128) + + os.environ['RANK'] = '0' + os.environ['WORLD_SIZE'] = '1' + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '8088' + os.environ['LOCAL_RANK'] = '0' + deepspeed.init_distributed(get_accelerator().communication_backend_name()) + deepspeed.initialize(model=model, config='ds_config.json') + string = get_accelerator().device_name() #'xpu' or 'cuda' + string0 = get_accelerator().device_name(0) #'xpu:0' or 'cuda:0' + string1 = get_accelerator().device_name(1) #'xpu:1' or 'cuda:1' + assert string == 'xpu' or string == 'cuda' + assert string0 == 'xpu:0' or string0 == 'cuda:0' + assert string1 == 'xpu:1' or string1 == 'cuda:1' diff --git a/tests/benchmarks/flatten_bench.py b/tests/benchmarks/flatten_bench.py index d87971dc1a78..744542d46f0c 100755 --- a/tests/benchmarks/flatten_bench.py +++ b/tests/benchmarks/flatten_bench.py @@ -12,11 +12,11 @@ import torch from torch._utils import _flatten_dense_tensors -from deepspeed.ops.op_builder import UtilsBuilder +from deepspeed.accelerator import get_accelerator from apex_C import flatten as flatten_apex -util_ops = UtilsBuilder().load() +util_ops = get_accelerator().create_op_builder("UtilsBuilder").load() flatten = util_ops.flatten unflatten = util_ops.unflatten @@ -24,11 +24,11 @@ # emulate a small typical model weights x = [ torch.rand((512, - 512)).cuda(), + 512)).to(get_accelerator().device_name()), torch.rand((512, - 1024)).cuda(), + 1024)).to(get_accelerator().device_name()), torch.rand((512, - 30000)).cuda() + 30000)).to(get_accelerator().device_name()) ] t = x * 30 @@ -69,15 +69,15 @@ def cprofileme(): print("py") cProfile.run("py()", sort=-1) gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() print("cpp") cProfile.run("cpp()", sort=-1) gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() print("apex") cProfile.run("apex()", sort=-1) gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() #### timeit #### @@ -89,13 +89,13 @@ def timeme(): print("--------------- timeit -----------------") print(f'py ={timeit.Timer("py()", globals=globals()).timeit(number=1)}') gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() print(f'cpp ={timeit.Timer("cpp()", globals=globals()).timeit(number=1)}') gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() print(f'apex={timeit.Timer("apex()", globals=globals()).timeit(number=1)}') gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() #### line_profiler #### @@ -109,15 +109,15 @@ def line_profileme(): print("py") profile(py)() # noqa: F821 gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() print("cpp") profile(cpp)() # noqa: F821 gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() print("apex") profile(apex)() # noqa: F821 gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() if __name__ == "__main__": diff --git a/tests/benchmarks/unflatten_bench.py b/tests/benchmarks/unflatten_bench.py index 23fb3f87566d..e451145e68ff 100755 --- 
a/tests/benchmarks/unflatten_bench.py +++ b/tests/benchmarks/unflatten_bench.py @@ -11,12 +11,12 @@ import gc import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from deepspeed.ops.op_builder import UtilsBuilder +from deepspeed.accelerator import get_accelerator from apex_C import flatten as flatten_apex from apex_C import unflatten as unflatten_apex -util_ops = UtilsBuilder().load() +util_ops = get_accelerator().create_op_builder("UtilsBuilder").load() flatten = util_ops.flatten unflatten = util_ops.unflatten @@ -24,11 +24,11 @@ # emulate a small typical model weights x = [ torch.rand((512, - 512)).cuda(), + 512)).to(get_accelerator().device_name()), torch.rand((512, - 1024)).cuda(), + 1024)).to(get_accelerator().device_name()), torch.rand((512, - 30000)).cuda() + 30000)).to(get_accelerator().device_name()) ] unflat_t = x * 30 @@ -78,15 +78,15 @@ def cprofileme(): print("py") cProfile.run("py()", sort=-1) gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() print("cpp") cProfile.run("cpp()", sort=-1) gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() print("apex") cProfile.run("apex()", sort=-1) gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() #### timeit #### @@ -98,13 +98,13 @@ def timeme(): print("--------------- timeit -----------------") print(f'py ={timeit.Timer("py()", globals=globals()).timeit(number=1)}') gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() print(f'cpp ={timeit.Timer("cpp()", globals=globals()).timeit(number=1)}') gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() print(f'apex={timeit.Timer("apex()", globals=globals()).timeit(number=1)}') gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() #### line_profiler #### @@ -118,15 +118,15 @@ def line_profileme(): print("py") profile(py)() # noqa: F821 gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() print("cpp") profile(cpp)() # noqa: F821 gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() print("apex") profile(apex)() # noqa: F821 gc.collect() - torch.cuda.empty_cache() + get_accelerator().empty_cache() if __name__ == "__main__": diff --git a/tests/onebit/test_mpi_backend.py b/tests/onebit/test_mpi_backend.py index 65cfb3ed96c5..0745a506aa4e 100644 --- a/tests/onebit/test_mpi_backend.py +++ b/tests/onebit/test_mpi_backend.py @@ -5,17 +5,19 @@ import deepspeed from deepspeed.runtime.comm.mpi import MpiBackend +from deepspeed.accelerator import get_accelerator comm = MPI.COMM_WORLD size = comm.Get_size() rank = comm.Get_rank() -deepspeed.init_distributed(dist_backend='nccl') +deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name()) # Change cuda_aware to True to test out CUDA-Aware MPI communication backend = MpiBackend(cuda_aware=False) -device = torch.device('cuda', rank % torch.cuda.device_count()) +local_rank = rank % get_accelerator().device_count() +device = torch.device(get_accelerator().device_name(), local_rank) # A simulated compression function using deepspeed.comm @@ -35,7 +37,7 @@ def torch_sim(a): [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())]) rank = dist.get_rank() server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank] - torch.cuda.synchronize() + get_accelerator().synchronize() dist.barrier() return a_server_compressed, worker_error, server_error @@ -56,8 +58,7 @@ def torch_sim(a): server_error = 
torch.zeros(right_server_size, device=device) a_torch, worker_error_torch, server_error_torch = torch_sim(a) -torch.cuda.empty_cache() -local_rank = rank % torch.cuda.device_count() +get_accelerator().empty_cache() a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank) diff --git a/tests/onebit/test_mpi_perf.py b/tests/onebit/test_mpi_perf.py index 1652e946985c..e4dcd4833cce 100644 --- a/tests/onebit/test_mpi_perf.py +++ b/tests/onebit/test_mpi_perf.py @@ -6,6 +6,7 @@ # Configure wall clock timer from deepspeed.utils.timer import SynchronizedWallClockTimer +from deepspeed.accelerator import get_accelerator from statistics import mean @@ -15,11 +16,12 @@ size = comm.Get_size() rank = comm.Get_rank() -deepspeed.init_distributed(dist_backend='nccl') +deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name()) # Change cuda_aware to True to test out CUDA-Aware MPI communication backend = MpiBackend(cuda_aware=False) -device = torch.device('cuda', rank % torch.cuda.device_count()) +local_rank = rank % get_accelerator().device_count() +device = torch.device(get_accelerator().device_name(), local_rank) tensor_size = 300 * 2**20 server_size = int(tensor_size / size) @@ -39,8 +41,6 @@ warmup = 10 iters = 10 -local_rank = rank % torch.cuda.device_count() - # Warmup for i in range(warmup): backend.compressed_allreduce(a, worker_error, server_error, local_rank) diff --git a/tests/onebit/test_nccl_backend.py b/tests/onebit/test_nccl_backend.py index 395b1053f917..50aebec1f716 100644 --- a/tests/onebit/test_nccl_backend.py +++ b/tests/onebit/test_nccl_backend.py @@ -6,16 +6,17 @@ import os from deepspeed.runtime.comm.nccl import NcclBackend +from deepspeed.accelerator import get_accelerator parser = argparse.ArgumentParser() parser.add_argument('--local_rank', type=int, default=-1) args = parser.parse_args() -deepspeed.init_distributed(dist_backend='nccl') +deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name()) args.local_rank = int(os.environ['LOCAL_RANK']) -torch.cuda.set_device(args.local_rank) -device = torch.device("cuda", args.local_rank) +get_accelerator().set_device(args.local_rank) +device = torch.device(get_accelerator().device_name(), args.local_rank) size = dist.get_world_size() rank = dist.get_rank() @@ -41,7 +42,7 @@ def torch_sim(a): [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())]) rank = dist.get_rank() server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank] - torch.cuda.synchronize() + get_accelerator().synchronize() dist.barrier() return a_server_compressed, worker_error, server_error @@ -62,7 +63,7 @@ def torch_sim(a): server_error = torch.zeros(right_server_size, device=device) a_torch, worker_error_torch, server_error_torch = torch_sim(a) -torch.cuda.empty_cache() +get_accelerator().empty_cache() a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank) diff --git a/tests/onebit/test_nccl_perf.py b/tests/onebit/test_nccl_perf.py index 86923ef90b82..c2dc1a674624 100644 --- a/tests/onebit/test_nccl_perf.py +++ b/tests/onebit/test_nccl_perf.py @@ -7,6 +7,7 @@ from deepspeed.runtime.comm.nccl import NcclBackend from deepspeed.utils.timer import SynchronizedWallClockTimer +from deepspeed.accelerator import get_accelerator from statistics import mean timers = SynchronizedWallClockTimer() @@ -15,11 +16,11 @@ parser.add_argument('--local_rank', type=int, default=-1) args = parser.parse_args() -deepspeed.init_distributed(dist_backend='nccl') 
+deepspeed.init_distributed(dist_backend=get_accelerator().communication_backend_name()) args.local_rank = int(os.environ['LOCAL_RANK']) -torch.cuda.set_device(args.local_rank) -device = torch.device("cuda", args.local_rank) +get_accelerator().set_device(args.local_rank) +device = torch.device(get_accelerator().device_name(), args.local_rank) size = dist.get_world_size() rank = dist.get_rank() diff --git a/tests/perf/adam_test1.py b/tests/perf/adam_test1.py index 88f1a1c5961d..69c393cc6af8 100755 --- a/tests/perf/adam_test1.py +++ b/tests/perf/adam_test1.py @@ -1,13 +1,15 @@ import torch from deepspeed.ops.adam import DeepSpeedCPUAdam import time +from deepspeed.accelerator import get_accelerator device = 'cpu' model_size = 1 * 1024**3 param = torch.nn.Parameter(torch.ones(model_size, device=device)) -param_fp16 = torch.nn.Parameter(torch.ones(model_size, - dtype=torch.half, - device='cuda:0')) +param_fp16 = torch.nn.Parameter( + torch.ones(model_size, + dtype=torch.half, + device=get_accelerator().device_name(0))) optimizer = DeepSpeedCPUAdam([param]) #torch.set_num_threads(128) diff --git a/tests/small_model_debugging/test.py b/tests/small_model_debugging/test.py index 331a8ef35ca6..ab1dd2a9bc05 100644 --- a/tests/small_model_debugging/test.py +++ b/tests/small_model_debugging/test.py @@ -1,6 +1,7 @@ import torch from deepspeed.pt.deepspeed_linear import LinearModuleForZeroStage3 from deepspeed.pt.log_utils import logger +from deepspeed.accelerator import get_accelerator def see_memory_usage(message): @@ -9,37 +10,42 @@ def see_memory_usage(message): logger.info(message) logger.info( "Memory Allocated %s GigaBytes ", - torch.cuda.memory_allocated() / (1024 * 1024 * 1024), + get_accelerator().memory_allocated() / (1024 * 1024 * 1024), ) logger.info( "Max Memory Allocated %s GigaBytes", - torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), + get_accelerator().max_memory_allocated() / (1024 * 1024 * 1024), ) logger.info( "Cache Allocated %s GigaBytes", - torch.cuda.memory_cached() / (1024 * 1024 * 1024), + get_accelerator().memory_cached() / (1024 * 1024 * 1024), ) logger.info( "Max cache Allocated %s GigaBytes", - torch.cuda.max_memory_cached() / (1024 * 1024 * 1024), + get_accelerator().max_memory_cached() / (1024 * 1024 * 1024), ) -tens = torch.rand(1024, 16384, dtype=torch.half, device=torch.device('cuda')) +tens = torch.rand(1024, + 16384, + dtype=torch.half, + device=torch.device(get_accelerator().device_name())) tens_back = tens.detach().clone() #linear_bk = torch.nn.functional.linear #torch.nn.functional.linear = deepspeed.pt.deepspeed_linear.LinearFunctionForZeroStage3.apply model = LinearModuleForZeroStage3(16384, 16384) -model.cuda().half() +model.to(get_accelerator().device_name()).half() see_memory_usage("Before forward") y = model(tens) see_memory_usage("After forward") -model.weight.data = torch.zeros(1, dtype=torch.half, device=torch.device('cuda')) +model.weight.data = torch.zeros(1, + dtype=torch.half, + device=torch.device(get_accelerator().device_name())) see_memory_usage("After weight zero") diff --git a/tests/unit/alexnet_model.py b/tests/unit/alexnet_model.py index b94bd7052a9f..cb12a52d7514 100644 --- a/tests/unit/alexnet_model.py +++ b/tests/unit/alexnet_model.py @@ -5,6 +5,7 @@ import deepspeed import deepspeed.comm as dist import deepspeed.runtime.utils as ds_utils +from deepspeed.accelerator import get_accelerator from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec @@ -108,7 +109,7 @@ def cifar_trainset(fp16=False): transform = 
transforms.Compose(transform_list) - local_rank = torch.cuda.current_device() + local_rank = get_accelerator().current_device() # Only one rank per machine downloads. dist.barrier() @@ -129,7 +130,8 @@ def train_cifar(model, average_dp_losses=True, fp16=True, seed=123): - with torch.random.fork_rng(devices=[torch.cuda.current_device()]): + with get_accelerator().random().fork_rng( + devices=[get_accelerator().current_device_name()]): ds_utils.set_random_seed(seed) # disable dropout @@ -152,7 +154,7 @@ def train_cifar(model, print(f'STEP={step} LOSS={loss.item()}') if average_dp_losses: - loss_tensor = torch.tensor(losses).cuda() + loss_tensor = torch.tensor(losses).to(get_accelerator().device_name()) dist.all_reduce(loss_tensor) loss_tensor /= dist.get_world_size() losses = loss_tensor.tolist() diff --git a/tests/unit/checkpoint/test_lr_scheduler.py b/tests/unit/checkpoint/test_lr_scheduler.py index bd950b4183b6..7f3c5b226980 100644 --- a/tests/unit/checkpoint/test_lr_scheduler.py +++ b/tests/unit/checkpoint/test_lr_scheduler.py @@ -1,5 +1,5 @@ import deepspeed -from deepspeed.ops.op_builder import CPUAdamBuilder +from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest from unit.simple_model import * @@ -26,7 +26,8 @@ class TestLRSchedulerCheckpoint(DistributedTest): world_size = 2 def test_checkpoint_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): - if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[ + get_accelerator().create_op_builder("CPUAdamBuilder").name]: pytest.skip("cpu-adam is not compatible") config_dict = { @@ -76,7 +77,8 @@ def test_checkpoint_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): load_lr_scheduler_states=True) def test_checkpoint_no_lr_scheduler(self, tmpdir, zero_stage, use_cpu_offload): - if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[ + get_accelerator().create_op_builder("CPUAdamBuilder").name]: pytest.skip("cpu-adam is not compatible") config_dict = { diff --git a/tests/unit/checkpoint/test_other_optimizer.py b/tests/unit/checkpoint/test_other_optimizer.py index 74a333399587..a6607330d89e 100644 --- a/tests/unit/checkpoint/test_other_optimizer.py +++ b/tests/unit/checkpoint/test_other_optimizer.py @@ -1,5 +1,5 @@ import deepspeed -from deepspeed.ops.op_builder import FusedLambBuilder +from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest from unit.simple_model import * @@ -12,7 +12,8 @@ class TestOtherOptimizerCheckpoint(DistributedTest): world_size = 2 - @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[ + get_accelerator().create_op_builder("FusedLambBuilder").name], reason="lamb is not compatible") def test_checkpoint_unfused_optimizer(self, tmpdir): config_dict = { diff --git a/tests/unit/checkpoint/test_zero_optimizer.py b/tests/unit/checkpoint/test_zero_optimizer.py index 73bde2fda940..972e5b0673f4 100644 --- a/tests/unit/checkpoint/test_zero_optimizer.py +++ b/tests/unit/checkpoint/test_zero_optimizer.py @@ -1,5 +1,5 @@ import deepspeed -from deepspeed.ops.op_builder import CPUAdamBuilder +from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest, DistributedFixture from unit.simple_model import * @@ -34,7 +34,8 @@ def test_load_optimizer_state(self, 
zero_stage, use_cpu_offload, adam_optimizer): - if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[ + get_accelerator().create_op_builder("CPUAdamBuilder").name]: pytest.skip("cpu-adam is not compatible") config_dict = { @@ -95,7 +96,8 @@ def test_not_load_optimizer_state(self, zero_stage, use_cpu_offload, adam_optimizer): - if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[ + get_accelerator().create_op_builder("CPUAdamBuilder").name]: pytest.skip("cpu-adam is not compatible") config_dict = { diff --git a/tests/unit/comm/test_dist.py b/tests/unit/comm/test_dist.py index 2a2abeba680e..0f067751cf2a 100644 --- a/tests/unit/comm/test_dist.py +++ b/tests/unit/comm/test_dist.py @@ -5,6 +5,7 @@ from unit.common import DistributedTest, DistributedFixture, get_master_port from unit.simple_model import SimpleModel +from deepspeed.accelerator import get_accelerator import pytest @@ -103,9 +104,9 @@ class TestDistAllReduce(DistributedTest): world_size = [1, 2, 4] def test(self): - x = torch.ones(1, 3).cuda() * (dist.get_rank() + 1) + x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1) sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2 - result = torch.ones(1, 3).cuda() * sum_of_ranks + result = torch.ones(1, 3).to(get_accelerator().device_name()) * sum_of_ranks dist.all_reduce(x) assert torch.all(x == result) @@ -115,16 +116,21 @@ class TestDistInit(DistributedTest): init_distributed = False def test_already_init(self, dist_init_required): - torch.distributed.init_process_group('nccl') - deepspeed.init_distributed('nccl', dist_init_required=dist_init_required) + torch.distributed.init_process_group( + get_accelerator().communication_backend_name()) + deepspeed.init_distributed(get_accelerator().communication_backend_name(), + dist_init_required=dist_init_required) def test_no_init(self, dist_init_required): if dist_init_required or dist_init_required is None: - deepspeed.init_distributed('nccl', dist_init_required=dist_init_required) + deepspeed.init_distributed(get_accelerator().communication_backend_name(), + dist_init_required=dist_init_required) else: # torch.dist is not done and for some reason the user says they don't want it done with pytest.raises(Exception): - deepspeed.init_distributed('nccl', dist_init_required=dist_init_required) + deepspeed.init_distributed( + get_accelerator().communication_backend_name(), + dist_init_required=dist_init_required) class TestDistInitNoEnv(DistributedTest): @@ -134,12 +140,13 @@ class TestDistInitNoEnv(DistributedTest): def test(self): torch.distributed.init_process_group( - backend='nccl', + backend=get_accelerator().communication_backend_name(), init_method=f"tcp://127.0.0.1:{get_master_port()}", world_size=1, rank=0) assert torch.distributed.is_initialized() - deepspeed.init_distributed('nccl', auto_mpi_discovery=True) + deepspeed.init_distributed(get_accelerator().communication_backend_name(), + auto_mpi_discovery=True) @pytest.mark.parametrize("dist_init_required", [True, False]) @@ -147,7 +154,8 @@ class TestDistInitWithModel(DistributedTest): init_distributed = False def test_already_init(self, dist_init_required): - torch.distributed.init_process_group('nccl') + torch.distributed.init_process_group( + get_accelerator().communication_backend_name()) model = SimpleModel(4) config_dict = { 
"train_micro_batch_size_per_gpu": 1, diff --git a/tests/unit/common.py b/tests/unit/common.py index df59ed62f017..f76f2089343a 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -7,6 +7,7 @@ import torch import torch.multiprocessing as mp import deepspeed +from deepspeed.accelerator import get_accelerator import deepspeed.comm as dist from torch.multiprocessing import Process @@ -34,23 +35,36 @@ def get_master_port(): return master_port -def set_cuda_visibile(): +def set_accelerator_visibile(): cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None) xdist_worker_id = get_xdist_worker_id() if xdist_worker_id is None: xdist_worker_id = 0 if cuda_visible is None: - # CUDA_VISIBLE_DEVICES is not set, discover it from nvidia-smi instead + # CUDA_VISIBLE_DEVICES is not set, discover it using accelerator specific command instead import subprocess - is_rocm_pytorch = hasattr(torch.version, 'hip') and torch.version.hip is not None - if is_rocm_pytorch: - rocm_smi = subprocess.check_output(['rocm-smi', '--showid']) - gpu_ids = filter(lambda s: 'GPU' in s, - rocm_smi.decode('utf-8').strip().split('\n')) - num_gpus = len(list(gpu_ids)) + if get_accelerator().device_name() == 'cuda': + is_rocm_pytorch = hasattr(torch.version, + 'hip') and torch.version.hip is not None + if is_rocm_pytorch: + rocm_smi = subprocess.check_output(['rocm-smi', '--showid']) + gpu_ids = filter(lambda s: 'GPU' in s, + rocm_smi.decode('utf-8').strip().split('\n')) + num_gpus = len(list(gpu_ids)) + else: + nvidia_smi = subprocess.check_output(['nvidia-smi', '--list-gpus']) + num_gpus = len(nvidia_smi.decode('utf-8').strip().split('\n')) else: - nvidia_smi = subprocess.check_output(['nvidia-smi', '--list-gpus']) - num_gpus = len(nvidia_smi.decode('utf-8').strip().split('\n')) + assert get_accelerator().device_name() == 'xpu' + import re + clinfo = subprocess.check_output(['clinfo']) + lines = clinfo.decode('utf-8').strip().split('\n') + num_gpus = 0 + for line in lines: + match = re.search('Device Type.*GPU', line) + if match: + num_gpus += 1 + cuda_visible = ",".join(map(str, range(num_gpus))) # rotate list based on xdist worker id, example below @@ -69,7 +83,7 @@ class DistributedExec(ABC): methods needed for DistributedTest and DistributedFixture. 
""" world_size = 2 - backend = "nccl" + backend = get_accelerator().communication_backend_name() init_distributed = True set_dist_env = True @@ -153,14 +167,14 @@ def _dist_init(self, local_rank, num_procs, skip_msg): # turn off NCCL logging if set os.environ.pop('NCCL_DEBUG', None) - set_cuda_visibile() + set_accelerator_visibile() if self.init_distributed: deepspeed.init_distributed(dist_backend=self.backend) dist.barrier() - if torch.cuda.is_available(): - torch.cuda.set_device(local_rank) + if get_accelerator().is_available(): + get_accelerator().set_device(local_rank) try: self.run(**self._fixture_kwargs) diff --git a/tests/unit/compression/test_compression.py b/tests/unit/compression/test_compression.py index 6ff215c4ec13..2268f471c75b 100644 --- a/tests/unit/compression/test_compression.py +++ b/tests/unit/compression/test_compression.py @@ -8,6 +8,7 @@ from unit.modelingpreln import BertEncoder as BertEncoderPreln from deepspeed.compression.basic_layer import LinearLayer_Compress, ColumnParallelLinear_Compress, RowParallelLinear_Compress from deepspeed.compression.helper import convert_conv1d_to_linear +from deepspeed.accelerator import get_accelerator TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) @@ -20,7 +21,7 @@ def reset_random(seed=1234): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) + get_accelerator().manual_seed_all(seed) def create_bert_model(): diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 9ff8fff43c98..53767c2c309c 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -12,6 +12,7 @@ from transformers.models.t5.modeling_t5 import T5Block from transformers.models.roberta.modeling_roberta import RobertaLayer from huggingface_hub import HfApi +from deepspeed.accelerator import get_accelerator rocm_version = OpBuilder.installed_rocm_version() if rocm_version != (0, 0): @@ -252,10 +253,10 @@ def test( # Warm-up queries for perf measurement #for i in range(10): # _ = pipe(query, **inf_kwargs) - torch.cuda.synchronize() + get_accelerator().synchronize() start = time.time() bs_output = pipe(query, **inf_kwargs) - torch.cuda.synchronize() + get_accelerator().synchronize() bs_time = time.time() - start pipe.model = deepspeed.init_inference( @@ -269,10 +270,10 @@ def test( # Warm-up queries for perf measurement #for i in range(10): # _ = pipe(query, **inf_kwargs) - torch.cuda.synchronize() + get_accelerator().synchronize() start = time.time() ds_output = pipe(query, **inf_kwargs) - torch.cuda.synchronize() + get_accelerator().synchronize() ds_time = time.time() - start # facebook/opt* and some bigscient/bloom* models are not matching @@ -326,7 +327,7 @@ def test( replace_method="auto", replace_with_kernel_inject=True) # Switch device to GPU so that input tensors are not on CPU - pipe.device = torch.device(f"cuda:{local_rank}") + pipe.device = torch.device(get_accelerator().device_name(local_rank)) ds_output = pipe(query, **inf_kwargs) print(local_rank, "baseline", bs_output) @@ -419,7 +420,7 @@ def test(self, model_family, model_name, task): import lm_eval.evaluator local_rank = os.getenv("LOCAL_RANK", "0") - device = torch.device(f"cuda:{local_rank}") + device = torch.device(get_accelerator().device_name(local_rank)) dtype = torch.float task_dict = lm_eval.tasks.get_task_dict([task]) @@ -433,12 +434,12 @@ def test(self, model_family, model_name, task): else: lm = 
lm_eval.models.get_model(model_family).create_from_arg_string( f"pretrained={model_name}", - {"device": f"cuda:{local_rank}"}) + {"device": get_accelerator().device_name(local_rank)}) - torch.cuda.synchronize() + get_accelerator().synchronize() start = time.time() bs_output = lm_eval.evaluator.evaluate(lm=lm, task_dict=task_dict) - torch.cuda.synchronize() + get_accelerator().synchronize() bs_time = time.time() - start ds_model = deepspeed.init_inference( @@ -451,10 +452,10 @@ def test(self, model_family, model_name, task): enable_cuda_graph=False, ) setattr(lm, model_family, ds_model) - torch.cuda.synchronize() + get_accelerator().synchronize() start = time.time() ds_output = lm_eval.evaluator.evaluate(lm=lm, task_dict=task_dict) - torch.cuda.synchronize() + get_accelerator().synchronize() ds_time = time.time() - start ppl_diff = abs(bs_output["results"][task]["ppl"] - diff --git a/tests/unit/megatron_model.py b/tests/unit/megatron_model.py index 6fc55393295f..b1282ecaf09b 100644 --- a/tests/unit/megatron_model.py +++ b/tests/unit/megatron_model.py @@ -5,6 +5,7 @@ from .common import get_test_path from deepspeed.pipe import PipelineModule, LayerSpec +from deepspeed.accelerator import get_accelerator def get_megatron_version(): @@ -37,10 +38,10 @@ def get_gpt2_model(args_others, mp_size=1): initialize_megatron(args_defaults=args_defaults, ignore_unknown_args=True) model = GPT2Model(num_tokentypes=0, parallel_output=False) - model.cuda() + model.to(get_accelerator().device_name()) from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from megatron import mpu - i = torch.cuda.current_device() + i = get_accelerator().current_device_name() model = torchDDP(model, device_ids=[i], output_device=i, @@ -76,8 +77,9 @@ def __init__(self, num_layers, mp_size, args_others, topo, **kwargs): class ParallelTransformerLayerPipe(ParallelTransformerLayer): def forward(self, args): # hardcode attn mask for testing, PP requires the attn_mask to be stashed - attention_mask = torch.tensor([[True]], - device=torch.cuda.current_device()) + attention_mask = torch.tensor( + [[True]], + device=get_accelerator().current_device_name()) return super().forward(args, attention_mask) layers = [] diff --git a/tests/unit/model_parallelism/test_configurable_parallel_mp.py b/tests/unit/model_parallelism/test_configurable_parallel_mp.py index dda4c22bcdae..4d8c4d2b1309 100644 --- a/tests/unit/model_parallelism/test_configurable_parallel_mp.py +++ b/tests/unit/model_parallelism/test_configurable_parallel_mp.py @@ -5,6 +5,7 @@ import random import numpy as np import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest, DistributedFixture from unit.megatron_model import get_gpt2_model, get_megatron_version @@ -40,7 +41,7 @@ def reset_random(self, seed=1234): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) + get_accelerator().manual_seed_all(seed) @pytest.fixture def inputs(self, bs=1, seq_len=20): @@ -68,7 +69,10 @@ def test_gpt2_basic(self, tmpdir, inputs): model = get_deepspeed_model(model) model.eval() - baseline = model(inputs[0].cuda(), inputs[1].cuda(), inputs[2].cuda()) + device_name = get_accelerator().device_name() + baseline = model(inputs[0].to(device_name), + inputs[1].to(device_name), + inputs[2].to(device_name)) tag = 'mp_1' state_dict = {} @@ -97,7 +101,10 @@ def test_gpt2_mp2_no_resize(self, tmpdir, inputs): model.eval() - baseline = model(inputs[0].cuda(), inputs[1].cuda(), 
inputs[2].cuda()) + device_name = get_accelerator().device_name() + baseline = model(inputs[0].to(device_name), + inputs[1].to(device_name), + inputs[2].to(device_name)) tag = 'mp_2' state_dict = {} @@ -109,7 +116,10 @@ def test_gpt2_mp2_no_resize(self, tmpdir, inputs): load_optimizer_states=False, load_lr_scheduler_states=False) - test = model(inputs[0].cuda(), inputs[1].cuda(), inputs[2].cuda()) + device_name = get_accelerator().device_name() + test = model(inputs[0].to(device_name), + inputs[1].to(device_name), + inputs[2].to(device_name)) assert torch.allclose(baseline, test, rtol=1.0, atol=1e-07), f"Baseline output {baseline} is not equal to save-then-load output {test}" @@ -131,7 +141,10 @@ def run(self, inputs, class_tmpdir): model.eval() with torch.no_grad(): - baseline = model(inputs[0].cuda(), inputs[1].cuda(), inputs[2].cuda()) + device_name = get_accelerator().device_name() + baseline = model(inputs[0].to(device_name), + inputs[1].to(device_name), + inputs[2].to(device_name)) if dist.get_rank() == 0: save_path = os.path.join(class_tmpdir, "output.pt") torch.save(baseline.cpu(), save_path) @@ -162,7 +175,10 @@ def test(self, baseline_mp2, inputs, class_tmpdir): model.load_checkpoint(class_tmpdir, load_optimizer_states=False, load_lr_scheduler_states=False) - test = model(inputs[0].cuda(), inputs[1].cuda(), inputs[2].cuda()) + device_name = get_accelerator().device_name() + test = model(inputs[0].to(device_name), + inputs[1].to(device_name), + inputs[2].to(device_name)) if dist.get_rank() == 0: load_path = os.path.join(class_tmpdir, "output.pt") baseline = torch.load(load_path) diff --git a/tests/unit/model_parallelism/test_configurable_parallel_pp.py b/tests/unit/model_parallelism/test_configurable_parallel_pp.py index 164e8cea5363..44985c615048 100644 --- a/tests/unit/model_parallelism/test_configurable_parallel_pp.py +++ b/tests/unit/model_parallelism/test_configurable_parallel_pp.py @@ -9,6 +9,7 @@ from unit.megatron_model import get_megatron_version from unit.megatron_model import MockGPT2ModelPipe as GPT2ModelPipe from deepspeed.utils import RepeatingLoader +from deepspeed.accelerator import get_accelerator TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) @@ -31,7 +32,7 @@ def get_deepspeed_model(model): model, _, _,_ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config_dict) - return model.cuda() + return model.to(get_accelerator().device_name()) def get_topology(mp, pp, world_size): @@ -50,7 +51,7 @@ def reset_random(self, seed=1234): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) + get_accelerator().manual_seed_all(seed) @pytest.fixture def inputs(self, bs=1, seq_len=1, hidden_size=128): @@ -146,7 +147,7 @@ def run(self, inputs, class_tmpdir, mp_size, pp_size): model = get_deepspeed_model(gpt2_pipe_model) with torch.no_grad(): - inputs = [x.cuda() for x in inputs] + inputs = [x.to(get_accelerator().device_name()) for x in inputs] if model.is_first_stage() or model.is_last_stage(): loader = RepeatingLoader([(inputs[0], 0)]) data_iter = iter(loader) @@ -206,7 +207,7 @@ def _test(self, inputs, class_tmpdir, mp_size, pp_size, mp_resize, pp_resize): model.load_checkpoint(class_tmpdir, load_optimizer_states=False, load_lr_scheduler_states=False) - inputs = [x.cuda() for x in inputs] + inputs = [x.to(get_accelerator().device_name()) for x in inputs] if model.is_first_stage() or model.is_last_stage(): loader = RepeatingLoader([(inputs[0], 
0)]) data_iter = iter(loader) diff --git a/tests/unit/modeling.py b/tests/unit/modeling.py index e8a38afc9538..b4c1eaba771f 100644 --- a/tests/unit/modeling.py +++ b/tests/unit/modeling.py @@ -43,6 +43,7 @@ #from numba import cuda #from deepspeed_cuda import DeepSpeedSoftmaxConfig, DeepSpeedSoftmax +from deepspeed.accelerator import get_accelerator logger = logging.getLogger(__name__) @@ -184,8 +185,8 @@ def swish(x): class GPUTimer: def __init__(self): super().__init__() - self.start = cuda.event() # noqa: F821 - self.stop = cuda.event() # noqa: F821 + self.start = get_accelerator().Event() # noqa: F821 + self.stop = get_accelerator().Event() # noqa: F821 def record(self): self.start.record() @@ -749,12 +750,12 @@ def __init__(self, config, bert_model_embedding_weights): def forward(self, hidden_states): hidden_states = self.transform(hidden_states) - torch.cuda.nvtx.range_push( + get_accelerator().range_push( "decoder input.size() = {}, weight.size() = {}".format( hidden_states.size(), self.decoder.weight.size())) hidden_states = self.decoder(hidden_states) + self.bias - torch.cuda.nvtx.range_pop() + get_accelerator().range_pop() return hidden_states @@ -884,7 +885,7 @@ def from_pretrained(cls, weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) state_dict = torch.load( weights_path, - map_location='cpu' if not torch.cuda.is_available() else None) + map_location='cpu' if not get_accelerator().is_available() else None) if tempdir: # Clean up temp dir shutil.rmtree(tempdir) diff --git a/tests/unit/modelingpreln.py b/tests/unit/modelingpreln.py index 673a73ac91f4..e68a3d0cc269 100644 --- a/tests/unit/modelingpreln.py +++ b/tests/unit/modelingpreln.py @@ -39,6 +39,7 @@ from torch.nn import Module import torch.nn.functional as F import torch.nn.init as init +from deepspeed.accelerator import get_accelerator #from numba import cuda @@ -184,8 +185,8 @@ def swish(x): class GPUTimer: def __init__(self): super().__init__() - self.start = cuda.event() # noqa: F821 - self.stop = cuda.event() # noqa: F821 + self.start = get_accelerator().Event() # noqa: F821 + self.stop = get_accelerator().Event() # noqa: F821 def record(self): self.start.record() @@ -844,12 +845,12 @@ def __init__(self, config, bert_model_embedding_weights): def forward(self, hidden_states): hidden_states = self.transform(hidden_states) - torch.cuda.nvtx.range_push( + get_accelerator().range_push( "decoder input.size() = {}, weight.size() = {}".format( hidden_states.size(), self.decoder.weight.size())) hidden_states = self.decoder(hidden_states) + self.bias - torch.cuda.nvtx.range_pop() + get_accelerator().range_pop() return hidden_states @@ -979,7 +980,7 @@ def from_pretrained(cls, weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) state_dict = torch.load( weights_path, - map_location='cpu' if not torch.cuda.is_available() else None) + map_location='cpu' if not get_accelerator().is_available() else None) if tempdir: # Clean up temp dir shutil.rmtree(tempdir) diff --git a/tests/unit/ops/adagrad/test_cpu_adagrad.py b/tests/unit/ops/adagrad/test_cpu_adagrad.py index 6f530d0309fa..7ae13888c7bd 100644 --- a/tests/unit/ops/adagrad/test_cpu_adagrad.py +++ b/tests/unit/ops/adagrad/test_cpu_adagrad.py @@ -4,9 +4,10 @@ import deepspeed from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad -from deepspeed.ops.op_builder import CPUAdagradBuilder +from deepspeed.accelerator import get_accelerator -if not deepspeed.ops.__compatible_ops__[CPUAdagradBuilder.NAME]: +if not 
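
The modeling.py and modelingpreln.py hunks route the NVTX-style profiling markers through the accelerator. A short sketch of that pattern (the wrapper and its arguments are illustrative); on CUDA these calls correspond to the torch.cuda.nvtx ranges they replace:

from deepspeed.accelerator import get_accelerator

def traced_decoder_step(decoder, hidden_states, bias):
    # Open a named profiling range around the decoder projection.
    get_accelerator().range_push(
        "decoder input.size() = {}".format(tuple(hidden_states.size())))
    out = decoder(hidden_states) + bias
    get_accelerator().range_pop()
    return out
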
deepspeed.ops.__compatible_ops__[get_accelerator().create_op_builder( + "CPUAdagradBuilder").name]: pytest.skip("cpu-adagrad is not compatible", allow_module_level=True) @@ -127,7 +128,7 @@ def gen_sparse_grad(vocabulary_size, dim, num_indices, dtype, device): def test_cpu_adam_gpu_error(): model_size = 64 - device = 'cuda:0' + device = get_accelerator().device_name(0) param = torch.nn.Parameter(torch.randn(model_size, device=device)) optimizer = DeepSpeedCPUAdagrad([param]) diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py index 54389ce5fcf4..dfb59c0a68b2 100644 --- a/tests/unit/ops/adam/test_cpu_adam.py +++ b/tests/unit/ops/adam/test_cpu_adam.py @@ -4,10 +4,11 @@ from cpuinfo import get_cpu_info import deepspeed +from deepspeed.accelerator import get_accelerator from deepspeed.ops.adam import FusedAdam -from deepspeed.ops.op_builder import CPUAdamBuilder -if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: +if not deepspeed.ops.__compatible_ops__[get_accelerator().create_op_builder( + "CPUAdamBuilder").name]: pytest.skip("cpu-adam is not compatible", allow_module_level=True) pytest.cpu_vendor = get_cpu_info()["vendor_id_raw"].lower() @@ -46,7 +47,9 @@ def test_cpu_adam_opt(dtype, model_size): param1_data = torch.randn(model_size, device=device) param1 = torch.nn.Parameter(param1_data) torch.set_rng_state(rng_state) - param2_data = torch.randn(model_size, device=device).to(dtype).cuda() + param2_data = torch.randn(model_size, + device=device).to(dtype).to( + get_accelerator().device_name()) param2 = torch.nn.Parameter(param2_data) optimizer1 = torch.optim.AdamW([param1]) @@ -59,7 +62,9 @@ def test_cpu_adam_opt(dtype, model_size): torch.set_rng_state(rng_state) param1.grad = torch.randn(model_size, device=device) torch.set_rng_state(rng_state) - param2.grad = torch.randn(model_size, device=device).to(dtype).cuda() + param2.grad = torch.randn(model_size, + device=device).to(dtype).to( + get_accelerator().device_name()) optimizer.step() optimizer2.step() @@ -78,7 +83,7 @@ def test_cpu_adam_opt(dtype, model_size): def test_cpu_adam_gpu_error(): model_size = 64 from deepspeed.ops.adam import DeepSpeedCPUAdam - device = 'cuda:0' + device = get_accelerator().device_name(0) param = torch.nn.Parameter(torch.randn(model_size, device=device)) optimizer = DeepSpeedCPUAdam([param]) diff --git a/tests/unit/ops/aio/test_aio.py b/tests/unit/ops/aio/test_aio.py index ca1f1b923743..c158201f7910 100644 --- a/tests/unit/ops/aio/test_aio.py +++ b/tests/unit/ops/aio/test_aio.py @@ -4,7 +4,7 @@ import torch import deepspeed import deepspeed.comm as dist -from deepspeed.ops.aio import AsyncIOBuilder +from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest MEGA_BYTE = 1024**2 @@ -13,7 +13,8 @@ IO_SIZE = 16 * MEGA_BYTE IO_PARALLEL = 2 -if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: +if not deepspeed.ops.__compatible_ops__[get_accelerator().create_op_builder( + "AsyncIOBuilder").name]: pytest.skip('Skip tests since async-io is not compatible', allow_module_level=True) @@ -31,9 +32,9 @@ def _get_test_file_and_buffer(tmpdir, ref_buffer, cuda_device, index=0): file_suffix = f'{dist.get_rank()}_{index}' test_file = os.path.join(tmpdir, f'_aio_write_random_{file_suffix}.pt') if cuda_device: - test_buffer = torch.cuda.ByteTensor(list(ref_buffer)) + test_buffer = get_accelerator().ByteTensor(list(ref_buffer)) else: - test_buffer = torch.ByteTensor(list(ref_buffer)).pin_memory() + test_buffer = 
get_accelerator().pin_memory(torch.ByteTensor(list(ref_buffer))) return test_file, test_buffer @@ -54,12 +55,16 @@ class TestRead(DistributedTest): def test_parallel_read(self, tmpdir, single_submit, overlap_events): ref_file, _ = _do_ref_write(tmpdir) - aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory() - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) + aio_buffer = get_accelerator().pin_memory( + torch.empty(IO_SIZE, + dtype=torch.uint8, + device='cpu')) + h = get_accelerator().create_op_builder("AsyncIOBuilder").load().aio_handle( + BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) _validate_handle_state(h, single_submit, overlap_events) @@ -75,17 +80,21 @@ def test_async_read(self, tmpdir, single_submit, overlap_events, cuda_device): ref_file, _ = _do_ref_write(tmpdir) if cuda_device: - aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cuda') - else: aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, - device='cpu').pin_memory() - - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) + device=get_accelerator().device_name()) + else: + aio_buffer = get_accelerator().pin_memory( + torch.empty(IO_SIZE, + dtype=torch.uint8, + device='cpu')) + + h = get_accelerator().create_op_builder("AsyncIOBuilder").load().aio_handle( + BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) _validate_handle_state(h, single_submit, overlap_events) @@ -109,11 +118,12 @@ def test_parallel_write(self, tmpdir, single_submit, overlap_events): ref_file, ref_buffer = _do_ref_write(tmpdir) aio_file, aio_buffer = _get_test_file_and_buffer(tmpdir, ref_buffer, False) - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) + h = get_accelerator().create_op_builder("AsyncIOBuilder").load().aio_handle( + BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) _validate_handle_state(h, single_submit, overlap_events) @@ -130,11 +140,12 @@ def test_async_write(self, tmpdir, single_submit, overlap_events, cuda_device): ref_file, ref_buffer = _do_ref_write(tmpdir) aio_file, aio_buffer = _get_test_file_and_buffer(tmpdir, ref_buffer, cuda_device) - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) + h = get_accelerator().create_op_builder("AsyncIOBuilder").load().aio_handle( + BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) _validate_handle_state(h, single_submit, overlap_events) @@ -164,18 +175,24 @@ def test_read(self, tmpdir, async_queue, cuda_device): aio_buffers = [] for i in range(async_queue): if cuda_device: - buf = torch.empty(IO_SIZE, dtype=torch.uint8, device='cuda') + buf = torch.empty(IO_SIZE, + dtype=torch.uint8, + device=get_accelerator().device_name()) else: - buf = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory() + buf = get_accelerator().pin_memory( + torch.empty(IO_SIZE, + dtype=torch.uint8, + device='cpu')) aio_buffers.append(buf) single_submit = True overlap_events = True - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) + h = get_accelerator().create_op_builder("AsyncIOBuilder").load().aio_handle( + BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) _validate_handle_state(h, single_submit, 
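
The test_aio.py hunks consistently allocate device buffers by the accelerator's device name and pin host buffers through get_accelerator().pin_memory() rather than tensor.pin_memory(). A compact sketch of that allocation choice (the helper name is illustrative):

import torch
from deepspeed.accelerator import get_accelerator

def make_io_buffer(num_bytes: int, on_device: bool) -> torch.Tensor:
    if on_device:
        # Device-side buffer, placed by device name rather than device='cuda'.
        return torch.empty(num_bytes,
                           dtype=torch.uint8,
                           device=get_accelerator().device_name())
    # Pinned host buffer for async I/O; pinning goes through the accelerator so
    # non-CUDA backends can supply their own mechanism.
    return get_accelerator().pin_memory(
        torch.empty(num_bytes, dtype=torch.uint8, device='cpu'))
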
overlap_events) @@ -209,11 +226,12 @@ def test_write(self, tmpdir, async_queue, cuda_device): single_submit = True overlap_events = True - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) + h = get_accelerator().create_op_builder("AsyncIOBuilder").load().aio_handle( + BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) _validate_handle_state(h, single_submit, overlap_events) diff --git a/tests/unit/ops/cuda/test_cuda_backward.py b/tests/unit/ops/cuda/test_cuda_backward.py index f2720ce5c1ce..b9680f28aa2f 100644 --- a/tests/unit/ops/cuda/test_cuda_backward.py +++ b/tests/unit/ops/cuda/test_cuda_backward.py @@ -6,6 +6,7 @@ import copy from torch import nn from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig +from deepspeed.accelerator import get_accelerator from unit.modeling import BertConfig, BertLayerNorm, BertEncoder as BertEncoderPostln from unit.modelingpreln import BertEncoder as BertEncoderPreln @@ -81,7 +82,7 @@ def zero_grad(variables): variable.grad.zero_() -device = torch.device("cuda") +device = torch.device(get_accelerator().device_name()) kwargs_fp32 = {'dtype': torch.float, 'device': device, 'requires_grad': True} kwargs_fp16 = {'dtype': torch.half, 'device': device, 'requires_grad': True} @@ -207,8 +208,8 @@ def create_models(ds_config): bert_encoder.half() ds_encoder.half() - bert_encoder.cuda() - ds_encoder.cuda() + bert_encoder.to(get_accelerator().device_name()) + ds_encoder.to(get_accelerator().device_name()) return bert_encoder, ds_encoder @@ -281,9 +282,8 @@ def test_backward(batch_size, is_preln, use_fp16, atol): - # Only run fp16 test cases on devices with 7+ capability. - major, _ = torch.cuda.get_device_capability() - if major < 7 and (use_fp16 is True or is_preln is False): + # Only run fp16 test cases on devices with FP16 capability. + if not get_accelerator().is_fp16_supported() and use_fp16 is True: return ds_config = DeepSpeedTransformerConfig() @@ -317,11 +317,9 @@ def test_backward(batch_size, # is_preln, # use_fp16, # atol): -# # Only run fp16 test cases on devices with 7+ capability. -# major, _ = torch.cuda.get_device_capability() -# if major < 7 and (use_fp16 is True or is_preln is False): +# # Only run fp16 test cases on devices with FP16 capability. 
+# if not get_accelerator().is_fp16_supported() and use_fp16 is True: # return -# # ds_config = DeepSpeedTransformerConfig() # ds_config.layer_id = None # ds_config.batch_size = batch_size diff --git a/tests/unit/ops/cuda/test_cuda_forward.py b/tests/unit/ops/cuda/test_cuda_forward.py index 546a596523a8..79960f409555 100644 --- a/tests/unit/ops/cuda/test_cuda_forward.py +++ b/tests/unit/ops/cuda/test_cuda_forward.py @@ -8,6 +8,7 @@ from unit.modelingpreln import BertEncoder as BertEncoderPreln from unit.modeling import BertLayerNorm, BertConfig, BertEncoder as BertEncoderPostln from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig +from deepspeed.accelerator import get_accelerator def check_equal(first, second, atol=1e-2, verbose=False): @@ -28,7 +29,7 @@ def zero_grad(variables): variable.grad.zero_() -device = torch.device("cuda") +device = torch.device(get_accelerator().device_name()) kwargs_fp32 = {'dtype': torch.float, 'device': device, 'requires_grad': True} kwargs_fp16 = {'dtype': torch.half, 'device': device, 'requires_grad': True} @@ -147,8 +148,8 @@ def create_models(ds_config): bert_encoder.half() ds_encoder.half() - bert_encoder.cuda() - ds_encoder.cuda() + bert_encoder.to(get_accelerator().device_name()) + ds_encoder.to(get_accelerator().device_name()) return bert_encoder, ds_encoder @@ -234,9 +235,8 @@ def test_forward(batch_size, num_layers, is_preln, use_fp16): - # Only run fp16 test cases on devices with 7+ capability. - major, _ = torch.cuda.get_device_capability() - if major < 7 and use_fp16 is True: + # Only run fp16 test cases on devices with FP16 capability. + if not get_accelerator().is_fp16_supported() and use_fp16 is True: return ds_config = DeepSpeedTransformerConfig() @@ -270,9 +270,8 @@ def test_forward_with_small_bsz(batch_size, num_layers, is_preln, use_fp16): - # Only run fp16 test cases on devices with 7+ capability. - major, _ = torch.cuda.get_device_capability() - if major < 7 and use_fp16 is True: + # Only run fp16 test cases on devices with FP16 capability. + if not get_accelerator().is_fp16_supported() and use_fp16 is True: return ds_config = DeepSpeedTransformerConfig() @@ -304,9 +303,8 @@ def test_forward_stochastic(batch_size, num_layers, is_preln, use_fp16): - # Only run fp16 test cases on devices with 7+ capability. - major, _ = torch.cuda.get_device_capability() - if major < 7 and use_fp16 is True: + # Only run fp16 test cases on devices with FP16 capability. + if not get_accelerator().is_fp16_supported() and use_fp16 is True: return ds_config = DeepSpeedTransformerConfig() diff --git a/tests/unit/ops/quantizer/test_quant.py b/tests/unit/ops/quantizer/test_quant.py index 1526937dd2bc..440f8a88bb6e 100644 --- a/tests/unit/ops/quantizer/test_quant.py +++ b/tests/unit/ops/quantizer/test_quant.py @@ -1,6 +1,6 @@ import torch import pytest -from deepspeed.ops import op_builder +from deepspeed.accelerator import get_accelerator quantizer_cuda_module = None @@ -30,7 +30,8 @@ def run_quant_dequant(inputs, groups, bits): global quantizer_cuda_module if quantizer_cuda_module is None: - quantizer_cuda_module = op_builder.QuantizerBuilder().load() + quantizer_cuda_module = get_accelerator().create_op_builder( + "QuantizerBuilder").load() return quantizer_cuda_module.ds_quantize_fp16(inputs, groups, bits) @@ -42,7 +43,8 @@ def run_quant_dequant(inputs, groups, bits): # Note that we have an explicit boundary for groups as ((size / groups) - 1) / 4096 + 1) <= MAX_REG. 
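
The transformer-kernel tests above replace the compute-capability check with a direct query of the accelerator. A one-function sketch of that guard (helper name illustrative):

from deepspeed.accelerator import get_accelerator

def fp16_case_supported(use_fp16: bool) -> bool:
    # fp32 cases always run; fp16 cases run only when the accelerator reports
    # fp16 support, replacing the old compute-capability >= 7 check.
    return (not use_fp16) or get_accelerator().is_fp16_supported()
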
def test_quant_dequant(tensor_shape, groups): - input_tensor = torch.rand((tensor_shape), dtype=torch.float16).cuda() + input_tensor = torch.rand((tensor_shape), + dtype=torch.float16).to(get_accelerator().device_name()) # 8-bit quantization. ref_input_8bit = input_tensor.clone().detach() diff --git a/tests/unit/ops/sparse_attention/test_sparse_attention.py b/tests/unit/ops/sparse_attention/test_sparse_attention.py index 740dfacdd0de..0a24a54147c3 100644 --- a/tests/unit/ops/sparse_attention/test_sparse_attention.py +++ b/tests/unit/ops/sparse_attention/test_sparse_attention.py @@ -6,9 +6,10 @@ import pytest import torch import deepspeed -from deepspeed.ops.op_builder import SparseAttnBuilder +from deepspeed.accelerator import get_accelerator -if not deepspeed.ops.__compatible_ops__[SparseAttnBuilder.NAME]: +if not deepspeed.ops.__compatible_ops__[get_accelerator().create_op_builder( + "SparseAttnBuilder").name]: pytest.skip("sparse attention op is not compatible on this system", allow_module_level=True) @@ -92,7 +93,13 @@ def init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, dense_x=True, layo if layout is None: layout = make_layout(rho, (H, M // block, N // block)) if dense_x: - x = torch.rand((Z, H, M, N), dtype=dtype, requires_grad=True, device='cuda') + x = torch.rand((Z, + H, + M, + N), + dtype=dtype, + requires_grad=True, + device=get_accelerator().device_name()) else: x = torch.rand((Z, layout.sum(), @@ -100,7 +107,7 @@ def init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, dense_x=True, layo block), dtype=dtype, requires_grad=True, - device='cuda') + device=get_accelerator().device_name()) dx = torch.rand_like(x) bool_attn_mask = torch.randint(low=0, high=2, @@ -108,7 +115,7 @@ def init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, dense_x=True, layo N), dtype=torch.bool, requires_grad=False, - device='cuda') + device=get_accelerator().device_name()) fp_attn_mask = bool_attn_mask.type(dtype) kp_mask = torch.randint(low=0, high=2, @@ -116,20 +123,24 @@ def init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, dense_x=True, layo N), dtype=dtype, requires_grad=False, - device='cuda') + device=get_accelerator().device_name()) kp_mask[kp_mask == 1.] 
= float('-inf') return layout, x, dx, bool_attn_mask, fp_attn_mask, kp_mask def _skip_on_cuda_compatability(): - if torch.cuda.get_device_capability()[0] < 7: - pytest.skip("needs higher compute capability than 7") - cuda_major = int(torch.version.cuda.split('.')[0]) * 10 - cuda_minor = int(torch.version.cuda.split('.')[1]) - cuda_version = cuda_major + cuda_minor - if (cuda_version != 101 and cuda_version != 102) and \ - (cuda_version != 111 and cuda_version != 110): - pytest.skip("requires cuda 10.1 or 10.2 or 11.0 or 11.1") + if deepspeed.accelerator.get_accelerator().device_name() == 'cuda': + if torch.cuda.get_device_capability()[0] < 7: + pytest.skip("needs higher compute capability than 7") + cuda_major = int(torch.version.cuda.split('.')[0]) * 10 + cuda_minor = int(torch.version.cuda.split('.')[1]) + cuda_version = cuda_major + cuda_minor + if (cuda_version != 101 and cuda_version != 102) and \ + (cuda_version != 111 and cuda_version != 110): + pytest.skip("requires cuda 10.1 or 10.2 or 11.0 or 11.1") + else: + assert deepspeed.accelerator.get_accelerator().device_name() == 'xpu' + return @pytest.mark.parametrize("block", [16, 32]) @@ -193,9 +204,21 @@ def init_matmul_inputs(Z, H, M, N, K, rho, mode, trans_a, trans_b, block, dtype, BS0 = N if trans_b else K BS1 = K if trans_b else N shape = {'sdd': (M, N), 'dsd': (AS0, AS1), 'dds': (BS0, BS1)}[mode] - x = torch.rand((Z, H, AS0, AS1), dtype=dtype, requires_grad=True, device='cuda') - w = torch.rand((Z, H, BS0, BS1), dtype=dtype, requires_grad=True, device='cuda') - dy = torch.rand((Z, H, M, N), dtype=dtype, device='cuda') + x = torch.rand((Z, + H, + AS0, + AS1), + dtype=dtype, + requires_grad=True, + device=get_accelerator().device_name()) + w = torch.rand((Z, + H, + BS0, + BS1), + dtype=dtype, + requires_grad=True, + device=get_accelerator().device_name()) + dy = torch.rand((Z, H, M, N), dtype=dtype, device=get_accelerator().device_name()) if layout is None: layout = make_layout(rho, (H, shape[0] // block, shape[1] // block)) else: diff --git a/tests/unit/ops/transformer/inference/test_bias_add.py b/tests/unit/ops/transformer/inference/test_bias_add.py index 2077390aabfc..3d573bc38ee2 100644 --- a/tests/unit/ops/transformer/inference/test_bias_add.py +++ b/tests/unit/ops/transformer/inference/test_bias_add.py @@ -1,9 +1,10 @@ import pytest import torch import deepspeed -from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.accelerator import get_accelerator -if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: +if not deepspeed.ops.__compatible_ops__[get_accelerator().create_op_builder( + "InferenceBuilder").name]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) @@ -24,7 +25,7 @@ def run_bias_add_reference(activations, bias): def run_bias_add_ds(activations, bias): global inference_module if inference_module is None: - inference_module = InferenceBuilder().load() + inference_module = get_accelerator().create_op_builder("InferenceBuilder").load() if activations.dtype == torch.float16: return inference_module.bias_add_fp16(activations, bias) else: @@ -37,8 +38,14 @@ def run_bias_add_ds(activations, bias): @pytest.mark.parametrize("channels", [512, 1232, 4096]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) def test_bias_add(batch, sequence, channels, dtype): - activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device='cuda') - bias_ds = torch.randn((channels), dtype=dtype, device='cuda') + activations_ds = 
torch.randn((batch, + sequence, + channels), + dtype=dtype, + device=get_accelerator().device_name()) + bias_ds = torch.randn((channels), + dtype=dtype, + device=get_accelerator().device_name()) activations_ref = activations_ds.clone().detach() bias_ref = bias_ds.clone().detach() diff --git a/tests/unit/ops/transformer/inference/test_bias_gelu.py b/tests/unit/ops/transformer/inference/test_bias_gelu.py index bf0b184fb5fe..0c7bfa0a19dc 100644 --- a/tests/unit/ops/transformer/inference/test_bias_gelu.py +++ b/tests/unit/ops/transformer/inference/test_bias_gelu.py @@ -5,9 +5,10 @@ import pytest import torch import deepspeed -from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.accelerator import get_accelerator -if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: +if not deepspeed.ops.__compatible_ops__[get_accelerator().create_op_builder( + "InferenceBuilder").name]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) @@ -41,7 +42,7 @@ def run_bias_gelu_reference(activations, bias): def run_bias_gelu_ds(activations, bias): global inference_module if inference_module is None: - inference_module = InferenceBuilder().load() + inference_module = get_accelerator().create_op_builder("InferenceBuilder").load() if activations.dtype == torch.float16: return inference_module.bias_gelu_fp16(activations, bias) else: @@ -54,8 +55,14 @@ def run_bias_gelu_ds(activations, bias): @pytest.mark.parametrize("channels", [512, 1232, 4096]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_bias_gelu(batch, sequence, channels, dtype): - activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device='cuda') - bias_ds = torch.randn((channels), dtype=dtype, device='cuda') + activations_ds = torch.randn((batch, + sequence, + channels), + dtype=dtype, + device=get_accelerator().device_name()) + bias_ds = torch.randn((channels), + dtype=dtype, + device=get_accelerator().device_name()) activations_ref = activations_ds.clone().detach() bias_ref = bias_ds.clone().detach() diff --git a/tests/unit/ops/transformer/inference/test_bias_relu.py b/tests/unit/ops/transformer/inference/test_bias_relu.py index c62b4b29bebd..6c13be3b4dd7 100644 --- a/tests/unit/ops/transformer/inference/test_bias_relu.py +++ b/tests/unit/ops/transformer/inference/test_bias_relu.py @@ -5,9 +5,10 @@ import pytest import torch import deepspeed -from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.accelerator import get_accelerator -if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: +if not deepspeed.ops.__compatible_ops__[get_accelerator().create_op_builder( + "InferenceBuilder").name]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) @@ -30,7 +31,7 @@ def run_bias_relu_reference(activations, bias): def run_bias_relu_ds(activations, bias): global inference_module if inference_module is None: - inference_module = InferenceBuilder().load() + inference_module = get_accelerator().create_op_builder("InferenceBuilder").load() if activations.dtype == torch.float16: return inference_module.bias_relu_fp16(activations, bias) else: @@ -43,8 +44,14 @@ def run_bias_relu_ds(activations, bias): @pytest.mark.parametrize("channels", [512, 1232, 4096]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_bias_relu(batch, sequence, channels, dtype): - activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device='cuda') - bias_ds = torch.randn((channels), 
dtype=dtype, device='cuda') + activations_ds = torch.randn((batch, + sequence, + channels), + dtype=dtype, + device=get_accelerator().device_name()) + bias_ds = torch.randn((channels), + dtype=dtype, + device=get_accelerator().device_name()) activations_ref = activations_ds.clone().detach() bias_ref = bias_ds.clone().detach() diff --git a/tests/unit/ops/transformer/inference/test_moe_res_matmult.py b/tests/unit/ops/transformer/inference/test_moe_res_matmult.py index 8b1b1cb16168..57cbfc539ac2 100644 --- a/tests/unit/ops/transformer/inference/test_moe_res_matmult.py +++ b/tests/unit/ops/transformer/inference/test_moe_res_matmult.py @@ -5,9 +5,10 @@ import pytest import torch import deepspeed -from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.accelerator import get_accelerator -if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: +if not deepspeed.ops.__compatible_ops__[get_accelerator().create_op_builder( + "InferenceBuilder").name]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) @@ -27,7 +28,7 @@ def run_moe_res_matmul_reference(residual, coef1, coef2, output): def run_moe_res_matmul_ds(residual, coef, output): global inference_module if inference_module is None: - inference_module = InferenceBuilder().load() + inference_module = get_accelerator().create_op_builder("InferenceBuilder").load() coef_t = coef.transpose(-1, -2).contiguous() return inference_module.moe_res_matmul(residual, coef_t, output) @@ -41,10 +42,22 @@ def test_moe_residual_matmul(hidden_dim, c, dtype): hidden_dim * c, hidden_dim), dtype=dtype, - device='cuda') - coeff1 = torch.randn((1, 1, hidden_dim), dtype=dtype, device='cuda') - coeff2 = torch.randn((1, 1, hidden_dim), dtype=dtype, device='cuda') - out_ds = torch.randn((c, hidden_dim * c, hidden_dim), dtype=dtype, device='cuda') + device=get_accelerator().device_name()) + coeff1 = torch.randn((1, + 1, + hidden_dim), + dtype=dtype, + device=get_accelerator().device_name()) + coeff2 = torch.randn((1, + 1, + hidden_dim), + dtype=dtype, + device=get_accelerator().device_name()) + out_ds = torch.randn((c, + hidden_dim * c, + hidden_dim), + dtype=dtype, + device=get_accelerator().device_name()) coeff_ds = torch.cat((coeff1, coeff2), dim=-1) residual_ref = residual_ds.clone().detach() coeff_ref = coeff_ds.clone().detach() diff --git a/tests/unit/ops/transformer/inference/test_residual_add.py b/tests/unit/ops/transformer/inference/test_residual_add.py index 336008f5a1e8..3d18a02ce6e7 100644 --- a/tests/unit/ops/transformer/inference/test_residual_add.py +++ b/tests/unit/ops/transformer/inference/test_residual_add.py @@ -5,9 +5,10 @@ import pytest import torch import deepspeed -from deepspeed.ops.op_builder import InferenceBuilder +from deepspeed.accelerator import get_accelerator -if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: +if not deepspeed.ops.__compatible_ops__[get_accelerator().create_op_builder( + "InferenceBuilder").name]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) @@ -20,7 +21,7 @@ def allclose(x, y): @pytest.fixture(scope="module") def inference_module(): - return InferenceBuilder().load() + return get_accelerator().create_op_builder("InferenceBuilder").load() def res_add_bias_ref(hidden_state, @@ -95,11 +96,27 @@ def test_residual_add(inference_module, add_bias, mp_size, pre_attn_norm): - ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device='cuda') - residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, 
device='cuda') - attn_output = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device='cuda') - final_bias = torch.randn((hidden_dim), dtype=dtype, device='cuda') - attn_bias = torch.randn((hidden_dim), dtype=dtype, device='cuda') + ds_out = torch.randn((batch, + sequence, + hidden_dim), + dtype=dtype, + device=get_accelerator().device_name()) + residual = torch.randn((batch, + sequence, + hidden_dim), + dtype=dtype, + device=get_accelerator().device_name()) + attn_output = torch.randn((batch, + sequence, + hidden_dim), + dtype=dtype, + device=get_accelerator().device_name()) + final_bias = torch.randn((hidden_dim), + dtype=dtype, + device=get_accelerator().device_name()) + attn_bias = torch.randn((hidden_dim), + dtype=dtype, + device=get_accelerator().device_name()) ref_out = ds_out.clone() ref_out = run_residual_add_reference(ref_out, diff --git a/tests/unit/pipe/test_pipe_module.py b/tests/unit/pipe/test_pipe_module.py index 2b8cce57d7a8..ed2380009246 100644 --- a/tests/unit/pipe/test_pipe_module.py +++ b/tests/unit/pipe/test_pipe_module.py @@ -9,6 +9,7 @@ import deepspeed from deepspeed.pipe import PipelineModule from deepspeed.utils import RepeatingLoader +from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest @@ -70,7 +71,8 @@ def test(self, sequential_model, simple_config, batch_input): # Ensure all parameters are accounted for. my_params = sum(p.numel() for p in pipe_model.parameters()) - total_pipe_params = torch.LongTensor([my_params]).to('cuda') + total_pipe_params = torch.LongTensor([my_params + ]).to(get_accelerator().device_name()) dist.all_reduce(total_pipe_params) total_pipe_params = total_pipe_params.item() assert total_pipe_params == base_params @@ -81,7 +83,7 @@ def test(self, sequential_model, simple_config, batch_input): model_parameters=[p for p in pipe_model.parameters()]) if pipe_model.is_first_stage or pipe_model.is_last_stage: - pipe_input = base_input.clone().detach().to('cuda') + pipe_input = base_input.clone().detach().to(get_accelerator().device_name()) # label 0 is meaningless dataset = [(pipe_input, 0)] loader = RepeatingLoader(dataset) diff --git a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py index 375ef30de0f6..9e3804d057a1 100644 --- a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py +++ b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py @@ -3,6 +3,7 @@ import pytest import torch import deepspeed +from deepspeed.accelerator import get_accelerator from copy import deepcopy from unit.common import DistributedTest @@ -36,7 +37,7 @@ def _prep_inputs(*inputs): for inp in inputs: inp = deepcopy(inp) if torch.is_tensor(inp): - inp = inp.cuda() + inp = inp.to(get_accelerator().device_name()) _inputs.append(inp) return tuple(_inputs) @@ -57,7 +58,7 @@ def _match_outputs(ref, tgt): def _test_activation_checkpoint(module, *inputs): # Move to device - module.cuda() + module.to(get_accelerator().device_name()) # Get rid of dropouts until we fork the RNG between tests. module.eval() @@ -77,7 +78,7 @@ def _test_activation_checkpoint(module, *inputs): def _test_activation_checkpoint_ordering(module, expected_ordering, *inputs): # Move to device - module.cuda() + module.to(get_accelerator().device_name()) # Get rid of dropouts until we fork the RNG between tests. 
module.eval() diff --git a/tests/unit/runtime/comm/test_coalesced_collectives.py b/tests/unit/runtime/comm/test_coalesced_collectives.py index 92a081fb309b..1a007595b21a 100644 --- a/tests/unit/runtime/comm/test_coalesced_collectives.py +++ b/tests/unit/runtime/comm/test_coalesced_collectives.py @@ -3,6 +3,7 @@ import torch import deepspeed.comm as dist from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced +from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest @@ -15,7 +16,7 @@ def test_single_input(self): ), dist.get_rank(), dtype=torch.half, - device=torch.cuda.current_device()) + device=get_accelerator().current_device_name()) (output, ) = reduce_scatter_coalesced([input], dist.get_world_group()) @@ -23,7 +24,10 @@ def test_single_input(self): assert torch.allclose(output, torch.full_like(output, 0.5)) def test_two_inputs(self): - tensor_kwargs = {"device": torch.cuda.current_device(), "dtype": torch.half} + tensor_kwargs = { + "device": get_accelerator().current_device_name(), + "dtype": torch.half + } inputs = [ dist.get_rank() * torch.arange(0, 6, @@ -51,7 +55,10 @@ class TestReduceScatterCoalescedTensorSmallerThanWorldSize(DistributedTest): world_size = 2 def test(self): - input = torch.zeros((1, ), dtype=torch.half, device=torch.cuda.current_device()) + input = torch.zeros((1, + ), + dtype=torch.half, + device=get_accelerator().current_device_name()) (output, ) = reduce_scatter_coalesced([input], dist.get_world_group()) diff --git a/tests/unit/runtime/half_precision/onebit/test_onebit.py b/tests/unit/runtime/half_precision/onebit/test_onebit.py index 451d6abb6731..ddf8e09e6a2a 100644 --- a/tests/unit/runtime/half_precision/onebit/test_onebit.py +++ b/tests/unit/runtime/half_precision/onebit/test_onebit.py @@ -13,6 +13,7 @@ from unit.common import DistributedTest from unit.simple_model import SimpleModel, random_dataloader from unit.alexnet_model import AlexNetPipe, train_cifar +from deepspeed.accelerator import get_accelerator PipeTopo = PipeDataParallelTopology @@ -46,7 +47,7 @@ def test(self, dtype): "weight_decay": 0.01, "freeze_step": 2, "cuda_aware": False, - "comm_backend_name": "nccl", + "comm_backend_name": get_accelerator().communication_backend_name(), }, }, "gradient_clipping": 1.0, @@ -89,7 +90,7 @@ def test(self): "weight_decay": 0.01, "freeze_step": 2, "cuda_aware": False, - "comm_backend_name": "nccl", + "comm_backend_name": get_accelerator().communication_backend_name(), }, }, "gradient_clipping": 1.0, @@ -156,7 +157,7 @@ def test(self, tmpdir): "weight_decay": 0.01, "freeze_step": 2, "cuda_aware": False, - "comm_backend_name": "nccl", + "comm_backend_name": get_accelerator().communication_backend_name(), }, }, "gradient_clipping": 1.0, @@ -312,7 +313,7 @@ def test_overflow(self, tmpdir): "weight_decay": 0.01, "freeze_step": 2, "cuda_aware": False, - "comm_backend_name": "nccl", + "comm_backend_name": get_accelerator().communication_backend_name(), }, }, "gradient_clipping": 1.0, @@ -379,7 +380,7 @@ def test(self, topo_config): "weight_decay": 3e-7, "freeze_step": 200, "cuda_aware": False, - "comm_backend_name": "nccl", + "comm_backend_name": get_accelerator().communication_backend_name(), }, }, "gradient_clipping": 1.0, @@ -434,7 +435,7 @@ def test(self, dtype): "local_step_scaler": 1, "local_step_clipper": 2, "cuda_aware": False, - "comm_backend_name": "nccl", + "comm_backend_name": get_accelerator().communication_backend_name(), }, }, "gradient_clipping": 1.0, @@ -480,7 +481,7 @@ def test(self): 
"local_step_scaler": 1, "local_step_clipper": 2, "cuda_aware": False, - "comm_backend_name": "nccl", + "comm_backend_name": get_accelerator().communication_backend_name(), }, }, "gradient_clipping": 1.0, @@ -550,7 +551,7 @@ def test(self, tmpdir): "local_step_scaler": 1, "local_step_clipper": 2, "cuda_aware": False, - "comm_backend_name": "nccl", + "comm_backend_name": get_accelerator().communication_backend_name(), }, }, "gradient_clipping": 1.0, @@ -705,7 +706,7 @@ def test_overflow(self, tmpdir): "local_step_scaler": 1, "local_step_clipper": 2, "cuda_aware": False, - "comm_backend_name": "nccl", + "comm_backend_name": get_accelerator().communication_backend_name(), }, }, "gradient_clipping": 1.0, @@ -775,7 +776,7 @@ def test(self, topo_config): "local_step_scaler": 1, "local_step_clipper": 2, "cuda_aware": False, - "comm_backend_name": "nccl", + "comm_backend_name": get_accelerator().communication_backend_name(), }, }, "gradient_clipping": 1.0, @@ -829,7 +830,7 @@ def test(self, dtype): "min_coeff": 0.01, "freeze_step": 2, "cuda_aware": False, - "comm_backend_name": "nccl", + "comm_backend_name": get_accelerator().communication_backend_name(), "coeff_beta": 0.9, "factor_max": 1.0, "factor_min": 0.5, @@ -878,7 +879,7 @@ def test(self): "min_coeff": 0.01, "freeze_step": 2, "cuda_aware": False, - "comm_backend_name": "nccl", + "comm_backend_name": get_accelerator().communication_backend_name(), "coeff_beta": 0.9, "factor_max": 1.0, "factor_min": 0.5, @@ -950,7 +951,7 @@ def test(self, tmpdir): "min_coeff": 0.01, "freeze_step": 2, "cuda_aware": False, - "comm_backend_name": "nccl", + "comm_backend_name": get_accelerator().communication_backend_name(), "coeff_beta": 0.9, "factor_max": 1.0, "factor_min": 0.5, @@ -1125,7 +1126,7 @@ def test_overflow(self, tmpdir): "min_coeff": 0.01, "freeze_step": 2, "cuda_aware": False, - "comm_backend_name": "nccl", + "comm_backend_name": get_accelerator().communication_backend_name(), "coeff_beta": 0.9, "factor_max": 1.0, "factor_min": 0.5, @@ -1196,7 +1197,7 @@ def test(self, topo_config): "weight_decay": 3e-7, "freeze_step": 200, "cuda_aware": False, - "comm_backend_name": "nccl", + "comm_backend_name": get_accelerator().communication_backend_name(), }, }, "gradient_clipping": 1.0, @@ -1244,7 +1245,7 @@ def test(self, tmpdir): rank = dist.get_rank() backend = NcclBackend() local_rank = dist.get_rank() - device = torch.device("cuda", dist.get_rank()) + device = torch.device(get_accelerator().device_name(), dist.get_rank()) # A simulated compression function using deepspeed.comm def torch_sim(a): @@ -1266,7 +1267,7 @@ def torch_sim(a): [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())]) rank = dist.get_rank() server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank] - torch.cuda.synchronize() + get_accelerator().synchronize() dist.barrier() return a_server_compressed, worker_error, server_error @@ -1286,7 +1287,7 @@ def torch_sim(a): server_error = torch.zeros(right_server_size, device=device) a_torch, worker_error_torch, server_error_torch = torch_sim(a) - torch.cuda.empty_cache() + get_accelerator().empty_cache() a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank) diff --git a/tests/unit/runtime/half_precision/test_bf16.py b/tests/unit/runtime/half_precision/test_bf16.py index de15a0868df4..7208fa7281b7 100644 --- a/tests/unit/runtime/half_precision/test_bf16.py +++ b/tests/unit/runtime/half_precision/test_bf16.py @@ -3,7 +3,7 @@ import pytest from deepspeed.ops.adam import FusedAdam from 
unit.common import DistributedTest -from deepspeed.ops.op_builder import CPUAdamBuilder +from deepspeed.accelerator import get_accelerator from unit.simple_model import SimpleModel, SimpleOptimizer, random_dataloader from unit.util import bf16_required_version_check from deepspeed import comm as dist @@ -18,7 +18,8 @@ def test(self, zero_stage=2, use_cpu_offload=False): " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" ) - if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[ + get_accelerator().create_op_builder("CPUAdamBuilder").name]: pytest.skip("cpu-adam is not compatible") config_dict = { @@ -80,7 +81,8 @@ def test(self, zero_stage=2, use_cpu_offload=False): " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" ) - if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[ + get_accelerator().create_op_builder("CPUAdamBuilder").name]: pytest.skip("cpu-adam is not compatible") config_dict = { @@ -118,7 +120,8 @@ def test(self, zero_stage=2, use_cpu_offload=False): " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" ) - if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[ + get_accelerator().create_op_builder("CPUAdamBuilder").name]: pytest.skip("cpu-adam is not compatible") if zero_stage == 3: diff --git a/tests/unit/runtime/half_precision/test_fp16.py b/tests/unit/runtime/half_precision/test_fp16.py index a6d1b12c0349..98f7d0ba5b17 100644 --- a/tests/unit/runtime/half_precision/test_fp16.py +++ b/tests/unit/runtime/half_precision/test_fp16.py @@ -4,9 +4,9 @@ import pytest from deepspeed.ops.adam import FusedAdam from unit.common import DistributedTest -from deepspeed.ops.op_builder import CPUAdamBuilder from unit.simple_model import SimpleModel, SimpleOptimizer, random_dataloader, SimpleMoEModel, sequence_dataloader from unit.util import required_torch_version +from deepspeed.accelerator import get_accelerator try: from apex import amp # noqa: F401 @@ -193,7 +193,7 @@ def test_unfused_gradnorm(self, monkeypatch): hidden_dim = 10 def mock_unscale_and_clip_grads(total_norm, apply_scale=True): - torch_norm_tensor = torch.cuda.FloatTensor([total_norm]) + torch_norm_tensor = get_accelerator().FloatTensor([total_norm]) all_gather_results = [ torch.zeros_like(torch_norm_tensor) for _ in range(dist.get_world_size()) ] @@ -234,7 +234,7 @@ def test_fused_gradnorm(self, monkeypatch): hidden_dim = 10 def mock_unscale_and_clip_grads(grads_groups_flat, total_norm, apply_scale=True): - torch_norm_tensor = torch.cuda.FloatTensor([total_norm]) + torch_norm_tensor = get_accelerator().FloatTensor([total_norm]) all_gather_results = [ torch.zeros_like(torch_norm_tensor) for _ in range(dist.get_world_size()) ] @@ -283,7 +283,7 @@ def test_lamb_gradnorm(self, monkeypatch, fused_lamb_legacy: bool): hidden_dim = 10 def mock_unscale_and_clip_grads(total_norm, apply_scale=True): - torch_norm_tensor = torch.cuda.FloatTensor([total_norm]) + torch_norm_tensor = get_accelerator().FloatTensor([total_norm]) all_gather_results = [ torch.zeros_like(torch_norm_tensor) for _ in range(dist.get_world_size()) ] @@ -345,7 +345,8 @@ class 
TestAdamFP16ZeroOneCycleCompatibility(DistributedTest): world_size = 1 def test(self, zero_stage, use_cpu_offload): - if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[ + get_accelerator().create_op_builder("CPUAdamBuilder").name]: pytest.skip("cpu-adam is not compatible") config_dict = { @@ -402,7 +403,8 @@ class TestZeroStaticScale(DistributedTest): world_size = 1 def test(self, zero_stage, use_cpu_offload, hidden_dim): - if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[ + get_accelerator().create_op_builder("CPUAdamBuilder").name]: pytest.skip("cpu-adam is not compatible") config_dict = { @@ -450,7 +452,8 @@ class TestZeroAllowUntestedOptimizer(DistributedTest): world_size = 1 def test(self, zero_stage, use_cpu_offload): - if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[ + get_accelerator().create_op_builder("CPUAdamBuilder").name]: pytest.skip("cpu-adam is not compatible") config_dict = { @@ -482,7 +485,8 @@ class TestZeroEmptyPartition(DistributedTest): world_size = 3 def test(self, zero_stage, use_cpu_offload): - if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[ + get_accelerator().create_op_builder("CPUAdamBuilder").name]: pytest.skip("cpu-adam is not compatible") if zero_stage == 3: diff --git a/tests/unit/runtime/pipe/test_topology.py b/tests/unit/runtime/pipe/test_topology.py index 35860c5f5167..43a89602272f 100644 --- a/tests/unit/runtime/pipe/test_topology.py +++ b/tests/unit/runtime/pipe/test_topology.py @@ -7,6 +7,7 @@ from deepspeed.runtime.pipe.topology import ProcessTopology as Topo from deepspeed.runtime.pipe.topology import _prime_factors +from deepspeed.accelerator import get_accelerator from unit.common import DistributedTest @@ -173,13 +174,13 @@ def test_grid_pipe_data(self): grid.get_stage_id() == grid.get_pipe_parallel_world_size() - 1) # Test collectives along the pipeline parallel process groups - rank_tensor = torch.LongTensor(data=[rank]).cuda() + rank_tensor = torch.LongTensor(data=[rank]).to(get_accelerator().device_name()) dist.all_reduce(rank_tensor, group=grid.get_pipe_parallel_group()) pipe_group = grid.pp_group assert torch.all(rank_tensor == sum(pipe_group)) # Test collectives along the data parallel process groups - rank_tensor = torch.LongTensor(data=[rank]).cuda() + rank_tensor = torch.LongTensor(data=[rank]).to(get_accelerator().device_name()) dist.all_reduce(rank_tensor, group=grid.get_data_parallel_group()) data_group = grid.dp_group assert torch.all(rank_tensor == sum(data_group)) diff --git a/tests/unit/runtime/test_autocast.py b/tests/unit/runtime/test_autocast.py index f402486455ca..b7f6e145c1f6 100644 --- a/tests/unit/runtime/test_autocast.py +++ b/tests/unit/runtime/test_autocast.py @@ -1,39 +1,50 @@ import pytest import torch from deepspeed.runtime.zero.linear import LinearModuleForZeroStage3 +from deepspeed.accelerator import get_accelerator @pytest.mark.parametrize('half_op', [False, True]) def test_missing_amp_autocast(tmpdir, half_op): hidden_dim = 4 if half_op: - input = torch.randn(hidden_dim).cuda().half() - ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda().half() + input = torch.randn(hidden_dim).to(get_accelerator().device_name()).half() + 
diff --git a/tests/unit/runtime/test_autocast.py b/tests/unit/runtime/test_autocast.py
index f402486455ca..b7f6e145c1f6 100644
--- a/tests/unit/runtime/test_autocast.py
+++ b/tests/unit/runtime/test_autocast.py
@@ -1,39 +1,50 @@
 import pytest
 import torch
 from deepspeed.runtime.zero.linear import LinearModuleForZeroStage3
+from deepspeed.accelerator import get_accelerator
 
 
 @pytest.mark.parametrize('half_op', [False, True])
 def test_missing_amp_autocast(tmpdir, half_op):
     hidden_dim = 4
     if half_op:
-        input = torch.randn(hidden_dim).cuda().half()
-        ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda().half()
+        input = torch.randn(hidden_dim).to(get_accelerator().device_name()).half()
+        ds_linear = LinearModuleForZeroStage3(
+            hidden_dim,
+            hidden_dim).to(get_accelerator().device_name()).half()
     else:
-        input = torch.randn(hidden_dim).cuda()
-        ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda()
+        input = torch.randn(hidden_dim).to(get_accelerator().device_name())
+        ds_linear = LinearModuleForZeroStage3(hidden_dim,
+                                              hidden_dim).to(
+                                                  get_accelerator().device_name())
 
     output = ds_linear(input)
     assert output.dtype == ds_linear.weight.dtype
 
 
+@pytest.mark.skipif(get_accelerator().amp() is None, reason='amp is not installed')
 @pytest.mark.parametrize('half_op', [False, True])
 def test_disable_autocast_linear(tmpdir, half_op):
-    amp = pytest.importorskip("torch.cuda.amp")
+    amp = get_accelerator().amp()
 
     hidden_dim = 4
     if half_op:
-        input = torch.randn(hidden_dim).cuda().half()
-        ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda().half()
+        input = torch.randn(hidden_dim).to(get_accelerator().device_name()).half()
+        ds_linear = LinearModuleForZeroStage3(
+            hidden_dim,
+            hidden_dim).to(get_accelerator().device_name()).half()
     else:
-        input = torch.randn(hidden_dim).cuda()
-        ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda()
+        input = torch.randn(hidden_dim).to(get_accelerator().device_name())
+        ds_linear = LinearModuleForZeroStage3(hidden_dim,
+                                              hidden_dim).to(
+                                                  get_accelerator().device_name())
 
     with amp.autocast(False):
         output = ds_linear(input)
         assert output.dtype == ds_linear.weight.dtype
 
 
+@pytest.mark.skipif(get_accelerator().amp() is None, reason='amp is not installed')
 @pytest.mark.parametrize('half_input, half_weight',
                          [(False,
                            False),
@@ -44,11 +55,12 @@ def test_disable_autocast_linear(tmpdir, half_op):
                           (True,
                            True)])
 def test_autocast_linear(tmpdir, half_input, half_weight):
-    amp = pytest.importorskip("torch.cuda.amp")
+    amp = get_accelerator().amp()
 
     hidden_dim = 4
-    input = torch.randn(hidden_dim).cuda()
-    ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda()
+    input = torch.randn(hidden_dim).to(get_accelerator().device_name())
+    ds_linear = LinearModuleForZeroStage3(hidden_dim,
+                                          hidden_dim).to(get_accelerator().device_name())
 
     if half_input:
         input = input.half()
@@ -58,4 +70,4 @@ def test_autocast_linear(tmpdir, half_input, half_weight):
 
     with amp.autocast():
         output = ds_linear(input)
-        assert output.dtype == torch.half
+        assert output.dtype == torch.half or output.dtype == torch.bfloat16
diff --git a/tests/unit/runtime/test_data.py b/tests/unit/runtime/test_data.py
index e87f6c5e96f0..5e47eaa7ff7e 100644
--- a/tests/unit/runtime/test_data.py
+++ b/tests/unit/runtime/test_data.py
@@ -2,6 +2,7 @@
 import torch
 import pytest
 import deepspeed
+from deepspeed.accelerator import get_accelerator
 
 from unit.common import DistributedTest
 from unit.simple_model import SimpleModel, random_dataset
@@ -49,8 +50,8 @@ def test(self, train_batch_size, drop_last):
                                                       training_data=train_dataset,
                                                       optimizer=optimizer)
         for n, batch in enumerate(training_dataloader):
-            x = batch[0].to(torch.cuda.current_device())
-            y = batch[1].to(torch.cuda.current_device())
+            x = batch[0].to(get_accelerator().current_device_name())
+            y = batch[1].to(get_accelerator().current_device_name())
             loss = model(x, y)
             model.backward(loss)
             model.step()
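Reviewer note, illustrative only and not part of the patch: the autocast edits above share one pattern. get_accelerator().amp() replaces the direct torch.cuda.amp import and may be None when AMP is unavailable, and the final dtype check is relaxed because autocast can produce bf16 on some accelerators. A minimal sketch, with arbitrary tensor shapes and a matmul chosen only for illustration:

import pytest
import torch
from deepspeed.accelerator import get_accelerator

amp = get_accelerator().amp()  # None when AMP is not available for this accelerator


@pytest.mark.skipif(amp is None, reason='amp is not installed')
def test_autocast_matmul_sketch():
    device = get_accelerator().device_name()
    a = torch.randn(4, 4).to(device)
    b = torch.randn(4, 4).to(device)
    with amp.autocast():
        c = a @ b
    # Autocast may lower matmuls to fp16 or bf16 depending on the accelerator.
    assert c.dtype in (torch.float16, torch.bfloat16)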
diff --git a/tests/unit/runtime/test_ds_config_dict.py b/tests/unit/runtime/test_ds_config_dict.py
index dafe5ab674e5..0b579d0a4aa2 100644
--- a/tests/unit/runtime/test_ds_config_dict.py
+++ b/tests/unit/runtime/test_ds_config_dict.py
@@ -1,10 +1,10 @@
 # A test on its own
-import torch
 import pytest
 import json
 import argparse
 from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
+from deepspeed.accelerator import get_accelerator
 
 from unit.common import DistributedTest, get_test_path
 from unit.simple_model import SimpleModel, create_config_from_dict, random_dataloader
@@ -33,7 +33,7 @@ def base_config():
 
 
 def test_cuda():
-    assert (torch.cuda.is_available())
+    assert (get_accelerator().is_available())
 
 
 def test_check_version():
diff --git a/tests/unit/runtime/test_runtime_utils.py b/tests/unit/runtime/test_runtime_utils.py
index 33f40ad30a06..0353cfe2530e 100644
--- a/tests/unit/runtime/test_runtime_utils.py
+++ b/tests/unit/runtime/test_runtime_utils.py
@@ -5,6 +5,7 @@
 import deepspeed.runtime.utils as ds_utils
 import deepspeed.utils.groups as groups
+from deepspeed.accelerator import get_accelerator
 
 from unit.common import DistributedTest
 
@@ -36,10 +37,11 @@ def test(self):
         groups._create_expert_and_data_parallel(2)
 
         norm = ds_utils.clip_grad_norm_(parameters, max_norm=0.1)
-        norm = torch.Tensor([norm]).to(dist.get_rank())
-
+        norm = torch.Tensor([norm]).to(get_accelerator().device_name(dist.get_rank()))
         world_size = dist.get_world_size()
-        gathered_norm = [torch.zeros(1).cuda() for i in range(world_size)]
+        gathered_norm = [
+            torch.zeros(1).to(get_accelerator().device_name()) for i in range(world_size)
+        ]
 
         dist.all_gather(gathered_norm, norm)
diff --git a/tests/unit/runtime/utils/test_partition.py b/tests/unit/runtime/utils/test_partition.py
index e5e6ed14c586..d9aa244d213b 100644
--- a/tests/unit/runtime/utils/test_partition.py
+++ b/tests/unit/runtime/utils/test_partition.py
@@ -7,6 +7,7 @@
 from deepspeed.runtime.utils import partition_balanced
 from deepspeed.runtime.utils import prefix_sum_inc
 from deepspeed.runtime.utils import PartitionedTensor
+from deepspeed.accelerator import get_accelerator
 
 from unit.common import DistributedTest
 
@@ -23,7 +24,7 @@ def test(self):
         rows = world * 4
         cols = 3
 
-        full = torch.rand(rows, cols).cuda()
+        full = torch.rand(rows, cols).to(get_accelerator().device_name())
         dist.broadcast(full, src=0, group=group)
         part = PartitionedTensor(full, group=group)
 
@@ -46,7 +47,7 @@ def test(self):
         rows = world * 7
         cols = 3
 
-        full = torch.rand(rows, cols).cuda()
+        full = torch.rand(rows, cols).to(get_accelerator().device_name())
         dist.broadcast(full, src=0, group=group)
         part = PartitionedTensor(full, group=group)
diff --git a/tests/unit/runtime/zero/test_ignore_unused_parameters.py b/tests/unit/runtime/zero/test_ignore_unused_parameters.py
index 329a221bb826..3f2e893f6989 100644
--- a/tests/unit/runtime/zero/test_ignore_unused_parameters.py
+++ b/tests/unit/runtime/zero/test_ignore_unused_parameters.py
@@ -1,7 +1,7 @@
 import pytest
 from unit.common import DistributedTest
 from unit.simple_model import UnusedParametersModel, random_dataloader
-from deepspeed.ops.op_builder import CPUAdamBuilder
+from deepspeed.accelerator import get_accelerator
 import deepspeed
 
 
@@ -13,7 +13,8 @@ class TestStage2IgnoreUnusedParameters(DistributedTest):
     def test(self, ignore_unused_parameters):
         use_cpu_offload = True
 
-        if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]:
+        if use_cpu_offload and not deepspeed.ops.__compatible_ops__[
+                get_accelerator().create_op_builder("CPUAdamBuilder").name]:
             pytest.skip("cpu-adam is not compatible")
 
         config_dict = {
diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py
index f9715bd1dc8e..277042baf0a7 100644
--- a/tests/unit/runtime/zero/test_zero.py
+++ b/tests/unit/runtime/zero/test_zero.py
@@ -16,6 +16,7 @@
 from deepspeed.runtime.engine import DeepSpeedEngine
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+from deepspeed.accelerator import get_accelerator
 
 
 def run_unbalanced_gradients(model, data_loader):
@@ -688,30 +689,30 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor:
         grad_multiplier = 1 if zero_grad else (train_iter + 1)
         if dist.get_rank() == 0:
             assert torch.allclose(
-                dloss_wrt_layer3.cuda(),
+                dloss_wrt_layer3.to(get_accelerator().device_name()),
                 grad_multiplier * create_tensor([2] * 8,
                                                 torch.float))
             assert torch.allclose(
-                dloss_wrt_layer2.cuda(),
+                dloss_wrt_layer2.to(get_accelerator().device_name()),
                 grad_multiplier * create_tensor([3 * 1] * 8,
                                                 torch.float))
             assert torch.allclose(
-                dloss_wrt_layer1.cuda(),
+                dloss_wrt_layer1.to(get_accelerator().device_name()),
                 grad_multiplier * create_tensor([3 * 2 * 1] * 8,
                                                 torch.float))
         elif dist.get_rank() == 1:
             # parameters dont split evenly across ranks so rank 1 has a zero-padded
             # partition
             assert torch.allclose(
-                dloss_wrt_layer3.cuda(),
+                dloss_wrt_layer3.to(get_accelerator().device_name()),
                 grad_multiplier * create_tensor(([8] * 7) + [0],
                                                 torch.float))
             assert torch.allclose(
-                dloss_wrt_layer2.cuda(),
+                dloss_wrt_layer2.to(get_accelerator().device_name()),
                 grad_multiplier * create_tensor(([6 * 2] * 7) + [0],
                                                 torch.float))
             assert torch.allclose(
-                dloss_wrt_layer1.cuda(),
+                dloss_wrt_layer1.to(get_accelerator().device_name()),
                 grad_multiplier * create_tensor(([6 * 4 * 1] * 7) + [0],
                                                 torch.float))
         else:
@@ -1118,28 +1119,28 @@ def create_tensor(vals):
         grad_multiplier = 1 if zero_grad else (train_iter + 1)
         if dist.get_rank() == 0:
             assert torch.allclose(
-                dloss_wrt_layer3.cuda(),
+                dloss_wrt_layer3.to(get_accelerator().device_name()),
                 grad_multiplier * create_tensor([2] * 8).to(expected_grad_dtype))
             assert torch.allclose(
-                dloss_wrt_layer2.cuda(),
+                dloss_wrt_layer2.to(get_accelerator().device_name()),
                 grad_multiplier * create_tensor([3 * 1] * 8).to(expected_grad_dtype))
             assert torch.allclose(
-                dloss_wrt_layer1.cuda(),
+                dloss_wrt_layer1.to(get_accelerator().device_name()),
                 grad_multiplier * create_tensor([3 * 2 * 1] * 8).to(expected_grad_dtype))
         elif dist.get_rank() == 1:
             # parameters dont split evenly across ranks so rank 1 has a zero-padded
             # partition
             assert torch.allclose(
-                dloss_wrt_layer3.cuda(),
+                dloss_wrt_layer3.to(get_accelerator().device_name()),
                 grad_multiplier * create_tensor(([8] * 7) + [0]).to(expected_grad_dtype))
             assert torch.allclose(
-                dloss_wrt_layer2.cuda(),
+                dloss_wrt_layer2.to(get_accelerator().device_name()),
                 grad_multiplier * create_tensor(([6 * 2] * 7) + [0]).to(expected_grad_dtype))
             assert torch.allclose(
-                dloss_wrt_layer1.cuda(),
+                dloss_wrt_layer1.to(get_accelerator().device_name()),
                 grad_multiplier * create_tensor(([6 * 4 * 1] * 7) + [0]).to(expected_grad_dtype))
         else:
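Reviewer note, a hypothetical refactor rather than part of the patch: the repeated .to(get_accelerator().device_name()) calls on the expected-gradient tensors above could be avoided by building the expectations on the accelerator directly. The helper name expected_grad is invented for this sketch.

import torch
from deepspeed.accelerator import get_accelerator


def expected_grad(vals, dtype=torch.float):
    # device_name() returns the accelerator's device string (e.g. 'cuda' on NVIDIA GPUs)
    return torch.tensor(vals, dtype=dtype, device=get_accelerator().device_name())

# usage sketch: assert torch.allclose(dloss_wrt_layer3, grad_multiplier * expected_grad([2] * 8))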
diff --git a/tests/unit/runtime/zero/test_zero_context.py b/tests/unit/runtime/zero/test_zero_context.py
index d45beb3a618f..9dc7a668cd04 100644
--- a/tests/unit/runtime/zero/test_zero_context.py
+++ b/tests/unit/runtime/zero/test_zero_context.py
@@ -5,6 +5,7 @@
 import pytest
 
 import deepspeed
+from deepspeed.accelerator import get_accelerator
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape
 import deepspeed.comm as dist
 
@@ -359,7 +360,7 @@ def test_subclass_param_init():
     assert model.param.ds_status == ZeroParamStatus.NOT_AVAILABLE
 
     # test that the weights manipulation during each __init__ worked in all w/o needing gathering
-    ones = torch.ones(5).half().cuda()
+    ones = torch.ones(5).half().to(get_accelerator().device_name())
     with deepspeed.zero.GatheredParameters(list(model.parameters(recurse=False))):
         assert torch.equal(model.param, ones + 1)
         assert torch.equal(model.param_pa, ones + 2)
diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py
index 96f1927cda92..a5e9e38f340e 100644
--- a/tests/unit/simple_model.py
+++ b/tests/unit/simple_model.py
@@ -5,6 +5,7 @@
 
 from deepspeed.pipe import PipelineModule, LayerSpec
 from deepspeed.moe.layer import MoE
+from deepspeed.accelerator import get_accelerator
 
 import deepspeed.comm as dist
 
@@ -272,7 +273,7 @@ def create_deepspeed_args():
     args.deepspeed = True
     if dist.is_initialized():
         # We assume up to one full node executing unit tests
-        assert dist.get_world_size() <= torch.cuda.device_count()
+        assert dist.get_world_size() <= get_accelerator().device_count()
         args.local_rank = dist.get_rank()
     return args
diff --git a/tests/unit/utils/test_init_on_device.py b/tests/unit/utils/test_init_on_device.py
index 46d179b439cf..9a4338cc0db7 100644
--- a/tests/unit/utils/test_init_on_device.py
+++ b/tests/unit/utils/test_init_on_device.py
@@ -3,9 +3,10 @@
 from unit.simple_model import SimpleModel
 from deepspeed import OnDevice
 from packaging import version as pkg_version
+from deepspeed.accelerator import get_accelerator
 
 
-@pytest.mark.parametrize('device', ['meta', 'cuda:0'])
+@pytest.mark.parametrize('device', ['meta', get_accelerator().device_name(0)])
 def test_on_device(device):
     if device == "meta" and pkg_version.parse(
             torch.__version__) < pkg_version.parse("1.10"):
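Reviewer note, illustrative only and not part of the patch: taken together, the test changes map the former direct torch.cuda usage onto the accelerator interface roughly as sketched below, using only calls that appear in this patch.

import torch
from deepspeed.accelerator import get_accelerator

acc = get_accelerator()

available = acc.is_available()        # was: torch.cuda.is_available()
n_devices = acc.device_count()        # was: torch.cuda.device_count()
device = acc.device_name()            # was: the implicit 'cuda' in .cuda()
device_0 = acc.device_name(0)         # was: the literal 'cuda:0'
current = acc.current_device_name()   # was: torch.cuda.current_device()

x = torch.ones(4).to(device)          # was: torch.ones(4).cuda()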