Changes from all commits (41 commits)
0a849d5
[device abstraction] add device abstraction to allow other device tha…
delock Aug 16, 2022
e4f40f0
Merge branch '202208-base' into 202208
delock Aug 24, 2022
4a216ea
[rebase-202208] additional changes needed when rebase to 202208
delock Aug 24, 2022
2137642
Merge branch '20220824-base' into 20220824
delock Aug 24, 2022
089657e
[rebase] cleanup direct cuda usage after merge
delock Aug 24, 2022
d5a8424
[precommit] fix pre-commit issues
delock Aug 25, 2022
96d0765
Merge branch 'master' into gma/device-abstraction
tjruwase Aug 30, 2022
ac64c7a
[pin_memory] make pin_memory select device type
delock Sep 1, 2022
02c3a57
Merge branch 'master' into gma/device-abstraction
delock Sep 8, 2022
522b24b
[downstream] merge from xpu support downstream
delock Sep 9, 2022
a3b1e02
Merge branch 'master' into gma/device-abstraction
tjruwase Sep 12, 2022
4557c33
Merge branch 'master' into gma/device-abstraction
tjruwase Sep 13, 2022
2ef7d6c
Merge branch 'up-master' into gma/merge-upstream-20220921
delock Sep 21, 2022
9656321
[device] port cuda device to literal_device() in new tests
delock Sep 21, 2022
65729e3
[accel_runtime] add pin_memory to accelerator runtime interface.
delock Sep 22, 2022
f94d53e
[accelerator abstraction] merge from #2320
delock Sep 26, 2022
6005abe
Merge branch 'up-master' into gma/device-abstraction
delock Sep 26, 2022
31c0997
change call site of literal_device, on_accel_device and accel_runtime…
delock Oct 12, 2022
1785c26
add new interface definition from olruwase/accelerator_abstraction
delock Oct 12, 2022
17203a4
[accelerator abstraction] remove name() from interface, device_name()…
delock Oct 14, 2022
e8daea6
merge with master (ec13da6ba7cabc44bb4745a64a208b8580792954)
delock Oct 14, 2022
cfd23ed
Merge branch 'up-master' into gma/device-abstraction
delock Oct 14, 2022
13bbbdf
[OpBuilder] Add op builder abstraction
delock Oct 23, 2022
06e39a5
Merge branch 'up-master' into gma/device-abstraction
delock Oct 23, 2022
257490f
convert op builder usage in merged code
delock Oct 23, 2022
c93b999
[OpBuilder] add create_op_builder interface in abstract_accelerator.py
delock Oct 23, 2022
9858d42
[OpBuilder] fix op builder usage in tests
delock Oct 23, 2022
68ce006
[OpBuilder] fix <op builder>.NAME usage in tests to follow op builder…
delock Oct 23, 2022
4b62dab
import get_accelerator from deepspeed.accelerator directly
delock Oct 23, 2022
c5b2070
[OpBuilder] remove unused function and sync with main
delock Oct 23, 2022
9532843
add missing get_accelerator import
delock Oct 25, 2022
0729695
fix obsolete name in CPU Adam which should be create_op_builder
delock Oct 25, 2022
be517d8
fix create_op_builder calls
delock Oct 25, 2022
3af870f
fix misuse of new accelerator abstraction interface in tests
delock Oct 25, 2022
8fa64b9
Merge from downstream for bug fixing
delock Oct 28, 2022
4873538
merge from downstream
delock Nov 3, 2022
61b10b0
remove SYCL_KERNEL specific code
delock Nov 4, 2022
b9c7c79
don't gather partitioned activations for mp size 1 (#2454)
guoyejun Nov 4, 2022
b9211e8
Merge pull request #1 from guoyejun/forgma
delock Nov 10, 2022
7ebaeaa
stage_1_and_2.py: no allreduce needed when mp size is 1 (#2494)
guoyejun Nov 10, 2022
996655c
Merge pull request #2 from guoyejun/forgma
delock Nov 10, 2022
27 changes: 17 additions & 10 deletions benchmarks/communication/all_gather.py
@@ -1,5 +1,6 @@
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *
from deepspeed.accelerator import get_accelerator

import time

@@ -83,16 +84,20 @@ def run_all_gather(local_rank, args):
try:
mat = torch.ones(world_size,
M,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
dtype=getattr(
torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
# Delete original mat to avoid OOM
del mat
torch.cuda.empty_cache()
get_accelerator().empty_cache()
output = torch.zeros(input.nelement() * world_size,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
dtype=getattr(
torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
@@ -121,15 +126,17 @@ def run_all_gather(local_rank, args):
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
args.dtype)).to(
get_accelerator().device_name(local_rank))
# multiply each GPU's tensor by the rank to ease debugging
input = ((mat.mul_(float(global_rank))).view(-1))
# Delete original mat to avoid OOM
del mat
torch.cuda.empty_cache()
output = torch.zeros(elements_per_gpu * world_size,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
get_accelerator().empty_cache()
output = torch.zeros(
elements_per_gpu * world_size,
dtype=getattr(torch,
args.dtype)).to(get_accelerator().device_name(local_rank))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
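The change above is the recurring device-placement pattern in these benchmarks: `tensor.cuda(local_rank)` becomes `tensor.to(get_accelerator().device_name(local_rank))`, and `torch.cuda.empty_cache()` becomes `get_accelerator().empty_cache()`. A minimal sketch of the pattern, assuming only the accelerator methods visible in this diff (`device_name`, `empty_cache`); the helper name is illustrative, not part of the PR:

```python
import torch
from deepspeed.accelerator import get_accelerator

def ones_on_accelerator(numel, dtype, local_rank):
    # Old, CUDA-only form: torch.ones(numel, dtype=dtype).cuda(local_rank)
    # device_name(local_rank) resolves to a device string such as 'cuda:0' or 'xpu:0'.
    device = get_accelerator().device_name(local_rank)
    return torch.ones(numel, dtype=dtype).to(device)

# Releasing cached device memory goes through the same object:
# get_accelerator().empty_cache()  # replaces torch.cuda.empty_cache()
```

The same substitution repeats in all_reduce.py, all_to_all.py, broadcast.py, and pt2pt.py below.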
10 changes: 7 additions & 3 deletions benchmarks/communication/all_reduce.py
@@ -1,5 +1,6 @@
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *
from deepspeed.accelerator import get_accelerator

import time

@@ -62,8 +63,10 @@ def run_all_reduce(local_rank, args):
try:
mat = torch.ones(world_size,
M,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
dtype=getattr(
torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
@@ -86,7 +89,8 @@ def run_all_reduce(local_rank, args):
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
args.dtype)).to(
get_accelerator().device_name(local_rank))
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
if 'out of memory' in str(e):
19 changes: 12 additions & 7 deletions benchmarks/communication/all_to_all.py
@@ -1,5 +1,6 @@
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *
from deepspeed.accelerator import get_accelerator

import time

@@ -61,8 +62,10 @@ def run_all_to_all(local_rank, args):
try:
mat = torch.ones(world_size,
M,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
dtype=getattr(
torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
assert mat.numel() % world_size == 0, f"tensor cannot be divided in {world_size} chunks"
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
@@ -86,15 +89,17 @@ def run_all_to_all(local_rank, args):
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
args.dtype)).to(
get_accelerator().device_name(local_rank))
assert mat.numel() % world_size == 0, f"tensor with {mat.numel()} elements cannot be divided in {world_size} chunks"
input = ((mat.mul_(float(global_rank))).view(-1))
# Delete original mat to avoid OOM
del mat
torch.cuda.empty_cache()
output = torch.zeros(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
get_accelerator().empty_cache()
output = torch.zeros(
elements_per_gpu,
dtype=getattr(torch,
args.dtype)).to(get_accelerator().device_name(local_rank))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
10 changes: 7 additions & 3 deletions benchmarks/communication/broadcast.py
@@ -1,6 +1,7 @@
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *
from deepspeed.accelerator import get_accelerator

import time

@@ -63,8 +64,10 @@ def run_broadcast(local_rank, args):
try:
mat = torch.ones(world_size,
M,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
dtype=getattr(
torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
@@ -87,7 +90,8 @@ def run_broadcast(local_rank, args):
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
args.dtype)).to(
get_accelerator().device_name(local_rank))
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
if 'out of memory' in str(e):
4 changes: 3 additions & 1 deletion benchmarks/communication/constants.py
@@ -1,7 +1,9 @@
from deepspeed.accelerator import get_accelerator

DEFAULT_WARMUPS = 5
DEFAULT_TRIALS = 50
DEFAULT_TYPE = 'float'
DEFAULT_BACKEND = 'nccl'
DEFAULT_BACKEND = get_accelerator().communication_backend_name()
DEFAULT_UNIT = 'Gbps'
DEFAULT_DIST = 'deepspeed'
DEFAULT_MAXSIZE = 24
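With DEFAULT_BACKEND now queried from the accelerator, the benchmarks default to whichever backend matches the active device ('nccl' for CUDA, 'ccl' for XPU per the choices added to utils.py below). A small sketch of how a caller might read it, assuming only the `communication_backend_name()` method shown here:

```python
from deepspeed.accelerator import get_accelerator

# Replaces the hard-coded DEFAULT_BACKEND = 'nccl'; the value feeds the
# --backend default in the benchmark argument parser.
backend = get_accelerator().communication_backend_name()
print(f"Default communication backend: {backend}")  # e.g. 'nccl' or 'ccl'
```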
10 changes: 7 additions & 3 deletions benchmarks/communication/pt2pt.py
@@ -1,5 +1,6 @@
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *
from deepspeed.accelerator import get_accelerator

import time

@@ -81,8 +82,10 @@ def run_pt2pt(local_rank, args):
try:
mat = torch.ones(world_size,
M,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
dtype=getattr(
torch,
args.dtype)).to(
get_accelerator().device_name(local_rank))
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
@@ -105,7 +108,8 @@ def run_pt2pt(local_rank, args):
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
args.dtype)).to(
get_accelerator().device_name(local_rank))
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
if 'out of memory' in str(e):
13 changes: 7 additions & 6 deletions benchmarks/communication/utils.py
@@ -3,6 +3,7 @@
import math
import argparse
from benchmarks.communication.constants import *
from deepspeed.accelerator import get_accelerator

global dist

@@ -12,7 +13,7 @@ def init_torch_distributed(backend):
import torch.distributed as dist
torch.distributed.init_process_group(backend)
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
get_accelerator().set_device(local_rank)


def init_deepspeed_comm(backend):
@@ -21,7 +22,7 @@ def init_deepspeed_comm(backend):
import deepspeed.comm as dist
deepspeed.init_distributed(dist_backend=backend)
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
get_accelerator().set_device(local_rank)


def init_processes(local_rank, args):
@@ -99,14 +100,13 @@ def get_metric_strings(args, tput, busbw, duration):


def sync_all():
torch.cuda.synchronize()
get_accelerator().synchronize()
dist.barrier()


def max_numel(comm_op, dtype, mem_factor, local_rank, args):
dtype_size = _element_size(dtype)
max_memory_per_gpu = torch.cuda.get_device_properties(
local_rank).total_memory * mem_factor
max_memory_per_gpu = get_accelerator().total_memory(local_rank) * mem_factor
if comm_op == 'all_reduce' or comm_op == 'pt2pt' or comm_op == 'broadcast':
elements_per_gpu = int(max_memory_per_gpu // dtype_size)
elif comm_op == 'all_gather':
@@ -183,7 +183,8 @@ def benchmark_parser():
parser.add_argument("--backend",
type=str,
default=DEFAULT_BACKEND,
choices=['nccl'],
choices=['nccl',
'ccl'],
help='Communication library to use')
parser.add_argument("--dist",
type=str,
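Taken together, utils.py now routes all of its runtime device management through the accelerator object. A condensed sketch of those substitutions, using only the methods this diff introduces (`set_device`, `synchronize`, `total_memory`); the helper name and the mem_factor default are illustrative:

```python
import os
from deepspeed.accelerator import get_accelerator

def device_setup_and_budget(mem_factor=0.8):
    acc = get_accelerator()
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))

    acc.set_device(local_rank)   # replaces torch.cuda.set_device(local_rank)
    acc.synchronize()            # replaces torch.cuda.synchronize()

    # replaces torch.cuda.get_device_properties(local_rank).total_memory
    max_memory_per_gpu = acc.total_memory(local_rank) * mem_factor
    return max_memory_per_gpu
```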
7 changes: 4 additions & 3 deletions benchmarks/inference/bert-bench.py
@@ -3,6 +3,7 @@
import deepspeed
import argparse
from transformers import pipeline
from deepspeed.accelerator import get_accelerator

parser = argparse.ArgumentParser()
parser.add_argument("--model", "-m", type=str, help="hf model name")
@@ -44,7 +45,7 @@ def print_latency(latency_set, title, warmup=3):
print("\t999 Latency: {0:8.2f} ms".format(p999 * 1000))


deepspeed.init_distributed("nccl")
deepspeed.init_distributed()

print(args.model, args.max_tokens, args.dtype)

@@ -72,10 +73,10 @@ def print_latency(latency_set, title, warmup=3):
times = []
mtimes = []
for i in range(args.trials):
torch.cuda.synchronize()
get_accelerator().synchronize()
start = time.time()
r = pipe("Hello I'm a [MASK] model")
torch.cuda.synchronize()
get_accelerator().synchronize()
end = time.time()
responses.append(r)
times.append((end - start))
7 changes: 4 additions & 3 deletions benchmarks/inference/gpt-bench.py
@@ -4,6 +4,7 @@
import deepspeed
import argparse
from transformers import pipeline
from deepspeed.accelerator import get_accelerator

parser = argparse.ArgumentParser()
parser.add_argument("--model", "-m", type=str, help="hf model name")
@@ -61,7 +62,7 @@ def print_latency(latency_set, title, warmup=3):
print("\t999 Latency: {0:8.2f} ms".format(p999 * 1000))


deepspeed.init_distributed("nccl")
deepspeed.init_distributed()

if args.local_rank == 0:
print("BENCHMARK SETTINGS:")
@@ -101,10 +102,10 @@ def print_latency(latency_set, title, warmup=3):
times = []
mtimes = []
for i in range(args.trials):
torch.cuda.synchronize()
get_accelerator().synchronize()
start = time.time()
r = pipe("DeepSpeed is", do_sample=False, max_new_tokens=args.max_tokens)
torch.cuda.synchronize()
get_accelerator().synchronize()
end = time.time()
responses.append(r)
times.append(end - start) # / (args.max_tokens - 3))
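Both inference benchmarks above (bert-bench.py and gpt-bench.py) switch their timing loops from `torch.cuda.synchronize()` to the accelerator's `synchronize()`, and drop the explicit "nccl" argument to `deepspeed.init_distributed()` so the backend follows the device. A minimal sketch of the timing pattern, assuming only the `get_accelerator()` API shown in this diff; `timed_call` and the `pipe` object in the usage comment are illustrative stand-ins:

```python
import time
from deepspeed.accelerator import get_accelerator

def timed_call(fn, *args, **kwargs):
    """Time a callable with device-agnostic synchronization.

    synchronize() drains queued device work so the wall-clock timestamps
    bracket the actual execution, whatever the accelerator is.
    """
    get_accelerator().synchronize()
    start = time.time()
    result = fn(*args, **kwargs)
    get_accelerator().synchronize()
    return result, time.time() - start

# Usage (pipe stands in for the HF pipeline built in the benchmark script):
# response, latency = timed_call(pipe, "DeepSpeed is", max_new_tokens=50)
```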
4 changes: 2 additions & 2 deletions csrc/aio/py_test/aio_bench_perf_sweep.py
@@ -15,6 +15,7 @@
from test_ds_aio_utils import refine_integer_value
from perf_sweep_utils import READ_OP_DESC, WRITE_OP_DESC, BENCH_LOG_DIR, \
READ_IO_DIR, WRITE_IO_DIR, READ_LOG_DIR, WRITE_LOG_DIR
from deepspeed.accelerator import get_accelerator

OTHER_OPTIONS = '--handle'
PERF_SCRIPT = 'test_ds_aio.py'
@@ -277,8 +278,7 @@ def script_path():


def async_io_setup():
from deepspeed.ops.aio import AsyncIOBuilder
return AsyncIOBuilder().is_compatible()
return get_accelerator().create_op_builder("AsyncIOBuilder").is_compatible()


def get_block_size_and_count(io_bytes):
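The compatibility check above illustrates the new op-builder lookup: instead of importing AsyncIOBuilder from deepspeed.ops.aio, the builder is requested from the accelerator by name. A minimal sketch of that pattern as it appears in this revision of the PR (the string-based `create_op_builder` call is what the diff shows; later revisions of the abstraction may differ, and the function name here is illustrative):

```python
from deepspeed.accelerator import get_accelerator

def async_io_supported():
    # Old form: from deepspeed.ops.aio import AsyncIOBuilder; AsyncIOBuilder().is_compatible()
    builder = get_accelerator().create_op_builder("AsyncIOBuilder")
    return builder.is_compatible()
```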
37 changes: 21 additions & 16 deletions csrc/aio/py_test/ds_aio_basic.py
@@ -8,9 +8,9 @@
import torch
import os
import time
from deepspeed.ops.aio import AsyncIOBuilder
from multiprocessing import Pool, Barrier
from test_ds_aio_utils import report_results, task_log, task_barrier
from deepspeed.accelerator import get_accelerator


def pre_basic(args, tid, read_op):
@@ -19,7 +19,10 @@ def pre_basic(args, tid, read_op):
file = args.read_file if read_op else f'{args.write_file}.{tid}'

task_log(tid, f'Allocate tensor of size {num_bytes} bytes')
buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory()
buffer = get_accelerator().pin_memory(
torch.empty(num_bytes,
dtype=torch.uint8,
device='cpu'))
task_log(
tid,
f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}'
@@ -56,13 +59,14 @@ def post_basic(pool_params):
def main_basic_read(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_read(ctxt['buffer'],
ctxt['file'],
args.block_size,
args.queue_depth,
args.single_submit,
args.overlap_events,
args.validate)
get_accelerator().create_op_builder("AsyncIOBuilder").load().aio_read(
ctxt['buffer'],
ctxt['file'],
args.block_size,
args.queue_depth,
args.single_submit,
args.overlap_events,
args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time

@@ -72,13 +76,14 @@ def main_basic_read(pool_params):
def main_basic_write(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_write(ctxt['buffer'],
ctxt['file'],
args.block_size,
args.queue_depth,
args.single_submit,
args.overlap_events,
args.validate)
get_accelerator().create_op_builder("AsyncIOBuilder").load().aio_write(
ctxt['buffer'],
ctxt['file'],
args.block_size,
args.queue_depth,
args.single_submit,
args.overlap_events,
args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time

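ds_aio_basic.py combines two of the abstraction's entry points: pinned host buffers now come from the accelerator's `pin_memory()`, and the AIO extension is built and loaded through `create_op_builder(...).load()` rather than a direct import. A condensed sketch, assuming only the calls shown in this file's diff; the helper names are illustrative:

```python
import torch
from deepspeed.accelerator import get_accelerator

def pinned_host_buffer(num_bytes):
    # replaces torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory()
    return get_accelerator().pin_memory(
        torch.empty(num_bytes, dtype=torch.uint8, device='cpu'))

def load_aio_module():
    # replaces AsyncIOBuilder().load() from the direct deepspeed.ops.aio import;
    # the loaded module exposes aio_read/aio_write as used in this benchmark.
    return get_accelerator().create_op_builder("AsyncIOBuilder").load()
```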