Merged
27 commits
a04480e
Fix the half-precision version of CPU-Adam (#2032)
RezaYazdaniAminabadi Jun 23, 2022
7bae53d
Fix for AMD unit tests (#2047)
mrwyattii Jun 23, 2022
38a00be
correct partition_id in fp32 param -> fp16 param for MoE+z2 (#2058)
siddharth9820 Jun 27, 2022
76ea053
Fix missing import in replace_module.py (#2050)
aphedges Jun 29, 2022
9b70ce5
Comms Benchmarks (#2040)
Quentin-Anthony Jun 29, 2022
9fc4e5f
add ds inference paper (#2072)
jeffra Jul 6, 2022
9305916
Comments for better understanding of zero stage1_2 (#2027)
kisseternity Jul 6, 2022
559fb8e
[docs] fix broken read-the-docs build (#2075)
jeffra Jul 6, 2022
3540ce7
Check for bf16 support only if CUDA is available (#2049)
aphedges Jul 6, 2022
b3388e1
Fix partition id in the fp32->fp16 param copying step for z2+cpu-offl…
siddharth9820 Jul 7, 2022
50a652e
Codeowner addendum and fix to small model debugging script (#2076)
samadejacobs Jul 8, 2022
0ad0860
remove require grad in params count (#2065)
cli99 Jul 13, 2022
db3252b
Add missing newline for ZeroOneAdam parameter table (#2088)
manuelciosici Jul 13, 2022
b052378
fixed "None type has no len()" (#2091)
xiazeyu Jul 13, 2022
c1af73f
Improving memory utilization of Z2+MoE (#2079)
siddharth9820 Jul 13, 2022
2feaf6d
bump to 0.6.7
jeffra Jul 18, 2022
aa88137
Add Inference support for running the BigScience-BLOOM Architecture (…
RezaYazdaniAminabadi Jul 18, 2022
16699d8
[ds-inference] checkpoint loading => tqdm (#2107)
stas00 Jul 19, 2022
9027f86
Dont overwrite hook handles in flop profiler (#2106)
Sanger2000 Jul 19, 2022
ee7ea3b
use HF NeoX (#2087)
mrwyattii Jul 19, 2022
6b9df56
bump to 0.7.0
jeffra Jul 19, 2022
0f4f2f9
Adding DeepSpeed Compression Composer (#2105)
yaozhewei Jul 19, 2022
69b7c97
Remove hardcoded ROCm install path (#2093)
mrwyattii Jul 19, 2022
b4513f6
fix softmax dim of Residual MoE in moe/layer.py (#2110)
Jul 20, 2022
844d9f3
reduce ds-inference log verbosity (#2111)
jeffra Jul 20, 2022
46fffc0
DeepSpeed Compression announcement (#2114)
conglongli Jul 20, 2022
80d0a32
Checkpoint reshaping (#1953)
tjruwase Jul 20, 2022
2 changes: 1 addition & 1 deletion .github/workflows/amd.yml
@@ -54,7 +54,7 @@ jobs:
# Runs a set of commands using the runners shell
- name: Install deepspeed
run: |
-        sudo /opt/conda/bin/pip install .[dev,1bit,autotuning]
+        pip install .[dev,1bit,autotuning]
#python -c "from deepspeed.env_report import cli_main; cli_main()"
ds_report

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -39,7 +39,7 @@ repos:
name: check-torchdist
entry: ./scripts/check-torchdist.py
language: script
-    exclude: ^(deepspeed/comm/|docs/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py)
+    exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py)
# Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm

- repo: https://github.com/codespell-project/codespell
2 changes: 1 addition & 1 deletion CODEOWNERS
@@ -1 +1 @@
- * @jeffra @samyam @tjruwase @ShadenSmith @conglongli @awan-10 @cli99 @eltonzheng @minjiaz @RezaYazdaniAminabadi @duli2012 @mrwyattii @yaozhewei @arashb @xiaoxiawu-microsoft
+ * @jeffra @samyam @tjruwase @ShadenSmith @conglongli @awan-10 @cli99 @eltonzheng @minjiaz @RezaYazdaniAminabadi @duli2012 @mrwyattii @yaozhewei @arashb @xiaoxiawu-microsoft @samadejacobs
6 changes: 5 additions & 1 deletion README.md
@@ -14,7 +14,9 @@ Remove until pypi issue is resolved: https://status.python.org/incidents/2jj696s
[![Downloads](https://pepy.tech/badge/deepspeed/month)](https://pepy.tech/project/deepspeed)
-->
## Latest News
- * [2022/06/22] DeepSpeed Compression: 50x model size reduction via [XTC](https://arxiv.org/abs/2206.01859) and 5000x compression cost reduction via [ZeroQuant](https://arxiv.org/abs/2206.01861). Stay tuned for upcoming code release!
+ * [2022/07/20] [DeepSpeed Compression: A composable library for extreme compression and zero-cost quantization](https://www.microsoft.com/en-us/research/blog/deepspeed-compression-a-composable-library-for-extreme-compression-and-zero-cost-quantization/)
+   * [Tutorial](https://www.deepspeed.ai/tutorials/model-compression/) and [Code examples](https://github.com/microsoft/DeepSpeedExamples/tree/master/model_compression).
+   * 50x model size reduction via [XTC](https://arxiv.org/abs/2206.01859) and 5000x compression cost reduction via [ZeroQuant](https://arxiv.org/abs/2206.01861).
* [2022/03/21] [Supporting efficient large model training on AMD Instinct GPUs with DeepSpeed](https://cloudblogs.microsoft.com/opensource/2022/03/21/supporting-efficient-large-model-training-on-amd-instinct-gpus-with-deepspeed/)
* [2022/03/07] [Maximizing Communication Efficiency for Large-scale Training via 0/1 Adam](https://www.deepspeed.ai/tutorials/zero-one-adam/)
* [2022/01/19] [DeepSpeed: Advancing MoE inference and training to power next-generation AI scale](https://www.microsoft.com/en-us/research/blog/deepspeed-advancing-moe-inference-and-training-to-power-next-generation-ai-scale/)
@@ -227,6 +229,8 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
11. Shaden Smith, Mostofa Patwary, Brandon Norick, Patrick LeGresley, Samyam Rajbhandari, Jared Casper, Zhun Liu, Shrimai Prabhumoye, George Zerveas, Vijay Korthikanti, Elton Zhang, Rewon Child, Reza Yazdani Aminabadi, Julie Bernauer, Xia Song, Mohammad Shoeybi, Yuxiong He, Michael Houston, Saurabh Tiwary, Bryan Catanzaro. (2022) Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model [arXiv:2201.11990](https://arxiv.org/abs/2201.11990).
12. Xiaoxia Wu, Zhewei Yao, Minjia Zhang, Conglong Li, Yuxiong He. (2022) Extreme Compression for Pre-trained Transformers Made Simple and Efficient. [arXiv:2206.01859](https://arxiv.org/abs/2206.01859).
13. Zhewei Yao, Reza Yazdani Aminabadi, Minjia Zhang, Xiaoxia Wu, Conglong Li, Yuxiong He. (2022) ZeroQuant: Efficient and Affordable Post-Training Quantization for Large-Scale Transformers. [arXiv:2206.01861](https://arxiv.org/abs/2206.01861).
14. Reza Yazdani Aminabadi, Samyam Rajbhandari, Minjia Zhang, Ammar Ahmad Awan, Cheng Li, Du Li, Elton Zheng, Jeff Rasley, Shaden Smith, Olatunji Ruwase, Yuxiong He. (2022) DeepSpeed Inference: Enabling Efficient Inference of Transformer Models at Unprecedented Scale. [arXiv:2207.00032](https://arxiv.org/abs/2207.00032).


# Videos
1. DeepSpeed KDD 2020 Tutorial
65 changes: 65 additions & 0 deletions benchmarks/communication/README.md
@@ -0,0 +1,65 @@
# Running Communication Benchmarks


To run benchmarks, there are two options:

1. Run a single communication operation:

For example, run with a single large message size:
<pre>
deepspeed all_reduce.py
</pre>

Scan across message sizes:
<pre>
deepspeed all_reduce.py --scan
</pre>

Each communication operation's benchmark has its own set of options. For `all_reduce.py`, for example:

<pre>
usage: ds_bench [-h] [--local_rank LOCAL_RANK] [--trials TRIALS] [--warmup WARMUP] [--maxsize MAXSIZE] [--async-op] [--bw-unit {Gbps,GBps}] [--backend {nccl}] [--dist {deepspeed,torch}] [--scan] [--dtype DTYPE] [--mem-factor MEM_FACTOR] [--debug]

optional arguments:
-h, --help show this help message and exit
--local_rank LOCAL_RANK
--trials TRIALS Number of timed iterations
--warmup WARMUP Number of warmup (non-timed) iterations
--maxsize MAXSIZE Max message size as a power of 2
--async-op Enables non-blocking communication
--bw-unit {Gbps,GBps}
--backend {nccl} Communication library to use
--dist {deepspeed,torch}
Distributed DL framework to use
--scan Enables scanning all message sizes
--dtype DTYPE PyTorch tensor dtype
--mem-factor MEM_FACTOR
Proportion of max available GPU memory to use for single-size evals
--debug Enables alltoall debug prints
</pre>
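
For instance, to run 10 timed trials with 5 warmup iterations, report bandwidth in GBps, and scan message sizes with `float16` tensors (all flags come from the help text above; the values here are only illustrative):

<pre>
deepspeed all_reduce.py --scan --trials 10 --warmup 5 --bw-unit GBps --dtype float16
</pre>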

2. Run all available communication benchmarks:

<pre>
deepspeed run_all.py
</pre>

Like the individual benchmarks, `run_all.py` supports the same arguments (scan, max message size, bw-unit, etc.). Simply pass the desired arguments to `run_all.py` and they'll be propagated to each comm op.
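
For example, to scan message sizes with a larger maximum and report bandwidth in Gbps (flags as listed in the help text above; the values are illustrative):

<pre>
deepspeed run_all.py --scan --maxsize 24 --bw-unit Gbps
</pre>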

Note that `ds_bench` is a pre-packaged wrapper around `run_all.py`. Users can pass the same arguments as well:

<pre>
<path to deepspeed>/bin/ds_bench --scan --trials=10
</pre>


# Adding Communication Benchmarks

To add new communication benchmarks, follow this general procedure (a minimal sketch of a hypothetical new benchmark follows the list):

1. Copy a similar benchmark file (e.g. to add `reduce_scatter`, copy `all_reduce.py` as a template)
2. Add a new bw formula in `utils.get_bw`
3. Add a new maximum tensor element formula in `utils.max_numel`
4. Replace the comm op calls in the new file with the new op (find-and-replace works well)
5. Find a good default `mem_factor` for use in the `run_<collective>_single()` function
6. Add the new comm op to `run_all.py`
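
To make steps 1, 2, and 4 concrete, below is a minimal sketch of the timed portion of a hypothetical `broadcast` benchmark, using `all_reduce.py` (shown further down in this PR) as the template. The file name, the `timed_broadcast` function, and the `'broadcast'` key passed to `get_bw` are illustrative assumptions; the helpers (`sync_all`, `get_bw`, `get_metric_strings`, `print_rank_0`, `convert_size`) come from `benchmarks/communication/utils.py` in this PR.

<pre>
# Hypothetical benchmarks/communication/broadcast.py -- a sketch, not part of this PR
from benchmarks.communication.utils import *
import time


def timed_broadcast(input, args):
    # Select the distributed backend the same way the other benchmarks do
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmup, establish connections, etc.
    for i in range(args.warmup):
        dist.broadcast(input, src=0, async_op=args.async_op)
    sync_all()

    # Time the comm op over args.trials iterations and average
    pre = time.perf_counter()
    for i in range(args.trials):
        dist.broadcast(input, src=0, async_op=args.async_op)
    sync_all()
    avg_duration = (time.perf_counter() - pre) / args.trials

    # Step 2: a 'broadcast' bandwidth formula would need to be added to utils.get_bw
    size = input.element_size() * input.nelement()
    tput, busbw = get_bw('broadcast', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'
    print_rank_0(
        f"{convert_size(size):<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}"
    )
</pre>

A matching driver with a tuned `mem_factor` and an entry in `run_all.py` (steps 5 and 6) would mirror `run_allreduce` below.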
153 changes: 153 additions & 0 deletions benchmarks/communication/all_gather.py
@@ -0,0 +1,153 @@
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *

import time
import argparse
import os

import math


# Run allgather and print metrics
def timed_allgather(input, output, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist

sync_all()
# Warmup, establish connections, etc.
for i in range(args.warmup):
# use all_gather_base if available
if args.dist == 'torch':
if hasattr(torch.distributed, "_all_gather_base"):
dist._all_gather_base(output, input, group=None, async_op=args.async_op)
else:
                # Fall back to regular all_gather when _all_gather_base is unavailable
                output_tensors = list(
                    torch.chunk(output, dist.get_world_size()))
                dist.all_gather(output_tensors, input, group=None, async_op=args.async_op)
elif args.dist == 'deepspeed':
dist.allgather_fn(output, input, group=None, async_op=args.async_op)
sync_all()

# time the actual comm op trials times and average it
pre = time.perf_counter()
for i in range(args.trials):
# use all_gather_base if available
if args.dist == 'torch':
if hasattr(torch.distributed, "_all_gather_base"):
dist._all_gather_base(output, input, group=None, async_op=args.async_op)
else:
                # Fall back to regular all_gather when _all_gather_base is unavailable
                output_tensors = list(
                    torch.chunk(output, dist.get_world_size()))
                dist.all_gather(output_tensors, input, group=None, async_op=args.async_op)
elif args.dist == 'deepspeed':
dist.allgather_fn(output, input, group=None, async_op=args.async_op)
sync_all()
duration = time.perf_counter() - pre

# maintain and clean performance data
avg_duration = duration / args.trials
size = input.element_size() * input.nelement()
n = dist.get_world_size()
tput, busbw = get_bw('allgather', size, avg_duration, args)
tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
desc = f'{input.nelement()}x{input.element_size()}'

print_rank_0(
f"{convert_size(size):<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}"
)


def run_allgather(local_rank, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist

# Prepare benchmark header
print_header(args, 'allgather')
global_rank = dist.get_rank()
world_size = dist.get_world_size()

if args.scan:
# Create list of message sizes
M_LIST = []
for x in (2**p for p in range(1, args.maxsize)):
M_LIST.append(x)

sync_all()
# loop over various tensor sizes
for M in M_LIST:
global_rank = dist.get_rank()
try:
mat = torch.ones(world_size,
M,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
# Delete original mat to avoid OOM
del mat
torch.cuda.empty_cache()
output = torch.zeros(input.nelement() * world_size,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print('WARNING: Ran out of GPU memory. Exiting comm op.')
sync_all()
break
sync_all()
timed_allgather(input, output, args)
else:
# all_gather_base saves memory
if (args.dist == 'torch'
and hasattr(torch.distributed,
"_all_gather_base")) or (args.dist == 'deepspeed'
and dist.has_allgather_base):
mem_factor = args.mem_factor + 0.2
else:
mem_factor = args.mem_factor
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
sync_all()
elements_per_gpu = max_numel(comm_op='allgather',
dtype=getattr(torch,
args.dtype),
mem_factor=mem_factor,
local_rank=local_rank,
args=args)
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
# multiply each GPU's tensor by the rank to ease debugging
input = ((mat.mul_(float(global_rank))).view(-1))
# Delete original mat to avoid OOM
del mat
torch.cuda.empty_cache()
output = torch.zeros(elements_per_gpu * world_size,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print(
'WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!'
)
sync_all()
return

sync_all()
timed_allgather(input, output, args)


if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
init_processes(local_rank=rank, args=args)
run_allgather(local_rank=rank, args=args)
109 changes: 109 additions & 0 deletions benchmarks/communication/all_reduce.py
@@ -0,0 +1,109 @@
import torch
from benchmarks.communication.utils import *
from benchmarks.communication.constants import *

import time
import argparse
import os
import math


def timed_allreduce(input, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist

sync_all()
# Warmup, establish connections, etc.
for i in range(args.warmup):
dist.all_reduce(input, async_op=args.async_op)
sync_all()

# time the actual comm op trials times and average it
pre = time.perf_counter()
for i in range(args.trials):
dist.all_reduce(input, async_op=args.async_op)
sync_all()
duration = time.perf_counter() - pre

# maintain and clean performance data
avg_duration = duration / args.trials
size = input.element_size() * input.nelement()
n = dist.get_world_size()
tput, busbw = get_bw('allreduce', size, avg_duration, args)
tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
desc = f'{input.nelement()}x{input.element_size()}'

print_rank_0(
f"{convert_size(size):<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}"
)


def run_allreduce(local_rank, args):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
import deepspeed.comm as dist

# Prepare benchmark header
print_header(args, 'allreduce')

world_size = dist.get_world_size()
global_rank = dist.get_rank()

if args.scan:
M_LIST = []
for x in (2**p for p in range(1, args.maxsize)):
M_LIST.append(x)

sync_all()
# loop over various tensor sizes
for M in M_LIST:
global_rank = dist.get_rank()
try:
mat = torch.ones(world_size,
M,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print('WARNING: Ran out of GPU memory. Exiting comm op.')
sync_all()
break
sync_all()
timed_allreduce(input, args)
else:
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
# Don't need output tensor, so we double mem_factor
elements_per_gpu = max_numel(comm_op='allreduce',
dtype=getattr(torch,
args.dtype),
mem_factor=args.mem_factor * 2,
local_rank=local_rank,
args=args)
try:
mat = torch.ones(elements_per_gpu,
dtype=getattr(torch,
args.dtype)).cuda(local_rank)
input = ((mat.mul_(float(global_rank))).view(-1))
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print(
'WARNING: Ran out of GPU memory. Try to reduce the --mem-factor argument!'
)
sync_all()
return
sync_all()
timed_allreduce(input, args)


if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
init_processes(local_rank=rank, args=args)
run_allreduce(local_rank=rank, args=args)