2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -39,7 +39,7 @@ repos:
name: check-torchdist
entry: ./scripts/check-torchdist.py
language: script
exclude: ^(deepspeed/comm/|docs/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py)
exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py)
# Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm

- repo: https://github.com/codespell-project/codespell
65 changes: 65 additions & 0 deletions benchmarks/communication/README.md
@@ -0,0 +1,65 @@
# Running Communication Benchmarks


There are two options for running the communication benchmarks:

1. Run a single communication operation:

For example, run with a single large message size:
<pre>
deepspeed all_reduce.py
</pre>

Scan across message sizes:
<pre>
deepspeed all_reduce.py --scan
</pre>

Each communication operation has its own benchmarking options. For `all_reduce.py`, for example:

<pre>
usage: ds_bench [-h] [--local_rank LOCAL_RANK] [--trials TRIALS] [--warmup WARMUP] [--maxsize MAXSIZE] [--async-op] [--bw-unit {Gbps,GBps}] [--backend {nccl}] [--dist {deepspeed,torch}] [--scan] [--dtype DTYPE] [--mem-factor MEM_FACTOR] [--debug]

optional arguments:
-h, --help show this help message and exit
--local_rank LOCAL_RANK
--trials TRIALS Number of timed iterations
--warmup WARMUP Number of warmup (non-timed) iterations
--maxsize MAXSIZE Max message size as a power of 2
--async-op Enables non-blocking communication
--bw-unit {Gbps,GBps}
--backend {nccl} Communication library to use
--dist {deepspeed,torch}
Distributed DL framework to use
--scan Enables scanning all message sizes
--dtype DTYPE PyTorch tensor dtype
--mem-factor MEM_FACTOR
Proportion of max available GPU memory to use for single-size evals
--debug Enables alltoall debug prints
</pre>

2. Run all available communication benchmarks:

<pre>
deepspeed run_all.py
</pre>

Like the individual benchmarks, `run_all.py` supports arguments such as `--scan`, the max message size, the bandwidth unit, etc. Simply pass the desired arguments to `run_all.py` and they will be propagated to each comm op.
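
For example, the following invocation (the flag values here are only illustrative) scans across message sizes, caps the scan via `--maxsize`, reports bandwidth in GB/s, and benchmarks half-precision tensors:
<pre>
deepspeed run_all.py --scan --maxsize 24 --bw-unit GBps --dtype float16
</pre>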

Note that `ds_bench` is a pre-packaged wrapper around `run_all.py`, so it accepts the same arguments:

<pre>
<path to deepspeed>/bin/ds_bench --scan --trials=10
</pre>


# Adding Communication Benchmarks

To add new communication benchmarks, follow this general procedure (a rough sketch of steps 2 and 3 follows the list):

1. Copy a similar benchmark file (e.g. to add `reduce_scatter`, copy `all_reduce.py` as a template)
2. Add a new bandwidth formula to `utils.get_bw`
3. Add a new maximum-tensor-element formula to `utils.max_numel`
4. Replace the comm op calls in the new file (e.g. with find-and-replace)
5. Find a good default `mem_factor` for the `run_<collective>_single()` function
6. Add the new comm op to `run_all.py`
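
As a rough sketch of steps 2 and 3, the formulas for a hypothetical `reduce_scatter` benchmark could look like the standalone functions below. The helper names and signatures are illustrative only (the real `utils.get_bw` and `utils.max_numel` take the benchmark's `args` and branch on the comm op); the `(n-1)/n` bus-bandwidth factor follows the usual nccl-tests convention.
<pre>
# Illustrative sketch only -- not the actual utils API.

def reduce_scatter_bw(size_bytes, avg_duration_s, world_size):
    # Algorithmic throughput: bytes moved by the timed op per second
    tput = size_bytes / avg_duration_s
    # Bus bandwidth correction for reduce_scatter (nccl-tests convention)
    busbw = tput * (world_size - 1) / world_size
    return tput, busbw

def reduce_scatter_max_numel(total_gpu_mem_bytes, element_size_bytes, mem_factor, world_size):
    # Each rank holds a full-size input plus a 1/world_size output chunk,
    # so budget the usable memory across both tensors.
    usable_bytes = total_gpu_mem_bytes * mem_factor
    return int(usable_bytes / (element_size_bytes * (1.0 + 1.0 / world_size)))
</pre>

The existing benchmarks report both the algorithmic throughput and the corrected bus bandwidth, which is why `get_bw` returns a `(tput, busbw)` pair.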
Empty file.
153 changes: 153 additions & 0 deletions benchmarks/communication/all_gather.py
@@ -0,0 +1,153 @@
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *

import time
import argparse
import os

import math


# Run allgather and print metrics
def timed_allgather(input, output, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist

sync_all()
# Warmup, establish connections, etc.
for i in range(args.warmup):
# use all_gather_base if available
if args.dist == 'torch':
if hasattr(torch.distributed, "_all_gather_base"):
dist._all_gather_base(output, input, group=None, async_op=args.async_op)
else:
                # Fall back to regular all_gather over a list of per-rank output chunks
                output_tensors = list(torch.chunk(output, dist.get_world_size()))
                dist.all_gather(output_tensors, input, group=None, async_op=args.async_op)
elif args.dist == 'deepspeed':
dist.allgather_fn(output, input, group=None, async_op=args.async_op)
sync_all()

# time the actual comm op trials times and average it
pre = time.perf_counter()
for i in range(args.trials):
# use all_gather_base if available
if args.dist == 'torch':
if hasattr(torch.distributed, "_all_gather_base"):
dist._all_gather_base(output, input, group=None, async_op=args.async_op)
else:
                # Fall back to regular all_gather over a list of per-rank output chunks
                output_tensors = list(torch.chunk(output, dist.get_world_size()))
                dist.all_gather(output_tensors, input, group=None, async_op=args.async_op)
elif args.dist == 'deepspeed':
dist.allgather_fn(output, input, group=None, async_op=args.async_op)
sync_all()
duration = time.perf_counter() - pre

# maintain and clean performance data
avg_duration = duration / args.trials
size = input.element_size() * input.nelement()
n = dist.get_world_size()
tput, busbw = get_bw('allgather', size, avg_duration, args)
tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
desc = f'{input.nelement()}x{input.element_size()}'

print_rank_0(
f"{convert_size(size):<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}"
)


def run_allgather(local_rank, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist

# Prepare benchmark header
print_header(args, 'allgather')
global_rank = dist.get_rank()
world_size = dist.get_world_size()

if args.scan:
# Create list of message sizes
M_LIST = []
for x in (2**p for p in range(1, args.maxsize)):
M_LIST.append(x)

sync_all()
# loop over various tensor sizes
for M in M_LIST:
global_rank = dist.get_rank()
try:
mat = torch.ones(world_size,
M,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
# Delete original mat to avoid OOM
del mat
torch.cuda.empty_cache()
output = torch.zeros(input.nelement() * world_size,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print('WARNING: Ran out of GPU memory. Exiting comm op.')
sync_all()
break
sync_all()
timed_allgather(input, output, args)
else:
# all_gather_base saves memory
if (args.dist == 'torch'
and hasattr(torch.distributed,
"_all_gather_base")) or (args.dist == 'deepspeed'
and dist.has_allgather_base):
mem_factor = args.mem_factor + 0.2
else:
mem_factor = args.mem_factor
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
sync_all()
elements_per_gpu = max_numel(comm_op='allgather',
dtype=getattr(torch,
args.dtype),
mem_factor=mem_factor,
local_rank=local_rank,
args=args)
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
# multiply each GPU's tensor by the rank to ease debugging
input = ((mat.mul_(float(global_rank))).view(-1))
# Delete original mat to avoid OOM
del mat
torch.cuda.empty_cache()
output = torch.zeros(elements_per_gpu * world_size,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print(
'WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!'
)
sync_all()
return

sync_all()
timed_allgather(input, output, args)


if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
init_processes(local_rank=rank, args=args)
run_allgather(local_rank=rank, args=args)
109 changes: 109 additions & 0 deletions benchmarks/communication/all_reduce.py
@@ -0,0 +1,109 @@
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *

import time
import argparse
import os
import math


def timed_allreduce(input, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist

sync_all()
# Warmup, establish connections, etc.
for i in range(args.warmup):
dist.all_reduce(input, async_op=args.async_op)
sync_all()

# time the actual comm op trials times and average it
pre = time.perf_counter()
for i in range(args.trials):
dist.all_reduce(input, async_op=args.async_op)
sync_all()
duration = time.perf_counter() - pre

# maintain and clean performance data
avg_duration = duration / args.trials
size = input.element_size() * input.nelement()
n = dist.get_world_size()
tput, busbw = get_bw('allreduce', size, avg_duration, args)
tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
desc = f'{input.nelement()}x{input.element_size()}'

print_rank_0(
f"{convert_size(size):<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}"
)


def run_allreduce(local_rank, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist

# Prepare benchmark header
print_header(args, 'allreduce')

world_size = dist.get_world_size()
global_rank = dist.get_rank()

if args.scan:
M_LIST = []
for x in (2**p for p in range(1, args.maxsize)):
M_LIST.append(x)

sync_all()
# loop over various tensor sizes
for M in M_LIST:
global_rank = dist.get_rank()
try:
mat = torch.ones(world_size,
M,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print('WARNING: Ran out of GPU memory. Exiting comm op.')
sync_all()
break
sync_all()
timed_allreduce(input, args)
else:
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
# Don't need output tensor, so we double mem_factor
elements_per_gpu = max_numel(comm_op='allreduce',
dtype=getattr(torch,
args.dtype),
mem_factor=args.mem_factor * 2,
local_rank=local_rank,
args=args)
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print(
'WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!'
)
sync_all()
return
sync_all()
timed_allreduce(input, args)


if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
init_processes(local_rank=rank, args=args)
run_allreduce(local_rank=rank, args=args)