diff --git a/.github/workflows/amd.yml b/.github/workflows/amd.yml
index 91e8825b83e1..6f9eae41cd45 100644
--- a/.github/workflows/amd.yml
+++ b/.github/workflows/amd.yml
@@ -38,6 +38,10 @@ jobs:
sudo apt-get update
sudo apt-get install -y libaio-dev
+ - name: Python environment
+ run: |
+ pip list
+
- name: Install transformers
run: |
git clone https://github.com/huggingface/transformers
@@ -59,5 +63,5 @@ jobs:
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
- TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -x -n 4 -m 'not sequential' unit/
+ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -x -n 4 unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -x -m 'sequential' unit/
diff --git a/.github/workflows/nv-inference.yml b/.github/workflows/nv-inference.yml
new file mode 100644
index 000000000000..fbb4a54f59a3
--- /dev/null
+++ b/.github/workflows/nv-inference.yml
@@ -0,0 +1,63 @@
+name: nv-inference
+
+on:
+ push:
+ branches:
+ - 'master'
+ - 'staging**'
+ paths-ignore:
+ - 'docs/**'
+ pull_request:
+ paths-ignore:
+ - 'docs/**'
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ unit-tests:
+ runs-on: [self-hosted, nvidia, cu111, v100]
+
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: environment
+ run: |
+ nvidia-smi
+ which python
+ python --version
+ which nvcc
+ nvcc --version
+ pip install --upgrade pip
+ pip uninstall --yes torch torchvision
+ pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
+ python -c "import torch; print('torch:', torch.__version__, torch)"
+ python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
+
+ - name: Install transformers
+ run: |
+ git clone https://github.com/huggingface/transformers
+ cd transformers
+ # if needed switch to the last known good SHA until transformers@master is fixed
+ # git checkout 1cc453d33
+ git rev-parse --short HEAD
+ pip uninstall --yes transformers
+ pip install .
+
+ - name: Python environment
+ run: |
+ pip list
+
+ - name: Install deepspeed
+ run: |
+ pip uninstall --yes deepspeed
+ pip install .[dev,1bit,autotuning,sparse_attn,inf]
+ ds_report
+
+ - name: Unit tests
+ run: |
+ unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
+ if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
+ cd tests
+ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'inference' unit/ --torch_ver="1.8" --cuda_ver="11.1"
diff --git a/.github/workflows/nv-lightning-v100.yml b/.github/workflows/nv-lightning-v100.yml
index 6caceb48f2c4..bfdf9bc06eab 100644
--- a/.github/workflows/nv-lightning-v100.yml
+++ b/.github/workflows/nv-lightning-v100.yml
@@ -17,7 +17,7 @@ concurrency:
jobs:
unit-tests:
- runs-on: [self-hosted, nvidia, torch18, v100]
+ runs-on: [self-hosted, nvidia, cu111, v100]
steps:
- uses: actions/checkout@v2
@@ -29,16 +29,26 @@ jobs:
python --version
which nvcc
nvcc --version
+ pip install --upgrade pip
+ pip uninstall --yes torch torchvision
pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
+
+ - name: Python environment
+ run: |
+ pip list
+
- name: Install deepspeed
run: |
+ pip uninstall --yes deepspeed
pip install .[dev,autotuning]
ds_report
+
- name: PyTorch Lightning Tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
+ pip uninstall --yes pytorch-lightning
pip install pytorch-lightning
pip install "protobuf<4.21.0"
cd tests
diff --git a/.github/workflows/nv-nightly.yml b/.github/workflows/nv-nightly.yml
new file mode 100644
index 000000000000..98afd75105a3
--- /dev/null
+++ b/.github/workflows/nv-nightly.yml
@@ -0,0 +1,52 @@
+name: nv-nightly
+
+on:
+ schedule:
+ - cron: "0 0 * * *"
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ unit-tests:
+ runs-on: [self-hosted, nvidia, cu111, v100]
+
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: environment
+ run: |
+ nvidia-smi
+ which python
+ python --version
+ which nvcc
+ nvcc --version
+ pip install --upgrade pip
+ pip uninstall --yes torch torchvision
+ pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
+ python -c "import torch; print('torch:', torch.__version__, torch)"
+ python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
+
+ - name: Install transformers
+ run: |
+ git clone https://github.com/huggingface/transformers
+ cd transformers
+ # if needed switch to the last known good SHA until transformers@master is fixed
+ # git checkout 1cc453d33
+ git rev-parse --short HEAD
+ pip uninstall --yes transformers
+ pip install .
+
+ - name: Install deepspeed
+ run: |
+ pip uninstall --yes deepspeed
+ pip install .[dev,1bit,autotuning,sparse_attn,inf]
+ ds_report
+
+ - name: Unit tests
+ run: |
+ unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
+ if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
+ cd tests
+ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'nightly' unit/ --torch_ver="1.8" --cuda_ver="11.1"
diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml
new file mode 100644
index 000000000000..fd6859c7fb45
--- /dev/null
+++ b/.github/workflows/nv-torch-latest-v100.yml
@@ -0,0 +1,64 @@
+name: nv-torch-latest-v100
+
+on:
+ push:
+ branches:
+ - 'master'
+ - 'staging**'
+ paths-ignore:
+ - 'docs/**'
+ pull_request:
+ paths-ignore:
+ - 'docs/**'
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ unit-tests:
+ runs-on: [self-hosted, nvidia, cu113, v100]
+
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: environment
+ run: |
+ nvidia-smi
+ which python
+ python --version
+ which nvcc
+ nvcc --version
+ pip install --upgrade pip
+ pip uninstall --yes torch torchvision
+ pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113
+ python -c "import torch; print('torch:', torch.__version__, torch)"
+ python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
+
+ - name: Install transformers
+ run: |
+ git clone https://github.com/huggingface/transformers
+ cd transformers
+ # if needed switch to the last known good SHA until transformers@master is fixed
+ # git checkout 1cc453d33
+ git rev-parse --short HEAD
+ pip uninstall --yes transformers
+ pip install .
+
+ - name: Python environment
+ run: |
+ pip list
+
+ - name: Install deepspeed
+ run: |
+ pip uninstall --yes deepspeed
+ pip install .[dev,1bit,autotuning,sparse_attn]
+ ds_report
+
+ - name: Unit tests
+ run: |
+ unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
+ if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
+ cd tests
+ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 unit/
+ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'sequential' unit/
diff --git a/.github/workflows/nv-torch-nightly-v100.yml b/.github/workflows/nv-torch-nightly-v100.yml
new file mode 100644
index 000000000000..e1c916afba2d
--- /dev/null
+++ b/.github/workflows/nv-torch-nightly-v100.yml
@@ -0,0 +1,57 @@
+name: nv-torch-nightly-v100
+
+on:
+ schedule:
+ - cron: "0 0 * * *"
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ unit-tests:
+ runs-on: [self-hosted, nvidia, cu113, v100]
+
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: environment
+ run: |
+ nvidia-smi
+ which python
+ python --version
+ which nvcc
+ nvcc --version
+ pip install --upgrade pip
+ pip uninstall --yes torch torchvision
+ pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cu113
+ python -c "import torch; print('torch:', torch.__version__, torch)"
+ python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
+
+ - name: Install transformers
+ run: |
+ git clone https://github.com/huggingface/transformers
+ cd transformers
+ # if needed switch to the last known good SHA until transformers@master is fixed
+ # git checkout 1cc453d33
+ git rev-parse --short HEAD
+ pip uninstall --yes transformers
+ pip install .
+
+ - name: Python environment
+ run: |
+ pip list
+
+ - name: Install deepspeed
+ run: |
+ pip uninstall --yes deepspeed
+ pip install .[dev,1bit,autotuning,sparse_attn]
+ ds_report
+
+ - name: Unit tests
+ run: |
+ unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
+ if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
+ cd tests
+ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 unit/
+ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'sequential' unit/
diff --git a/.github/workflows/nv-torch12-p40.yml b/.github/workflows/nv-torch12-p40.yml
index 080543df6980..944ba3beb19d 100644
--- a/.github/workflows/nv-torch12-p40.yml
+++ b/.github/workflows/nv-torch12-p40.yml
@@ -17,7 +17,7 @@ concurrency:
jobs:
unit-tests:
- runs-on: [self-hosted, nvidia, torch12, p40]
+ runs-on: [self-hosted, nvidia, cu101, p40]
steps:
- uses: actions/checkout@v2
@@ -29,9 +29,16 @@ jobs:
python --version
which nvcc
nvcc --version
+ pip install --upgrade pip
+ pip uninstall --yes torch torchvision
+ pip install torch==1.2.0 torchvision==0.4.0
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
+ - name: Python environment
+ run: |
+ pip list
+
- name: Install transformers
run: |
git clone https://github.com/huggingface/transformers
@@ -39,15 +46,17 @@ jobs:
# if needed switch to the last known good SHA until transformers@master is fixed
# git checkout 1cc453d33
git rev-parse --short HEAD
+ pip uninstall --yes transformers
pip install .
- name: Install deepspeed
run: |
- pip install .[dev,autotuning]
+ pip uninstall --yes deepspeed
+ pip install .[dev,1bit,autotuning,sparse_attn]
ds_report
- name: Unit tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
- TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 unit/
+ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 unit/ --torch_ver="1.2" --cuda_ver="10"
diff --git a/.github/workflows/nv-torch18-v100.yml b/.github/workflows/nv-torch18-v100.yml
index 0afac798119a..b512ea29113f 100644
--- a/.github/workflows/nv-torch18-v100.yml
+++ b/.github/workflows/nv-torch18-v100.yml
@@ -17,7 +17,7 @@ concurrency:
jobs:
unit-tests:
- runs-on: [self-hosted, nvidia, torch18, v100]
+ runs-on: [self-hosted, nvidia, cu111, v100]
steps:
- uses: actions/checkout@v2
@@ -29,6 +29,8 @@ jobs:
python --version
which nvcc
nvcc --version
+ pip install --upgrade pip
+ pip uninstall --yes torch torchvision
pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
@@ -40,10 +42,16 @@ jobs:
# if needed switch to the last known good SHA until transformers@master is fixed
# git checkout 1cc453d33
git rev-parse --short HEAD
+ pip uninstall --yes transformers
pip install .
+ - name: Python environment
+ run: |
+ pip list
+
- name: Install deepspeed
run: |
+ pip uninstall --yes deepspeed
pip install .[dev,1bit,autotuning,sparse_attn]
ds_report
@@ -52,5 +60,5 @@ jobs:
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
- TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 -m 'not sequential' unit/
- TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'sequential' unit/
+ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -n 4 unit/ --torch_ver="1.8" --cuda_ver="11.1"
+ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m 'sequential' unit/ --torch_ver="1.8" --cuda_ver="11.1"
diff --git a/.github/workflows/nv-transformers-v100.yml b/.github/workflows/nv-transformers-v100.yml
index c39e6978b15e..efbd015ce1b0 100644
--- a/.github/workflows/nv-transformers-v100.yml
+++ b/.github/workflows/nv-transformers-v100.yml
@@ -17,7 +17,7 @@ concurrency:
jobs:
unit-tests:
- runs-on: [self-hosted, nvidia, torch18, v100]
+ runs-on: [self-hosted, nvidia, cu111, v100]
steps:
- uses: actions/checkout@v2
@@ -29,13 +29,22 @@ jobs:
python --version
which nvcc
nvcc --version
+ pip install --upgrade pip
+ pip uninstall --yes torch torchvision
pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
+
+ - name: Python environment
+ run: |
+ pip list
+
- name: Install deepspeed
run: |
+ pip uninstall --yes deepspeed
pip install .[dev,autotuning]
ds_report
+
- name: HF transformers tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
@@ -47,5 +56,10 @@ jobs:
# scipy/sklearn required for tests, using the 'dev' extra forces torch re-install
pip install .[testing]
# find reqs used in ds integration tests
- find examples/pytorch -regextype posix-egrep -regex '.*(language-modeling|question-answering|summarization|image-classification|text-classification|translation).*/requirements.txt' -exec pip install -r {} \;
+ find examples/pytorch -regextype posix-egrep -regex '.*(language-modeling|question-answering|summarization|image-classification|text-classification|translation).*/requirements.txt' -exec grep -v 'torch' {} \; | xargs -I {} pip install --upgrade {}
+ # force datasets version due to issues
+ pip install datasets==2.2.2
+ # force protobuf version due to issues
+ pip install "protobuf<4.21.0"
+ pip list
TORCH_EXTENSIONS_DIR=./torch-extensions RUN_SLOW=1 pytest --color=yes --durations=0 --verbose tests/deepspeed
diff --git a/.gitmodules b/.gitmodules
deleted file mode 100644
index 37adb6f39e5c..000000000000
--- a/.gitmodules
+++ /dev/null
@@ -1,4 +0,0 @@
-[submodule "DeepSpeedExamples"]
- path = DeepSpeedExamples
- url = https://github.com/microsoft/DeepSpeedExamples
- branch = master
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f2fa818101c1..fb7ecd5bbfc6 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -33,6 +33,15 @@ repos:
- id: clang-format # formatter of C/C++ code based on a style guide: LLVM, Google, Chromium, Mozilla, and WebKit available
args: []
+- repo: local
+ hooks:
+ - id: check-torchdist
+ name: check-torchdist
+ entry: ./scripts/check-torchdist.py
+ language: script
+ exclude: ^(deepspeed/comm/|docs/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py)
+ # Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm
+
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
hooks:
diff --git a/CODEOWNERS b/CODEOWNERS
index ec7993c060aa..3eb8710cad7a 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1 +1 @@
-* @jeffra @samyam @tjruwase @ShadenSmith @conglongli @awan-10 @arashashari @cli99 @eltonzheng @minjiaz @RezaYazdaniAminabadi @niumanar
+* @jeffra @samyam @tjruwase @ShadenSmith @conglongli @awan-10 @cli99 @eltonzheng @minjiaz @RezaYazdaniAminabadi @duli2012 @mrwyattii @yaozhewei @arashb @xiaoxiawu-microsoft
diff --git a/DeepSpeedExamples b/DeepSpeedExamples
deleted file mode 160000
index 36212dd59cb3..000000000000
--- a/DeepSpeedExamples
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 36212dd59cb3eb342c39bc8965aaba04d5491933
diff --git a/README.md b/README.md
index 65482cc1a66e..d5475be0aea2 100755
--- a/README.md
+++ b/README.md
@@ -1,7 +1,8 @@
-[](https://github.com/microsoft/DeepSpeed/actions)
+[](https://github.com/Microsoft/DeepSpeed/blob/master/LICENSE)
[](https://pypi.org/project/deepspeed/)
-[](https://deepspeed.readthedocs.io/en/latest/?badge=latest)
-[](https://github.com/Microsoft/DeepSpeed/blob/master/LICENSE)
+[](https://pepy.tech/project/deepspeed)
+[](#build-pipeline-status)
+

@@ -13,6 +14,7 @@ Remove until pypi issue is resolved: https://status.python.org/incidents/2jj696s
[](https://pepy.tech/project/deepspeed)
-->
## Latest News
+* [2022/06/22] DeepSpeed Compression: 50x model size reduction via [XTC](https://arxiv.org/abs/2206.01859) and 5000x compression cost reduction via [ZeroQuant](https://arxiv.org/abs/2206.01861). Stay tuned for upcoming code release!
* [2022/03/21] [Supporting efficient large model training on AMD Instinct GPUs with DeepSpeed](https://cloudblogs.microsoft.com/opensource/2022/03/21/supporting-efficient-large-model-training-on-amd-instinct-gpus-with-deepspeed/)
* [2022/03/07] [Maximizing Communication Efficiency for Large-scale Training via 0/1 Adam](https://www.deepspeed.ai/tutorials/zero-one-adam/)
* [2022/01/19] [DeepSpeed: Advancing MoE inference and training to power next-generation AI scale](https://www.microsoft.com/en-us/research/blog/deepspeed-advancing-moe-inference-and-training-to-power-next-generation-ai-scale/)
@@ -50,6 +52,17 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale)
**_For further documentation, tutorials, and technical deep-dives please see [deepspeed.ai](https://www.deepspeed.ai/)!_**
+# Build Pipeline Status
+
+| Description | Status |
+| ----------- | ------ |
+| NVIDIA | [](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch12-p40.yml) [](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch18-v100.yml) [](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-latest-v100.yml) |
+| AMD | [](https://github.com/microsoft/DeepSpeed/actions/workflows/amd.yml) |
+| PyTorch Nightly | [](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml) |
+| Integrations | [](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-transformers-v100.yml) [](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-lightning-v100.yml) |
+| Misc | [](https://github.com/microsoft/DeepSpeed/actions/workflows/formatting.yml) [](https://github.com/microsoft/DeepSpeed/actions/workflows/pages/pages-build-deployment) [](https://deepspeed.readthedocs.io/en/latest/?badge=latest)|
+
+
# Table of Contents
| Section | Description |
| --------------------------------------- | ------------------------------------------- |
@@ -212,6 +225,8 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
9. Yucheng Lu, Conglong Li, Minjia Zhang, Christopher De Sa, Yuxiong He. (2022) Maximizing Communication Efficiency for Large-scale Training via 0/1 Adam. [arXiv:2202.06009](https://arxiv.org/abs/2202.06009).
10. Samyam Rajbhandari, Conglong Li, Zhewei Yao, Minjia Zhang, Reza Yazdani Aminabadi, Ammar Ahmad Awan, Jeff Rasley, Yuxiong He. (2022) DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale [arXiv:2201.05596](https://arxiv.org/abs/2201.05596).
11. Shaden Smith, Mostofa Patwary, Brandon Norick, Patrick LeGresley, Samyam Rajbhandari, Jared Casper, Zhun Liu, Shrimai Prabhumoye, George Zerveas, Vijay Korthikanti, Elton Zhang, Rewon Child, Reza Yazdani Aminabadi, Julie Bernauer, Xia Song, Mohammad Shoeybi, Yuxiong He, Michael Houston, Saurabh Tiwary, Bryan Catanzaro. (2022) Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model [arXiv:2201.11990](https://arxiv.org/abs/2201.11990).
+12. Xiaoxia Wu, Zhewei Yao, Minjia Zhang, Conglong Li, Yuxiong He. (2022) Extreme Compression for Pre-trained Transformers Made Simple and Efficient. [arXiv:2206.01859](https://arxiv.org/abs/2206.01859).
+13. Zhewei Yao, Reza Yazdani Aminabadi, Minjia Zhang, Xiaoxia Wu, Conglong Li, Yuxiong He. (2022) ZeroQuant: Efficient and Affordable Post-Training Quantization for Large-Scale Transformers. [arXiv:2206.01861](https://arxiv.org/abs/2206.01861).
# Videos
1. DeepSpeed KDD 2020 Tutorial
diff --git a/bin/dsr b/bin/dsr
new file mode 120000
index 000000000000..747bf4722c42
--- /dev/null
+++ b/bin/dsr
@@ -0,0 +1 @@
+ds_report
\ No newline at end of file
diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu
index 00c70eea22b5..acb35325dd9e 100644
--- a/csrc/transformer/inference/csrc/gelu.cu
+++ b/csrc/transformer/inference/csrc/gelu.cu
@@ -174,7 +174,8 @@ __global__ void fused_bias_residual(float* input,
float* attnbias,
int total_count,
int intermediate_size,
- int mp_size)
+ int mp_size,
+ bool preln)
{
float4* input_cast = reinterpret_cast
(input);
float4* output_cast = reinterpret_cast(output);
@@ -189,12 +190,17 @@ __global__ void fused_bias_residual(float* input,
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
-
- data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x);
- data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y);
- data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z);
- data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w);
-
+ if (preln) {
+ data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x);
+ data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y);
+ data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z);
+ data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w);
+ } else {
+ data.x = data.x + out.x + bias_data.x;
+ data.y = data.y + out.y + bias_data.y;
+ data.z = data.z + out.z + bias_data.z;
+ data.w = data.w + out.w + bias_data.w;
+ }
output_cast[offset] = data;
}
}
@@ -206,7 +212,8 @@ __global__ void fused_bias_residual(__half* input,
__half* attn_bias,
int total_count,
int intermediate_size,
- int mp_size)
+ int mp_size,
+ bool preln)
{
#ifdef HALF_PRECISION_AVAILABLE
@@ -248,15 +255,21 @@ __global__ void fused_bias_residual(__half* input,
float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
- low_data.x =
- (low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x));
- low_data.y =
- (low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y));
- high_data.x =
- (high_data.x + high_res.x) * mp_size + (high_out.x + (high_bias.x + attn_high_bias.x));
- high_data.y =
- (high_data.y + high_res.y) * mp_size + (high_out.y + (high_bias.y + attn_high_bias.y));
-
+ if (preln) {
+ low_data.x =
+ (low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x));
+ low_data.y =
+ (low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y));
+ high_data.x = (high_data.x + high_res.x) * mp_size +
+ (high_out.x + (high_bias.x + attn_high_bias.x));
+ high_data.y = (high_data.y + high_res.y) * mp_size +
+ (high_out.y + (high_bias.y + attn_high_bias.y));
+ } else {
+ low_data.x = (low_data.x + low_out.x + low_bias.x);
+ low_data.y = (low_data.y + low_out.y + low_bias.y);
+ high_data.x = (high_data.x + high_out.x + high_bias.x);
+ high_data.y = (high_data.y + high_out.y + high_bias.y);
+ }
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
@@ -274,6 +287,7 @@ void launch_bias_residual(T* input,
int batch,
int hidden_dim,
int mp_size,
+ bool preln,
cudaStream_t stream)
{
int total_count = batch * hidden_dim / 4;
@@ -281,20 +295,13 @@ void launch_bias_residual(T* input,
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
fused_bias_residual<<>>(
- input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
+ input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size, preln);
}
-template void
-launch_bias_residual(float*, float*, float*, float*, float*, int, int, int, cudaStream_t);
-template void launch_bias_residual<__half>(__half*,
- __half*,
- __half*,
- __half*,
- __half*,
- int,
- int,
- int,
- cudaStream_t);
+template void launch_bias_residual<
+ float>(float*, float*, float*, float*, float*, int, int, int, bool, cudaStream_t);
+template void launch_bias_residual<
+ __half>(__half*, __half*, __half*, __half*, __half*, int, int, int, bool, cudaStream_t);
__global__ void gptj_residual_add(float* input,
float* output,
diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp
index ef4f0597d985..2c300c0b6c92 100644
--- a/csrc/transformer/inference/csrc/pt_binding.cpp
+++ b/csrc/transformer/inference/csrc/pt_binding.cpp
@@ -787,17 +787,17 @@ at::Tensor ds_vector_matmul_int8(at::Tensor& input,
}
template
-void mlp_unfused_cublas(at::Tensor& output,
- at::Tensor& input,
- at::Tensor& residual,
- at::Tensor& input_bias,
- at::Tensor& weight,
- at::Tensor& bias,
- at::Tensor& gamma,
- at::Tensor& beta,
- const float epsilon,
- bool preLayerNorm,
- bool mlp_after_attn)
+at::Tensor mlp_unfused_cublas(at::Tensor& output,
+ at::Tensor& input,
+ at::Tensor& residual,
+ at::Tensor& input_bias,
+ at::Tensor& weight,
+ at::Tensor& bias,
+ at::Tensor& gamma,
+ at::Tensor& beta,
+ const float epsilon,
+ bool preLayerNorm,
+ bool mlp_after_attn)
{
int bsz = input.size(0) * input.size(1);
auto inp_norm = at::empty_like(input);
@@ -840,18 +840,19 @@ void mlp_unfused_cublas(at::Tensor& output,
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
+ return inp_norm;
}
template
-at::Tensor ds_mlp_gemm(at::Tensor& input,
- at::Tensor& residual,
- at::Tensor& input_bias,
- at::Tensor& weight,
- at::Tensor& bias,
- at::Tensor& gamma,
- at::Tensor& beta,
- const float epsilon,
- bool preLayerNorm,
- bool mlp_after_attn)
+std::vector ds_mlp_gemm(at::Tensor& input,
+ at::Tensor& residual,
+ at::Tensor& input_bias,
+ at::Tensor& weight,
+ at::Tensor& bias,
+ at::Tensor& gamma,
+ at::Tensor& beta,
+ const float epsilon,
+ bool preLayerNorm,
+ bool mlp_after_attn)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
@@ -863,19 +864,19 @@ at::Tensor ds_mlp_gemm(at::Tensor& input,
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
- mlp_unfused_cublas(output,
- mlp_after_attn ? input : residual,
- residual,
- input_bias,
- weight,
- bias,
- gamma,
- beta,
- epsilon,
- preLayerNorm,
- mlp_after_attn);
-
- return output;
+ auto res_add = mlp_unfused_cublas(output,
+ mlp_after_attn ? input : residual,
+ residual,
+ input_bias,
+ weight,
+ bias,
+ gamma,
+ beta,
+ epsilon,
+ preLayerNorm,
+ mlp_after_attn);
+
+ return {output, res_add};
}
template
@@ -1001,7 +1002,8 @@ void residual_add_bias(at::Tensor& output,
at::Tensor& attention_b,
int mp_size,
bool mlp_after_attn,
- bool add_bias)
+ bool add_bias,
+ bool preln)
{
int bsz = input.size(0) * input.size(1);
int hidden_size = input.size(2);
@@ -1017,6 +1019,7 @@ void residual_add_bias(at::Tensor& output,
bsz,
hidden_size,
mp_size,
+ preln,
Context::Instance().GetCurrentStream());
else
launch_gptj_residual_add((float*)input.data_ptr(),
@@ -1037,6 +1040,7 @@ void residual_add_bias(at::Tensor& output,
bsz,
hidden_size,
mp_size,
+ preln,
Context::Instance().GetCurrentStream());
else
launch_gptj_residual_add<__half>((__half*)input.data_ptr(),
diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/custom_cuda_layers.h
index c8a0b79a111b..a48b2d7f06cc 100644
--- a/csrc/transformer/inference/includes/custom_cuda_layers.h
+++ b/csrc/transformer/inference/includes/custom_cuda_layers.h
@@ -58,6 +58,7 @@ void launch_bias_residual(T* input,
int batch,
int hidden_dim,
int mp_size,
+ bool preln,
cudaStream_t stream);
template
diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index 585362096fe2..25229bde06d8 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -25,7 +25,7 @@
from .module_inject import replace_transformer_layer, revert_transformer_layer
from .utils import log_dist
-from .utils.distributed import init_distributed
+from .comm.comm import init_distributed
from .runtime import zero
from .runtime import DeepSpeedOptimizer, ZeROOptimizer
@@ -83,7 +83,7 @@ def initialize(args=None,
mpu: Optional: A model parallelism unit object that implements
get_{model,data}_parallel_{rank,group,world_size}()
- dist_init_required: Optional: None will auto-initialize torch.distributed if needed,
+ dist_init_required: Optional: None will auto-initialize torch distributed if needed,
otherwise the user can force it to be initialized or not via boolean.
collate_fn: Optional: Merges a list of samples to form a
@@ -114,6 +114,7 @@ def initialize(args=None,
__git_hash__,
__git_branch__),
ranks=[0])
+
assert model is not None, "deepspeed.initialize requires a model"
if not isinstance(model, PipelineModule):
diff --git a/deepspeed/autotuning/scheduler.py b/deepspeed/autotuning/scheduler.py
index e4090b38e4c6..e3c4fbe7708b 100755
--- a/deepspeed/autotuning/scheduler.py
+++ b/deepspeed/autotuning/scheduler.py
@@ -25,7 +25,7 @@
thread-N: start each experiment in its own thread
"""
-import torch.distributed as dist
+from deepspeed import comm as dist
from datetime import datetime
diff --git a/deepspeed/comm/__init__.py b/deepspeed/comm/__init__.py
new file mode 100644
index 000000000000..1975aab8582b
--- /dev/null
+++ b/deepspeed/comm/__init__.py
@@ -0,0 +1,39 @@
+import torch
+from .utils import *
+from deepspeed import utils
+
+supported_torch_version = False
+
+# See more details at: https://github.com/pytorch/pytorch/pull/48767
+# The PG API in torch versions lesser than 1.8 are different so it is
+# non-trivial to support both in the same API. We will just use the
+# DS comm. backend in deepspeed/comm/comm.py if torch version if 1.8+.
+
+if older_torch():
+ # Add custom deepspeed torch comm functions here since we can't import deepspeed.comm
+ # NOTE: We can't call torch.distributed directly here. Current hack is to import functions before calling them.
+ supported_torch_version = False
+ from torch.distributed import *
+
+ def get_world_group():
+ return group.WORLD
+
+ def get_global_rank(group, group_rank):
+ from torch.distributed.distributed_c10d import _get_global_rank
+ return _get_global_rank(group, group_rank)
+
+ def allgather_fn(output_tensor, input_tensor, group, async_op):
+ from torch.distributed import all_gather, get_world_size
+ from torch import chunk
+ output_tensors = list(chunk(output_tensor, get_world_size(group)))
+ return all_gather(output_tensors, input_tensor, group=group, async_op=True)
+
+ def reduce_scatter_fn(output_tensor, input_tensor, group):
+ from torch.distributed import reduce_scatter, get_world_size
+ from torch import chunk
+ input_tensor_lst = list(chunk(input_tensor, get_world_size(group)))
+ return reduce_scatter(output_tensor, input_tensor_lst, group=group)
+
+else:
+ supported_torch_version = True
+ from .comm import *
diff --git a/deepspeed/comm/backend.py b/deepspeed/comm/backend.py
new file mode 100644
index 000000000000..72b860ce870d
--- /dev/null
+++ b/deepspeed/comm/backend.py
@@ -0,0 +1,42 @@
+''' DeepSpeed Communication Backend.
+
+In the future, directly use NCCL/MPI/Gloo/etc without requiring torch.distributed. Simply wrap torch.distributed for now.
+
+# Custom DS Backends -- Direct C/Ops
+ - NCCL -- [EXPERIMENTAL]
+ - MPI -- [EXPERIMENTAL]
+ - RCCL -- [EXPERIMENTAL]
+ - GLOO -- [EXPERIMENTAL]
+
+# DS backend wrapper for torch.distributed [DEFAULT]
+ - T-NCCL -- [DEFAULT]
+ - T-GLOO
+ - T-MPI
+'''
+''' Backend is the base class
+ -- NcclBackend, MpiBackend, and TorchBackend are the main subclasses. TorchBackend is the only officially supported backend for now.
+'''
+
+
+class Backend(object):
+ def __init__(self, name='backend', rank=0, size=1):
+ self.name = name
+ # The world size and rank of the world process group
+ self.world_group = None
+ self.world_size = rank
+ self.world_rank = size
+ # Single process group (pg) implementation for now but keep a list for future
+ self.process_groups = []
+ self.initialized = False
+
+ def is_initialized(self):
+ return self.initialized
+
+ def new_group(self):
+ # create a new pg and add it to pg list
+ pass
+
+ def init_process_group(self):
+ # subclasses will initialize them fully
+ # - initialize a default world process group and add it to pg list
+ self.initialized = True
diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py
new file mode 100644
index 000000000000..ca328ee28dd9
--- /dev/null
+++ b/deepspeed/comm/comm.py
@@ -0,0 +1,540 @@
+"""
+ Copyright 2021 The Microsoft DeepSpeed Team
+
+ DeepSpeed Communication Package: deepspeed.comm
+
+ deepspeed.comm
+ -- import and use deepspeeed.ops.comm
+ -- use torch.distributed directly if both this package and torch.distributed use the same NCCL version
+ -- use custom collectives
+ -- can either use torch.dist or ds.ops.comm?
+
+ Note: the old 1-bit compressed allreduce variants that resided in deepspeed.runtime.comm will be moved here as well.
+
+ deepspeed.comm API
+ -- must be kept fully compatible (same signatures) as torch.dist API to ensure backward/cross-framework compatibility.
+ -- e.g. if a client code used
+ from deepspeed import comm as dist
+
+ instead of
+ import torch.distributed as dist
+
+ The code should work without breaking any of the public torch.distributed functionality
+
+ Future:
+ -- deepspeed groups API should be brought into ds.comm
+"""
+
+from enum import Enum
+import torch
+import os
+import torch
+
+from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout
+from .constants import *
+
+
+class ReduceOp(Enum):
+ SUM = 0
+ PRODUCT = 1
+ MIN = 2
+ MAX = 3
+ BAND = 4
+ BOR = 5
+ BXOR = 6
+ AVG = 7
+ UNUSED = 8
+
+
+from deepspeed.comm.backend import Backend
+from deepspeed.comm.torch import TorchBackend
+
+from deepspeed import utils
+from datetime import timedelta
+
+# Current deepspeed.comm backend (cdb) global object for simple access by client code
+use_ds_backend = False
+cdb = None
+
+# Ensure we don't warn about base collectives more than once
+has_warned_all_gather = False
+has_warned_reduce_scatter = False
+
+# Maintain objects of all initialized ds backends and assign them using the API functions in this file
+nccl_backend = None
+mpi_backend = None
+
+# This should be set here so all rank/size information from the launcher can be propagated
+from deepspeed.comm.utils import *
+
+# For compatibility with torch distributed's init_process_group, we shall retain the signature from PyTorch code.
+# DeepSpeed NCCL/MPI backend may not need all these params as we will have our own implementation.
+# Please read full torch.distributed API docs from https://pytorch.org/docs/stable/distributed.html
+
+
+# UNUSED: Future helper function to initialize DS backends
+def init_deepspeed_backend(ds_backend):
+ global cdb
+ global nccl_backend
+ global mpi_backend
+ global use_ds_backend
+
+ if ds_backend == NCCL_BACKEND:
+ utils.logger.warn("NCCL backend in DeepSpeed not yet implemented")
+ elif ds_backend == MPI_BACKEND:
+ utils.logger.warn("MPI backend in DeepSpeed not yet implemented")
+ elif ds_backend == GLOO_BACKEND:
+ utils.logger.warn("Gloo backend in DeepSpeed not yet implemented")
+ else:
+ utils.logger.warn(f"DeepSpeed does not support {ds_backend} backend")
+
+
+def is_initialized():
+ #assert cdb is not None, 'DeepSpeed backend not set, please initialize it using init_process_group()'
+ if cdb is None:
+ return False
+ else:
+ return cdb.is_initialized()
+
+
+def destroy_process_group(group=None):
+ global cdb
+ return cdb.destroy_process_group(group=group)
+
+
+def new_group(ranks):
+ global cdb
+ assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
+ return cdb.new_group(ranks)
+
+
+def is_available() -> bool:
+
+ # Returns ``True`` if the deepspeed comm package is available.
+
+ # TODO: load other ops. Clients including deepspeed itself should use deepspeed.comm to import
+ # any communication related primitives from this package.
+ # use hasattr(deepspeed.csrc.ops, "_comm") or something
+ return True
+
+
+def set_backend(backend):
+ if not use_ds_backend:
+ utils.logger.error(
+ "DeepSpeed communication backend is required. Please use deepspeed.comm.init_distributed(backend, use_deepspeed=True) to use this functionality"
+ )
+ raise RuntimeError(
+ 'Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.'
+ )
+
+ global cdb
+ global nccl_backend
+ global mpi_backend
+
+ try:
+ if backend_name == NCCL_BACKEND:
+ if nccl_backend is not None and nccl_backend.is_initialized():
+ cdb = nccl_backend
+ elif backend_name == MPI_BACKEND:
+ if mpi_backend is not None and mpi_backend.is_initialized():
+ cdb = mpi_backend
+ except Exception as inst:
+ print(inst)
+
+
+def broadcast(tensor, src, group=None, async_op=False):
+ global cdb
+ return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
+
+
+def all_gather(tensor_list, tensor, group=None, async_op=False):
+ global cdb
+ return cdb.all_gather(tensor_list=tensor_list,
+ tensor=tensor,
+ group=group,
+ async_op=async_op)
+
+
+def has_reduce_scatter_base():
+ global cdb
+ assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
+ assert cdb.has_reduce_scatter_base is not None, 'has_reduce_scatter_base is not yet defined'
+ return cdb.has_reduce_scatter_base
+
+
+def reduce_scatter_fn(output_tensor: torch.Tensor, input_tensor: torch.Tensor, group):
+ global cdb
+ global has_warned_reduce_scatter
+ assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
+ if cdb.has_reduce_scatter_base:
+ return cdb.reduce_scatter_base(output_tensor, input_tensor, group=group)
+ else:
+ if not has_warned_reduce_scatter:
+ utils.logger.warning(
+ "unable to find torch.distributed._reduce_scatter_base. will fall back to "
+ "torch.distributed.all_gather which will result in suboptimal performance. "
+ "please consider upgrading your pytorch installation.")
+ has_warned_reduce_scatter = True
+ input_tensor_lst = list(torch.chunk(input_tensor, cdb.get_world_size(group)))
+ return cdb.reduce_scatter(output_tensor, input_tensor_lst, group=group)
+
+
+def reduce_scatter_base(output_tensor, input_tensor, group=None):
+ global cdb
+ return cdb.reduce_scatter_base(output_tensor=output_tensor,
+ input_tensor=input_tensor,
+ group=group)
+
+
+def has_allgather_base():
+ global cdb
+ assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
+ assert cdb.has_allgather_base is not None, 'has_allgather_base is not yet defined'
+ return cdb.has_allgather_base
+
+
+def allgather_fn(output_tensor: torch.Tensor,
+ input_tensor: torch.Tensor,
+ group,
+ async_op):
+ global cdb
+ global has_warned_all_gather
+ assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
+ if cdb.has_allgather_base:
+ return cdb.all_gather_base(output_tensor,
+ input_tensor,
+ group=group,
+ async_op=True)
+ else:
+ if not has_warned_all_gather:
+ utils.logger.warning(
+ "unable to find torch.distributed._all_gather_base. will fall back to "
+ "torch.distributed.all_gather which will result in suboptimal performance. "
+ "please consider upgrading your pytorch installation.")
+ has_warned_all_gather = True
+ output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
+ return cdb.all_gather(output_tensors, input_tensor, group=group, async_op=True)
+
+
+def all_gather_base(output_tensor, input_tensor, group=None, async_op=False):
+ global cdb
+ return cdb.all_gather_base(output_tensor=output_tensor,
+ input_tensor=input_tensor,
+ group=group,
+ async_op=async_op)
+
+
+def all_to_all_single(
+ output,
+ input,
+ output_split_sizes=None,
+ input_split_sizes=None,
+ group=None,
+ async_op=False,
+):
+ global cdb
+ return cdb.all_to_all_single(output=output,
+ input=input,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op)
+
+
+def send(tensor, dst, group=None, tag=0):
+ global cdb
+ return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
+
+
+def recv(tensor, src=None, group=None, tag=0):
+ global cdb
+ return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
+
+
+def isend(tensor, dst, group=None, tag=0):
+ global cdb
+ return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
+
+
+def irecv(tensor, src=None, group=None, tag=0):
+ global cdb
+ return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
+
+
+def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
+ global cdb
+ return cdb.gather(tensor=tensor,
+ gather_list=gather_list,
+ dst=dst,
+ group=group,
+ async_op=async_op)
+
+
+def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
+ global cdb
+ return cdb.scatter(tensor=tensor,
+ scatter_list=scatter_list,
+ src=src,
+ group=group,
+ async_op=async_op)
+
+
+def barrier(group=None):
+ global cdb
+ return cdb.barrier()
+
+
+# Local enum for Reduction operators
+#from .utils import ReduceOp
+
+
+def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
+ global cdb
+ return cdb.reduce(tensor=tensor, dst=dst, op=op, group=group, async_op=async_op)
+
+
+def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
+ global cdb
+ return cdb.reduce_scatter(output=output,
+ input_list=input_list,
+ op=op,
+ group=group,
+ async_op=async_op)
+
+
+def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
+ #if profile_comm:
+ # context of the timers?
+ # timers.start()
+ # TensorBoard logging for comm calls.?
+ global cdb
+ #print(f'op = {op}, cdb= {cdb.name}')
+ return cdb.all_reduce(tensor, op, group, async_op)
+
+
+def get_world_group():
+ global cdb
+ assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
+ return cdb.get_world_group()
+
+
+def get_world_size(group=None) -> int:
+ """
+ Returns the number of processes in the current process group
+ Args:
+ group (ProcessGroup, optional): The process group to work on. If None,
+ the default process group will be used.
+ Returns:
+ The world size of the process group
+ -1, if not part of the group
+ """
+ global cdb
+
+ assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
+ return cdb.get_world_size(group)
+
+
+def get_rank(group=None):
+ """
+ Returns the rank of the current process in the provided ``group`` or the
+ default group if none was provided.
+ Rank is a unique identifier assigned to each process within a distributed
+ process group. They are always consecutive integers ranging from 0 to
+ ``world_size``.
+ Args:
+ group (ProcessGroup, optional): The process group to work on. If None,
+ the default process group will be used.
+ Returns:
+ The rank of the process group
+ -1, if not part of the group
+ """
+ global cdb
+ assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
+ return cdb.get_rank(group)
+
+
+def get_local_rank():
+ """
+ Helper function to get local rank after a backend has been set and initialized
+ Args:
+ None
+ Returns:
+ local rank (= GPU device ID)
+ """
+ global cdb
+ assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
+ return get_local_rank_from_launcher()
+
+
+def get_global_rank(group=None, group_rank=0):
+ global cdb
+ assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
+ return cdb.get_global_rank(group, group_rank)
+
+
+# Main DeepSpeed Comms. public API.
+def init_distributed(dist_backend="nccl",
+ auto_mpi_discovery=True,
+ distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
+ verbose=True,
+ timeout=default_pg_timeout,
+ init_method=None,
+ dist_init_required=None):
+ ''' Initialize dist backend, potentially performing MPI discovery if needed
+
+ Arguments:
+ dist_backend: Optional (str). torch distributed backend, e.g., nccl, mpi, gloo
+ auto_mpi_discovery Optional (bool). if distributed environment variables are not set, attempt to discover them from MPI
+ distributed_port: Optional (int). torch distributed backend port
+ verbose: Optional (bool). verbose logging
+ timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
+ init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified.
+ '''
+ global cdb
+
+ if dist_init_required is None:
+ dist_init_required = cdb is None or not cdb.is_initialized()
+
+ if dist_init_required is False:
+ assert (
+ cdb is not None and cdb.is_initialized() is True
+ ), "Distributed backend is not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
+ else:
+ # Initialize torch distributed if needed
+ required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+ if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
+ if verbose:
+ utils.logger.info(
+ "Not using the DeepSpeed or dist launchers, attempting to detect MPI environment..."
+ )
+ if in_aml() and not in_dlts():
+ patch_aml_env_for_torch_nccl_backend(verbose=verbose)
+ elif in_aws_sm():
+ patch_aws_sm_env_for_torch_nccl_backend(verbose=verbose)
+ else:
+ mpi_discovery(distributed_port=distributed_port, verbose=verbose)
+
+ if cdb is not None and cdb.is_initialized():
+ if int(os.getenv('RANK', '0')) == 0:
+ utils.logger.info('Distributed backend already initialized')
+ else:
+ assert isinstance(timeout, timedelta)
+ if int(os.getenv('RANK', '0')) == 0:
+ utils.logger.info(
+ 'Initializing TorchBackend in DeepSpeed with backend {}'.format(
+ dist_backend))
+ # Create a torch backend object, initialize torch distributed, and assign to cdb
+ cdb = TorchBackend(dist_backend, timeout, init_method)
+
+
+def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
+ '''
+ Discovery MPI environment via mpi4py and map to relevant dist state
+ '''
+ from mpi4py import MPI
+ import subprocess
+ comm = MPI.COMM_WORLD
+ rank = comm.Get_rank()
+ world_size = comm.Get_size()
+
+ master_addr = None
+ if rank == 0:
+ hostname_cmd = ["hostname -I"]
+ result = subprocess.check_output(hostname_cmd, shell=True)
+ master_addr = result.decode('utf-8').split()[0]
+ master_addr = comm.bcast(master_addr, root=0)
+
+ # Determine local rank by assuming hostnames are unique
+ proc_name = MPI.Get_processor_name()
+ all_procs = comm.allgather(proc_name)
+ local_rank = sum([i == proc_name for i in all_procs[:rank]])
+
+ os.environ['RANK'] = str(rank)
+ os.environ['WORLD_SIZE'] = str(world_size)
+ os.environ['LOCAL_RANK'] = str(local_rank)
+ os.environ['MASTER_ADDR'] = master_addr
+ os.environ['MASTER_PORT'] = str(distributed_port)
+
+ if verbose:
+ utils.logger.info(
+ "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
+ .format(os.environ['RANK'],
+ os.environ['LOCAL_RANK'],
+ os.environ['WORLD_SIZE'],
+ os.environ['MASTER_ADDR'],
+ os.environ['MASTER_PORT']))
+
+ if cdb is not None and cdb.is_initialized():
+ assert cdb.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(
+ rank, cdb.get_rank())
+ assert cdb.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
+ world_size, cdb.get_world_size())
+
+
+def in_aml():
+ # Are we running inside an Azure Machine Learning (AML) environment?
+ return 'AZUREML_EXPERIMENT_ID' in os.environ
+
+
+def in_aws_sm():
+ # Are we running inside an AWS SageMaker environment?
+ return 'SM_TRAINING_ENV' in os.environ
+
+
+def in_dlts():
+ # Are we running on a DLTS cluster?
+ return 'DLTS_JOB_ID' in os.environ
+
+
+def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
+ """Helper routine to get and set environment variables.
+ This is adapted from Azure ML's documentation available from:
+ https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi
+ """
+ os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
+ os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
+ single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(
+ os.environ["WORLD_SIZE"])
+
+ if not single_node:
+ master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
+ os.environ["MASTER_ADDR"] = master_node_params[0]
+ # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
+ if "MASTER_PORT" not in os.environ:
+ os.environ["MASTER_PORT"] = str(master_port)
+ else:
+ os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
+ os.environ["MASTER_PORT"] = DEFAULT_AML_MASTER_PORT
+
+ if verbose:
+ utils.logger.info("NCCL_SOCKET_IFNAME original value = {}".format(
+ os.environ["NCCL_SOCKET_IFNAME"]))
+
+ os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_AML_NCCL_SOCKET_IFNAME
+ os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
+
+ if verbose:
+ utils.logger.info(
+ "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
+ .format(os.environ['RANK'],
+ os.environ['LOCAL_RANK'],
+ os.environ['WORLD_SIZE'],
+ os.environ['MASTER_ADDR'],
+ os.environ['MASTER_PORT']))
+
+
+def patch_aws_sm_env_for_torch_nccl_backend(verbose=True):
+ """Helper routine to get and set environment variables when running inside an AWS SageMaker environment.
+ """
+ os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
+ os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
+ os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
+
+ if verbose:
+ utils.logger.info(
+ "Discovered AWS SageMaker settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
+ .format(os.environ['RANK'],
+ os.environ['LOCAL_RANK'],
+ os.environ['WORLD_SIZE'],
+ os.environ['MASTER_ADDR'],
+ os.environ['MASTER_PORT']))
diff --git a/deepspeed/comm/constants.py b/deepspeed/comm/constants.py
new file mode 100644
index 000000000000..d85f72e8cbaa
--- /dev/null
+++ b/deepspeed/comm/constants.py
@@ -0,0 +1,7 @@
+NCCL_BACKEND = 'nccl'
+MPI_BACKEND = 'mpi'
+GLOO_BACKEND = 'gloo'
+SCCL_BACKEND = 'sccl'
+
+DEFAULT_AML_MASTER_PORT = "54965"
+DEFAULT_AML_NCCL_SOCKET_IFNAME = "^docker0,lo"
diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
new file mode 100644
index 000000000000..dcfae9161213
--- /dev/null
+++ b/deepspeed/comm/torch.py
@@ -0,0 +1,205 @@
+'''
+Copyright 2021 The Microsoft DeepSpeed Team
+'''
+import os
+import torch
+
+from deepspeed import utils
+from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout
+from datetime import timedelta
+
+from .utils import *
+from .backend import *
+from .comm import *
+
+
+class TorchBackend(Backend):
+ """
+ A light-weight wrapper class for torch.distributed API.
+ Only a subset of functions are wrapped. Once the init_process_group
+ is initialized, standard torch.distributed.* can be used directly
+ so no need to wrap all the functions. We can keep adding wrappers as
+ needed.
+ """
+ def __init__(self, backend, timeout, init_method, name='torch'):
+ super(TorchBackend, self).__init__()
+ self.torch_version_before_18 = older_torch()
+ self.has_allgather_base = has_allgather_base()
+ self.has_reduce_scatter_base = has_reduce_scatter_base()
+ self.initialized = True
+ self.name = name
+ # Future functionality to support ds.initialize() on a single GPU
+ # The idea is to fake that dist backend is initialized even when
+ # it is not so we can run on a single GPU without doing any init_process_group
+ self.single_gpu_mode = True
+ self.init_process_group(backend, timeout, init_method)
+
+ def init_process_group(self, backend, timeout, init_method):
+ return torch.distributed.init_process_group(backend,
+ timeout=timeout,
+ init_method=init_method)
+
+ def all_reduce(self,
+ tensor,
+ op=torch.distributed.ReduceOp.SUM,
+ group=None,
+ async_op=False):
+ op = self._reduce_op(op)
+ return torch.distributed.all_reduce(tensor=tensor,
+ op=op,
+ group=group,
+ async_op=async_op)
+
+ def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
+ return torch.distributed.reduce(tensor=tensor,
+ dst=dst,
+ op=self._reduce_op(op),
+ group=group,
+ async_op=async_op)
+
+ def reduce_scatter(self,
+ output,
+ input_list,
+ op=ReduceOp.SUM,
+ group=None,
+ async_op=False):
+ return torch.distributed.reduce_scatter(output=output,
+ input_list=input_list,
+ op=self._reduce_op(op),
+ group=group,
+ async_op=async_op)
+
+ def broadcast(self, tensor, src, group=None, async_op=False):
+ return torch.distributed.broadcast(tensor=tensor,
+ src=src,
+ group=group,
+ async_op=async_op)
+
+ def all_gather(self, tensor_list, tensor, group=None, async_op=False):
+ return torch.distributed.all_gather(tensor_list=tensor_list,
+ tensor=tensor,
+ group=group,
+ async_op=async_op)
+
+ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
+ if self.has_allgather_base:
+ return torch.distributed.distributed_c10d._all_gather_base(
+ output_tensor=output_tensor,
+ input_tensor=input_tensor,
+ group=group,
+ async_op=async_op)
+ else:
+ utils.logger.warning(
+ "unable to find torch.distributed._all_gather_base. will fall back to "
+ "torch.distributed.reduce_scatter which will result in suboptimal performance. "
+ "please consider upgrading your pytorch installation.")
+ pass
+
+ def reduce_scatter_base(self, output_tensor, input_tensor, group=None):
+ if self.has_reduce_scatter_base:
+ return torch.distributed._reduce_scatter_base(output_tensor,
+ input_tensor,
+ group=group)
+ else:
+ utils.logger.warning(
+ "unable to find torch.distributed._reduce_scatter_base. will fall back to "
+ "torch.distributed.reduce_scatter which will result in suboptimal performance. "
+ "please consider upgrading your pytorch installation.")
+ pass
+
+ def all_to_all_single(self,
+ output,
+ input,
+ output_split_sizes=None,
+ input_split_sizes=None,
+ group=None,
+ async_op=False):
+ return torch.distributed.all_to_all_single(output=output,
+ input=input,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ async_op=async_op)
+
+ def send(self, tensor, dst, group=None, tag=0):
+ return torch.distributed.send(tensor=tensor, dst=dst, group=group, tag=tag)
+
+ def recv(self, tensor, src=None, group=None, tag=0):
+ return torch.distributed.recv(tensor=tensor, src=src, group=group, tag=tag)
+
+ def isend(self, tensor, dst, group=None, tag=0):
+ return torch.distributed.isend(tensor=tensor, dst=dst, group=group, tag=tag)
+
+ def irecv(self, tensor, src=None, group=None, tag=0):
+ return torch.distributed.irecv(tensor=tensor, src=src, group=group, tag=tag)
+
+ def gather(self, tensor, gather_list=None, dst=0, group=None, async_op=False):
+ return torch.distributed.gather(tensor=tensor,
+ gather_list=gather_list,
+ dst=dst,
+ group=group,
+ async_op=async_op)
+
+ def scatter(self, tensor, scatter_list=None, src=0, group=None, async_op=False):
+ return torch.distributed.scatter(tensor=tensor,
+ scatter_list=scatter_list,
+ src=src,
+ group=group,
+ async_op=async_op)
+
+ def barrier(self):
+ return torch.distributed.barrier()
+
+ def get_rank(self, group=None):
+ return torch.distributed.get_rank(group=group)
+
+ def get_world_size(self, group=None):
+ return torch.distributed.get_world_size(group=group)
+
+ def is_initialized(self):
+ return torch.distributed.is_initialized()
+
+ def get_backend(self, group=None):
+ return torch.distributed.get_backend(group=group)
+
+ def new_group(self, ranks):
+ utils.logger.info(f"new group called with {ranks}")
+ return torch.distributed.new_group(ranks)
+
+ def get_global_rank(self, group, group_rank):
+ return torch.distributed.distributed_c10d._get_global_rank(group, group_rank)
+
+ def get_world_group(self):
+ return torch.distributed.group.WORLD
+
+ def destroy_process_group(self, group=None):
+ return torch.distributed.destroy_process_group(group=group)
+
+ def _reduce_op(self, op):
+ '''
+ Helper function. If the op provided is not a torch.dist.ReduceOp, convert it and return
+ '''
+ if not isinstance(op, torch.distributed.ReduceOp):
+ if op == ReduceOp.SUM:
+ op = torch.distributed.ReduceOp.SUM
+ elif op == ReduceOp.PRODUCT:
+ op = torch.distributed.ReduceOp.PRODUCT
+ elif op == ReduceOp.AVG:
+ op = torch.distributed.ReduceOp.AVG
+ elif op == ReduceOp.MIN:
+ op = torch.distributed.ReduceOp.MIN
+ elif op == ReduceOp.MAX:
+ op = torch.distributed.ReduceOp.MAX
+ elif op == ReduceOp.BAND:
+ op = torch.distributed.ReduceOp.BAND
+ elif op == ReduceOp.BOR:
+ op = torch.distributed.ReduceOp.BOR
+ elif op == ReduceOp.BXOR:
+ op = torch.distributed.ReduceOp.BXOR
+ return op
+
+
+# This will become a light-weight wrapper around torch.distributed functions
+# TODO: create some example to show how this wrapper can help profile communication
+# TODO: make sure there is no performance regression with this approach
+# TODO: explore monkey-patching if this does not work
diff --git a/deepspeed/comm/utils.py b/deepspeed/comm/utils.py
new file mode 100644
index 000000000000..7ca5581a2029
--- /dev/null
+++ b/deepspeed/comm/utils.py
@@ -0,0 +1,75 @@
+import os
+import enum
+import torch
+
+
+def older_torch():
+ '''
+ Helper to lookup torch version. For versions less than 1.8, torch.dist
+ used torch.distributed.group.WORLD as the default group argument instead of None.
+ See more details at: https://github.com/pytorch/pytorch/pull/48767
+ '''
+ TORCH_MAJOR = int(torch.__version__.split('.')[0])
+ TORCH_MINOR = int(torch.__version__.split('.')[1])
+ if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
+ return True
+ else:
+ return False
+
+
+def has_allgather_base():
+ '''
+ Helper to check if torch.distributed has _all_gather_base
+ '''
+ return hasattr(torch.distributed, "_all_gather_base")
+
+
+def has_reduce_scatter_base():
+ '''
+ Helper to check if torch.distributed has _reduce_scatter_base
+ '''
+ return hasattr(torch.distributed, "_reduce_scatter_base")
+
+
+def get_local_rank_from_launcher():
+
+ # DeepSpeed launcher will set it so get from there
+ rank = os.environ.get('LOCAL_RANK')
+
+ if rank is None:
+ rank = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK')
+
+ # Make it a single process job and set rank to 0
+ if rank is None:
+ rank = 0
+
+ return int(rank)
+
+
+def get_world_rank_from_launcher():
+
+ # DeepSpeed launcher will set it so get from there
+ rank = os.environ.get('RANK')
+
+ if rank is None:
+ rank = os.environ.get('OMPI_COMM_WORLD_RANK')
+
+ # Make it a single process job and set rank to 0
+ if rank is None:
+ rank = 0
+
+ return int(rank)
+
+
+def get_world_size_from_launcher():
+ # DeepSpeed launcher will set it so get from there
+ size = os.environ.get('WORLD_SIZE')
+
+ if size is None:
+ size = os.environ.get('OMPI_COMM_WORLD_SIZE')
+
+ # Make it a single process job and set size to 1
+ if size is None:
+ size = 1
+
+ return int(size)
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 5444016cae1b..b5e68aaf297f 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -3,20 +3,22 @@
'''
import torch
import os
+
+from deepspeed import comm as dist
+from deepspeed.utils import groups
+
from torch.nn.modules import Module
-import torch.distributed as dist
+from packaging import version as pkg_version
+
from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
from ..module_inject.replace_module import replace_transformer_layer
-from ..utils import logger, init_distributed
-
+from ..utils import logger
+from ..comm.comm import init_distributed
from ..pipe import PipelineModule
from ..moe.utils import has_moe_layers
from ..moe.layer import MoE
-import torch.distributed as dist
-import deepspeed.utils.groups as groups
-
DS_INFERENCE_ENABLED = False
@@ -87,9 +89,13 @@ def __init__(self,
self.ep_group = ep_group
self.expert_mp_group = expert_mp_group
self.enable_cuda_graph = enable_cuda_graph
- self.cuda_grah_created = False
+ self.cuda_graph_created = False
self._init_quantization_setting(quantization_setting)
+ if enable_cuda_graph:
+ assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
+ "If you want to use cuda graph, please upgrade torch to at least v1.10"
+
if self.checkpoint:
self._load_checkpoint(self.checkpoint)
@@ -371,7 +377,7 @@ def _create_cuda_graph(self, *inputs, **kwargs):
with torch.cuda.graph(self._cuda_graphs):
self.static_output = self.module(*self.static_inputs, **self.static_kwargs)
- self.cuda_grah_created = True
+ self.cuda_graph_created = True
def _graph_replay(self, *inputs, **kwargs):
for i in range(len(inputs)):
@@ -408,7 +414,7 @@ def forward(self, *inputs, **kwargs):
outputs = self.model_orig_fwd(*inputs, **kwargs)
else:
if self.enable_cuda_graph:
- if self.cuda_grah_created:
+ if self.cuda_graph_created:
outputs = self._graph_replay(*inputs, **kwargs)
else:
self._create_cuda_graph(*inputs, **kwargs)
diff --git a/deepspeed/launcher/launch.py b/deepspeed/launcher/launch.py
index 21ef04b338b7..06321cbf90b3 100755
--- a/deepspeed/launcher/launch.py
+++ b/deepspeed/launcher/launch.py
@@ -1,6 +1,6 @@
# Copyright 2020 The Microsoft DeepSpeed Team
"""
-DeepSpeed launcher, this is similar to torch.distributed.launch but supports
+DeepSpeed launcher, this is similar to torch's distributed.launch but supports
additional features such as arbitrary gpu exclusion.
deepspeed.launcher.launch is intended to be run on a single worker node and
diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py
index a962a8a7c925..27f32af30a2e 100644
--- a/deepspeed/launcher/multinode_runner.py
+++ b/deepspeed/launcher/multinode_runner.py
@@ -3,7 +3,7 @@
import shutil
import subprocess
import warnings
-from shlex import quote
+from shlex import quote, split
from abc import ABC, abstractmethod
from ..utils import logger
@@ -66,7 +66,7 @@ def get_cmd(self, environment, active_resources):
# PDSH flags for max node fan out and specific hosts to launch on
# See https://linux.die.net/man/1/pdsh for flag details
- pdsh_cmd_args = ['pdsh', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers]
+ pdsh_cmd_args = ['pdsh', '-S', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers]
exports = ""
for key, val in self.exports.items():
@@ -137,7 +137,7 @@ def get_cmd(self, environment, active_resources):
'--mca',
'btl_tcp_if_include',
'eth0',
- ]
+ ] + split(self.args.launcher_args)
export_cmd = []
for k, v in self.exports.items():
@@ -227,7 +227,7 @@ def get_cmd(self, environment, active_resources):
f'{process_per_node}',
'--hostfile',
f'{MVAPICH_TMP_HOSTFILE}',
- ]
+ ] + split(self.args.launcher_args)
export_cmd = []
for k, v in self.exports.items():
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 79c89ffd6bd8..72599e9e43a1 100755
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -19,7 +19,7 @@ def __init__(self, weight, bias=None, mp_group=None):
def forward(self, input):
output = torch.matmul(input, self.weight)
if self.mp_group is not None:
- torch.distributed.all_reduce(output, group=self.mp_group)
+ dist.all_reduce(output, group=self.mp_group)
if self.bias is not None:
output += self.bias
return output
@@ -41,7 +41,7 @@ def forward(self, input):
class ReplaceWithTensorSlicing:
def __init__(self, mp_group=None):
if mp_group is not None:
- self.gpu_index = torch.distributed.get_rank(group=mp_group)
+ self.gpu_index = dist.get_rank(group=mp_group)
else:
self.gpu_index = 0
@@ -247,7 +247,7 @@ def replace_with_policy(child,
if inference:
if moe:
- ep_world_size = torch.distributed.get_world_size()
+ ep_world_size = dist.get_world_size()
local_ep_size = 1 if num_experts < ep_world_size else num_experts // ep_world_size
transformer_config = transformer_inference.DeepSpeedMoEInferenceConfig(
@@ -351,8 +351,11 @@ def replace_with_policy(child,
# linear layer is created with [input, output] shape
# transpose it here to reduce inference cost!
def transpose(data):
+ # temp move to cpu to avoid requiring extra GPU memory during the reshape
+ data = data.to('cpu')
data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
data = data.reshape(data.shape[-1], data.shape[-2])
+ data.to(torch.cuda.current_device())
return data
if attn_linear_layer:
@@ -416,7 +419,7 @@ def _transpose(x):
mpl_block = new_module.mlp
if moe:
- gpu_index = torch.distributed.get_rank()
+ gpu_index = dist.get_rank()
gpu_index = 0
for ep_index in range(local_ep_size):
mpl_block[ep_index].inter_w.data = _h4h_w[
@@ -460,7 +463,7 @@ def _transpose(x):
new_module.norm_b.data = input_nb.to(torch.cuda.current_device())
else:
transformer_config = deepspeed.DeepSpeedTransformerConfig(
- batch_size=micro_batch_size,
+ batch_size=micro_batch_size if micro_batch_size > 0 else 1,
hidden_size=config.hidden_size,
heads=config.num_attention_heads,
attn_dropout_ratio=config.attention_probs_dropout_prob,
diff --git a/deepspeed/moe/layer.py b/deepspeed/moe/layer.py
index c596da4903e8..7dcf4144c0e6 100644
--- a/deepspeed/moe/layer.py
+++ b/deepspeed/moe/layer.py
@@ -4,11 +4,11 @@
import torch.nn.init as init
import torch
-import torch.distributed as dist
+from deepspeed import comm as dist
from deepspeed.utils import logger, log_dist
-import deepspeed.utils.groups as groups
+from deepspeed.utils import groups
from .sharded_moe import MOELayer, TopKGate
from .experts import Experts
import copy
diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index 024de2f51e05..d87d22dc7ff7 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -20,7 +20,7 @@
from time import perf_counter
import torch
from torch import Tensor
-import torch.distributed as dist
+from deepspeed import comm as dist
from torch.nn import Module, ModuleList
import torch.nn.functional as F
@@ -80,12 +80,20 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
return gumbel(shape)
+from deepspeed import comm as dist
+
+# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
+# See https://arxiv.org/pdf/2006.16668.pdf for details.
+
+
# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):
@staticmethod
- def forward(ctx: Any,
- group: dist.ProcessGroup,
- input: Tensor) -> Tensor: # type: ignore
+ def forward(
+ ctx: Any,
+ # TODO: replace with DS process group
+ group: torch.distributed.ProcessGroup,
+ input: Tensor) -> Tensor: # type: ignore
ctx.group = group
input = input.contiguous()
output = torch.empty_like(input)
@@ -206,7 +214,7 @@ def top1gating(logits: Tensor,
# if we don't want to drop any tokens
if not drop_tokens:
new_capacity = torch.max(exp_counts).to(logits.device)
- dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
+ dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
capacity = new_capacity
# Compute l_aux
diff --git a/deepspeed/moe/utils.py b/deepspeed/moe/utils.py
index 09922937e6d8..2b62a66eba09 100644
--- a/deepspeed/moe/utils.py
+++ b/deepspeed/moe/utils.py
@@ -1,6 +1,6 @@
from typing import List, Tuple, Dict
import torch
-import deepspeed.utils.groups as groups
+from deepspeed.utils import groups
from .layer import MoE
diff --git a/deepspeed/monitor/__init__.py b/deepspeed/monitor/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/deepspeed/monitor/config.py b/deepspeed/monitor/config.py
new file mode 100644
index 000000000000..b854b8580023
--- /dev/null
+++ b/deepspeed/monitor/config.py
@@ -0,0 +1,50 @@
+"""
+Copyright (c) Microsoft Corporation
+Licensed under the MIT license.
+"""
+
+from typing import Optional
+from deepspeed.runtime.config_utils import get_scalar_param
+from pydantic import BaseModel, validator, ValidationError, create_model
+from .constants import *
+
+
+class MonitorConfig(BaseModel):
+ class Config:
+ validate_all = True
+ validate_assignment = True
+ use_enum_values = True
+ extra = 'forbid'
+
+
+class TensorBoardConfig(MonitorConfig):
+ enabled: bool = TENSORBOARD_ENABLED_DEFAULT
+ output_path: str = TENSORBOARD_OUTPUT_PATH_DEFAULT
+ job_name: str = TENSORBOARD_JOB_NAME_DEFAULT
+
+
+class WandbConfig(MonitorConfig):
+ enabled: bool = WANDB_ENABLED_DEFAULT
+ group: str = WANDB_GROUP_NAME_DEFAULT
+ team: str = WANDB_TEAM_NAME_DEFAULT
+ project: str = WANDB_PROJECT_NAME_DEFAULT
+
+
+class CSVConfig(MonitorConfig):
+ enabled: bool = CSV_MONITOR_ENABLED_DEFAULT
+ output_path: str = CSV_MONITOR_OUTPUT_PATH_DEFAULT
+ job_name: str = CSV_MONITOR_JOB_NAME_DEFAULT
+
+
+class DeepSpeedMonitorConfig:
+ def __init__(self, ds_config):
+ self.tensorboard_enabled = 'tensorboard' in ds_config
+ self.wandb_enabled = 'wandb' in ds_config
+ self.csv_monitor_enabled = 'csv_monitor' in ds_config
+
+ if self.tensorboard_enabled:
+ self.tensorboard_config = TensorBoardConfig(**ds_config['tensorboard'])
+ if self.wandb_enabled:
+ self.wandb_config = WandbConfig(**ds_config['wandb'])
+ if self.csv_monitor_enabled:
+ self.csv_monitor_config = CSVConfig(**ds_config['csv_monitor'])
diff --git a/deepspeed/monitor/constants.py b/deepspeed/monitor/constants.py
new file mode 100644
index 000000000000..95cb970175c4
--- /dev/null
+++ b/deepspeed/monitor/constants.py
@@ -0,0 +1,85 @@
+#########################################
+# Tensorboard
+#########################################
+# Tensorboard. By default, this feature is not enabled.
+# Users can configure in ds_config.json as below example:
+TENSORBOARD_FORMAT = '''
+Tensorboard can be specified as:
+"tensorboard": {
+ "enabled": true,
+ "output_path": "/home/myname/foo",
+ "job_name": "model_lr2e-5_epoch3_seed2_seq64"
+}
+'''
+TENSORBOARD = "tensorboard"
+
+# Tensorboard enable signal
+TENSORBOARD_ENABLED = "enabled"
+TENSORBOARD_ENABLED_DEFAULT = False
+
+# Tensorboard output path
+TENSORBOARD_OUTPUT_PATH = "output_path"
+TENSORBOARD_OUTPUT_PATH_DEFAULT = ""
+
+# Tensorboard job name
+TENSORBOARD_JOB_NAME = "job_name"
+TENSORBOARD_JOB_NAME_DEFAULT = "DeepSpeedJobName"
+
+#########################################
+# Wandb
+#########################################
+# Wandb. By default, this feature is not enabled.
+# Users can configure in ds_config.json as below example:
+WANDB_FORMAT = '''
+Wandb can be specified as:
+"wandb": {
+ "enabled": true,
+ "team_name": "deepspeed"
+ "project_name": "zero"
+ "group_name": "zero: stage 3",
+}
+'''
+WANDB = "wandb"
+
+# Wandb enable signal
+WANDB_ENABLED = "enabled"
+WANDB_ENABLED_DEFAULT = False
+
+# Wandb team
+WANDB_TEAM_NAME = "team"
+WANDB_TEAM_NAME_DEFAULT = None
+
+# Wandb project
+WANDB_PROJECT_NAME = "project"
+WANDB_PROJECT_NAME_DEFAULT = "deepspeed"
+
+# Wandb group
+WANDB_GROUP_NAME = "group"
+WANDB_GROUP_NAME_DEFAULT = None
+
+#########################################
+# csv monitor
+#########################################
+# Basic CSV monitor. By default, this feature is not enabled.
+# Users can configure in ds_config.json as below example:
+CSV_FORMAT = '''
+The basic csv monitor can be specified as:
+"csv_monitor": {
+ "enabled": true,
+ "output_path": "/home/myname/foo",
+ "job_name": "model_lr2e-5_epoch3_seed2_seq64"
+}
+'''
+CSV_MONITOR = "csv_monitor"
+
+# csv monitor enable signal
+CSV_MONITOR_ENABLED = "enabled"
+CSV_MONITOR_ENABLED_DEFAULT = False
+
+# csv monitor output path
+CSV_MONITOR_OUTPUT_PATH = "output_path"
+CSV_MONITOR_OUTPUT_PATH_DEFAULT = ""
+
+# csv_monitor job name
+CSV_MONITOR_JOB_NAME = "job_name"
+CSV_MONITOR_JOB_NAME_DEFAULT = "DeepSpeedJobName"
diff --git a/deepspeed/monitor/csv_monitor.py b/deepspeed/monitor/csv_monitor.py
new file mode 100644
index 000000000000..b2b05260e445
--- /dev/null
+++ b/deepspeed/monitor/csv_monitor.py
@@ -0,0 +1,62 @@
+from .monitor import Monitor
+import os
+
+import deepspeed.comm as dist
+
+
+class csvMonitor(Monitor):
+ def __init__(self, monitor_config):
+ super().__init__(monitor_config)
+ import csv
+ self.filenames = []
+ self.enabled = monitor_config.csv_monitor_config.enabled
+ self.output_path = monitor_config.csv_monitor_config.output_path
+ self.job_name = monitor_config.csv_monitor_config.job_name
+ self.log_dir = self.setup_log_dir()
+
+ def setup_log_dir(self, base=os.path.join(os.path.expanduser("~"), "csv_monitor")):
+ if self.enabled and dist.get_rank() == 0:
+ if self.output_path is not None:
+ log_dir = os.path.join(self.output_path, self.job_name)
+ # NOTE: This code path currently is never used since the default tensorboard_output_path is an empty string and not None. Saving it in case we want this functionality in the future.
+ else:
+ if "DLWS_JOB_ID" in os.environ:
+ infra_job_id = os.environ["DLWS_JOB_ID"]
+ elif "DLTS_JOB_ID" in os.environ:
+ infra_job_id = os.environ["DLTS_JOB_ID"]
+ else:
+ infra_job_id = "unknown-job-id"
+
+ csv_monitor_dir_name = os.path.join(infra_job_id, "logs")
+ log_dir = os.path.join(base, csv_monitor_dir_name, self.job_name)
+ os.makedirs(log_dir, exist_ok=True)
+ return log_dir
+
+ def write_events(self, event_list):
+ if self.enabled and dist.get_rank() == 0:
+ import csv
+ # We assume each event_list element is a tensorboard-style tuple in the format: (log_name: String, value, step: Int)
+ for event in event_list:
+ log_name = event[0]
+ value = event[1]
+ step = event[2]
+
+ # Set the header to the log_name
+ # Need this check because the deepspeed engine currently formats log strings to separate with '/'
+ if '/' in log_name:
+ record_splits = log_name.split('/')
+ header = record_splits[len(record_splits) - 1]
+ else:
+ header = log_name
+
+ # sanitize common naming conventions into filename
+ filename = log_name.replace('/', '_').replace(' ', '_')
+ fname = self.log_dir + '/' + filename + '.csv'
+
+ # Open file and record event. Insert header if this is the first time writing
+ with open(fname, 'a+') as csv_monitor_file:
+ csv_monitor_writer = csv.writer(csv_monitor_file)
+ if filename not in self.filenames:
+ self.filenames.append(filename)
+ csv_monitor_writer.writerow(['step', header])
+ csv_monitor_writer.writerow([step, value])
diff --git a/deepspeed/monitor/monitor.py b/deepspeed/monitor/monitor.py
new file mode 100644
index 000000000000..a5ac271861ff
--- /dev/null
+++ b/deepspeed/monitor/monitor.py
@@ -0,0 +1,47 @@
+"""
+ Support different forms of monitoring such as wandb and tensorboard
+"""
+
+from abc import ABC, abstractmethod
+import deepspeed.comm as dist
+
+
+class Monitor(ABC):
+ @abstractmethod
+ def __init__(self, monitor_config):
+ self.monitor_config = monitor_config
+
+ @abstractmethod
+ def write_events(self, event_list):
+ pass
+
+
+from .wandb import WandbMonitor
+from .tensorboard import TensorBoardMonitor
+from .csv_monitor import csvMonitor
+
+
+class MonitorMaster(Monitor):
+ def __init__(self, monitor_config):
+ super().__init__(monitor_config)
+ self.tb_monitor = None
+ self.wandb_monitor = None
+ self.csv_monitor = None
+ self.enabled = monitor_config.tensorboard_enabled or monitor_config.csv_monitor_enabled or monitor_config.wandb_enabled
+
+ if dist.get_rank() == 0:
+ if monitor_config.tensorboard_enabled:
+ self.tb_monitor = TensorBoardMonitor(monitor_config)
+ if monitor_config.wandb_enabled:
+ self.wandb_monitor = WandbMonitor(monitor_config)
+ if monitor_config.csv_monitor_enabled:
+ self.csv_monitor = csvMonitor(monitor_config)
+
+ def write_events(self, event_list):
+ if dist.get_rank() == 0:
+ if self.tb_monitor is not None:
+ self.tb_monitor.write_events(event_list)
+ if self.wandb_monitor is not None:
+ self.wandb_monitor.write_events(event_list)
+ if self.csv_monitor is not None:
+ self.csv_monitor.write_events(event_list)
diff --git a/deepspeed/monitor/tensorboard.py b/deepspeed/monitor/tensorboard.py
new file mode 100644
index 000000000000..447143e53b05
--- /dev/null
+++ b/deepspeed/monitor/tensorboard.py
@@ -0,0 +1,52 @@
+from .utils import check_tb_availability
+from .monitor import Monitor
+import os
+
+import deepspeed.comm as dist
+
+
+class TensorBoardMonitor(Monitor):
+ def __init__(self, monitor_config):
+ super().__init__(monitor_config)
+ check_tb_availability()
+
+ self.summary_writer = None
+ self.enabled = monitor_config.tensorboard_config.enabled
+ self.output_path = monitor_config.tensorboard_config.output_path
+ self.job_name = monitor_config.tensorboard_config.job_name
+
+ if self.enabled and dist.get_rank() == 0:
+ self.get_summary_writer()
+
+ def get_summary_writer(self,
+ base=os.path.join(os.path.expanduser("~"),
+ "tensorboard")):
+ if self.enabled and dist.get_rank() == 0:
+ from torch.utils.tensorboard import SummaryWriter
+ if self.output_path is not None:
+ log_dir = os.path.join(self.output_path, self.job_name)
+ # NOTE: This code path currently is never used since the default output_path is an empty string and not None. Saving it in case we want this functionality in the future.
+ else:
+ if "DLWS_JOB_ID" in os.environ:
+ infra_job_id = os.environ["DLWS_JOB_ID"]
+ elif "DLTS_JOB_ID" in os.environ:
+ infra_job_id = os.environ["DLTS_JOB_ID"]
+ else:
+ infra_job_id = "unknown-job-id"
+
+ summary_writer_dir_name = os.path.join(infra_job_id, "logs")
+ log_dir = os.path.join(base, summary_writer_dir_name, self.output_path)
+ os.makedirs(log_dir, exist_ok=True)
+ self.summary_writer = SummaryWriter(log_dir=log_dir)
+ return self.summary_writer
+
+ def write_events(self, event_list, flush=True):
+ if self.enabled and self.summary_writer is not None and dist.get_rank() == 0:
+ for event in event_list:
+ self.summary_writer.add_scalar(*event)
+ if flush:
+ self.summary_writer.flush()
+
+ def flush(self):
+ if self.enabled and self.summary_writer is not None and dist.get_rank() == 0:
+ self.summary_writer.flush()
diff --git a/deepspeed/monitor/utils.py b/deepspeed/monitor/utils.py
new file mode 100644
index 000000000000..f519a71823a9
--- /dev/null
+++ b/deepspeed/monitor/utils.py
@@ -0,0 +1,18 @@
+def check_tb_availability():
+ try:
+ # torch.utils.tensorboard will fail if `tensorboard` is not available,
+ # see their docs for more details: https://pytorch.org/docs/1.8.0/tensorboard.html
+ import tensorboard
+ except ImportError:
+ print('If you want to use tensorboard logging, please `pip install tensorboard`')
+ raise
+
+
+def check_wandb_availability():
+ try:
+ import wandb
+ except ImportError:
+ print(
+ 'If you want to use wandb logging, please `pip install wandb` and follow the instructions at https://docs.wandb.ai/quickstart'
+ )
+ raise
diff --git a/deepspeed/monitor/wandb.py b/deepspeed/monitor/wandb.py
new file mode 100644
index 000000000000..63f5879633b5
--- /dev/null
+++ b/deepspeed/monitor/wandb.py
@@ -0,0 +1,32 @@
+from .utils import check_wandb_availability
+from .monitor import Monitor
+
+import deepspeed.comm as dist
+
+
+class WandbMonitor(Monitor):
+ def __init__(self, monitor_config):
+ super().__init__(monitor_config)
+ check_wandb_availability()
+ import wandb
+
+ self.enabled = monitor_config.wandb_config.enabled
+ self.group = monitor_config.wandb_config.group
+ self.team = monitor_config.wandb_config.team
+ self.project = monitor_config.wandb_config.project
+
+ if self.enabled and dist.get_rank() == 0:
+ wandb.init(project=self.project, group=self.group, entity=self.team)
+
+ def log(self, data, step=None, commit=None, sync=None):
+ if self.enabled and dist.get_rank() == 0:
+ import wandb
+ return wandb.log(data, step=step, commit=commit, sync=sync)
+
+ def write_events(self, event_list):
+ if self.enabled and dist.get_rank() == 0:
+ for event in event_list:
+ label = event[0]
+ value = event[1]
+ step = event[2]
+ self.log({label: value}, step=step)
diff --git a/deepspeed/ops/transformer/inference/moe_inference.py b/deepspeed/ops/transformer/inference/moe_inference.py
index 855211baf57b..830110cc5f05 100644
--- a/deepspeed/ops/transformer/inference/moe_inference.py
+++ b/deepspeed/ops/transformer/inference/moe_inference.py
@@ -16,7 +16,7 @@
import torch.nn as nn
from .transformer_inference import DeepSpeedSelfAttention, DeepSpeedInferenceConfig
from ....moe.sharded_moe import TopKGate
-import torch.distributed as dist
+from deepspeed import comm as dist
import torch.nn.functional as F
diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py
index 1c3d69bf9d78..4fa162bc3491 100755
--- a/deepspeed/ops/transformer/inference/transformer_inference.py
+++ b/deepspeed/ops/transformer/inference/transformer_inference.py
@@ -10,7 +10,7 @@
import time
from ... import op_builder
import torch.nn as nn
-import torch.distributed as dist
+from deepspeed import comm as dist
# Cuda modules will be imported if needed
inference_cuda_module = None
@@ -471,7 +471,7 @@ def forward(ctx,
config.pre_layer_norm,
False)
else:
- intermediate = mlp_gemm_func(input,
+ intermediate, residual_add = mlp_gemm_func(input,
residual,
bias,
inter_w,
@@ -482,14 +482,16 @@ def forward(ctx,
config.pre_layer_norm,
config.mlp_after_attn)
output = vector_matmul_func(intermediate, output_w, False)
- inference_cuda_module.residual_add(output,
- residual,
- input,
- output_b,
- bias if bias is not None else output_b,
- config.mp_size,
- config.mlp_after_attn,
- bias is not None)
+ inference_cuda_module.residual_add(
+ output,
+ residual if config.pre_layer_norm else residual_add,
+ input,
+ output_b,
+ bias if bias is not None else output_b,
+ config.mp_size,
+ config.mlp_after_attn,
+ bias is not None,
+ config.pre_layer_norm)
if mp_group is not None and dist.get_world_size(group=mp_group) > 1:
dist.all_reduce(output, group=mp_group)
return output
@@ -671,7 +673,7 @@ def forward(self,
self.config.epsilon)
output = output.to(input_type)
- #print(f'[{torch.distributed.get_rank()}] {self.config.layer_id}: {output.norm()}')
+ #print(f'[{deepspeed.comm.get_rank()}] {self.config.layer_id}: {output.norm()}')
#exit()
if get_present:
output = (output, presents)
diff --git a/deepspeed/ops/transformer/transformer.py b/deepspeed/ops/transformer/transformer.py
index 4b3104b6bdea..084587ba2a3b 100755
--- a/deepspeed/ops/transformer/transformer.py
+++ b/deepspeed/ops/transformer/transformer.py
@@ -218,7 +218,7 @@ def forward(ctx,
output_b,
norm_w,
norm_b,
- config.training,
+ config.training and config.is_grad_enabled,
config.pre_layer_norm,
config.attn_dropout_checkpoint,
config.normalize_invertible,
diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py
index 35ef25938d79..48c03b309d87 100644
--- a/deepspeed/profiling/flops_profiler/profiler.py
+++ b/deepspeed/profiling/flops_profiler/profiler.py
@@ -287,7 +287,7 @@ def print_model_profile(self,
print('{:<60} {:<8}'.format(
'params of model = params per GPU * mp_size: ',
params_to_string(total_params *
- (self.ds_engine.mp_world_size) if self.ds_engine else 1)))
+ ((self.ds_engine.mp_world_size) if self.ds_engine else 1))))
print('{:<60} {:<8}'.format('fwd MACs per GPU: ', macs_to_string(total_macs)))
@@ -296,7 +296,7 @@ def print_model_profile(self,
print('{:<60} {:<8}'.format(
'fwd flops of model = fwd flops per GPU * mp_size: ',
num_to_string(total_flops *
- (self.ds_engine.mp_world_size) if self.ds_engine else 1)))
+ ((self.ds_engine.mp_world_size) if self.ds_engine else 1))))
fwd_latency = self.get_total_duration()
if self.ds_engine and self.ds_engine.wall_clock_breakdown():
diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py
index 06d1a2a59833..e3a945f2eff6 100644
--- a/deepspeed/runtime/activation_checkpointing/checkpointing.py
+++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -16,7 +16,7 @@
import copy
import torch
import contextlib
-import torch.distributed as dist
+from deepspeed import comm as dist
import mmap
from torch import _C
@@ -222,12 +222,12 @@ def model_parallel_cuda_manual_seed(seed):
# Data parallel gets the original seed.
data_parallel_seed = seed
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
logger.info(
'> initializing model parallel cuda seeds on global rank {}, '
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format(
- torch.distributed.get_rank(),
+ dist.get_rank(),
tp_rank,
mpu.get_data_parallel_rank(),
model_parallel_seed,
diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index 3805fee7e306..9dd3bdd4e329 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -1,5 +1,5 @@
import torch
-import torch.distributed as dist
+from deepspeed import comm as dist
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.ops.op_builder import UtilsBuilder
from deepspeed.runtime import ZeROOptimizer
@@ -68,7 +68,7 @@ def get_full_hp_param(self, optim_state_key=None):
hp_fragment = self._hp_mapping.get_optim_state_fragment(optim_state_key)
reduce_fragment.data.copy_(hp_fragment.data)
- torch.distributed.all_reduce(reduce_buffer, group=self._dp_group)
+ dist.all_reduce(reduce_buffer, group=self._dp_group)
return reduce_buffer.reshape_as(self)
diff --git a/deepspeed/runtime/comm/coalesced_collectives.py b/deepspeed/runtime/comm/coalesced_collectives.py
index 3b42d223182d..e92af044f53e 100644
--- a/deepspeed/runtime/comm/coalesced_collectives.py
+++ b/deepspeed/runtime/comm/coalesced_collectives.py
@@ -6,36 +6,19 @@
import torch
from torch import Tensor
-import torch.distributed
+from deepspeed import comm as dist
+# NOTE: Use torch.distributed's ProcessGroup class until we have our own.
from torch.distributed import ProcessGroup
import torch.nn.functional
from deepspeed.utils import instrument_w_nvtx
from deepspeed.utils.logging import logger
-if hasattr(torch.distributed, "_reduce_scatter_base"):
-
- def torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group):
- instrument_w_nvtx(torch.distributed._reduce_scatter_base)(
- output_tensor,
- input_tensor,
- group=group,
- )
-else:
- logger.warning(
- "unable to find torch.distributed._reduce_scatter_base. will fall back to "
- "torch.distributed.reduce_scatter which will result in suboptimal performance. "
- "please consider upgrading your pytorch installation.")
-
- def torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group):
- input_tensor_lst = list(
- torch.chunk(input_tensor,
- torch.distributed.get_world_size(group)))
- instrument_w_nvtx(torch.distributed.reduce_scatter)(
- output_tensor,
- input_tensor_lst,
- group=group,
- )
+
+def _torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group):
+ return instrument_w_nvtx(dist.reduce_scatter_fn)(output_tensor,
+ input_tensor,
+ group=group)
@instrument_w_nvtx
@@ -49,8 +32,8 @@ def reduce_scatter_coalesced(
TODO. see if PyTorch team wants a c++ version of this for ProcessGroupNCCL
"""
- this_rank = torch.distributed.get_rank(group)
- world_sz = torch.distributed.get_world_size(group)
+ this_rank = dist.get_rank(group)
+ world_sz = dist.get_world_size(group)
partition_lst_for_each_tensor = [None] * len(tensors)
for tensor_idx, tensor in enumerate(tensors):
@@ -97,9 +80,9 @@ def reduce_scatter_coalesced(
world_sz)
# batched reduce-scatter call
- torch_reduce_scatter_fn(tensor_partition_flat_buffer,
- tensor_partition_buffer_for_each_rank[this_rank],
- group)
+ _torch_reduce_scatter_fn(tensor_partition_flat_buffer,
+ tensor_partition_buffer_for_each_rank[this_rank],
+ group)
# reverse procedure of the interleaving done previously, done on the
# result of the batched reduce-scatter
diff --git a/deepspeed/runtime/comm/nccl.py b/deepspeed/runtime/comm/nccl.py
index fcab478a3ca8..ed80059a9067 100644
--- a/deepspeed/runtime/comm/nccl.py
+++ b/deepspeed/runtime/comm/nccl.py
@@ -3,7 +3,7 @@
'''
import torch
-import torch.distributed as dist
+from deepspeed import comm as dist
import time
import cupy
import numpy as np
diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index 1df5912ef172..4571cbdf7056 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -24,6 +24,9 @@
from .zero.config import DeepSpeedZeroConfig
from .zero.constants import *
from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
+from ..monitor.config import DeepSpeedMonitorConfig
+
+from deepspeed import comm as dist
from ..git_version_info import version as __version__
from ..utils import logger
@@ -615,15 +618,6 @@ def get_memory_breakdown(param_dict):
return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT)
-def get_tensorboard_enabled(param_dict):
- if TENSORBOARD in param_dict.keys():
- return get_scalar_param(param_dict[TENSORBOARD],
- TENSORBOARD_ENABLED,
- TENSORBOARD_ENABLED_DEFAULT)
- else:
- return False
-
-
def get_eigenvalue_config(param_dict):
if get_quantize_enabled(param_dict):
param_dict = param_dict[QUANTIZE_TRAINING]
@@ -724,26 +718,6 @@ def get_eigenvalue_layer_num(param_dict):
return EIGENVALUE_LAYER_NUM_DEFAULT
-def get_tensorboard_output_path(param_dict):
- if get_tensorboard_enabled(param_dict):
- return get_scalar_param(
- param_dict[TENSORBOARD],
- TENSORBOARD_OUTPUT_PATH,
- TENSORBOARD_OUTPUT_PATH_DEFAULT,
- )
- else:
- return TENSORBOARD_OUTPUT_PATH_DEFAULT
-
-
-def get_tensorboard_job_name(param_dict):
- if get_tensorboard_enabled(param_dict):
- return get_scalar_param(param_dict[TENSORBOARD],
- TENSORBOARD_JOB_NAME,
- TENSORBOARD_JOB_NAME_DEFAULT)
- else:
- return TENSORBOARD_JOB_NAME_DEFAULT
-
-
def get_checkpoint_params(param_dict):
return param_dict.get(CHECKPOINT, {})
@@ -803,9 +777,9 @@ def __init__(self, config: Union[str, dict], mpu=None):
f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}"
)
try:
- self.global_rank = torch.distributed.get_rank()
+ self.global_rank = dist.get_rank()
if mpu is None:
- self.world_size = torch.distributed.get_world_size()
+ self.world_size = dist.get_world_size()
else:
self.world_size = mpu.get_data_parallel_world_size()
except:
@@ -897,6 +871,8 @@ def _initialize_params(self, param_dict):
self.activation_checkpointing_config = DeepSpeedActivationCheckpointingConfig(
param_dict)
+ self.monitor_config = DeepSpeedMonitorConfig(param_dict)
+
self.gradient_clipping = get_gradient_clipping(param_dict)
self.fp16_enabled = get_fp16_enabled(param_dict)
self.bfloat16_enabled = get_bfloat16_enabled(param_dict)
@@ -943,9 +919,6 @@ def _initialize_params(self, param_dict):
| self.flops_profiler_config.enabled)
self.memory_breakdown = get_memory_breakdown(param_dict)
self.autotuning_config = DeepSpeedAutotuningConfig(param_dict)
- self.tensorboard_enabled = get_tensorboard_enabled(param_dict)
- self.tensorboard_output_path = get_tensorboard_output_path(param_dict)
- self.tensorboard_job_name = get_tensorboard_job_name(param_dict)
(
self.eigenvalue_enabled,
diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py
index ee2e51c6109f..88b055b3e210 100755
--- a/deepspeed/runtime/constants.py
+++ b/deepspeed/runtime/constants.py
@@ -282,33 +282,6 @@
MEMORY_BREAKDOWN = 'memory_breakdown'
MEMORY_BREAKDOWN_DEFAULT = False
-#########################################
-# Tensorboard
-#########################################
-# Tensorboard. By default, this feature is not enabled.
-# Users can configure in ds_config.json as below example:
-TENSORBOARD_FORMAT = '''
-Tensorboard can be specified as:
-"tensorboard": {
- "enabled": true,
- "output_path": "/home/myname/foo",
- "job_name": "model_lr2e-5_epoch3_seed2_seq64"
-}
-'''
-TENSORBOARD = "tensorboard"
-
-# Tensorboard enable signal
-TENSORBOARD_ENABLED = "enabled"
-TENSORBOARD_ENABLED_DEFAULT = False
-
-# Tensorboard output path
-TENSORBOARD_OUTPUT_PATH = "output_path"
-TENSORBOARD_OUTPUT_PATH_DEFAULT = ""
-
-# Tensorboard job name
-TENSORBOARD_JOB_NAME = "job_name"
-TENSORBOARD_JOB_NAME_DEFAULT = "DeepSpeedJobName"
-
#########################################
# Eigenvalue
#########################################
diff --git a/deepspeed/runtime/dataloader.py b/deepspeed/runtime/dataloader.py
index acd21d9557fe..84b206e1eecd 100644
--- a/deepspeed/runtime/dataloader.py
+++ b/deepspeed/runtime/dataloader.py
@@ -68,10 +68,15 @@ def __init__(self,
self.device_count = device_count
self.batch_size = batch_size
self.pin_memory = pin_memory
- self.len = len(self.data_sampler)
self.data = None
self.dataloader_drop_last = dataloader_drop_last
+ if self.dataloader_drop_last:
+ self.len = len(self.data_sampler) // self.batch_size
+ else:
+ from math import ceil
+ self.len = ceil(len(self.data_sampler) / self.batch_size)
+
def __iter__(self):
self._create_dataloader()
return self
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 731708549ab6..892394e049ad 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -9,7 +9,6 @@
import torch
import warnings
import hashlib
-import torch.distributed as dist
from collections import defaultdict, OrderedDict
from shutil import copyfile
@@ -26,9 +25,12 @@
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException
+from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
+
from deepspeed.runtime.activation_checkpointing import (
checkpointing as activation_checkpointing,
)
+
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
@@ -46,12 +48,14 @@
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT
from deepspeed.runtime.sparse_tensor import SparseTensor
-import deepspeed.runtime.lr_schedules as lr_schedules
-import deepspeed.utils.groups as groups
+from deepspeed.runtime import lr_schedules
+from deepspeed.utils import groups
from deepspeed.runtime.utils import get_grad_norm
-from deepspeed.utils import logger, log_dist, init_distributed, instrument_w_nvtx
+from deepspeed.utils import logger, log_dist, instrument_w_nvtx
+from deepspeed.comm.comm import init_distributed
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.utils.debug import debug_extract_module_and_param_names
+from deepspeed.monitor.monitor import MonitorMaster
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
from deepspeed.runtime.utils import clip_grad_norm_
from deepspeed.runtime.eigenvalue import Eigenvalue
@@ -70,6 +74,9 @@
from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler
from deepspeed.utils.logging import print_json_dist
+# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
+dist = None
+
MEMORY_OPT_ALLREDUCE_SIZE = 500000000
DeepSpeedOptimizerCallable = \
@@ -212,6 +219,10 @@ def __init__(
self.moe_layers = []
self._step_applied = False
self._global_grad_norm = None
+ self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend.
+
+ global dist
+ from deepspeed import comm as dist
self._is_gradient_accumulation_boundary = None
# for debug purposes - can then debug print: debug_get_module_name(module)
@@ -224,16 +235,22 @@ def __init__(
if self.config is None and config_params is not None:
self.config = config_params
- if dist_init_required is None:
- dist_init_required = not dist.is_initialized()
-
- if dist_init_required is False:
- assert (
- dist.is_initialized() is True
- ), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
+ from deepspeed.comm import supported_torch_version
+ # This supported_torch_version check is for torch1.2 compatibility only
+ if supported_torch_version:
+ dist.init_distributed(dist_backend=self.dist_backend,
+ dist_init_required=dist_init_required)
else:
- # Initialize torch distributed if needed
- init_distributed(dist_backend=self.dist_backend)
+ if dist_init_required is None:
+ dist_init_required = not dist.is_initialized()
+
+ if dist_init_required is False:
+ assert (
+ dist.is_initialized() is True
+ ), "Torch distributed not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
+ else:
+ if not dist.is_initialized():
+ dist.init_process_group(backend=self.dist_backend)
self._do_args_sanity_check(args)
self._configure_with_arguments(args, mpu)
@@ -247,8 +264,7 @@ def __init__(
self._set_distributed_vars(args)
- if self.tensorboard_enabled() and self.global_rank == 0:
- self.summary_writer = self.get_summary_writer()
+ self.monitor = MonitorMaster(self._config.monitor_config)
see_memory_usage(
f"DeepSpeed Engine: Before configure distributed model",
@@ -315,7 +331,8 @@ def __init__(
self.save_non_zero_checkpoint = False
self.save_zero_checkpoint = False
- self._configure_checkpointing(dist_init_required)
+ if not isinstance(self.optimizer, DeepSpeedZeRoOffload):
+ self._configure_checkpointing(dist_init_required)
if self.eigenvalue_enabled():
self.eigenvalue = self._configure_eigenvalue()
@@ -488,54 +505,6 @@ def curriculum_enabled(self):
def curriculum_params(self):
return self._config.curriculum_params
- def tensorboard_enabled(self):
- return self._config.tensorboard_enabled
-
- def tensorboard_output_path(self):
- return self._config.tensorboard_output_path
-
- def tensorboard_job_name(self):
- return self._config.tensorboard_job_name
-
- def get_summary_writer(
- self,
- name="DeepSpeedJobName",
- base=os.path.join(os.path.expanduser("~"),
- "tensorboard"),
- ):
- if self.tensorboard_output_path():
- base_dir = self.tensorboard_output_path()
- job_name = self.tensorboard_job_name()
- log_dir = os.path.join(base_dir, job_name)
- else:
- if self.tensorboard_job_name():
- name = self.tensorboard_job_name()
-
- # Infrastructure-specific job-id
- if "DLWS_JOB_ID" in os.environ:
- infra_job_id = os.environ["DLWS_JOB_ID"]
- elif "DLTS_JOB_ID" in os.environ:
- infra_job_id = os.environ["DLTS_JOB_ID"]
- else:
- infra_job_id = "unknown-job-id"
-
- summary_writer_dir_name = os.path.join(infra_job_id, "logs")
- log_dir = os.path.join(base, summary_writer_dir_name, name)
-
- os.makedirs(log_dir, exist_ok=True)
- try:
- # torch.utils.tensorboard will fail if `tensorboard` is not available,
- # see their docs for more details: https://pytorch.org/docs/1.8.0/tensorboard.html
- import tensorboard
- except ImportError:
- print(
- 'If you want to use tensorboard logging please `pip install tensorboard`'
- )
- raise
- from torch.utils.tensorboard import SummaryWriter
-
- return SummaryWriter(log_dir=log_dir)
-
def wall_clock_breakdown(self):
return self._config.wall_clock_breakdown
@@ -808,8 +777,7 @@ def _configure_checkpointing(self, dist_init_required):
dp_rank == 0) or self.zero_optimization_partition_weights()
if self.zero_optimization() or self.bfloat16_enabled():
- param_rank = torch.distributed.get_rank(
- group=self.optimizer.dp_process_group)
+ param_rank = dist.get_rank(group=self.optimizer.dp_process_group)
# Only the first parameter parallel process needs to store the
# optimizer state checkpoints for zero
@@ -884,7 +852,7 @@ def _do_args_sanity_check(self, args):
args.deepspeed_config = args.deepscale_config
assert "LOCAL_RANK" in os.environ or "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment " \
- "variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch.distributed launcher. If using a " \
+ "variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch's launcher. If using a " \
"different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed."
if hasattr(args, 'local_rank') and args.local_rank != None:
@@ -1373,7 +1341,6 @@ def _configure_zero_optimizer(self, optimizer):
"Pipeline parallelism does not support overlapped communication, will be disabled."
)
overlap_comm = False
-
optimizer = DeepSpeedZeroOptimizer(
optimizer,
timers=timers,
@@ -1410,33 +1377,47 @@ def _configure_zero_optimizer(self, optimizer):
logger.info("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
- optimizer = DeepSpeedZeroOptimizer_Stage3(
- self.module,
- optimizer,
- timers=timers,
- ds_config=self.config,
- static_loss_scale=self.loss_scale(),
- dynamic_loss_scale=self.dynamic_loss_scale(),
- dynamic_loss_args=self.dynamic_loss_scale_args(),
- clip_grad=self.gradient_clipping(),
- contiguous_gradients=self.zero_contiguous_gradients(),
- reduce_bucket_size=self.zero_reduce_bucket_size(),
- prefetch_bucket_size=self.zero_prefetch_bucket_size(),
- max_reuse_distance=self.zero_max_reuse_distance(),
- max_live_parameters=self.zero_max_live_parameters(),
- param_persistence_threshold=self.zero_param_persistence_threshold(),
- dp_process_group=self.data_parallel_group,
- reduce_scatter=self.zero_reduce_scatter(),
- overlap_comm=self.zero_overlap_comm(),
- offload_optimizer_config=self.zero_offload_optimizer(),
- offload_param_config=self.zero_offload_param(),
- sub_group_size=self.zero_sub_group_size(),
- mpu=self.mpu,
- postscale_gradients=self.postscale_gradients(),
- gradient_predivide_factor=self.gradient_predivide_factor(),
- gradient_accumulation_steps=self.gradient_accumulation_steps(),
- aio_config=self.aio_config(),
- communication_data_type=self.communication_data_type)
+ if isinstance(optimizer, DummyOptim):
+ optimizer = DeepSpeedZeRoOffload(
+ self.module,
+ timers=timers,
+ ds_config=self.config,
+ overlap_comm=self.zero_overlap_comm(),
+ prefetch_bucket_size=self.zero_prefetch_bucket_size(),
+ max_reuse_distance=self.zero_max_reuse_distance(),
+ max_live_parameters=self.zero_max_live_parameters(),
+ param_persistence_threshold=self.zero_param_persistence_threshold(),
+ offload_param_config=self.zero_offload_param(),
+ mpu=self.mpu)
+ else:
+
+ optimizer = DeepSpeedZeroOptimizer_Stage3(
+ self.module,
+ optimizer,
+ timers=timers,
+ ds_config=self.config,
+ static_loss_scale=self.loss_scale(),
+ dynamic_loss_scale=self.dynamic_loss_scale(),
+ dynamic_loss_args=self.dynamic_loss_scale_args(),
+ clip_grad=self.gradient_clipping(),
+ contiguous_gradients=self.zero_contiguous_gradients(),
+ reduce_bucket_size=self.zero_reduce_bucket_size(),
+ prefetch_bucket_size=self.zero_prefetch_bucket_size(),
+ max_reuse_distance=self.zero_max_reuse_distance(),
+ max_live_parameters=self.zero_max_live_parameters(),
+ param_persistence_threshold=self.zero_param_persistence_threshold(),
+ dp_process_group=self.data_parallel_group,
+ reduce_scatter=self.zero_reduce_scatter(),
+ overlap_comm=self.zero_overlap_comm(),
+ offload_optimizer_config=self.zero_offload_optimizer(),
+ offload_param_config=self.zero_offload_param(),
+ sub_group_size=self.zero_sub_group_size(),
+ mpu=self.mpu,
+ postscale_gradients=self.postscale_gradients(),
+ gradient_predivide_factor=self.gradient_predivide_factor(),
+ gradient_accumulation_steps=self.gradient_accumulation_steps(),
+ aio_config=self.aio_config(),
+ communication_data_type=self.communication_data_type)
else:
raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
@@ -1654,9 +1635,9 @@ def print_forward_breakdown(self, fwd_time):
# TODO: Allreduce/average them across ranks for more accurate timing.
- # if torch.distributed.get_rank() == 0:
+ # if deepspeed.comm.get_rank() == 0:
log_dist(
- f"rank={torch.distributed.get_rank()} time (ms) | forward: {fwd_time:.2f} (forward_moe: {moe_time:.2f}, 1st alltoall: {falltoall:.2f}, 2nd alltoall: {salltoall:.2f}, top-k: {gate_time:.2f})",
+ f"rank={dist.get_rank()} time (ms) | forward: {fwd_time:.2f} (forward_moe: {moe_time:.2f}, 1st alltoall: {falltoall:.2f}, 2nd alltoall: {salltoall:.2f}, top-k: {gate_time:.2f})",
ranks=[0])
@instrument_w_nvtx
@@ -1701,7 +1682,7 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False):
loss = self._scale_loss_by_gas(loss.float())
# Log training Loss
- if self.tensorboard_enabled():
+ if self.monitor.enabled:
if self.is_gradient_accumulation_boundary():
if self.global_rank == 0:
self.summary_events = [(
@@ -1709,9 +1690,7 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False):
loss.mean().item() * self.gradient_accumulation_steps(),
self.global_samples,
)]
- for event in self.summary_events: # write_summary_events
- self.summary_writer.add_scalar(event[0], event[1], event[2])
- self.summary_writer.flush()
+ self.monitor.write_events(self.summary_events)
self._start_timers(self.engine_timers.backward_timers)
@@ -1934,14 +1913,13 @@ def step(self, lr_kwargs=None):
self._stop_timers(self.engine_timers.step_timers)
# Log learning rate
- if self.tensorboard_enabled():
+ if self.monitor.enabled:
if self.is_gradient_accumulation_boundary():
if self.global_rank == 0:
self.summary_events = [(f"Train/Samples/lr",
self.get_lr()[0],
self.global_samples)]
- for event in self.summary_events: # write_summary_events
- self.summary_writer.add_scalar(event[0], event[1], event[2])
+
if self.fp16_enabled() and hasattr(self.optimizer, "cur_scale"):
self.summary_events.append((
f"Train/Samples/loss_scale",
@@ -1953,16 +1931,12 @@ def step(self, lr_kwargs=None):
self.eigenvalue_gas_boundary_resolution()):
ev_values = self.block_eigenvalue.values()
for i in range(len(ev_values)):
- self.summary_writer.add_scalar(
+ self.summary_events.append((
f"Train/Eigenvalues/ModelBlockParam_{i}",
self.ev_values[i][0],
self.global_samples,
- )
- self.summary_writer.flush()
-
- for event in self.summary_events: # write_summary_events
- self.summary_writer.add_scalar(event[0], event[1], event[2])
- self.summary_writer.flush()
+ ))
+ self.monitor.write_events(self.summary_events)
# Check flops profiling
if flops_profiler_active:
@@ -1990,8 +1964,8 @@ def step(self, lr_kwargs=None):
if self.wall_clock_breakdown() or self.flops_profiler_enabled():
# Log global timing and reset
if self.is_gradient_accumulation_boundary():
- if self.tensorboard_enabled():
- self._write_tensorboard()
+ if self.monitor.enabled:
+ self._write_monitor()
if self.has_moe_layers:
fwd_time = self.timers(FORWARD_GLOBAL_TIMER).elapsed(
@@ -2034,7 +2008,7 @@ def _autotuning_exit(self):
atexit.register(print, "Autotuning: done with running current ds config.")
exit()
- def _write_tensorboard(self):
+ def _write_monitor(self):
if self.global_rank == 0:
self.summary_events = [
(
@@ -2065,9 +2039,7 @@ def _write_tensorboard(self):
self.global_samples,
),
]
- for event in self.summary_events: # write_summary_events
- self.summary_writer.add_scalar(event[0], event[1], event[2])
- self.summary_writer.flush()
+ self.monitor.write_events(self.summary_events)
def _get_optimizer_param(self, param_name):
result = []
@@ -2344,7 +2316,7 @@ def load_moe_state_dict(checkpoint_path,
else:
moe_layer_id = 0
for n_module, module in model.named_modules():
- if isinstance(module, MoE): # and torch.distributed.get_rank() == 0:
+ if isinstance(module, MoE): # and deepspeed.comm.get_rank() == 0:
group_name = module.expert_group_name
num_local_experts = module.num_local_experts
expp_rank = groups._get_expert_parallel_rank(group_name)
@@ -2395,7 +2367,7 @@ def _get_rank_zero_ckpt_name(self,
def _get_zero_ckpt_name(self, checkpoints_path, tag):
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
- pp_rank = torch.distributed.get_rank(group=self.optimizer.dp_process_group)
+ pp_rank = dist.get_rank(group=self.optimizer.dp_process_group)
bf16_mode = self.bfloat16_enabled()
return self._get_rank_zero_ckpt_name(checkpoints_path,
tag,
@@ -2412,7 +2384,7 @@ def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None):
if self.zero_optimization_partition_weights():
filename = "zero_pp_rank_{}".format(
- torch.distributed.get_rank(group=self.optimizer.dp_process_group))
+ dist.get_rank(group=self.optimizer.dp_process_group))
ckpt_name = os.path.join(
checkpoints_path,
str(tag),
@@ -2781,8 +2753,8 @@ def _checkpoint_tag_validation(self, tag):
bhash = torch.ByteTensor([s_hash.digest()]).flatten().to(self.device)
max_bhash = bhash.clone()
min_bhash = bhash.clone()
- dist.all_reduce(max_bhash, op=torch.distributed.ReduceOp.MAX)
- dist.all_reduce(min_bhash, op=torch.distributed.ReduceOp.MIN)
+ dist.all_reduce(max_bhash, op=dist.ReduceOp.MAX)
+ dist.all_reduce(min_bhash, op=dist.ReduceOp.MIN)
valid = all(min_bhash == bhash) and all(max_bhash == bhash)
msg = (
f"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is not consistent across "
@@ -2817,7 +2789,7 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)
# Ensure save_dir directory exists
os.makedirs(save_dir, exist_ok=True)
- torch.distributed.barrier()
+ dist.barrier()
if tag is None:
tag = f"global_step{self.global_steps}"
@@ -2845,7 +2817,7 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)
self.optimizer.checkpoint_event_epilogue()
# Save latest checkpoint tag
- torch.distributed.barrier()
+ dist.barrier()
if save_latest and self.global_rank == 0:
with open(os.path.join(save_dir, 'latest'), 'w') as fd:
fd.write(tag)
@@ -2871,7 +2843,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}):
# Using layer_#_export_# to save the model's expert state_dict
moe_layer_id = 0
for n_module, module in self.module.named_modules():
- if isinstance(module, MoE): # and torch.distributed.get_rank() == 0:
+ if isinstance(module, MoE): # and deepspeed.comm.get_rank() == 0:
group_name = module.expert_group_name
num_local_experts = module.num_local_experts
expp_rank = groups._get_expert_parallel_rank(group_name)
@@ -3135,7 +3107,7 @@ def _zero3_consolidated_16bit_state_dict(self):
if not self.zero_optimization_partition_weights():
raise ValueError("this function requires ZeRO-3 mode")
- state_dict = OrderedDict() if torch.distributed.get_rank() == 0 else None
+ state_dict = OrderedDict() if dist.get_rank() == 0 else None
shared_params = {}
def get_layer_state_dict(module, prefix=""):
@@ -3145,7 +3117,7 @@ def get_layer_state_dict(module, prefix=""):
with deepspeed.zero.GatheredParameters(list(
module.parameters(recurse=False)),
modifier_rank=0):
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
# handle params
for name, param in module.named_parameters(recurse=False):
if param is None:
@@ -3224,7 +3196,7 @@ def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"):
else:
state_dict = self.module.state_dict()
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
os.makedirs(save_dir, exist_ok=True)
logger.info(f"Saving model weights to {path}")
torch.save(state_dict, path)
diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py
index 0a18b8fdba97..479a0f7a2839 100755
--- a/deepspeed/runtime/fp16/fused_optimizer.py
+++ b/deepspeed/runtime/fp16/fused_optimizer.py
@@ -12,8 +12,8 @@
from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import groups, logger, log_dist
+from deepspeed import comm as dist
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD
-import torch.distributed as dist
class FP16_Optimizer(DeepSpeedOptimizer):
@@ -338,7 +338,7 @@ def _get_norm_with_moe_layers(self, all_groups_norm):
dtype=torch.float)
dist.all_reduce(scaled_norm_tensor, group=pg)
all_groups_norm = scaled_norm_tensor.item()
- #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {torch.distributed.get_rank()}")
+ #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
return all_groups_norm
def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True):
diff --git a/deepspeed/runtime/fp16/onebit/adam.py b/deepspeed/runtime/fp16/onebit/adam.py
index 5ce451d7676f..71805176ae41 100644
--- a/deepspeed/runtime/fp16/onebit/adam.py
+++ b/deepspeed/runtime/fp16/onebit/adam.py
@@ -6,7 +6,7 @@
import importlib
import numpy as np
import time
-import torch.distributed as dist
+from deepspeed import comm as dist
from deepspeed.utils.logging import logger
@@ -185,7 +185,7 @@ def step(self, closure=None, grads=None):
device=p.device)
torch.cuda.empty_cache()
self.adam_freeze_key = True
- if not self.initialize and torch.distributed.get_rank() == 0:
+ if not self.initialize and dist.get_rank() == 0:
print("Cupy Buffers Initialized Successfully.")
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
@@ -249,9 +249,7 @@ def step(self, closure=None, grads=None):
if not self.initialize:
self.adam_freeze_key = False
self.initialize = True
- print(
- f"Finished the initialization step at rank {torch.distributed.get_rank()}"
- )
+ print(f"Finished the initialization step at rank {dist.get_rank()}")
return loss
if self.adam_freeze_key is False:
@@ -282,7 +280,7 @@ def load_state_dict(self, state_dict):
state_dict['param_groups'][i].pop('exp_avg_mask')
super().load_state_dict(state_dict)
if self.state[self.param_groups[0]['params'][0]]['step'] < self.freeze_step:
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
print("Checkpoint loaded and OnebitAdam warmup stage starts/continues.")
if self.adam_freeze_key is True:
self.adam_freeze_key = False
@@ -291,7 +289,7 @@ def load_state_dict(self, state_dict):
else:
self.deepspeed.enable_backward_allreduce = True
else:
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
print(
"Checkpoint loaded and OnebitAdam compression stage starts/continues."
)
diff --git a/deepspeed/runtime/fp16/onebit/lamb.py b/deepspeed/runtime/fp16/onebit/lamb.py
index 01c6cd878488..aeff08b9861b 100644
--- a/deepspeed/runtime/fp16/onebit/lamb.py
+++ b/deepspeed/runtime/fp16/onebit/lamb.py
@@ -4,7 +4,7 @@
import types
import torch
import numpy as np
-import torch.distributed as dist
+from deepspeed import comm as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -305,7 +305,7 @@ def step(self, closure=None, grads=None):
torch.zeros(self.server_chunk_sizes[i],
device=self.exp_avg_flat[i].device))
torch.cuda.empty_cache()
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
print("Cupy Buffers Initialized Successfully.")
self.comm_backend_handle.compressed_allreduce(
@@ -314,7 +314,7 @@ def step(self, closure=None, grads=None):
self.server_errors[0],
self.deepspeed.local_rank)
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
print('Pop out errors', flush=True)
del self.worker_errors[:]
del self.server_errors[:]
@@ -389,9 +389,7 @@ def step(self, closure=None, grads=None):
if not self.initialize:
self.lamb_freeze_key = False
self.initialize = True
- print(
- f"Finished the initialization step at rank {torch.distributed.get_rank()}"
- )
+ print(f"Finished the initialization step at rank {dist.get_rank()}")
return loss
if self.lamb_freeze_key is False:
@@ -427,7 +425,7 @@ def load_state_dict(self, state_dict):
del self.corrected_tensor_sizes[:]
del self.server_chunk_sizes[:]
if self.state[self.param_groups[0]['params'][0]]['step'] < self.freeze_step:
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
print("Checkpoint loaded and OnebitLamb warmup stage starts/continues.")
if self.lamb_freeze_key is True:
self.lamb_freeze_key = False
@@ -442,7 +440,7 @@ def load_state_dict(self, state_dict):
if 'scaling_coeff' in self.state[p]:
self.state[p].pop('scaling_coeff')
else:
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
print(
"Checkpoint loaded and OnebitLamb compression stage starts/continues."
)
diff --git a/deepspeed/runtime/fp16/onebit/zoadam.py b/deepspeed/runtime/fp16/onebit/zoadam.py
index b0238b1997f5..11ccb404a9ff 100644
--- a/deepspeed/runtime/fp16/onebit/zoadam.py
+++ b/deepspeed/runtime/fp16/onebit/zoadam.py
@@ -6,7 +6,7 @@
import importlib
import numpy as np
import time
-import torch.distributed as dist
+from deepspeed import comm as dist
from deepspeed.utils.logging import logger
@@ -198,7 +198,7 @@ def step(self, closure=None, grads=None):
state['momentum_accumulator'] = torch.zeros_like(p.data)
torch.cuda.empty_cache()
# self.freeze_key = True
- if not self.initialize and torch.distributed.get_rank() == 0:
+ if not self.initialize and dist.get_rank() == 0:
print("Cupy Buffers Initialized Successfully.")
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
@@ -313,9 +313,7 @@ def step(self, closure=None, grads=None):
if not self.initialize:
self.initialize = True
- print(
- f"Finished the initialization step at rank {torch.distributed.get_rank()}"
- )
+ print(f"Finished the initialization step at rank {dist.get_rank()}")
return loss
if self.state[self.param_groups[0]['params'][0]]['step'] > self.var_freeze_step:
diff --git a/deepspeed/runtime/fp16/unfused_optimizer.py b/deepspeed/runtime/fp16/unfused_optimizer.py
index 1fb4eb672dab..88f0df443405 100755
--- a/deepspeed/runtime/fp16/unfused_optimizer.py
+++ b/deepspeed/runtime/fp16/unfused_optimizer.py
@@ -14,6 +14,7 @@
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT
+from deepspeed import comm as dist
class FP16_UnfusedOptimizer(DeepSpeedOptimizer):
@@ -36,7 +37,7 @@ def __init__(self,
self.fused_lamb_legacy = fused_lamb_legacy
self._global_grad_norm = 0.
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
logger.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ')
if not torch.cuda.is_available:
diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index 61ef60f535cf..94add6f9c8e4 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -12,7 +12,7 @@
import torch
import torch.nn as nn
import torch.optim as optim
-import torch.distributed as dist
+from deepspeed import comm as dist
from deepspeed.utils.logging import logger
from deepspeed.utils.timer import SynchronizedWallClockTimer, ThroughputTimer
@@ -365,16 +365,12 @@ def train_batch(self, data_iter=None):
f'iter time (s): {iter_time:0.3f} '
f'samples/sec: {tput:0.3f}')
- # Tensorboard
- if self.tensorboard_enabled():
- if self.global_rank == 0:
- self.summary_events = [(f'Train/Samples/train_loss',
- self.agg_train_loss.mean().item(),
- self.global_samples)]
- for event in self.summary_events: # write_summary_events
- self.summary_writer.add_scalar(event[0], event[1], event[2])
- if self.global_steps % self.steps_per_print() == 0:
- self.summary_writer.flush()
+ # Monitoring
+ if self.global_rank == 0 and self.monitor.enabled:
+ self.summary_events = [(f'Train/Samples/train_loss',
+ self.agg_train_loss.mean().item(),
+ self.global_samples)]
+ self.monitor.write_events(self.summary_events)
if self.wall_clock_breakdown(
) and self.global_steps % self.steps_per_print() == 0:
@@ -458,14 +454,11 @@ def eval_batch(self,
if compute_loss:
eval_output = self._bcast_pipe_scalar(eval_output)
- if self.tensorboard_enabled():
- if self.global_rank == 0:
- self.summary_events = [(f'Train/Samples/eval_loss',
- eval_output.mean().item(),
- self.global_samples)]
- for event in self.summary_events: # write_summary_events
- self.summary_writer.add_scalar(event[0], event[1], event[2])
- self.summary_writer.flush()
+ if self.global_rank == 0 and self.monitor.enabled:
+ self.summary_events = [(f'Train/Samples/eval_loss',
+ eval_output.mean().item(),
+ self.global_samples)]
+ self.monitor.write_events(self.summary_events)
# Restore the training iterator
self.set_dataiterator(train_iterator)
@@ -1171,17 +1164,15 @@ def _exec_optimizer_step(self, lr_kwargs=None):
self.mem_status('AFTER STEP')
- if self.tensorboard_enabled():
- if self.global_rank == 0:
- self.summary_events = [(f'Train/Samples/lr',
- self.get_lr()[0],
- self.global_samples)]
- if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'):
- self.summary_events.append((f'Train/Samples/loss_scale',
- self.optimizer.cur_scale,
- self.global_samples))
- for event in self.summary_events: # write_summary_events
- self.summary_writer.add_scalar(event[0], event[1], event[2])
+ if self.global_rank == 0 and self.monitor.enabled:
+ self.summary_events = [(f'Train/Samples/lr',
+ self.get_lr()[0],
+ self.global_samples)]
+ if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'):
+ self.summary_events.append((f'Train/Samples/loss_scale',
+ self.optimizer.cur_scale,
+ self.global_samples))
+ self.monitor.write_events(self.summary_events)
if self.wall_clock_breakdown():
self.timers('step_microstep').stop()
diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py
index 4121a0f8fd3f..f47806ee8673 100644
--- a/deepspeed/runtime/pipe/module.py
+++ b/deepspeed/runtime/pipe/module.py
@@ -9,7 +9,7 @@
import torch
import torch.nn as nn
-import torch.distributed as dist
+from deepspeed import comm as dist
from deepspeed.utils import logger
from .. import utils as ds_utils
diff --git a/deepspeed/runtime/pipe/p2p.py b/deepspeed/runtime/pipe/p2p.py
index 8a1b71926006..d0c9c2f9f364 100644
--- a/deepspeed/runtime/pipe/p2p.py
+++ b/deepspeed/runtime/pipe/p2p.py
@@ -6,7 +6,7 @@
import typing
import torch
-import torch.distributed as dist
+from deepspeed import comm as dist
# To query whether we have send/recv support
from packaging.version import Version
@@ -25,7 +25,7 @@ def can_send_recv() -> bool:
#initializes adjacent process groups
-#run this only after torch.distributed.init_process_group() has been called
+#run this only after deepspeed.init_distributed() has been called
def init_process_groups(grid):
global _groups, _grid
_grid = grid
diff --git a/deepspeed/runtime/pipe/topology.py b/deepspeed/runtime/pipe/topology.py
index 240c973a3fc1..954e73592943 100644
--- a/deepspeed/runtime/pipe/topology.py
+++ b/deepspeed/runtime/pipe/topology.py
@@ -2,7 +2,7 @@
from deepspeed.utils import logger
-import torch.distributed as dist
+from deepspeed import comm as dist
import sys
from collections import namedtuple
diff --git a/deepspeed/runtime/swap_tensor/async_swapper.py b/deepspeed/runtime/swap_tensor/async_swapper.py
index e6e19a4c67ef..45614abd794e 100644
--- a/deepspeed/runtime/swap_tensor/async_swapper.py
+++ b/deepspeed/runtime/swap_tensor/async_swapper.py
@@ -6,6 +6,7 @@
"""
import torch
+from deepspeed import comm as dist
from deepspeed.utils.logging import logger
from deepspeed.runtime.swap_tensor.utils import swap_out_tensors, SwapBuffer
@@ -66,10 +67,10 @@ def swap_out_tensors(self, tensor_list, path_list):
self._swap_out_tensor(tensor, swap_path)
def _report_statistics(self, message):
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
element_size = torch.tensor([], dtype=self.dtype).element_size()
swapped_GB = (self.num_elements_swapped * element_size) / (1024**3)
- logger.info(
+ logger.debug(
f'{message} num_elems = {self.num_elements_swapped}, {swapped_GB:5.2f} GB'
)
diff --git a/deepspeed/runtime/swap_tensor/optimizer_utils.py b/deepspeed/runtime/swap_tensor/optimizer_utils.py
index a08af96f2a12..f34ff3a457a8 100644
--- a/deepspeed/runtime/swap_tensor/optimizer_utils.py
+++ b/deepspeed/runtime/swap_tensor/optimizer_utils.py
@@ -8,6 +8,7 @@
import os
import torch
+from deepspeed import comm as dist
from deepspeed.utils.logging import logger
from deepspeed.runtime.zero.offload_constants import *
from deepspeed.runtime.swap_tensor.constants import *
@@ -133,7 +134,7 @@ def __init__(self,
self.swap_element_size = torch.tensor([], dtype=dtype).element_size()
self.swap_folder = os.path.join(base_folder,
'optimizer',
- f'rank{torch.distributed.get_rank()}')
+ f'rank{dist.get_rank()}')
os.makedirs(self.swap_folder, exist_ok=True)
self.optimizer = optimizer
@@ -271,7 +272,7 @@ def _initialize_from_swapped_fp16_params(self,
fp16_partitions_info=fp16_partitions_info[curr_index:],
fp16_swap_buffers=fp16_swap_buffers)
- if torch.distributed.get_rank() == 0 and SWAPPER_DEBUG_MODE:
+ if dist.get_rank() == 0 and SWAPPER_DEBUG_MODE:
for i, tensor in enumerate(fp16_pinned_tensors):
true_index = curr_index + i
logger.info(
@@ -376,7 +377,7 @@ def _initialize_parameters(self, parameters, src_tensors, aio_handle):
dest_paths=swap_paths,
pinned_buffers=pinned_buffers)
- if torch.distributed.get_rank() == 0 and SWAPPER_DEBUG_MODE:
+ if dist.get_rank() == 0 and SWAPPER_DEBUG_MODE:
for i, tensor in enumerate(src_tensors):
logger.info(
f'copy_in_fp16_param: fp32_id = {id(parameters[i])} index = {i}, swap_num_elem = {src_tensors[i].numel()}'
diff --git a/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py b/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py
index cbcb136aa337..124500de888b 100644
--- a/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py
+++ b/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py
@@ -10,6 +10,7 @@
from deepspeed.utils.logging import logger
from deepspeed.ops.aio import AsyncIOBuilder
+from deepspeed import comm as dist
from deepspeed.runtime.swap_tensor.constants import *
from deepspeed.runtime.swap_tensor.utils import swap_in_tensors, swap_out_tensors, print_object, \
@@ -62,7 +63,7 @@ def __init__(self,
'print_exclude_list'
]
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
print_object(obj=self,
name='PartitionedOptimizerSwapper',
exclude_list=self.print_exclude_list)
@@ -160,7 +161,7 @@ def swap_out_optimizer_state(self, parameter, async_swap=False):
self._log_timers([WRITE_TIMER])
- if DEBUG_MODE and torch.distributed.get_rank() == 0:
+ if DEBUG_MODE and dist.get_rank() == 0:
logger.info(f'optimizer_param_swap_out: {(swap_bytes/(1024**3)):5.2f} GB')
def swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors):
@@ -200,7 +201,7 @@ def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
t.data = buffer.data
self._log_timers([READ_TIMER, WAIT_TIMER])
- if DEBUG_MODE and torch.distributed.get_rank() == 0:
+ if DEBUG_MODE and dist.get_rank() == 0:
logger.info(f'optimizer_param_swap_in: {(swap_bytes/(1024**3)):5.2f} GB')
def _separate_pinned_tensors(self, swap_info):
diff --git a/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py b/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py
index 5128bbf6a59a..c488b84692cd 100644
--- a/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py
+++ b/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py
@@ -9,7 +9,7 @@
import shutil
from enum import Enum
import torch
-import torch.distributed as dist
+from deepspeed import comm as dist
from deepspeed.utils.logging import logger
from deepspeed.ops.aio import AsyncIOBuilder
@@ -19,7 +19,7 @@
def print_rank_0(message, debug=False, force=False):
- if torch.distributed.get_rank() == 0 and (debug or force):
+ if dist.get_rank() == 0 and (debug or force):
print(message)
@@ -305,7 +305,7 @@ def swap_in(self, params, async_op=True, swap_in_buffers=None):
f'Num inflight: params {len(self.inflight_params)}, buffers {len(self.inflight_swap_in_buffers)}, numel = {self.inflight_numel}',
force=True)
print_rank_0(
- f'Num available: param {len(self.available_params)}, numel = {self.available_numel}',
+ f'Num available params: count = {len(self.available_params)}, ids = {self.available_params}, numel = {self.available_numel}',
force=True)
assert len(swap_in_paths) <= len(self.available_buffer_ids), f"Not enough buffers {len(self.available_buffer_ids)} for swapping {len(swap_in_paths)}"
diff --git a/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py b/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py
index 7d0116faab5b..598585078632 100644
--- a/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py
+++ b/deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py
@@ -10,6 +10,7 @@
from deepspeed.utils.logging import logger
from deepspeed.ops.aio import AsyncIOBuilder
+from deepspeed import comm as dist
from deepspeed.runtime.zero.offload_constants import *
from deepspeed.runtime.swap_tensor.constants import *
@@ -113,7 +114,7 @@ def __init__(self,
'print_exclude_list'
]
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
print_object(obj=self,
name='PipelinedOptimizerSwapper',
exclude_list=self.print_exclude_list)
diff --git a/deepspeed/runtime/swap_tensor/utils.py b/deepspeed/runtime/swap_tensor/utils.py
index 2a751e398000..9d544e3d9b11 100644
--- a/deepspeed/runtime/swap_tensor/utils.py
+++ b/deepspeed/runtime/swap_tensor/utils.py
@@ -9,6 +9,8 @@
import torch
from deepspeed.utils.logging import logger
+from deepspeed import comm as dist
+
from deepspeed.runtime.swap_tensor.constants import AIO_BLOCK_SIZE, AIO_QUEUE_DEPTH, \
AIO_THREAD_COUNT, AIO_SINGLE_SUBMIT, AIO_OVERLAP_EVENTS
@@ -190,7 +192,7 @@ def __init__(self, num_elems, count, dtype):
self.gigabytes = (self.all_buffers[0].element_size() * num_elems * count) / (1024
**3)
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
exclude_list = ['all_buffers']
print_object(obj=self, name='SwapBufferManager', exclude_list=exclude_list)
diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index ff6daba1c6c9..69660ac3c272 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -17,7 +17,7 @@
import torch
from torch._six import inf
-import torch.distributed as dist
+from deepspeed import comm as dist
from deepspeed.utils import groups, logger
from deepspeed.runtime.constants import PIPE_REPLICATED
@@ -202,11 +202,11 @@ def check_using_norm(self, norm_group, reduce_overflow=True):
op=dist.ReduceOp.MAX,
group=groups._get_max_expert_parallel_group())
if self.mpu is not None:
- torch.distributed.all_reduce(overflow_gpu,
- op=torch.distributed.ReduceOp.MAX,
- group=self.mpu.get_model_parallel_group())
+ dist.all_reduce(overflow_gpu,
+ op=dist.ReduceOp.MAX,
+ group=self.mpu.get_model_parallel_group())
elif reduce_overflow:
- dist.all_reduce(overflow_gpu, op=torch.distributed.ReduceOp.MAX)
+ dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX)
dist.barrier()
overflow = overflow_gpu[0].item()
return bool(overflow)
@@ -243,8 +243,8 @@ def has_overflow(self, params, has_moe_params=None):
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
overflow_gpu = torch.cuda.ByteTensor([overflow])
- # torch.distributed.all_reduce(overflow_gpu,
- # op=torch.distributed.ReduceOp.MAX,
+ # deepspeeed.comm.all_reduce(overflow_gpu,
+ # op=deepspeed.comm.ReduceOp.MAX,
# group=mpu.get_model_parallel_group())
if has_moe_params:
# All reduce this across expert_parallel_group, so that if an expert
@@ -253,9 +253,9 @@ def has_overflow(self, params, has_moe_params=None):
op=dist.ReduceOp.MAX,
group=groups._get_max_expert_parallel_group())
if self.zero_reduce_scatter:
- torch.distributed.all_reduce(overflow_gpu,
- op=torch.distributed.ReduceOp.MAX,
- group=torch.distributed.group.WORLD)
+ dist.all_reduce(overflow_gpu,
+ op=dist.ReduceOp.MAX,
+ group=dist.get_world_group())
elif self.mpu is not None:
if self.deepspeed is not None:
using_pipeline = hasattr(self.deepspeed,
@@ -264,17 +264,16 @@ def has_overflow(self, params, has_moe_params=None):
and self.deepspeed.pipeline_enable_backward_allreduce is False
) or (not using_pipeline
and self.deepspeed.enable_backward_allreduce is False):
- torch.distributed.all_reduce(
- overflow_gpu,
- op=torch.distributed.ReduceOp.MAX,
- group=self.mpu.get_data_parallel_group())
- torch.distributed.all_reduce(overflow_gpu,
- op=torch.distributed.ReduceOp.MAX,
- group=self.mpu.get_model_parallel_group())
+ dist.all_reduce(overflow_gpu,
+ op=dist.ReduceOp.MAX,
+ group=self.mpu.get_data_parallel_group())
+ dist.all_reduce(overflow_gpu,
+ op=dist.ReduceOp.MAX,
+ group=self.mpu.get_model_parallel_group())
elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False:
- torch.distributed.all_reduce(overflow_gpu,
- op=torch.distributed.ReduceOp.MAX,
- group=torch.distributed.group.WORLD)
+ dist.all_reduce(overflow_gpu,
+ op=dist.ReduceOp.MAX,
+ group=dist.get_world_group())
overflow = overflow_gpu[0].item()
return bool(overflow)
@@ -304,7 +303,7 @@ def _has_inf_or_nan(x, i):
def _handle_overflow(cpu_sum, x, i):
import math
- rank = torch.distributed.get_rank()
+ rank = dist.get_rank()
if rank == 0:
t_i = -1
for v_i, v in enumerate(x.data.contiguous().view(-1)):
@@ -356,9 +355,9 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all GPUs.
if mpu is not None:
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.MAX,
- group=mpu.get_model_parallel_group())
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.MAX,
+ group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0
@@ -375,9 +374,9 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if mpu is not None:
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM,
- group=mpu.get_model_parallel_group())
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.SUM,
+ group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
# Need to average total_norm across different GPUs due to the presence of moe params
@@ -422,9 +421,9 @@ def get_grad_norm(parameters, norm_type=2, mpu=None):
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all GPUs.
if mpu is not None:
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.MAX,
- group=mpu.get_model_parallel_group())
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.MAX,
+ group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.
@@ -445,9 +444,9 @@ def get_grad_norm(parameters, norm_type=2, mpu=None):
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if mpu is not None:
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM,
- group=mpu.get_model_parallel_group())
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.SUM,
+ group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float(
@@ -491,9 +490,9 @@ def get_grad_zeros(parameters, mpu=None):
# Sum across all model parallel GPUs.
total_zeros_cuda = torch.cuda.FloatTensor([float(total_zeros)])
if mpu is not None:
- torch.distributed.all_reduce(total_zeros_cuda,
- op=torch.distributed.ReduceOp.SUM,
- group=mpu.get_model_parallel_group())
+ dist.all_reduce(total_zeros_cuda,
+ op=dist.ReduceOp.SUM,
+ group=mpu.get_model_parallel_group())
total_zeros = total_zeros_cuda[0].item()
return total_zeros
@@ -525,9 +524,9 @@ def get_weight_norm(parameters, norm_type=2, mpu=None):
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all GPUs.
if mpu is not None:
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.MAX,
- group=mpu.get_model_parallel_group())
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.MAX,
+ group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.
@@ -548,9 +547,9 @@ def get_weight_norm(parameters, norm_type=2, mpu=None):
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if mpu is not None:
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM,
- group=mpu.get_model_parallel_group())
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.SUM,
+ group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float(
@@ -810,7 +809,7 @@ def memory_status(msg, print_rank=-1, reset_max=False):
def get_ma_status():
- if torch.distributed.is_initialized() and not torch.distributed.get_rank() == 0:
+ if dist.is_initialized() and not dist.get_rank() == 0:
return 0
return torch.cuda.memory_allocated()
@@ -818,7 +817,7 @@ def get_ma_status():
def see_memory_usage(message, force=False):
if not force:
return
- if torch.distributed.is_initialized() and not torch.distributed.get_rank() == 0:
+ if dist.is_initialized() and not dist.get_rank() == 0:
return
# python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
@@ -915,18 +914,18 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None):
total_norm = max(t.data.abs().max() for t in input_tensors)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if mpu is not None:
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.MAX,
- group=mpu.get_model_parallel_group())
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.MAX,
+ group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()
else:
total_norm = sum(
[t.data.float().norm(norm_type).item()**norm_type for t in input_tensors])
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
if mpu is not None:
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM,
- group=mpu.get_model_parallel_group())
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.SUM,
+ group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float(
diff --git a/deepspeed/runtime/zero/contiguous_memory_allocator.py b/deepspeed/runtime/zero/contiguous_memory_allocator.py
index a5b14ae13342..883e58301e04 100644
--- a/deepspeed/runtime/zero/contiguous_memory_allocator.py
+++ b/deepspeed/runtime/zero/contiguous_memory_allocator.py
@@ -1,8 +1,10 @@
import torch
+from deepspeed import comm as dist
+
def print_rank_0(message):
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
print(message)
diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py
index f72f594f8368..02a585b02b01 100644
--- a/deepspeed/runtime/zero/linear.py
+++ b/deepspeed/runtime/zero/linear.py
@@ -18,12 +18,13 @@
from torch.nn import init
from torch.nn.modules.module import Module
from deepspeed.runtime.utils import noop_decorator
+from deepspeed import comm as dist
tensor_map = {}
def print_rank_0(message, debug=False, force=False):
- if torch.distributed.get_rank() == 0 and (debug or force):
+ if dist.get_rank() == 0 and (debug or force):
print(message)
diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py
new file mode 100644
index 000000000000..688b81900e36
--- /dev/null
+++ b/deepspeed/runtime/zero/parameter_offload.py
@@ -0,0 +1,485 @@
+"""
+"Copyright 2022 The Microsoft DeepSpeed Team.
+Licensed under the MIT license.
+"""
+
+import torch
+from torch.cuda import Stream
+from collections import OrderedDict
+from deepspeed.runtime.utils import see_memory_usage
+from deepspeed.runtime.zero.partition_parameters import _init_external_params
+from deepspeed.runtime.zero.partition_parameters import *
+from deepspeed.runtime.zero.offload_constants import *
+from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, iter_params
+
+FWD_MODULE_STACK = list()
+
+
+def is_builtin_type(obj):
+ # https://stackoverflow.com/a/17795199
+ return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins"
+
+
+#apply torch.autograd.Function that calls a backward_function to tensors in output
+def _apply_to_tensors_only(module, functional, backward_function, outputs):
+ if isinstance(outputs, (tuple, list)):
+ touched_outputs = []
+ for output in outputs:
+ touched_output = _apply_to_tensors_only(module,
+ functional,
+ backward_function,
+ output)
+ touched_outputs.append(touched_output)
+ return outputs.__class__(touched_outputs)
+ elif isinstance(outputs, dict):
+ # apply inplace to avoid recreating dict inherited objects
+ for key in outputs.keys():
+ outputs[key] = _apply_to_tensors_only(module,
+ functional,
+ backward_function,
+ outputs[key])
+ return outputs
+
+ elif type(outputs) is torch.Tensor:
+ return functional.apply(module, backward_function, outputs)
+ else:
+ if not is_builtin_type(outputs):
+ logger.warning(
+ f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
+ "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and "
+ "output tensors and therefore may not get triggered properly.")
+ return outputs
+
+
+#for each tensor in outputs run the forward_function and register backward_function as hook
+def _apply_forward_and_backward_to_tensors_only(module,
+ forward_function,
+ backward_function,
+ outputs):
+ if type(outputs) is tuple:
+ touched_outputs = []
+ for output in outputs:
+ touched_output = _apply_forward_and_backward_to_tensors_only(
+ module,
+ forward_function,
+ backward_function,
+ output)
+ touched_outputs.append(touched_output)
+ return tuple(touched_outputs)
+ elif type(outputs) is torch.Tensor:
+ forward_function(outputs)
+ if outputs.requires_grad:
+ outputs.register_hook(backward_function)
+ return outputs
+ else:
+ return outputs
+
+
+class ZeROOrderedDict(OrderedDict):
+ def __init__(self, parent_module, *args, **kwargs):
+ """A replacement for ``collections.OrderedDict`` to detect external ZeRO params.
+
+ Args:
+ parent_module (``collections.OrderedDict``): the collection to replace
+ """
+
+ super().__init__(*args, **kwargs)
+ self._parent_module = parent_module
+ self._in_forward = False
+
+ def __getitem__(self, key):
+ param = super().__getitem__(key)
+
+ # Params can be registered as None (e.g., bias)
+ if param is None:
+ return param
+
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if self._parent_module._parameters._in_forward:
+ register_external_parameter(FWD_MODULE_STACK[-1], param)
+ param.all_gather()
+ print_rank_0(
+ f'Registering external parameter from getter {key} ds_id = {param.ds_id}',
+ force=False)
+
+ return param
+
+
+def _inject_parameters(module, cls):
+ for module in module.modules():
+ if cls == ZeROOrderedDict:
+ new_param = cls(parent_module=module)
+ else:
+ new_param = cls()
+
+ for key, param in module._parameters.items():
+ new_param[key] = param
+ module._parameters = new_param
+
+
+class PreBackwardFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, module, pre_backward_function, outputs):
+ ctx.module = module
+ ctx.pre_backward_function = pre_backward_function
+ if not hasattr(module, "applied_pre_backward_ref_cnt"):
+ module.applied_pre_backward_ref_cnt = 0
+ module.applied_pre_backward_ref_cnt += 1
+ #print(f"After Forward: {ctx.module.__class__.__name__}")
+ outputs = outputs.detach()
+ return outputs
+
+ @staticmethod
+ def backward(ctx, *args):
+ #print(f"Before Backward: {ctx.module.__class__.__name__}")
+ ctx.pre_backward_function(ctx.module)
+ return (None, None) + args
+
+
+class PostBackwardFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, module, pre_backward_function, output):
+ ctx.module = module
+ if output.requires_grad:
+ #TODO SOME TIMES post backward does not seem to be triggered debug in detail
+ #Should only cause increase in memory not correctness issue
+ #if output.grad_fn.__class__.__name__ == 'ViewBackward':
+ # ctx.view=True
+ # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
+ #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
+ #if module.ds_grads_remaining == 0:
+ # print(f"Before Forward: {ctx.module.__class__.__name__}")
+ module.ds_grads_remaining += 1
+ ctx.pre_backward_function = pre_backward_function
+ output = output.detach()
+ return output
+
+ @staticmethod
+ def backward(ctx, *args):
+ ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
+ if ctx.module.ds_grads_remaining == 0:
+ ctx.pre_backward_function(ctx.module)
+ #print(f"After Backward: {ctx.module.__class__.__name__}")
+ return (None, None) + args
+
+
+class DeepSpeedZeRoOffload(object):
+ def __init__(self,
+ module,
+ timers,
+ ds_config,
+ overlap_comm=True,
+ prefetch_bucket_size=50000000,
+ max_reuse_distance=1000000000,
+ max_live_parameters=1000000000,
+ param_persistence_threshold=100000,
+ offload_param_config=None,
+ mpu=None):
+
+ see_memory_usage("TensorOffload initialize beginning", force=True)
+
+ print_rank_0(f"initialized {__class__.__name__} with args: {locals()}",
+ force=False)
+
+ self.module = module
+ self.dtype = list(module.parameters())[0].dtype
+ self.offload_device = None
+ self.offload_param_pin_memory = False
+ if offload_param_config is not None:
+ self.offload_device = offload_param_config[OFFLOAD_PARAM_DEVICE]
+ self.offload_param_pin_memory = offload_param_config[
+ OFFLOAD_PARAM_PIN_MEMORY]
+
+ self._convert_to_zero_parameters(ds_config, module, mpu)
+
+ for m in module.modules():
+ _init_external_params(m)
+
+ _inject_parameters(module, ZeROOrderedDict)
+
+ self.persistence_threshold = int(param_persistence_threshold)
+ self.persistent_parameters = self.mark_persistent_parameters()
+
+ self.param_coordinators = {}
+ self._prefetch_bucket_sz = int(prefetch_bucket_size)
+ self._max_reuse_distance_in_numel = int(max_reuse_distance)
+ self._max_available_parameters_in_numel = int(max_live_parameters)
+ self.__allgather_stream = Stream(
+ ) if overlap_comm else torch.cuda.default_stream()
+
+ self.forward_hooks = []
+ self.backward_hooks = []
+ self.setup_zero_stage3_hooks()
+ print_rank_0(
+ f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}',
+ force=False)
+
+ @instrument_w_nvtx
+ def partition_all_parameters(self):
+ """Partitioning Parameters that were not partitioned usually if parameters
+ of modules whose input parameters do not require grad computation do not
+ trigger post call and will therefore will remain unpartitioned"""
+ self.get_param_coordinator(training=self.module.training).release_and_reset_all(
+ self.module)
+ for param in iter_params(self.module, recurse=True):
+ if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
+ raise RuntimeError(f"{param.ds_summary()} expected to be released")
+
+ def get_param_coordinator(self, training):
+ if not training in self.param_coordinators:
+ self.param_coordinators[training] = PartitionedParameterCoordinator(
+ prefetch_bucket_sz=self._prefetch_bucket_sz,
+ max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
+ max_available_parameters_in_numel=self.
+ _max_available_parameters_in_numel,
+ allgather_stream=self.__allgather_stream,
+ prefetch_nvme=self.offload_device == OFFLOAD_NVME_DEVICE,
+ )
+
+ return self.param_coordinators[training]
+
+ def _convert_to_zero_parameters(self, ds_config, module, mpu):
+ non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
+ if non_zero_params:
+ zero_params = [p for p in module.parameters() if is_zero_param(p)]
+ if zero_params:
+ zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
+ else:
+ group = None
+ if mpu:
+ group = mpu.get_data_parallel_group()
+
+ Init(module=module,
+ data_parallel_group=group,
+ dtype=self.dtype,
+ config_dict_or_path=ds_config,
+ remote_device=self.offload_device,
+ pin_memory=self.offload_param_pin_memory,
+ mpu=mpu)
+
+ def destroy(self):
+ self._remove_module_hooks()
+
+ def _remove_module_hooks(self):
+ num_forward_hooks = len(self.forward_hooks)
+ num_backward_hooks = len(self.backward_hooks)
+
+ for hook in self.forward_hooks:
+ hook.remove()
+
+ for hook in self.backward_hooks:
+ hook.remove()
+
+ print_rank_0(
+ f'Deleted module hooks: forward = {num_forward_hooks}, backward = {num_backward_hooks}',
+ force=False)
+
+ def setup_zero_stage3_hooks(self):
+ self.hierarchy = 0
+
+ #reset step if in inference mode
+ @instrument_w_nvtx
+ def _end_of_forward_hook(module, *args):
+
+ if not torch._C.is_grad_enabled():
+ self.get_param_coordinator(training=False).reset_step()
+
+ #likely one of them should be enough but just to be safe
+ self._register_hooks_recursively(self.module)
+ self.module.register_forward_hook(_end_of_forward_hook)
+
+ # Add top module to stack trace
+ global FWD_MODULE_STACK
+ FWD_MODULE_STACK.append(self.module)
+
+ def mark_persistent_parameters(self):
+ persistent_params = []
+ total_persistent_parameters = 0
+ params_count = 0
+ for _, param in self.module.named_parameters(recurse=True):
+ if param.ds_numel < self.persistence_threshold:
+ params_count += 1
+ param.ds_persist = True
+ persistent_params.append(param)
+ total_persistent_parameters += param.ds_numel
+
+ print_rank_0(
+ f"Parameter Offload: Total persistent parameters: {total_persistent_parameters} in {params_count} params",
+ force=False)
+
+ return persistent_params
+
+ def _register_hooks_recursively(self, module, count=[0]):
+ my_count = count[0]
+ module.id = my_count
+
+ #print(f"{module.__class__} : {module.id}")
+
+ for child in module.children():
+ count[0] = count[0] + 1
+ self._register_hooks_recursively(child, count=count)
+
+ @instrument_w_nvtx
+ def _pre_forward_module_hook(module, *args):
+ self.pre_sub_module_forward_function(module)
+
+ @instrument_w_nvtx
+ def _post_forward_module_hook(module, input, output):
+ global FWD_MODULE_STACK
+ FWD_MODULE_STACK.pop()
+ if output is None:
+ output = []
+ elif not isinstance(output, (list, tuple)):
+ if torch.is_tensor(output):
+ output = [output]
+ else:
+ #print(f'got UNKNOWN type {type(output)}')
+ outputs = []
+ output = output if isinstance(output, dict) else vars(output)
+ for name, val in output.items():
+ if not name.startswith('__') and torch.is_tensor(val):
+ outputs.append(val)
+ output = outputs
+ #print(f'convert output to {output}')
+
+ for item in filter(lambda item: is_zero_param(item), output):
+ if not any(id(item) in m._external_params for m in FWD_MODULE_STACK):
+ item.is_external_param = True
+ module_to_register = FWD_MODULE_STACK[-1]
+ register_external_parameter(module_to_register, item)
+ print_rank_0(
+ f'Registering dangling parameter for module {module_to_register.__class__.__name__}, ds_id = {item.ds_id}.',
+ force=False)
+
+ # It's possible that the parameter was already external to the completed module. If so, remove it the
+ # registration as it will be covered by the outer module instead.
+ if id(item) in module._external_params:
+ print_rank_0(
+ f' Unregistering nested dangling parameter from module {module.__class__.__name__}, ds_id = {item.ds_id}',
+ force=False)
+ unregister_external_parameter(module, item)
+
+ item.all_gather()
+
+ self.post_sub_module_forward_function(module)
+
+ def _pre_backward_module_hook(module, inputs, output):
+ @instrument_w_nvtx
+ def _run_before_backward_function(sub_module):
+ # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
+ # before doing backwards, so each backward will need a pre-fetch - using reference
+ # counting to support this scenario
+ #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
+ if sub_module.applied_pre_backward_ref_cnt > 0:
+ self.pre_sub_module_backward_function(sub_module)
+ sub_module.applied_pre_backward_ref_cnt -= 1
+ #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")
+
+ return _apply_to_tensors_only(module,
+ PreBackwardFunction,
+ _run_before_backward_function,
+ output)
+
+ #This is an alternate to doing _post_backward_module_hook
+ #it uses tensor.register_hook instead of using torch.autograd.Function
+ def _alternate_post_backward_module_hook(module, inputs):
+ module.ds_grads_remaining = 0
+
+ #print(f"Before Forward {module.__class__.__name__}")
+
+ def _run_after_backward_hook(*unused):
+ module.ds_grads_remaining = module.ds_grads_remaining - 1
+ if module.ds_grads_remaining == 0:
+ #print(f"After backward {module.__class__.__name__}")
+ self.post_sub_module_backward_function(module)
+
+ def _run_before_forward_function(input):
+ if input.requires_grad:
+ module.ds_grads_remaining += 1
+
+ return _apply_forward_and_backward_to_tensors_only(
+ module,
+ _run_before_forward_function,
+ _run_after_backward_hook,
+ inputs)
+
+ def _post_backward_module_hook(module, inputs):
+ module.ds_grads_remaining = 0
+
+ @instrument_w_nvtx
+ def _run_after_backward_function(sub_module):
+ if sub_module.ds_grads_remaining == 0:
+ self.post_sub_module_backward_function(sub_module)
+
+ return _apply_to_tensors_only(module,
+ PostBackwardFunction,
+ _run_after_backward_function,
+ inputs)
+
+ # Pre forward hook
+ self.forward_hooks.append(
+ module.register_forward_pre_hook(_pre_forward_module_hook))
+
+ # Post forward hook
+ self.forward_hooks.append(
+ module.register_forward_hook(_post_forward_module_hook))
+
+ # Pre backward hook
+ self.backward_hooks.append(
+ module.register_forward_hook(_pre_backward_module_hook))
+
+ # post backward hook
+ self.backward_hooks.append(
+ module.register_forward_pre_hook(_post_backward_module_hook))
+
+ @torch.no_grad()
+ def pre_sub_module_forward_function(self, sub_module):
+ see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}",
+ force=False)
+
+ global FWD_MODULE_STACK
+ FWD_MODULE_STACK.append(sub_module)
+
+ param_coordinator = self.get_param_coordinator(training=sub_module.training)
+ param_coordinator.trace_prologue(sub_module)
+ if param_coordinator.is_record_trace():
+ param_coordinator.record_module(sub_module)
+ param_coordinator.fetch_sub_module(sub_module)
+
+ see_memory_usage(
+ f"Before sub module function {sub_module.__class__.__name__} after fetch",
+ force=False)
+
+ @torch.no_grad()
+ def post_sub_module_forward_function(self, sub_module):
+ see_memory_usage(
+ f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
+ force=False)
+
+ param_coordinator = self.get_param_coordinator(training=sub_module.training)
+ param_coordinator.release_sub_module(sub_module)
+
+ see_memory_usage(
+ f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
+ force=False)
+
+ @torch.no_grad()
+ def pre_sub_module_backward_function(self, sub_module):
+ param_coordinator = self.get_param_coordinator(training=sub_module.training)
+ param_coordinator.trace_prologue(sub_module)
+ if param_coordinator.is_record_trace():
+ param_coordinator.record_module(sub_module)
+ param_coordinator.fetch_sub_module(sub_module)
+
+ @torch.no_grad()
+ def post_sub_module_backward_function(self, sub_module):
+ see_memory_usage(
+ f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
+ force=False)
+
+ self.get_param_coordinator(
+ training=sub_module.training).release_sub_module(sub_module)
+
+ see_memory_usage(
+ f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
+ force=False)
diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 8fb080f81072..62cd21b3710f 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -15,8 +15,7 @@
import torch
from torch import Tensor
-import torch.distributed as dist
-from torch.distributed.distributed_c10d import _get_global_rank, group
+from deepspeed import comm as dist
from torch.nn import Module
from torch.nn import Parameter
@@ -26,7 +25,8 @@
import deepspeed
from ..utils import get_only_unique_item, see_memory_usage
from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks
-from deepspeed.utils import init_distributed, instrument_w_nvtx, logger
+from deepspeed.utils import instrument_w_nvtx, logger
+from deepspeed.comm.comm import init_distributed
from deepspeed.utils.debug import (debug_param2name_id_shape,
debug_param2name_id_shape_device,
debug_module2name,
@@ -42,40 +42,16 @@
param_count = 0
partitioned_param_data_shape = [0]
-if hasattr(torch.distributed, "_all_gather_base"):
- def torch_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group):
- try:
- return instrument_w_nvtx(torch.distributed._all_gather_base)(
- output_tensor,
- input_tensor,
- group=group,
- async_op=True,
- )
- except RuntimeError as e:
- raise RuntimeError(
- f"output_tensor: {output_tensor.device}, input_tensor: {input_tensor.device}"
- ) from e
-else:
- logger.warning(
- "unable to find torch.distributed._all_gather_base. will fall back to "
- "torch.distributed.all_gather which will result in suboptimal performance. "
- "please consider upgrading your pytorch installation.")
-
- def torch_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group):
- output_tensors = list(
- torch.chunk(output_tensor,
- torch.distributed.get_world_size(group)))
- return instrument_w_nvtx(torch.distributed.all_gather)(
- output_tensors,
- input_tensor,
- group=group,
- async_op=True,
- )
+def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group):
+ return instrument_w_nvtx(dist.allgather_fn)(output_tensor,
+ input_tensor,
+ group=group,
+ async_op=True)
def print_rank_0(message, debug=False, force=False):
- rank = torch.distributed.get_rank()
+ rank = dist.get_rank()
if rank == 0 and (debug or force):
print(message)
# other variations
@@ -86,7 +62,7 @@ def print_rank_0(message, debug=False, force=False):
def debug_rank0(msg: str) -> None:
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
logger.debug(msg)
@@ -341,9 +317,7 @@ def wrapped_fn_to_apply(module_to_apply_fn_to: Module) -> None:
fn_to_apply(module_to_apply_fn_to)
for param in params_to_apply_fn_to:
- torch.distributed.broadcast(param.data,
- 0,
- group=param.ds_process_group)
+ dist.broadcast(param.data, 0, group=param.ds_process_group)
for param in params_to_apply_fn_to:
param.partition(has_been_updated=True)
@@ -459,7 +433,7 @@ def _disable_class(cls):
# if self.mem_efficient_linear:
# torch.nn.functional.linear = self.linear_bk
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
logger.info("finished initializing model with %.2fB parameters",
param_count / 1e9)
@@ -573,7 +547,7 @@ def __init__(self,
Args:
module (``torch.nn.Module``, optional): If provided, partition the model as
if it was constructed in the context.
- data_parallel_group (``torch.distributed`` process group, optional):
+ data_parallel_group (``deepspeed.comm`` process group, optional):
The group of processes to partition among. Defaults to all processes.
mem_efficient_linear (bool, optional): Replace
torch.nn.functional.linear with an implementation that allows
@@ -622,7 +596,7 @@ def __init__(self,
this feature must be used.
.. note::
- Initializes ``torch.distributed`` if it has not already been done so.
+ Initializes ``deepspeed.comm`` if it has not already been done so.
See :meth:`deepseed.init_distributed` for more information.
.. note::
@@ -677,16 +651,16 @@ def get_model():
mem_efficient_linear=mem_efficient_linear,
ds_config=_ds_config,
dtype=dtype)
- if not torch.distributed.is_initialized():
+ if not dist.is_initialized():
init_distributed()
- assert torch.distributed.is_initialized(), "Parameters cannot be scattered without initializing torch.distributed"
+ assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm"
if data_parallel_group is None:
- self.ds_process_group = torch.distributed.group.WORLD
+ self.ds_process_group = dist.get_world_group()
else:
self.ds_process_group = data_parallel_group
- self.rank = torch.distributed.get_rank(group=self.ds_process_group)
- self.world_size = torch.distributed.get_world_size(group=self.ds_process_group)
+ self.rank = dist.get_rank(group=self.ds_process_group)
+ self.world_size = dist.get_world_size(group=self.ds_process_group)
# Local device is the device where the parameters are consumed, must be default device.
# It is the device where parameters are fully instantiated using allgather
@@ -717,10 +691,9 @@ def get_model():
self._convert_to_zero_parameters(module.parameters(recurse=True))
self.use_all_gather_base = False
- try:
- from torch.distributed.distributed_c10d import _all_gather_base as all_gather
+ if dist.has_allgather_base():
self.use_all_gather_base = True
- except:
+ else:
logger.info(
f"_all_gather_base API is not available in torch {torch.__version__}")
@@ -764,9 +737,9 @@ def _post_init_method(self, module):
)
if param.is_cuda:
- torch.distributed.broadcast(param, 0, self.ds_process_group)
+ dist.broadcast(param, 0, self.ds_process_group)
else:
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
logger.warn(f"param `{name}` in {module.__class__.__name__} "
f"not on GPU so was not broadcasted from rank 0")
@@ -858,7 +831,7 @@ def all_gather_coalesced(params: Iterable[Parameter],
device=torch.cuda.current_device(),
requires_grad=False,
)
- handle = torch_allgather_fn(
+ handle = _dist_allgather_fn(
param.ds_tensor.to(torch.cuda.current_device()),
param_buffer,
self.ds_process_group,
@@ -885,7 +858,7 @@ def all_gather_coalesced(params: Iterable[Parameter],
instrument_w_nvtx(torch.cat)(
[p.ds_tensor.to(torch.cuda.current_device()) for p in params],
out=partitions[self.rank])
- handle = torch_allgather_fn(partitions[self.rank],
+ handle = _dist_allgather_fn(partitions[self.rank],
flat_tensor,
self.ds_process_group)
@@ -937,8 +910,8 @@ def aligned_size():
def padding_size():
return self._padding_size(param)
- def partitioned_size():
- return self._partitioned_size(param)
+ def partition_numel():
+ return self._partition_numel(param)
def item_override():
param.all_gather()
@@ -980,7 +953,7 @@ def wrapped(*args, **kwargs):
# Partitioning size utilities
param.aligned_size = aligned_size
param.padding_size = padding_size
- param.partitioned_size = partitioned_size
+ param.partition_numel = partition_numel
param.ds_summary = types.MethodType(ds_summary, param)
param.item = allgather_before(param.item)
@@ -994,7 +967,7 @@ def _padding_size(self, param):
remainder = param.ds_numel % self.world_size
return (self.world_size - remainder) if remainder else 0
- def _partitioned_size(self, param):
+ def _partition_numel(self, param):
return param.ds_tensor.ds_numel
def _ensure_availability_of_partitioned_params(self, params):
@@ -1074,7 +1047,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
# if numel in empty_buffers:
# empty_buffers[numel].append(buffer)
- # if torch.distributed.get_rank():
+ # if deepspeed.comm.get_rank():
# print(f"Releasing {param.data.numel()}")
if param.ds_tensor is not None and not has_been_updated:
@@ -1225,10 +1198,10 @@ def _allgather_param(self, param, async_op=False, hierarchy=0):
# return None
if self.use_all_gather_base:
# try the _all_gather_base on PyTorch master branch
- handle = dist._all_gather_base(flat_tensor,
- param.ds_tensor.cuda(),
- group=self.ds_process_group,
- async_op=async_op)
+ handle = dist.all_gather_base(flat_tensor,
+ param.ds_tensor.cuda(),
+ group=self.ds_process_group,
+ async_op=async_op)
else:
partitions = []
for i in range(self.world_size):
@@ -1281,10 +1254,10 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0):
if self.use_all_gather_base:
# try the _all_gather_base from Pytorch master
- h = dist._all_gather_base(allgather_params[param_idx],
- input_tensor,
- group=self.ds_process_group,
- async_op=True)
+ h = dist.all_gather_base(allgather_params[param_idx],
+ input_tensor,
+ group=self.ds_process_group,
+ async_op=True)
else:
output_list = []
for i in range(self.world_size):
@@ -1346,10 +1319,10 @@ def _allgather_params(self, param_list, hierarchy=0):
offset += param_numel
- torch.distributed.all_gather(partitions,
- partitions[self.rank],
- group=self.ds_process_group,
- async_op=False)
+ dist.all_gather(partitions,
+ partitions[self.rank],
+ group=self.ds_process_group,
+ async_op=False)
param_offset = 0
for param in param_list:
@@ -1443,11 +1416,11 @@ def _reduce_scatter_gradient(self, param):
#print("after reduce scatter gradients")
input_list.append(input)
- rank = torch.distributed.get_rank(group=self.ds_process_group)
- handle = torch.distributed.reduce_scatter(input_list[rank],
- input_list,
- group=self.ds_process_group,
- async_op=True)
+ rank = dist.get_rank(group=self.ds_process_group)
+ handle = dist.reduce_scatter(input_list[rank],
+ input_list,
+ group=self.ds_process_group,
+ async_op=True)
return handle, input_list[rank]
@@ -1479,7 +1452,7 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
assert partition_buffer.numel(
) >= partition_size, f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"
- rank = torch.distributed.get_rank(group=self.ds_process_group)
+ rank = dist.get_rank(group=self.ds_process_group)
start = partition_size * rank
end = start + partition_size
@@ -1559,12 +1532,12 @@ def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True):
with deepspeed.zero.GatheredParameters(linear.weight,
modifier_rank=0):
- if torch.distributed.get_rank() == 0:
+ if deepspeed.comm.get_rank() == 0:
linear.weight.zero_()
with deepspeed.zero.GatheredParameters(linear.weight,
modifier_rank=0):
- if torch.distributed.get_rank() == 0:
+ if deepspeed.comm.get_rank() == 0:
linear.weight.zero_()
#. Collect a partitioned weight to pass to another module during
@@ -1598,7 +1571,7 @@ def load(module: nn.Module, prefix=""):
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
- if torch.distributed.get_rank() == 0:
+ if deepspeed.comm.get_rank() == 0:
module._load_from_state_dict(state_dict, prefix)
for name, child in module._modules.items():
@@ -1626,12 +1599,12 @@ def load(module: nn.Module, prefix=""):
self.params = [p for p in params if hasattr(p, "ds_id")]
self.src_rank = None
if modifier_rank is not None:
- if self.params[0].ds_process_group == torch.distributed.group.WORLD:
+ if self.params[0].ds_process_group == dist.get_world_group():
self.src_rank = modifier_rank
else:
# A group was specified; convert DP rank to global rank
- self.src_rank = _get_global_rank(self.params[0].ds_process_group,
- modifier_rank)
+ self.src_rank = dist.get_global_rank(self.params[0].ds_process_group,
+ modifier_rank)
self.fwd_module = fwd_module
if self.fwd_module is not None:
# is a no-op if already registered
@@ -1650,10 +1623,10 @@ def __exit__(self, *exc):
return
handles = [
- torch.distributed.broadcast(p,
- self.src_rank,
- group=p.ds_process_group,
- async_op=True) for p in self.params
+ dist.broadcast(p,
+ self.src_rank,
+ group=p.ds_process_group,
+ async_op=True) for p in self.params
]
for h in handles:
h.wait()
diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py
index e4064dd03d3e..7baf12f9f4b7 100644
--- a/deepspeed/runtime/zero/partitioned_param_coordinator.py
+++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -12,6 +12,7 @@
from torch.cuda import Event, Stream
from torch.nn import Module, Parameter
+from deepspeed import comm as dist
from deepspeed.utils.logging import logger
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.offload_constants import *
@@ -420,12 +421,23 @@ def __params_to_release(self,
params_to_release = set(p.ds_id for p in iter_params(submodule_to_release)
if not p.ds_persist)
+ # Problem: When prefetcher scans the param trace, it skips AVAILABLE params.
+ # This creates issues if those params are released before the skipped uses:
+ # 1) It hurts performance as the skipped uses are never prefetched.
+ # 2) For nvme params, we run out of swap buffers because the prefetch order
+ # diverges from the trace.
+ # Solution: Don't release params whose reuse was skipped by prefetch. This is
+ # possible because we detect such skips during prefetch and mark those params.
+ for param in iter_params(submodule_to_release):
+ if self.__most_recent_step_id_param_fetched_for[param] > step_id:
+ params_to_release.discard(param.ds_id)
+
# examine all modules within `max_reuse_dist_in_numel` of the current step,
# if we see any of the candidate parameters to be released reoccur while
# doing this, remove them from the set of parameters to release.
params_traversed = 0
for module in self.__submodule_order[step_id:]:
- if params_traversed > self.__max_reuse_dist_in_numel:
+ if params_traversed >= self.__max_reuse_dist_in_numel:
break
for param in iter_params(module):
params_to_release.discard(param.ds_id)
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index e963ef643677..e4afdfdbfb0c 100755
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -14,7 +14,7 @@
import torch
from torch.cuda import Event, Stream
from torch.nn import Module, Parameter
-import torch.distributed as dist
+from deepspeed import comm as dist
import math
from torch._six import inf
from torch.nn import Module
@@ -24,9 +24,10 @@
from deepspeed.utils.logging import logger
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced
-from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim
+from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.partition_parameters import _init_external_params
+from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.op_builder import UtilsBuilder
@@ -41,12 +42,11 @@
# with gradient partitioning and without
pg_correctness_test = False
-FWD_MODULE_STACK = list()
from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id, debug_param2name_id_numel, debug_param2name_id_shape_device, debug_module2name_class, printflock, log_rank_file
def print_rank_0(message, debug=False, force=False):
- rank = torch.distributed.get_rank()
+ rank = dist.get_rank()
if rank == 0 and (debug or force):
print(message)
# other variations
@@ -74,154 +74,6 @@ def move_to_cpu(tensor_list):
tensor.data = tensor.data.cpu()
-def is_builtin_type(obj):
- # https://stackoverflow.com/a/17795199
- return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins"
-
-
-#apply torch.autograd.Function that calls a backward_function to tensors in output
-def _apply_to_tensors_only(module, functional, backward_function, outputs):
- if isinstance(outputs, (tuple, list)):
- touched_outputs = []
- for output in outputs:
- touched_output = _apply_to_tensors_only(module,
- functional,
- backward_function,
- output)
- touched_outputs.append(touched_output)
- return outputs.__class__(touched_outputs)
- elif isinstance(outputs, dict):
- # apply inplace to avoid recreating dict inherited objects
- for key in outputs.keys():
- outputs[key] = _apply_to_tensors_only(module,
- functional,
- backward_function,
- outputs[key])
- return outputs
-
- elif type(outputs) is torch.Tensor:
- return functional.apply(module, backward_function, outputs)
- else:
- if not is_builtin_type(outputs):
- logger.warning(
- f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
- "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and "
- "output tensors and therefore may not get triggered properly.")
- return outputs
-
-
-#for each tensor in outputs run the forward_function and register backward_function as hook
-def _apply_forward_and_backward_to_tensors_only(module,
- forward_function,
- backward_function,
- outputs):
- if type(outputs) is tuple:
- touched_outputs = []
- for output in outputs:
- touched_output = _apply_forward_and_backward_to_tensors_only(
- module,
- forward_function,
- backward_function,
- output)
- touched_outputs.append(touched_output)
- return tuple(touched_outputs)
- elif type(outputs) is torch.Tensor:
- forward_function(outputs)
- if outputs.requires_grad:
- outputs.register_hook(backward_function)
- return outputs
- else:
- return outputs
-
-
-class ZeROOrderedDict(OrderedDict):
- def __init__(self, parent_module, *args, **kwargs):
- """A replacement for ``collections.OrderedDict`` to detect external ZeRO params.
-
- Args:
- parent_module (``collections.OrderedDict``): the collection to replace
- """
-
- super().__init__(*args, **kwargs)
- self._parent_module = parent_module
- self._in_forward = False
-
- def __getitem__(self, key):
- param = super().__getitem__(key)
-
- # Params can be registered as None (e.g., bias)
- if param is None:
- return param
-
- if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
- if self._parent_module._parameters._in_forward:
- register_external_parameter(FWD_MODULE_STACK[-1], param)
- param.all_gather()
- print_rank_0(
- f'Registering external parameter from getter {key} ds_id = {param.ds_id}',
- force=False)
-
- return param
-
-
-def _inject_parameters(module, cls):
- for module in module.modules():
- if cls == ZeROOrderedDict:
- new_param = cls(parent_module=module)
- else:
- new_param = cls()
-
- for key, param in module._parameters.items():
- new_param[key] = param
- module._parameters = new_param
-
-
-class PreBackwardFunction(torch.autograd.Function):
- @staticmethod
- def forward(ctx, module, pre_backward_function, outputs):
- ctx.module = module
- ctx.pre_backward_function = pre_backward_function
- if not hasattr(module, "applied_pre_backward_ref_cnt"):
- module.applied_pre_backward_ref_cnt = 0
- module.applied_pre_backward_ref_cnt += 1
- #print(f"After Forward: {ctx.module.__class__.__name__}")
- outputs = outputs.detach()
- return outputs
-
- @staticmethod
- def backward(ctx, *args):
- #print(f"Before Backward: {ctx.module.__class__.__name__}")
- ctx.pre_backward_function(ctx.module)
- return (None, None) + args
-
-
-class PostBackwardFunction(torch.autograd.Function):
- @staticmethod
- def forward(ctx, module, pre_backward_function, output):
- ctx.module = module
- if output.requires_grad:
- #TODO SOME TIMES post backward does not seem to be triggered debug in detail
- #Should only cause increase in memory not correctness issue
- #if output.grad_fn.__class__.__name__ == 'ViewBackward':
- # ctx.view=True
- # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
- #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
- #if module.ds_grads_remaining == 0:
- # print(f"Before Forward: {ctx.module.__class__.__name__}")
- module.ds_grads_remaining += 1
- ctx.pre_backward_function = pre_backward_function
- output = output.detach()
- return output
-
- @staticmethod
- def backward(ctx, *args):
- ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
- if ctx.module.ds_grads_remaining == 0:
- ctx.pre_backward_function(ctx.module)
- #print(f"After Backward: {ctx.module.__class__.__name__}")
- return (None, None) + args
-
-
INITIAL_MICRO_STEP_ID = -1
@@ -266,7 +118,7 @@ def __init__(self,
elastic_checkpoint=False,
aio_config=None):
- see_memory_usage("Stage 3 initialize beginning", force=False)
+ see_memory_usage("Stage 3 initialize beginning", force=True)
print_rank_0(f"initialized {__class__.__name__} with args: {locals()}",
force=False)
@@ -285,8 +137,8 @@ def __init__(self,
# - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
if not torch.cuda.is_available:
raise SystemError("Cannot use fp16 without CUDA.")
+
self.optimizer = init_optimizer
- self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)
# Load pre-built or JIT compile (un)flatten ops
util_ops = UtilsBuilder().load()
@@ -309,20 +161,21 @@ def __init__(self,
self.params_in_nvme_and_cpu = False
self.max_params_in_cpu = 0
+ self.parameter_offload = DeepSpeedZeRoOffload(module,
+ timers,
+ ds_config,
+ overlap_comm,
+ prefetch_bucket_size,
+ max_reuse_distance,
+ max_live_parameters,
+ param_persistence_threshold,
+ offload_param_config)
+ self.persistent_parameters = self.parameter_offload.persistent_parameters
self._configure_offloading(offload_optimizer_config, offload_param_config)
- self._convert_to_zero_parameters(ds_config, module, mpu)
-
- for m in module.modules():
- _init_external_params(m)
-
self.module = module
self.elastic_checkpoint = elastic_checkpoint
- # Replace ._parameters with a new class to enable auto-registration of
- # external parameters
- _inject_parameters(module, ZeROOrderedDict)
-
self.__inf_or_nan_tracker: Tensor = torch.zeros(
1,
dtype=torch.bool,
@@ -335,41 +188,14 @@ def __init__(self,
self.device = torch.cuda.current_device(
) if not self.offload_optimizer else OFFLOAD_CPU_DEVICE
### streams used for overlapping computation with communication
- self.__allgather_stream = Stream(
- ) if overlap_comm else torch.cuda.default_stream()
self.__reduce_and_partition_stream = Stream(
) if overlap_comm else torch.cuda.default_stream()
############################################################################
- see_memory_usage("Before Partitioned Parameter Coordinator", force=False)
- self.param_coordinators = {}
- self._prefetch_bucket_sz = int(prefetch_bucket_size)
- self._max_reuse_distance_in_numel = int(max_reuse_distance)
- self._max_available_parameters_in_numel = int(max_live_parameters)
- see_memory_usage("After Partitioned Parameter Coordinator", force=False)
-
self.__n_caching_allocator_flushes = 0
#-------------Stage 3 Setup-------------------#
- # parameters smaller than the threshold will be collectively gathered at the
- # end of the optimizer step and will be kept till the end of the backward pass
- # TODO maybe worth just replicating these parameters and doing all reduce for them
- self.persistence_threshold = int(param_persistence_threshold)
-
- self.persistent_parameters = self.persistent_parameters()
-
- self.forward_hooks = []
- self.backward_hooks = []
- self.setup_zero_stage3_hooks()
- print_rank_0(
- f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}',
- force=False)
-
- #resetting ds_tensor just in case parameters have been changed after initialization
- #example .half() or .to()
- #self.reset_ds_tensor()
- #---------------------------------------------#
self.timers = timers
@@ -426,6 +252,7 @@ def __init__(self,
self.all_reduce_print = False
self.prefetch_elements = int(prefetch_bucket_size)
+
self.contiguous_gradients = contiguous_gradients
# padding on each partition for alignment purposes
@@ -488,10 +315,9 @@ def __init__(self,
f'Largest partitioned param numel = {largest_partitioned_param_numel}',
force=False)
+ self._setup_for_real_optimizer()
self.grad_position = {}
- if self.using_real_optimizer:
- self._setup_for_real_optimizer()
- self.set_grad_positions()
+ self.set_grad_positions()
if self.offload_optimizer:
self.norm_for_param_grads = {}
@@ -517,7 +343,6 @@ def __init__(self,
self.dynamic_loss_scale = False
self.loss_scaler = LossScaler(scale=loss_scale_value)
- cur_iter = 0
else:
if dynamic_loss_args is None:
self.loss_scaler = DynamicLossScaler()
@@ -532,21 +357,7 @@ def __init__(self,
see_memory_usage(f"After initializing ZeRO optimizer", force=True)
def destroy(self):
- self._remove_module_hooks()
-
- def _remove_module_hooks(self):
- num_forward_hooks = len(self.forward_hooks)
- num_backward_hooks = len(self.backward_hooks)
-
- for hook in self.forward_hooks:
- hook.remove()
-
- for hook in self.backward_hooks:
- hook.remove()
-
- print_rank_0(
- f'Deleted module hooks: forward = {num_forward_hooks}, backward = {num_backward_hooks}',
- force=False)
+ self.parameter_offload.destroy()
def _setup_for_real_optimizer(self):
see_memory_usage("Before creating fp32 partitions", force=False)
@@ -579,7 +390,7 @@ def _setup_for_real_optimizer(self):
all_params = list(itertools.chain.from_iterable(self.fp16_groups))
grad_partitions_flat_buffer: Tensor = torch.zeros(
- sum(p.ds_tensor.ds_numel for p in all_params),
+ sum(p.partition_numel() for p in all_params),
dtype=self.dtype,
device=self.device,
pin_memory=self.offload_optimizer_pin_memory)
@@ -590,8 +401,8 @@ def _setup_for_real_optimizer(self):
param.ds_id] = grad_partitions_flat_buffer.narrow(
0,
offset,
- param.ds_tensor.numel())
- offset += param.ds_tensor.numel()
+ param.partition_numel())
+ offset += param.partition_numel()
def set_lr(self, lr):
"""Set the learning rate."""
@@ -641,17 +452,7 @@ def defragment(tensors: List[Tensor]) -> Tensor:
return device_buffer
def _get_param_coordinator(self, training):
- if not training in self.param_coordinators:
- self.param_coordinators[training] = PartitionedParameterCoordinator(
- prefetch_bucket_sz=self._prefetch_bucket_sz,
- max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
- max_available_parameters_in_numel=self.
- _max_available_parameters_in_numel,
- allgather_stream=self.__allgather_stream,
- prefetch_nvme=self.params_in_nvme_and_cpu,
- )
-
- return self.param_coordinators[training]
+ return self.parameter_offload.get_param_coordinator(training)
def _configure_offloading(self, offload_optimizer_config, offload_param_config):
###################### offload optimizer setup ##################################
@@ -666,8 +467,6 @@ def _configure_offloading(self, offload_optimizer_config, offload_param_config):
###################### offload param setup ##################################
if offload_param_config is not None:
- if self.using_real_optimizer:
- assert self.offload_optimizer, "parameter offload is only available with optimizer state offload"
self.offload_param = True
self.offload_param_pin_memory = offload_param_config[
OFFLOAD_PARAM_PIN_MEMORY]
@@ -678,38 +477,12 @@ def _configure_offloading(self, offload_optimizer_config, offload_param_config):
f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}",
force=False)
- def _convert_to_zero_parameters(self, ds_config, module, mpu):
- non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
- if non_zero_params:
- zero_params = [p for p in module.parameters() if is_zero_param(p)]
- if zero_params:
- zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
- else:
- group = None
- if mpu:
- group = mpu.get_data_parallel_group()
-
- if self.params_in_nvme_and_cpu:
- remote_device = OFFLOAD_NVME_DEVICE
- elif self.offload_param:
- remote_device = OFFLOAD_CPU_DEVICE
- else:
- remote_device = None
-
- Init(module=module,
- data_parallel_group=group,
- dtype=self.dtype,
- config_dict_or_path=ds_config,
- remote_device=remote_device,
- pin_memory=self.offload_param_pin_memory,
- mpu=mpu)
-
def _configure_tensor_swapping(self, offload_optimizer_config, aio_config):
nvme_swap_folder = os.path.join(
offload_optimizer_config[OFFLOAD_OPTIMIZER_NVME_PATH],
'zero_stage_3')
os.makedirs(nvme_swap_folder, exist_ok=True)
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
logger.info(f'Tensor Swapping: Adding optimizer tensors')
swapper_type = PipelinedOptimizerSwapper if offload_optimizer_config[
@@ -748,7 +521,7 @@ def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False):
'''if the parameter was initialized in nvme then bring it to the destination buffer directly'''
if src.status == PartitionedParamStatus.NOT_AVAILABLE:
print_rank_0(
- f"Swapping in {param.ds_id} with partition size {param.ds_tensor.ds_numel} permanently to CPU"
+ f"Swapping in {param.ds_id} with partition size {param.partition_numel()} permanently to CPU"
)
param.nvme_swapper.swap_into_buffer(param, dest)
src.data = dest.data
@@ -767,7 +540,7 @@ def _create_param_groups_fp16_flat_cpu_memory(self):
aggregate_params_count = 0
for j, param_group in enumerate(self.optimizer.param_groups):
- params_in_group = sum([p.ds_tensor.ds_numel for p in param_group['params']])
+ params_in_group = sum([p.partition_numel() for p in param_group['params']])
flat_buffer_size = params_in_group
@@ -816,7 +589,7 @@ def _create_fp16_partitions_with_defragmentation(self):
# record total elements of parameter partitions in sub group
self.fp16_partitioned_groups_flat_numel.append(
- sum(p.ds_tensor.ds_numel for p in sub_group))
+ sum(p.partition_numel() for p in sub_group))
# record padding required to align group to world size (only applies to last rank)
rank_requires_padding = dist.get_rank(
@@ -839,7 +612,7 @@ def _create_fp16_partitions_with_defragmentation(self):
# contiguous flat buffer for all parameters that we created earlier
offset = 0
for sub_group in self.fp16_groups:
- sub_group_numel = sum(param.ds_tensor.ds_numel for param in sub_group)
+ sub_group_numel = sum(param.partition_numel() for param in sub_group)
self.fp16_partitioned_groups_flat.append(
device_buffer.narrow(0,
offset,
@@ -851,7 +624,7 @@ def _create_fp16_partitions_with_defragmentation(self):
for param_group_idx, param_group in enumerate(param_groups):
flat_offset = 0
for i, sub_group in enumerate(param_group):
- total_elements = sum(p.ds_tensor.ds_numel for p in sub_group)
+ total_elements = sum(p.partition_numel() for p in sub_group)
print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}")
#Flat buffer may not be available for parameters that reside in NVME
if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[
@@ -887,7 +660,7 @@ def _create_fp16_partitions_with_defragmentation(self):
if should_create_fp16_flat_reuse_buffer:
max_partition_numel, largest_partition_numel = 0, None
for sub_group in self.fp16_groups:
- total_elements = sum(t.ds_tensor.ds_numel for t in sub_group)
+ total_elements = sum(t.partition_numel() for t in sub_group)
if total_elements > max_partition_numel:
largest_partition_numel = [t.ds_numel for t in sub_group]
max_partition_numel = total_elements
@@ -905,7 +678,7 @@ def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id):
dest = flat_buffer.narrow(0, offset, partitioned_param.ds_numel)
if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE:
print_rank_0(
- f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.ds_tensor.ds_numel}"
+ f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.partition_numel()}"
)
param.nvme_swapper.swap_in([param], async_op=False)
dest.data.copy_(partitioned_param.data)
@@ -935,7 +708,7 @@ def _get_sub_group_partitions(self, sub_group_id):
if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE:
swap_path = param.nvme_swapper.get_path(param, True)
sub_group_partitions.append((partitioned_param,
- param.ds_tensor.ds_numel,
+ param.partition_numel(),
swap_path))
else:
sub_group_partitions.append((partitioned_param,
@@ -1051,7 +824,7 @@ def _create_fp32_partitions(self):
def _create_fp16_sub_groups(self, params_group):
- params_group_numel = sum([param.partitioned_size() for param in params_group])
+ params_group_numel = sum([param.partition_numel() for param in params_group])
sub_group_size = self.sub_group_size
if sub_group_size is None or sub_group_size >= params_group_numel:
@@ -1063,7 +836,7 @@ def _create_fp16_sub_groups(self, params_group):
for param in params_group:
sub_group.append(param)
- local_sub_group_size += param.partitioned_size()
+ local_sub_group_size += param.partition_numel()
if local_sub_group_size >= sub_group_size or id(param) == id(
params_group[-1]):
@@ -1075,221 +848,6 @@ def _create_fp16_sub_groups(self, params_group):
return sub_groups
- # def reset_ds_tensor(self):
- # for name, param in self.module.named_parameters(recurse=True):
- # assert hasattr(param,'ds_id'), "Parameters have not been converted to be Zero 3 compatible"
- # assert (param.ds_status == ZeroParamStatus.NOT_AVAILABLE), "All the parameters must have been partitioned by now"
- # param.ds_tensor.data = param.data
-
- def setup_zero_stage3_hooks(self):
- self.hierarchy = 0
-
- #reset step if in inference mode
- @instrument_w_nvtx
- def _end_of_forward_hook(module, *args):
-
- if not torch._C.is_grad_enabled():
- self._get_param_coordinator(training=False).reset_step()
-
- #likely one of them should be enough but just to be safe
- self._register_hooks_recursively(self.module)
- self.module.register_forward_hook(_end_of_forward_hook)
-
- # Add top module to stack trace
- global FWD_MODULE_STACK
- FWD_MODULE_STACK.append(self.module)
-
- def persistent_parameters(self):
- persistent_params = []
- total_persistent_parameters = 0
- params_count = 0
- for _, param in self.module.named_parameters(recurse=True):
- if param.ds_numel < self.persistence_threshold:
- params_count += 1
- param.ds_persist = True
- persistent_params.append(param)
- total_persistent_parameters += param.ds_numel
-
- print_rank_0(
- f"ZeRO 3: Total persistent parameters: {total_persistent_parameters} in {params_count} params",
- force=False)
- return persistent_params
-
- def _register_hooks_recursively(self, module, count=[0]):
- my_count = count[0]
- module.id = my_count
-
- #print(f"{module.__class__} : {module.id}")
-
- for child in module.children():
- count[0] = count[0] + 1
- self._register_hooks_recursively(child, count=count)
-
- @instrument_w_nvtx
- def _pre_forward_module_hook(module, *args):
- self.pre_sub_module_forward_function(module)
-
- @instrument_w_nvtx
- def _post_forward_module_hook(module, input, output):
- global FWD_MODULE_STACK
- FWD_MODULE_STACK.pop()
- if output is None:
- output = []
- elif not isinstance(output, (list, tuple)):
- if torch.is_tensor(output):
- output = [output]
- else:
- #print(f'got UNKNOWN type {type(output)}')
- outputs = []
- output = output if isinstance(output, dict) else vars(output)
- for name, val in output.items():
- if not name.startswith('__') and torch.is_tensor(val):
- outputs.append(val)
- output = outputs
- #print(f'convert output to {output}')
-
- for item in filter(lambda item: is_zero_param(item), output):
- if not any(id(item) in m._external_params for m in FWD_MODULE_STACK):
- item.is_external_param = True
- module_to_register = FWD_MODULE_STACK[-1]
- register_external_parameter(module_to_register, item)
- print_rank_0(
- f'Registering dangling parameter for module {module_to_register.__class__.__name__}, ds_id = {item.ds_id}.',
- force=False)
-
- # It's possible that the parameter was already external to the completed module. If so, remove it the
- # registration as it will be covered by the outer module instead.
- if id(item) in module._external_params:
- print_rank_0(
- f' Unregistering nested dangling parameter from module {module.__class__.__name__}, ds_id = {item.ds_id}',
- force=False)
- unregister_external_parameter(module, item)
-
- item.all_gather()
-
- self.post_sub_module_forward_function(module)
-
- def _pre_backward_module_hook(module, inputs, output):
- @instrument_w_nvtx
- def _run_before_backward_function(sub_module):
- # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
- # before doing backwards, so each backward will need a pre-fetch - using reference
- # counting to support this scenario
- #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
- if sub_module.applied_pre_backward_ref_cnt > 0:
- self.pre_sub_module_backward_function(sub_module)
- sub_module.applied_pre_backward_ref_cnt -= 1
- #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")
-
- return _apply_to_tensors_only(module,
- PreBackwardFunction,
- _run_before_backward_function,
- output)
-
- #This is an alternate to doing _post_backward_module_hook
- #it uses tensor.register_hook instead of using torch.autograd.Function
- def _alternate_post_backward_module_hook(module, inputs):
- module.ds_grads_remaining = 0
-
- #print(f"Before Forward {module.__class__.__name__}")
-
- def _run_after_backward_hook(*unused):
- module.ds_grads_remaining = module.ds_grads_remaining - 1
- if module.ds_grads_remaining == 0:
- #print(f"After backward {module.__class__.__name__}")
- self.post_sub_module_backward_function(module)
-
- def _run_before_forward_function(input):
- if input.requires_grad:
- module.ds_grads_remaining += 1
-
- return _apply_forward_and_backward_to_tensors_only(
- module,
- _run_before_forward_function,
- _run_after_backward_hook,
- inputs)
-
- def _post_backward_module_hook(module, inputs):
- module.ds_grads_remaining = 0
-
- @instrument_w_nvtx
- def _run_after_backward_function(sub_module):
- if sub_module.ds_grads_remaining == 0:
- self.post_sub_module_backward_function(sub_module)
-
- return _apply_to_tensors_only(module,
- PostBackwardFunction,
- _run_after_backward_function,
- inputs)
-
- # Pre forward hook
- self.forward_hooks.append(
- module.register_forward_pre_hook(_pre_forward_module_hook))
-
- # Post forward hook
- self.forward_hooks.append(
- module.register_forward_hook(_post_forward_module_hook))
-
- # Pre backward hook
- self.backward_hooks.append(
- module.register_forward_hook(_pre_backward_module_hook))
-
- # post backward hook
- self.backward_hooks.append(
- module.register_forward_pre_hook(_post_backward_module_hook))
-
- @torch.no_grad()
- def pre_sub_module_forward_function(self, sub_module):
- see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}",
- force=False)
-
- global FWD_MODULE_STACK
- FWD_MODULE_STACK.append(sub_module)
-
- param_coordinator = self._get_param_coordinator(training=sub_module.training)
- param_coordinator.trace_prologue(sub_module)
- if param_coordinator.is_record_trace():
- param_coordinator.record_module(sub_module)
- param_coordinator.fetch_sub_module(sub_module)
-
- see_memory_usage(
- f"Before sub module function {sub_module.__class__.__name__} after fetch",
- force=False)
-
- @torch.no_grad()
- def post_sub_module_forward_function(self, sub_module):
- see_memory_usage(
- f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
- force=False)
-
- param_coordinator = self._get_param_coordinator(training=sub_module.training)
- param_coordinator.release_sub_module(sub_module)
-
- see_memory_usage(
- f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
- force=False)
-
- @torch.no_grad()
- def pre_sub_module_backward_function(self, sub_module):
- param_coordinator = self._get_param_coordinator(training=sub_module.training)
- param_coordinator.trace_prologue(sub_module)
- if param_coordinator.is_record_trace():
- param_coordinator.record_module(sub_module)
- param_coordinator.fetch_sub_module(sub_module)
-
- @torch.no_grad()
- def post_sub_module_backward_function(self, sub_module):
- see_memory_usage(
- f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
- force=False)
-
- self._get_param_coordinator(
- training=sub_module.training).release_sub_module(sub_module)
-
- see_memory_usage(
- f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
- force=False)
-
def _release_ipg_buffers(self):
if self.contiguous_gradients:
self.ipg_buffer = None
@@ -1633,7 +1191,7 @@ def set_grad_positions(self):
current_offset = 0
for param in group:
param_id = self.get_param_id(param)
- num_elements = param.ds_tensor.ds_numel
+ num_elements = param.partition_numel()
self.grad_position[param_id] = [
int(i),
@@ -1680,12 +1238,11 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM,
- group=self.dp_process_group)
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.SUM,
+ group=self.dp_process_group)
- self._model_parallel_all_reduce(tensor=total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM)
+ self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
@@ -1700,7 +1257,7 @@ def __partition_grads(self,
params_to_release: List[Parameter],
grad_partitions: List[Tensor]) -> None:
for param, grad_partition in zip(params_to_release, grad_partitions):
- if param.ds_tensor.ds_numel * dist.get_rank(
+ if param.partition_numel() * dist.get_rank(
self.dp_process_group) > param.ds_numel:
# this grad partition is empty - don't need to do anything
continue
@@ -1866,7 +1423,7 @@ def allreduce_bucket(self,
# "All Reducing"
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
else:
- global_rank = _get_global_rank(self.dp_process_group, rank)
+ global_rank = dist.get_global_rank(self.dp_process_group, rank)
dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group)
if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
@@ -1982,9 +1539,7 @@ def _model_parallel_all_reduce(self, tensor, op):
if self.model_parallel_group is None:
pass
else:
- torch.distributed.all_reduce(tensor=tensor,
- op=op,
- group=self.model_parallel_group)
+ dist.all_reduce(tensor=tensor, op=op, group=self.model_parallel_group)
@instrument_w_nvtx
def get_grad_norm_direct(self, gradients, params, norm_type=2):
@@ -2008,13 +1563,12 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.MAX,
- group=self.dp_process_group)
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.MAX,
+ group=self.dp_process_group)
# Take max across all GPUs.
- self._model_parallel_all_reduce(tensor=total_norm_cuda,
- op=torch.distributed.ReduceOp.MAX)
+ self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
else:
# if dist.get_rank() == 0:
@@ -2027,12 +1581,11 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
# Sum across all model parallel GPUs.
total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2))
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM,
- group=self.dp_process_group)
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.SUM,
+ group=self.dp_process_group)
- self._model_parallel_all_reduce(tensor=total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM)
+ self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)
total_norm = total_norm_cuda.item()**(1. / norm_type)
@@ -2277,7 +1830,7 @@ def _overflow_clean_up(self, prev_scale):
see_memory_usage('After overflow after clearing gradients', force=False)
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
logger.info(
"[deepspeed] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
"reducing to {}".format(dist.get_rank(),
@@ -2472,9 +2025,9 @@ def has_overflow(self, partition_gradients=True):
overflow = self.local_overflow
#overflow = self.has_overflow_partitioned_grads_serial()
overflow_gpu = torch.cuda.ByteTensor([overflow])
- torch.distributed.all_reduce(overflow_gpu,
- op=torch.distributed.ReduceOp.MAX,
- group=self.dp_process_group)
+ dist.all_reduce(overflow_gpu,
+ op=dist.ReduceOp.MAX,
+ group=self.dp_process_group)
else:
params = []
@@ -2487,8 +2040,7 @@ def has_overflow(self, partition_gradients=True):
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
- self._model_parallel_all_reduce(tensor=overflow_gpu,
- op=torch.distributed.ReduceOp.MAX)
+ self._model_parallel_all_reduce(tensor=overflow_gpu, op=dist.ReduceOp.MAX)
overflow = overflow_gpu[0].item()
return bool(overflow)
@@ -2564,14 +2116,7 @@ def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]:
@instrument_w_nvtx
def _partition_all_parameters(self):
- """Partitioning Parameters that were not partitioned usually if parameters
- of modules whose input parameters do not require grad computation do not
- trigger post call and will therefore will remain unpartitioned"""
- self._get_param_coordinator(training=self.module.training).release_and_reset_all(
- self.module)
- for param in iter_params(self.module, recurse=True):
- if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
- raise RuntimeError(f"{param.ds_summary()} expected to be released")
+ self.parameter_offload.partition_all_parameters()
def check_overflow(self, partition_gradients=True):
self._check_overflow(partition_gradients)
@@ -2866,7 +2411,7 @@ def checkpoint_event_epilogue(self):
def _handle_overflow(cpu_sum, x, i):
import math
- rank = torch.distributed.get_rank()
+ rank = dist.get_rank()
if rank == 0:
t_i = -1
for v_i, v in enumerate(x.data.contiguous().view(-1)):
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 6e3fd3e0612b..7b683d12eeb4 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -3,8 +3,7 @@
'''
import torch
-from torch.distributed.distributed_c10d import _get_global_rank
-import torch.distributed as dist
+from deepspeed import comm as dist
from torch._six import inf
from packaging import version as pkg_version
@@ -517,7 +516,14 @@ def is_moe_group(self, group):
return 'moe' in group and group['moe']
def _configure_moe_settings(self):
- assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
+ # if we're using ZeRO stage 2, ensure contiguous gradients are used
+ if self.partition_gradients:
+ assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
+ # NOTE: To run ZeRO stage 1 with MoE, we need to set self.contiguous_gradients to True or ignore the assertion
+ if not self.partition_gradients and not self.contiguous_gradients:
+ logger.warn(
+ "ZeRO Stage 1 has not been thoroughly tested with MoE. This configuration is still experimental."
+ )
assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
assert any([self.is_moe_group(group) for group in self.optimizer.param_groups]), "The model has moe layers, but None of the param groups are marked as MoE. Create a param group with 'moe' key set to True before creating optimizer"
@@ -961,7 +967,7 @@ def average_tensor(self, tensor):
# print(f"Rank {dist.get_rank()} rank offset id {i} real dp size {dist.get_world_size(group=real_dp_process_group[i])} and dst: {dst}")
# dist.barrier()
#dist.barrier()
- dst_rank = _get_global_rank(real_dp_process_group[i], dst)
+ dst_rank = dist.get_global_rank(real_dp_process_group[i], dst)
async_handle = dist.reduce(grad_slice,
dst=dst_rank,
group=real_dp_process_group[i],
@@ -1141,12 +1147,11 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM,
- group=self.dp_process_group)
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.SUM,
+ group=self.dp_process_group)
- self._model_parallel_all_reduce(tensor=total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM)
+ self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
@@ -1348,7 +1353,7 @@ def allreduce_bucket(self,
# "All Reducing"
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)
else:
- global_rank = _get_global_rank(self.dp_process_group, rank)
+ global_rank = dist.get_global_rank(self.dp_process_group, rank)
dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group)
if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
@@ -1489,9 +1494,7 @@ def _model_parallel_all_reduce(self, tensor, op):
if self.model_parallel_group is None:
pass
else:
- torch.distributed.all_reduce(tensor=tensor,
- op=op,
- group=self.model_parallel_group)
+ dist.all_reduce(tensor=tensor, op=op, group=self.model_parallel_group)
def get_grad_norm_direct(self, gradients, params, norm_type=2):
"""Clips gradient norm of an iterable of parameters.
@@ -1514,13 +1517,12 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.MAX,
- group=self.dp_process_group)
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.MAX,
+ group=self.dp_process_group)
# Take max across all GPUs.
- self._model_parallel_all_reduce(tensor=total_norm_cuda,
- op=torch.distributed.ReduceOp.MAX)
+ self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.0
@@ -1535,12 +1537,11 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
total_norm += param_norm.item()**2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
- torch.distributed.all_reduce(total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM,
- group=self.dp_process_group)
+ dist.all_reduce(total_norm_cuda,
+ op=dist.ReduceOp.SUM,
+ group=self.dp_process_group)
- self._model_parallel_all_reduce(tensor=total_norm_cuda,
- op=torch.distributed.ReduceOp.SUM)
+ self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
@@ -1841,9 +1842,9 @@ def has_overflow(self, partition_gradients=True):
overflow_gpu = torch.cuda.ByteTensor([overflow])
'''This will capture overflow across all data parallel and expert parallel process
Since expert parallel process are a subset of data parallel process'''
- torch.distributed.all_reduce(overflow_gpu,
- op=torch.distributed.ReduceOp.MAX,
- group=self.dp_process_group)
+ dist.all_reduce(overflow_gpu,
+ op=dist.ReduceOp.MAX,
+ group=self.dp_process_group)
else:
params = []
@@ -1856,8 +1857,7 @@ def has_overflow(self, partition_gradients=True):
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
- self._model_parallel_all_reduce(tensor=overflow_gpu,
- op=torch.distributed.ReduceOp.MAX)
+ self._model_parallel_all_reduce(tensor=overflow_gpu, op=dist.ReduceOp.MAX)
overflow = overflow_gpu[0].item()
return bool(overflow)
@@ -2222,7 +2222,7 @@ def load_state_dict(self,
def _handle_overflow(cpu_sum, x, i):
import math
- rank = torch.distributed.get_rank()
+ rank = dist.get_rank()
if rank == 0:
t_i = -1
for v_i, v in enumerate(x.data.contiguous().view(-1)):
diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py
index fff8ebd31ef3..665802ce8e93 100755
--- a/deepspeed/runtime/zero/utils.py
+++ b/deepspeed/runtime/zero/utils.py
@@ -2,7 +2,7 @@
from typing import List
import torch
-import torch.distributed as dist
+from deepspeed import comm as dist
from deepspeed.utils import logger
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.adam import FusedAdam
@@ -21,7 +21,7 @@ def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
my_group = None
for i in range(data_parallel_size // parameter_parallel_size):
ranks = range(i * parameter_parallel_size, (i + 1) * parameter_parallel_size)
- group = torch.distributed.new_group(ranks)
+ group = dist.new_group(ranks)
if rank in ranks:
my_group = group
return my_group
diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py
index efa45853f0df..efe18c264316 100644
--- a/deepspeed/utils/__init__.py
+++ b/deepspeed/utils/__init__.py
@@ -1,5 +1,5 @@
from .logging import logger, log_dist
-from .distributed import init_distributed
+#from .distributed import init_distributed
from .groups import *
from .nvtx import instrument_w_nvtx
from deepspeed.runtime.dataloader import RepeatingLoader
diff --git a/deepspeed/utils/debug.py b/deepspeed/utils/debug.py
index 0841b7627a1a..288f20198f77 100644
--- a/deepspeed/utils/debug.py
+++ b/deepspeed/utils/debug.py
@@ -70,7 +70,7 @@ def printflock(*msgs):
2. Override the usual versions with ::
def print_rank_0(message, debug=False, force=False):
- rank = torch.distributed.get_rank()
+ rank = deepspeed.comm.get_rank()
printflock(f"[{rank}] {message}")
3. run the program and you get both logs non-interleaved
@@ -104,7 +104,7 @@ def log_rank_file(rank, *msgs):
2. Override the usual versions of print_rank_0 in those files with ::
def print_rank_0(message, debug=False, force=False):
- rank = torch.distributed.get_rank()
+ rank = deepspeed.comm.get_rank()
log_rank_file(rank, message)
3. run the program
diff --git a/deepspeed/utils/distributed.py b/deepspeed/utils/distributed.py
deleted file mode 100644
index b643f4672e86..000000000000
--- a/deepspeed/utils/distributed.py
+++ /dev/null
@@ -1,166 +0,0 @@
-'''
-Copyright 2020 The Microsoft DeepSpeed Team
-'''
-import os
-import torch
-from datetime import timedelta
-
-from .logging import logger
-from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout
-
-
-def init_distributed(dist_backend="nccl",
- auto_mpi_discovery=True,
- distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
- verbose=True,
- timeout=default_pg_timeout,
- init_method=None):
- """Initialize torch.distributed backend, potentially performing MPI discovery if needed
-
- Arguments:
- dist_backend: Optional (str). torch distributed backend, e.g., nccl, mpi, gloo
-
- auto_mpi_discovery Optional (bool). if distributed environment variables are not set, attempt to discover them from MPI
-
- distributed_port: Optional (int). torch distributed backend port
-
- verbose: Optional (bool). verbose logging
-
- timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
-
- init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified.
- """
- required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
- if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
- if verbose:
- logger.info(
- "Not using the DeepSpeed or torch.distributed launchers, attempting to detect MPI environment..."
- )
- if in_aml() and not in_dlts():
- patch_aml_env_for_torch_nccl_backend(verbose=verbose)
- elif in_aws_sm():
- patch_aws_sm_env_for_torch_nccl_backend(verbose=verbose)
- else:
- mpi_discovery(distributed_port=distributed_port, verbose=verbose)
-
- if not torch.distributed.is_initialized():
- if verbose and int(os.getenv('RANK', '0')) == 0:
- logger.info(
- "Initializing torch distributed with backend: {}".format(dist_backend))
- assert isinstance(timeout, timedelta)
- torch.distributed.init_process_group(backend=dist_backend,
- timeout=timeout,
- init_method=init_method)
-
-
-def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
- """
- Discovery MPI environment via mpi4py and map to relevant torch.distributed state
- """
- from mpi4py import MPI
- import subprocess
- comm = MPI.COMM_WORLD
- rank = comm.Get_rank()
- world_size = comm.Get_size()
-
- master_addr = None
- if rank == 0:
- hostname_cmd = ["hostname -I"]
- result = subprocess.check_output(hostname_cmd, shell=True)
- master_addr = result.decode('utf-8').split()[0]
- master_addr = comm.bcast(master_addr, root=0)
-
- # Determine local rank by assuming hostnames are unique
- proc_name = MPI.Get_processor_name()
- all_procs = comm.allgather(proc_name)
- local_rank = sum([i == proc_name for i in all_procs[:rank]])
-
- os.environ['RANK'] = str(rank)
- os.environ['WORLD_SIZE'] = str(world_size)
- os.environ['LOCAL_RANK'] = str(local_rank)
- os.environ['MASTER_ADDR'] = master_addr
- os.environ['MASTER_PORT'] = str(distributed_port)
-
- if verbose:
- logger.info(
- "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
- .format(os.environ['RANK'],
- os.environ['LOCAL_RANK'],
- os.environ['WORLD_SIZE'],
- os.environ['MASTER_ADDR'],
- os.environ['MASTER_PORT']))
-
- if torch.distributed.is_initialized():
- assert torch.distributed.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(
- rank, torch.distributed.get_rank())
- assert torch.distributed.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
- world_size, torch.distributed.get_world_size())
-
-
-def in_aml():
- # Are we running inside an Azure Machine Learning (AML) environment?
- return 'AZUREML_EXPERIMENT_ID' in os.environ
-
-
-def in_aws_sm():
- # Are we running inside an AWS SageMaker environment?
- return 'SM_TRAINING_ENV' in os.environ
-
-
-def in_dlts():
- # Are we running on a DLTS cluster?
- return 'DLTS_JOB_ID' in os.environ
-
-
-def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
- """Helper routine to get and set environment variables.
- This is adapted from Azure ML's documentation available from:
- https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi
- """
- os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
- os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
- single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(
- os.environ["WORLD_SIZE"])
-
- if not single_node:
- master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
- os.environ["MASTER_ADDR"] = master_node_params[0]
- # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
- if "MASTER_PORT" not in os.environ:
- os.environ["MASTER_PORT"] = str(master_port)
- else:
- os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
- os.environ["MASTER_PORT"] = "54965"
-
- if verbose:
- logger.info("NCCL_SOCKET_IFNAME original value = {}".format(
- os.environ["NCCL_SOCKET_IFNAME"]))
-
- os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
- os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
-
- if verbose:
- logger.info(
- "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
- .format(os.environ['RANK'],
- os.environ['LOCAL_RANK'],
- os.environ['WORLD_SIZE'],
- os.environ['MASTER_ADDR'],
- os.environ['MASTER_PORT']))
-
-
-def patch_aws_sm_env_for_torch_nccl_backend(verbose=True):
- """Helper routine to get and set environment variables when running inside an AWS SageMaker environment.
- """
- os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
- os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
- os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
-
- if verbose:
- logger.info(
- "Discovered AWS SageMaker settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
- .format(os.environ['RANK'],
- os.environ['LOCAL_RANK'],
- os.environ['WORLD_SIZE'],
- os.environ['MASTER_ADDR'],
- os.environ['MASTER_PORT']))
diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py
index 48e2c0ce57b7..f45b6681c77b 100644
--- a/deepspeed/utils/groups.py
+++ b/deepspeed/utils/groups.py
@@ -25,7 +25,8 @@
"""
import torch
-from torch.distributed.distributed_c10d import _get_global_rank
+from deepspeed import comm as dist
+
from deepspeed.utils import logger, log_dist
from deepspeed.utils.exceptions import DeprecatedException
@@ -33,7 +34,7 @@
_EXPERT_PARALLEL_GROUP = {}
# Expert data parallel group that the current rank belongs to.
_EXPERT_DATA_PARALLEL_GROUP = {}
-# torch.distributed world group needs to be cloned for some cases
+# dist world group needs to be cloned for some cases
_WORLD_GROUP = None
# global object to maintain mpu object if passed by a Megatron client
mpu = None
@@ -79,25 +80,25 @@ def _create_model_parallel(model_parallel_size_):
log_dist(f'Creating model parallel group with size {model_parallel_size_}',
ranks=[0])
# Get world size and rank. Ensure some consistencies.
- assert torch.distributed.is_initialized()
- world_size = torch.distributed.get_world_size()
+ assert dist.is_initialized()
+ world_size = dist.get_world_size()
model_parallel_size = min(model_parallel_size_, world_size)
_ensure_divisibility(world_size, model_parallel_size)
- rank = torch.distributed.get_rank()
+ rank = dist.get_rank()
_DATA_PARALLEL_GROUP = None
_MODEL_PARALLEL_GROUP = None
# Build the data parallel groups.
for i in range(model_parallel_size):
ranks = range(i, world_size, model_parallel_size)
- group = torch.distributed.new_group(ranks)
+ group = dist.new_group(ranks)
if i == (rank % model_parallel_size):
_DATA_PARALLEL_GROUP = group
# Build the model parallel groups.
for i in range(world_size // model_parallel_size):
ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
- group = torch.distributed.new_group(ranks)
+ group = dist.new_group(ranks)
if i == (rank // model_parallel_size):
_MODEL_PARALLEL_GROUP = group
@@ -117,11 +118,11 @@ def _create_expert_and_data_parallel(ep_size):
expert_parallel_group = [0, 1], [2,3], [4,5], [6,7], [8,9] - no all reduce, but all to all
data_parallel_group = [0,1,...,15] - all reduce is only on non-MoE
"""
- assert torch.distributed.is_initialized()
+ assert dist.is_initialized()
log_dist(f'Creating expert and data parallel groups with size {ep_size}', ranks=[0])
- world_size = torch.distributed.get_world_size()
- rank = torch.distributed.get_rank()
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
expert_parallel_size_ = min(ep_size, world_size)
_ensure_divisibility(world_size, expert_parallel_size_)
@@ -135,7 +136,7 @@ def _create_expert_and_data_parallel(ep_size):
if group_name not in _EXPERT_DATA_PARALLEL_GROUP:
for i in range(expert_parallel_size_):
ranks = range(i, world_size, expert_parallel_size_)
- group = torch.distributed.new_group(ranks)
+ group = dist.new_group(ranks)
log_dist(
f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}',
[0])
@@ -149,7 +150,7 @@ def _create_expert_and_data_parallel(ep_size):
if group_name not in _EXPERT_PARALLEL_GROUP:
for i in range(world_size // expert_parallel_size_):
ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_)
- group = torch.distributed.new_group(ranks)
+ group = dist.new_group(ranks)
log_dist(
f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}',
[0])
@@ -218,11 +219,11 @@ def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu):
expert_parallel_group = [0,2,4,6], [8,10,12,14] [1,3,5,7], [9,11,13,15]
expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15]
"""
- assert torch.distributed.is_initialized(), "torch distributed is not initialized"
+ assert dist.is_initialized(), "dist is not initialized"
model_parallel_size_ = mpu.get_model_parallel_world_size()
- world_size = torch.distributed.get_world_size()
- rank = torch.distributed.get_rank()
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
dp_world_size = mpu.get_data_parallel_world_size()
dp_rank = mpu.get_data_parallel_rank()
@@ -247,12 +248,12 @@ def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu):
expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks(
world_size, model_parallel_size_, expert_parallel_size_)
for ranks in expert_parallel_groups:
- group = torch.distributed.new_group(ranks)
+ group = dist.new_group(ranks)
if rank in list(ranks):
_EXPERT_PARALLEL_GROUP[group_name] = group
for ranks in expert_data_parallel_groups:
- group = torch.distributed.new_group(ranks)
+ group = dist.new_group(ranks)
if rank in list(ranks):
_EXPERT_DATA_PARALLEL_GROUP[group_name] = group
@@ -304,66 +305,64 @@ def _get_expert_data_parallel_group_dict():
def _clone_world_group():
"""Create a clone of the world group
- Note: We need to clone the torch.distributed world group because we
- use _get_global_rank() utility function in DeepSpeed at many places.
- As that function does not work on torch.distributed.group.WORLD, we
+ Note: We need to clone the dist world group because we
+ use dist.get_global_rank() utility function in DeepSpeed at many places.
+ As that function does not work on dist.group.WORLD, we
need to keep a clone of it.
"""
- assert torch.distributed.is_initialized(), "torch.distributed is not initialized"
+ assert dist.is_initialized(), "dist is not initialized"
global _WORLD_GROUP
if _WORLD_GROUP is None:
# If not cloned already, clone the world group
- _WORLD_GROUP = torch.distributed.new_group(
- ranks=range(torch.distributed.get_world_size()))
+ _WORLD_GROUP = dist.new_group(ranks=range(dist.get_world_size()))
return _WORLD_GROUP
def _get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
- assert torch.distributed.is_initialized(), \
- 'torch.distributed is not initialized'
+ assert dist.is_initialized(), \
+ 'dist is not initialized'
global mpu
if mpu is not None:
return mpu.get_data_parallel_group()
- # Return the clone of torch.distributed world group
+ # Return the clone of dist world group
return _clone_world_group()
def _get_broadcast_src_rank():
- return _get_global_rank(_get_data_parallel_group(), 0)
+ return dist.get_global_rank(_get_data_parallel_group(), 0)
def _get_expert_broadcast_src_rank(group_name):
- return _get_global_rank(_get_expert_data_parallel_group(group_name), 0)
+ return dist.get_global_rank(_get_expert_data_parallel_group(group_name), 0)
def _get_expert_parallel_world_size(group_name):
"""Return world size for the expert parallel group."""
- return torch.distributed.get_world_size(group=_get_expert_parallel_group(group_name))
+ return dist.get_world_size(group=_get_expert_parallel_group(group_name))
def _get_expert_data_parallel_world_size(group_name):
"""Return world size for the expert data parallel group."""
- return torch.distributed.get_world_size(
- group=_get_expert_data_parallel_group(group_name))
+ return dist.get_world_size(group=_get_expert_data_parallel_group(group_name))
def _get_expert_parallel_rank(group_name):
"""Return my rank for the expert parallel group."""
- return torch.distributed.get_rank(group=_get_expert_parallel_group(group_name))
+ return dist.get_rank(group=_get_expert_parallel_group(group_name))
def _get_expert_parallel_src_rank(group_name):
"""Calculate the global rank corresponding to a local rank zero
in the expert parallel group."""
- global_rank = torch.distributed.get_rank()
+ global_rank = dist.get_rank()
local_world_size = _get_expert_parallel_world_size(group_name)
return (global_rank // local_world_size) * local_world_size
def _get_expert_data_parallel_rank(group_name):
"""Return my rank for the expert data parallel group."""
- return torch.distributed.get_rank(group=_get_expert_data_parallel_group(group_name))
+ return dist.get_rank(group=_get_expert_data_parallel_group(group_name))
def _get_data_parallel_world_size():
@@ -371,7 +370,7 @@ def _get_data_parallel_world_size():
global mpu
if mpu is not None:
return mpu.get_data_parallel_world_size()
- return torch.distributed.get_world_size(group=_get_data_parallel_group())
+ return dist.get_world_size(group=_get_data_parallel_group())
def _get_model_parallel_world_size():
@@ -387,4 +386,4 @@ def _get_data_parallel_rank():
global mpu
if mpu is not None:
return mpu.get_data_parallel_rank()
- return torch.distributed.get_rank(group=_get_data_parallel_group())
+ return dist.get_rank(group=_get_data_parallel_group())
diff --git a/deepspeed/utils/logging.py b/deepspeed/utils/logging.py
index 3a089b3dd1b8..55636173b0a4 100644
--- a/deepspeed/utils/logging.py
+++ b/deepspeed/utils/logging.py
@@ -2,7 +2,7 @@
import sys
import os
-import torch.distributed as dist
+from deepspeed import comm as dist
log_levels = {
"debug": logging.DEBUG,
diff --git a/deepspeed/utils/timer.py b/deepspeed/utils/timer.py
index b04a0591fa52..19671add031f 100755
--- a/deepspeed/utils/timer.py
+++ b/deepspeed/utils/timer.py
@@ -8,6 +8,7 @@
import torch
from numpy import mean
from deepspeed.utils.logging import log_dist
+from deepspeed import comm as dist
from deepspeed.utils import logger
@@ -112,7 +113,7 @@ def memory_usage():
def log(self, names, normalizer=1.0, reset=True, memory_breakdown=False, ranks=None):
"""Log a group of timers."""
assert normalizer > 0.0
- string = f"rank={torch.distributed.get_rank()} time (ms)"
+ string = f"rank={dist.get_rank()} time (ms)"
for name in names:
if name in self.timers:
elapsed_time = (self.timers[name].elapsed(reset=reset) / normalizer)
diff --git a/docs/_config.yml b/docs/_config.yml
index dc79fc033b1a..456b16ff1d16 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -48,6 +48,7 @@ collections:
- mixture-of-experts.md
- mixture-of-experts-nlg.md
- mixture-of-experts-inference.md
+ - monitor.md
- one-cycle.md
- onebit-adam.md
- zero-one-adam.md
diff --git a/docs/_data/navigation.yml b/docs/_data/navigation.yml
index 20f00b66760f..babcb8da2283 100755
--- a/docs/_data/navigation.yml
+++ b/docs/_data/navigation.yml
@@ -61,8 +61,8 @@ lnav:
url: /docs/config-json/#activation-checkpointing
- title: 'Sparse Attention'
url: /docs/config-json/#sparse-attention
- - title: 'Logging to TensorBoard'
- url: /docs/config-json/#tensorboard-options
+ - title: 'Monitoring'
+ url: /docs/config-json/#monitoring-module-tensorboard-wandb-csv
- title: 'Tutorials'
url: /tutorials/
children:
@@ -100,6 +100,8 @@ lnav:
url: /tutorials/mixture-of-experts-inference/
- title: 'Mixture-of-Quantization'
url: /tutorials/MoQ-tutorial/
+ - title: 'Monitoring'
+ url: /tutorials/monitor
- title: 'One-Cycle Schedule'
url: /tutorials/one-cycle/
- title: 'One-Bit Adam'
diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md
index 53df586ec3e6..3b283b459b18 100755
--- a/docs/_pages/config-json.md
+++ b/docs/_pages/config-json.md
@@ -964,13 +964,15 @@ Configuring the asynchronous I/O module for offloading parameter and optimizer s
| ---------------------------------------------------------------------------------------------------------------------------- | ------- |
| List of which step to change difficulty level. One of the `schedule_config` when the `fixed_discrete` schedule_type is used. | N/A |
-### Logging to Tensorboard
+### Monitoring Module (TensorBoard, WandB, CSV)
**Note:** Deepspeed logs to TensorBoard through PyTorch. Logging to TensorBoard requires that the `tensorboard` package is installed (read more in the [PyTorch documentation](https://pytorch.org/docs/1.8.0/tensorboard.html)).
{: .notice--warning}
+**Note:** Logging to WandB requires that the `wandb` package is installed (read more in the [WandB documentation](https://docs.wandb.ai/quickstart)).
+{: .notice--warning}
-Deepspeed can log training details into a [Tensorboard](https://www.tensorflow.org/tensorboard)-compatible file. Below is an overview of what deepspeed will log.
+Deepspeed's Monitor module can log training details into a [Tensorboard](https://www.tensorflow.org/tensorboard)-compatible file, to [WandB](https://wandb.ai/site), or to simple CSV files. Below is an overview of what DeepSpeed will log automatically.
| Field | Description |Conditions |
| ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- |
@@ -989,11 +991,11 @@ Deepspeed can log training details into a [Tensorboard](https://www.tensorflow.o
| Fields | Value |Default |
| ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- |
| enabled | Whether logging to [Tensorboard](https://www.tensorflow.org/tensorboard) is enabled. | `false` |
-| job_name | Name for the current job. This will become a new directory inside `output_path` | `"DeepSpeedJobName"` |
-| output_path | Path to where the Tensorboard logs will be written. | `~/tensorboard/` |
+| output_path | Path to where the Tensorboard logs will be written. If None, the output path is set under the training script's launching path. | `null` |
+| job_name | Name for the current job. This will become a new directory inside `output_path`. | `"DeepSpeedJobName"` |
-Example of ** tensorboard** configuration:
+Example of **tensorboard** configuration:
```json
"tensorboard": {
@@ -1002,3 +1004,43 @@ Example of ** tensorboard** configuration:
"job_name": "train_bert"
}
```
+
+**wandb**: [dictionary]
+
+| Fields | Value |Default |
+| ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- |
+| enabled | Whether logging to [WandB](https://wandb.ai/site) is enabled. | `false` |
+| group | Name for the WandB group. This can be used to group together runs. | `None` |
+| team | Name for the WandB team. | `None` |
+| project | Name for the WandB project. | `deepspeed` |
+
+
+Example of **wandb** configuration:
+
+```json
+"wandb": {
+ "enabled": true,
+ "group": "my_group",
+ "team": "my_team",
+ "project": "my_project"
+}
+```
+
+**csv_monitor**: [dictionary]
+
+| Fields | Value |Default |
+| ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- |
+| enabled | Whether logging to local CSV files is enabled. | `false` |
+| output_path | Path to where the csv files will be written. If None, the output path is set under the training script's launching path. | `null` |
+| job_name | Name for the current job. This will become a new directory inside `output_path` | `"DeepSpeedJobName"` |
+
+
+Example of **csv_monitor** configuration:
+
+```json
+"csv_monitor": {
+ "enabled": true,
+ "output_path": "output/ds_logs/",
+ "job_name": "train_bert"
+}
+```
diff --git a/docs/_pages/features.md b/docs/_pages/features.md
index 4410f2b10268..c2da91340bda 100755
--- a/docs/_pages/features.md
+++ b/docs/_pages/features.md
@@ -322,6 +322,33 @@ The DeepSpeed Autotuner uses model information, system information, and heurist
```
The flops profiler can also be used as a standalone package. Please refer to the [Flops Profiler](/tutorials/flops-profiler) tutorial for more details.
+### Monitor
+
+The DeepSpeed Monitor logs live training metrics to one or more monitoring backends, including PyTorch's [TensorBoard](https://pytorch.org/docs/1.8.0/tensorboard.html), [WandB](https://docs.wandb.ai/quickstart), or simply to CSV files. The Monitor can be configured with one or more backends in the `deepspeed_config` file as follows:
+
+```json
+{
+ "tensorboard": {
+ "enabled": true,
+ "output_path": "output/ds_logs/",
+ "job_name": "train_bert"
+ }
+ "wandb": {
+ "enabled": true,
+ "team": "my_team",
+ "group": "my_group",
+ "project": "my_project"
+ }
+ "csv_monitor": {
+ "enabled": true,
+ "output_path": "output/ds_logs/",
+ "job_name": "train_bert"
+ }
+}
+
+```
+
+The Monitor can also be added to log custom metrics and client codes. Please refer to the [Monitor](/tutorials/monitor) tutorial for more details.
## Sparse Attention
DeepSpeed offers sparse attention to support long sequences. Please refer to the [Sparse Attention](/tutorials/sparse-attention/) tutorial.
diff --git a/docs/_tutorials/monitor.md b/docs/_tutorials/monitor.md
new file mode 100644
index 000000000000..a9c111f8eeec
--- /dev/null
+++ b/docs/_tutorials/monitor.md
@@ -0,0 +1,105 @@
+---
+title: "Monitor"
+excerpt: "Monitor your model's training metrics live and log for future analysis"
+tags: profiling performance-tuning
+---
+
+In this tutorial, we introduce the DeepSpeed Monitor and provide examples of its usage.
+
+ - [Overview](#overview)
+ - [Usage](#usage)
+
+## Overview
+
+Monitoring model and system metrics during training is vital to ensure hardware resources are fully utilized. The DeepSpeed Monitor enables live logging of metrics through one or more monitoring backends such as PyTorch's [TensorBoard](https://pytorch.org/docs/1.8.0/tensorboard.html), [WandB](https://docs.wandb.ai/quickstart), and simple CSV files.
+
+Below is a live monitoring view for TensorBoard:
+
+{: .align-center}
+
+Below is a live monitoring view for WandB:
+
+{: .align-center}
+
+## Usage
+
+The DeepSpeed Monitor is configured within the deepspeed [configuration file](/docs/config-json/#monitoring-module-tensorboard-wandb-csv). DeepSpeed will automatically monitor key training metrics, including those tracked with the `wall_clock_breakdown` configuration option. In addition, users can log their own custom events and metrics.
+
+ - [Automatic Monitoring](#automatic-monitoring)
+ - [Custom Monitoring](#custom-monitoring)
+
+### Automatic Monitoring
+
+When using DeepSpeed for model training, the Monitor can be configured in the DeepSpeed [configuration file](/docs/config-json/#monitoring-module-tensorboard-wandb-csv). No explicit API calls are needed to use the Monitor. The Monitor can be enabled by adding the following field to DeepSpeed's configuration json file. Refer to [Monitoring](/docs/config-json/#monitoring-module-tensorboard-wandb-csv) for details.
+
+```json
+{
+ "tensorboard": {
+ "enabled": true,
+ "output_path": "output/ds_logs/",
+ "job_name": "train_bert"
+ }
+ "wandb": {
+ "enabled": true,
+ "team": "my_team",
+ "group": "my_group",
+ "project": "my_project"
+ }
+ "csv_monitor": {
+ "enabled": true,
+ "output_path": "output/ds_logs/",
+ "job_name": "train_bert"
+ }
+}
+```
+
+DeepSpeed will automatically log to all available and enabled monitoring backends listed in the config, and will generate live monitoring views such as those listed above.
+
+### Custom Monitoring
+
+In addition to automatic monitoring, users can log their own custom metrics in client scripts. Currently, there are two ways to initialize Monitor objects:
+
+1. (Recommended) - Create a `MonitorMaster(ds_config.monitor_config)` object, which automatically initializes all monitor backends present in the DeepSpeed configuration
+2. Create a specific `TensorBoardMonitor(ds_config.monitor_config)`, `WandbMonitor(ds_config.monitor_config)`, `csvMonitor(ds_config.monitor_config)` object which will only initialize a specific monitor backend present in the DeepSpeed configuration
+
+
+The steps to create a custom monitor are as follows:
+
+1. Add import to your desired Monitor
+2. Initialize monitor with DeepSpeed config's `monitor_config`
+3. Create a list of one or more 3-tuples in the format `[("label", value, ds_engine.global_samples), ...]`\*
+4. Call `monitor.write_events` on the list from step 3
+
+\* Note - Some Monitor backends don't support mixed sample values. Be sure to use your DeepSpeed engine object's `global_samples` attribute in each 3-tuple
+
+For example usage, see the following modified [DeepSpeedExamples/cifar](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) example:
+
+```python
+# Step 1: Import monitor (and DeepSpeed config, if needed)
+from deepspeed.monitor.monitor import MonitorMaster
+from deepspeed.runtime.config import DeepSpeedConfig
+
+# Step 2: Initialized monitor with DeepSpeed config (get DeepSpeed config object, if needed)
+ds_config = DeepSpeedConfig("ds_config.json")
+monitor = MonitorMaster(ds_config.monitor_config)
+
+for epoch in range(2):
+
+ running_loss = 0.0
+ for i, data in enumerate(trainloader):
+ pre = time.time()
+ inputs, labels = data[0].to(model_engine.local_rank), data[1].to(
+ model_engine.local_rank)
+ if fp16:
+ inputs = inputs.half()
+ outputs = model_engine(inputs)
+ loss = criterion(outputs, labels)
+
+ model_engine.backward(loss)
+ model_engine.step()
+ post = time.time()
+ # Step 3: Create list of 3-tuple records (single entry in this case)
+ events = [("Time per step", post-pre, model_engine.global_samples)]
+ # Step 4: Call monitor.write_events on the list from step 3
+ monitor.write_events(events)
+```
diff --git a/docs/assets/images/tensorboard_monitor.PNG b/docs/assets/images/tensorboard_monitor.PNG
new file mode 100644
index 000000000000..b62d96c335b1
Binary files /dev/null and b/docs/assets/images/tensorboard_monitor.PNG differ
diff --git a/docs/assets/images/wandb_monitor.PNG b/docs/assets/images/wandb_monitor.PNG
new file mode 100644
index 000000000000..f65aa6c5cda8
Binary files /dev/null and b/docs/assets/images/wandb_monitor.PNG differ
diff --git a/docs/index.md b/docs/index.md
index 374359333253..a9eaec6dec6c 100755
--- a/docs/index.md
+++ b/docs/index.md
@@ -6,6 +6,7 @@ title: "Latest News"
---
+* [2022/06/22] DeepSpeed Compression: 50x model size reduction via [XTC](https://arxiv.org/abs/2206.01859) and 5000x compression cost reduction via [ZeroQuant](https://arxiv.org/abs/2206.01861). Stay tuned for upcoming code release!
* [2022/03/21] [Supporting efficient large model training on AMD Instinct GPUs with DeepSpeed](https://cloudblogs.microsoft.com/opensource/2022/03/21/supporting-efficient-large-model-training-on-amd-instinct-gpus-with-deepspeed/)
* [2022/03/07] [Maximizing Communication Efficiency for Large-scale Training via 0/1 Adam](https://www.deepspeed.ai/tutorials/zero-one-adam/)
* [2022/01/19] [DeepSpeed: Advancing MoE inference and training to power next-generation AI scale](https://www.microsoft.com/en-us/research/blog/deepspeed-advancing-moe-inference-and-training-to-power-next-generation-ai-scale/)
@@ -246,6 +247,8 @@ comments.
9. Yucheng Lu, Conglong Li, Minjia Zhang, Christopher De Sa, Yuxiong He. (2022) Maximizing Communication Efficiency for Large-scale Training via 0/1 Adam. [arXiv:2202.06009](https://arxiv.org/abs/2202.06009).
10. Samyam Rajbhandari, Conglong Li, Zhewei Yao, Minjia Zhang, Reza Yazdani Aminabadi, Ammar Ahmad Awan, Jeff Rasley, Yuxiong He. (2022) DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale [arXiv:2201.05596](https://arxiv.org/abs/2201.05596).
11. Shaden Smith, Mostofa Patwary, Brandon Norick, Patrick LeGresley, Samyam Rajbhandari, Jared Casper, Zhun Liu, Shrimai Prabhumoye, George Zerveas, Vijay Korthikanti, Elton Zhang, Rewon Child, Reza Yazdani Aminabadi, Julie Bernauer, Xia Song, Mohammad Shoeybi, Yuxiong He, Michael Houston, Saurabh Tiwary, Bryan Catanzaro. (2022) Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model [arXiv:2201.11990](https://arxiv.org/abs/2201.11990).
+12. Xiaoxia Wu, Zhewei Yao, Minjia Zhang, Conglong Li, Yuxiong He. (2022) Extreme Compression for Pre-trained Transformers Made Simple and Efficient. [arXiv:2206.01859](https://arxiv.org/abs/2206.01859).
+13. Zhewei Yao, Reza Yazdani Aminabadi, Minjia Zhang, Xiaoxia Wu, Conglong Li, Yuxiong He. (2022) ZeroQuant: Efficient and Affordable Post-Training Quantization for Large-Scale Transformers. [arXiv:2206.01861](https://arxiv.org/abs/2206.01861).
# Videos
1. DeepSpeed KDD 2020 Tutorial
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 000000000000..c61ee047df9a
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,9 @@
+# DeepSpeed Examples
+
+If you are looking for examples using DeepSpeed please see the following resources:
+
+1. [DeepSpeedExamples](https://github.com/microsoft/DeepSpeedExamples)
+2. [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed)
+3. [DeepSpeed + AzureML](https://github.com/Azure/azureml-examples/tree/main/python-sdk/workflows/train/deepspeed)
+4. [DeepSpeed + Hugging Face Transformers Integration](https://huggingface.co/docs/transformers/main_classes/deepspeed)
+5. [DeepSpeed + PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.utilities.deepspeed.html)
diff --git a/requirements/requirements-inf.txt b/requirements/requirements-inf.txt
new file mode 100644
index 000000000000..e9f45a392e1d
--- /dev/null
+++ b/requirements/requirements-inf.txt
@@ -0,0 +1,2 @@
+lm-eval>=0.2.0
+transformers
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 895e252a454f..e40a19b622fc 100755
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -4,5 +4,6 @@ numpy
packaging
psutil
py-cpuinfo
+pydantic
torch
tqdm
diff --git a/scripts/check-torchdist.py b/scripts/check-torchdist.py
new file mode 100755
index 000000000000..695bee58fd48
--- /dev/null
+++ b/scripts/check-torchdist.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+"""
+Checks each file in sys.argv for the string "torch.distributed".
+Modified from https://github.com/jlebar/pre-commit-hooks/blob/master/check_do_not_submit.py
+"""
+from __future__ import annotations
+
+import os
+import subprocess
+import sys
+
+
+def err(s: str) -> None:
+ print(s, file=sys.stderr)
+
+
+# There are many ways we could search for the string "torch.distributed", but `git
+# grep --no-index` is nice because
+# - it's very fast (as compared to iterating over the file in Python)
+# - we can reasonably assume it's available on all machines
+# - unlike plain grep, which is slower and has different flags on MacOS versus
+# Linux, git grep is always the same.
+res = subprocess.run(
+ ["git",
+ "grep",
+ "-Hn",
+ "--no-index",
+ "torch\.distributed",
+ *sys.argv[1:]],
+ capture_output=True,
+)
+if res.returncode == 0:
+ err('Error: The string "torch.distributed" was found. Please replace all calls to torch.distributed with "deepspeed.comm"'
+ )
+ err(res.stdout.decode("utf-8"))
+ sys.exit(1)
+elif res.returncode == 2:
+ err(f"Error invoking grep on {', '.join(sys.argv[1:])}:")
+ err(res.stderr.decode("utf-8"))
+ sys.exit(2)
diff --git a/setup.py b/setup.py
index b35bb48b70da..532d0dd2976a 100755
--- a/setup.py
+++ b/setup.py
@@ -61,7 +61,8 @@ def fetch_requirements(path):
'dev': fetch_requirements('requirements/requirements-dev.txt'),
'autotuning': fetch_requirements('requirements/requirements-autotuning.txt'),
'autotuning_ml': fetch_requirements('requirements/requirements-autotuning-ml.txt'),
- 'sparse_attn': fetch_requirements('requirements/requirements-sparse_attn.txt')
+ 'sparse_attn': fetch_requirements('requirements/requirements-sparse_attn.txt'),
+ 'inf': fetch_requirements('requirements/requirements-inf.txt')
}
# Add specific cupy version to both onebit extension variants
@@ -221,17 +222,28 @@ def create_dir_symlink(src, dest):
version_str += f'+{git_hash}'
torch_version = ".".join([TORCH_MAJOR, TORCH_MINOR])
+bf16_support = False
# Set cuda_version to 0.0 if cpu-only
cuda_version = "0.0"
+nccl_version = "0.0"
# Set hip_version to 0.0 if cpu-only
hip_version = "0.0"
if torch_available and torch.version.cuda is not None:
cuda_version = ".".join(torch.version.cuda.split('.')[:2])
+ if isinstance(torch.cuda.nccl.version(), int):
+ # This will break if minor version > 9
+ nccl_version = ".".join(str(torch.cuda.nccl.version())[:2])
+ else:
+ nccl_version = ".".join(map(str, torch.cuda.nccl.version()[:2]))
+ if hasattr(torch.cuda, 'is_bf16_supported'):
+ bf16_support = torch.cuda.is_bf16_supported()
if torch_available and hasattr(torch.version, 'hip') and torch.version.hip is not None:
hip_version = ".".join(torch.version.hip.split('.')[:2])
torch_info = {
"version": torch_version,
+ "bf16_support": bf16_support,
"cuda_version": cuda_version,
+ "nccl_version": nccl_version,
"hip_version": hip_version
}
@@ -269,10 +281,18 @@ def create_dir_symlink(src, dest):
},
install_requires=install_requires,
extras_require=extras_require,
- packages=find_packages(exclude=["docker",
- "third_party",
- "csrc",
- "op_builder"]),
+ packages=find_packages(exclude=[
+ "azure",
+ "csrc",
+ "docker",
+ "docs",
+ "examples",
+ "op_builder",
+ "release",
+ "requirements",
+ "scripts",
+ "tests"
+ ]),
include_package_data=True,
scripts=[
'bin/deepspeed',
@@ -280,6 +300,7 @@ def create_dir_symlink(src, dest):
'bin/ds',
'bin/ds_ssh',
'bin/ds_report',
+ 'bin/dsr',
'bin/ds_elastic'
],
classifiers=[
diff --git a/tests/conftest.py b/tests/conftest.py
index a0e4705f4984..4d4f23afe252 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,9 +1,45 @@
# tests directory-specific settings - this file is run automatically by pytest before any tests are run
import sys
+import pytest
from os.path import abspath, dirname, join
+import torch
+import warnings
# allow having multiple repository checkouts and not needing to remember to rerun
# 'pip install -e .[dev]' when switching between checkouts and running tests.
git_repo_path = abspath(join(dirname(dirname(__file__)), "src"))
sys.path.insert(1, git_repo_path)
+
+
+def pytest_addoption(parser):
+ parser.addoption("--torch_ver", default=None, type=str)
+ parser.addoption("--cuda_ver", default=None, type=str)
+
+
+def validate_version(expected, found):
+ version_depth = expected.count('.') + 1
+ found = '.'.join(found.split('.')[:version_depth])
+ return found == expected
+
+
+@pytest.fixture(scope="session", autouse=True)
+def check_environment(pytestconfig):
+ expected_torch_version = pytestconfig.getoption("torch_ver")
+ expected_cuda_version = pytestconfig.getoption("cuda_ver")
+ if expected_torch_version is None:
+ warnings.warn(
+ "Running test without verifying torch version, please provide an expected torch version with --torch_ver"
+ )
+ elif not validate_version(expected_torch_version, torch.__version__):
+ pytest.exit(
+ f"expected torch version {expected_torch_version} did not match found torch version {torch.__version__}",
+ returncode=2)
+ if expected_cuda_version is None:
+ warnings.warn(
+ "Running test without verifying cuda version, please provide an expected cuda version with --cuda_ver"
+ )
+ elif not validate_version(expected_cuda_version, torch.version.cuda):
+ pytest.exit(
+ f"expected cuda version {expected_cuda_version} did not match found cuda version {torch.version.cuda}",
+ returncode=2)
diff --git a/tests/onebit/test_mpi_backend.py b/tests/onebit/test_mpi_backend.py
index 785021cf0935..57dc7371c4f9 100644
--- a/tests/onebit/test_mpi_backend.py
+++ b/tests/onebit/test_mpi_backend.py
@@ -1,7 +1,7 @@
from mpi4py import MPI
import time
import torch
-import torch.distributed as dist
+import deepspeed.comm as dist
import numpy as np
import deepspeed
@@ -19,7 +19,7 @@
device = torch.device('cuda', rank % torch.cuda.device_count())
-# A simulated compression function using torch.distributed
+# A simulated compression function using deepspeed.comm
def torch_sim(a):
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
@@ -37,7 +37,7 @@ def torch_sim(a):
rank = dist.get_rank()
server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
torch.cuda.synchronize()
- torch.distributed.barrier()
+ dist.barrier()
return a_server_compressed, worker_error, server_error
diff --git a/tests/onebit/test_mpi_perf.py b/tests/onebit/test_mpi_perf.py
index 6017ec873c21..b782cbc5dc3e 100644
--- a/tests/onebit/test_mpi_perf.py
+++ b/tests/onebit/test_mpi_perf.py
@@ -1,7 +1,7 @@
from mpi4py import MPI
import time
import torch
-import torch.distributed as dist
+import deepspeed.comm as dist
import numpy as np
import deepspeed
diff --git a/tests/onebit/test_nccl_backend.py b/tests/onebit/test_nccl_backend.py
index 16de37174c10..6a99c9fe2a9c 100644
--- a/tests/onebit/test_nccl_backend.py
+++ b/tests/onebit/test_nccl_backend.py
@@ -1,6 +1,6 @@
import time
import torch
-import torch.distributed as dist
+import deepspeed.comm as dist
import numpy as np
import argparse
import deepspeed
@@ -25,7 +25,7 @@
local_rank = args.local_rank
-# A simulated compression function using torch.distributed
+# A simulated compression function using deepspeed.comm
def torch_sim(a):
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
@@ -43,7 +43,7 @@ def torch_sim(a):
rank = dist.get_rank()
server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
torch.cuda.synchronize()
- torch.distributed.barrier()
+ dist.barrier()
return a_server_compressed, worker_error, server_error
diff --git a/tests/onebit/test_nccl_perf.py b/tests/onebit/test_nccl_perf.py
index 1374cda4ddce..d4cfbccfd7da 100644
--- a/tests/onebit/test_nccl_perf.py
+++ b/tests/onebit/test_nccl_perf.py
@@ -1,6 +1,6 @@
import time
import torch
-import torch.distributed as dist
+import deepspeed.comm as dist
import numpy as np
import argparse
import deepspeed
@@ -62,7 +62,7 @@
for i in range(iters):
timers('compressed_allreduce').start()
backend.compressed_allreduce(a, worker_error, server_error, local_rank)
- #torch.distributed.all_reduce(a_compressed)
+ #deepspeed.comm.all_reduce(a_compressed)
timers('compressed_allreduce').stop()
time_list.append(timers('compressed_allreduce').elapsed())
diff --git a/tests/pytest.ini b/tests/pytest.ini
new file mode 100644
index 000000000000..a52a49e5bbc3
--- /dev/null
+++ b/tests/pytest.ini
@@ -0,0 +1,6 @@
+[pytest]
+addopts = -m "not sequential and not nightly and not inference"
+markers =
+ sequential:Tests that need to be run sequentially
+ inference:Inference model tests
+ nightly:Tests that should be run nightly
diff --git a/tests/small_model_debugging/test_model.py b/tests/small_model_debugging/test_model.py
index c957bf8f1ecb..720adeab3842 100755
--- a/tests/small_model_debugging/test_model.py
+++ b/tests/small_model_debugging/test_model.py
@@ -56,7 +56,7 @@ def get_args(tmpdir, config_dict):
def print0(msg):
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
print(msg, flush=True)
@@ -95,7 +95,7 @@ def print0(msg):
def print_params(tag, model):
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
for n, p in model.named_parameters():
print0("{} {}:{}".format(tag, n, p))
@@ -107,7 +107,7 @@ def print_params(tag, model):
#print_params('pre-train', model)
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
- if torch.distributed.get_rank() == 0:
+ if dist.get_rank() == 0:
print("LOSS:", loss.item())
model.backward(loss)
model.step()
diff --git a/tests/unit/common.py b/tests/unit/common.py
index 57ed50f17cea..10037008aa90 100644
--- a/tests/unit/common.py
+++ b/tests/unit/common.py
@@ -2,7 +2,7 @@
import time
import torch
-import torch.distributed as dist
+import deepspeed.comm as dist
from torch.multiprocessing import Process
import deepspeed
@@ -66,7 +66,7 @@ def set_cuda_visibile():
def distributed_test(world_size=2, backend='nccl'):
"""A decorator for executing a function (e.g., a unit test) in a distributed manner.
This decorator manages the spawning and joining of processes, initialization of
- torch.distributed, and catching of errors.
+ deepspeed.comm, and catching of errors.
Usage example:
@distributed_test(worker_size=[2,3])
@@ -82,7 +82,7 @@ def my_test():
def dist_wrap(run_func):
"""Second-level decorator for dist_test. This actually wraps the function. """
def dist_init(local_rank, num_procs, *func_args, **func_kwargs):
- """Initialize torch.distributed and execute the user function. """
+ """Initialize deepspeed.comm and execute the user function. """
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = get_master_port()
os.environ['LOCAL_RANK'] = str(local_rank)
@@ -96,6 +96,8 @@ def dist_init(local_rank, num_procs, *func_args, **func_kwargs):
set_cuda_visibile()
deepspeed.init_distributed(dist_backend=backend)
+ #dist.init_process_group(backend=backend)
+ dist.barrier()
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
@@ -103,10 +105,9 @@ def dist_init(local_rank, num_procs, *func_args, **func_kwargs):
run_func(*func_args, **func_kwargs)
# make sure all ranks finish at the same time
- torch.distributed.barrier()
-
+ dist.barrier()
# tear down after test completes
- torch.distributed.destroy_process_group()
+ dist.destroy_process_group()
def dist_launcher(num_procs, *func_args, **func_kwargs):
"""Launch processes and gracefully handle failures. """
diff --git a/tests/unit/modeling.py b/tests/unit/modeling.py
index 8bf2d6dba9da..e3b6b4d836f0 100755
--- a/tests/unit/modeling.py
+++ b/tests/unit/modeling.py
@@ -35,7 +35,7 @@
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils import checkpoint
-import torch.distributed as dist
+import deepspeed.comm as dist
from torch.nn import Module
from torch.nn.parameter import Parameter
diff --git a/tests/unit/modelingpreln.py b/tests/unit/modelingpreln.py
index 7661303a4145..34a933bc6b29 100755
--- a/tests/unit/modelingpreln.py
+++ b/tests/unit/modelingpreln.py
@@ -35,7 +35,7 @@
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils import checkpoint
-import torch.distributed as dist
+import deepspeed.comm as dist
from torch.nn import Module
from torch.nn.parameter import Parameter
diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py
index aaa8ffc67971..f94ee288d4b2 100755
--- a/tests/unit/simple_model.py
+++ b/tests/unit/simple_model.py
@@ -6,6 +6,8 @@
from deepspeed.pipe import PipelineModule, LayerSpec
from deepspeed.moe.layer import MoE
+import deepspeed.comm as dist
+
class SimpleModel(torch.nn.Module):
def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
@@ -261,10 +263,10 @@ def create_deepspeed_args():
parser = argparse.ArgumentParser()
args = parser.parse_args(args='')
args.deepspeed = True
- if torch.distributed.is_initialized():
+ if dist.is_initialized():
# We assume up to one full node executing unit tests
- assert torch.distributed.get_world_size() <= torch.cuda.device_count()
- args.local_rank = torch.distributed.get_rank()
+ assert dist.get_world_size() <= torch.cuda.device_count()
+ args.local_rank = dist.get_rank()
return args
diff --git a/tests/unit/test_activation_checkpointing.py b/tests/unit/test_activation_checkpointing.py
index e66f2abf7408..ad32a53385f2 100644
--- a/tests/unit/test_activation_checkpointing.py
+++ b/tests/unit/test_activation_checkpointing.py
@@ -59,8 +59,8 @@ def _match_outputs(ref, tgt):
assert torch.equal(ref, tgt)
-# This is distributed because checkpoint() assumes that torch.distributed is initialized.
-# torch.distributed is used with activation partitioning, but not for these simple cases.
+# This is distributed because checkpoint() assumes that deepspeed.comm is initialized.
+# deepspeed.comm is used with activation partitioning, but not for these simple cases.
@distributed_test(world_size=1)
def _test_activation_checkpoint(module, *inputs):
# Move to device
@@ -82,8 +82,8 @@ def _test_activation_checkpoint(module, *inputs):
_match_outputs(b, t)
-# This is distributed because checkpoint() assumes that torch.distributed is initialized.
-# torch.distributed is used with activation partitioning, but not for these simple cases.
+# This is distributed because checkpoint() assumes that deepspeed.comm is initialized.
+# deepspeed.comm is used with activation partitioning, but not for these simple cases.
@distributed_test(world_size=1)
def _test_activation_checkpoint_ordering(module, expected_ordering, *inputs):
# Move to device
diff --git a/tests/unit/test_aio.py b/tests/unit/test_aio.py
index fdec95a35ae7..389d422bbc91 100755
--- a/tests/unit/test_aio.py
+++ b/tests/unit/test_aio.py
@@ -3,7 +3,7 @@
import filecmp
import torch
import deepspeed
-import torch.distributed as dist
+import deepspeed.comm as dist
from deepspeed.ops.aio import AsyncIOBuilder
from .common import distributed_test
diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py
index c989f226cf2d..ddac8a3dcd02 100755
--- a/tests/unit/test_checkpointing.py
+++ b/tests/unit/test_checkpointing.py
@@ -1,6 +1,6 @@
import torch
-import torch.distributed as dist
+import deepspeed.comm as dist
import deepspeed
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
@@ -885,11 +885,9 @@ def _helper(args, model, hidden_dim):
model_parameters=model.parameters())
if valid_mode == "FAIL":
with pytest.raises(AssertionError):
- model.save_checkpoint(save_dir=tmpdir,
- tag=f"tag-{torch.distributed.get_rank()}")
+ model.save_checkpoint(save_dir=tmpdir, tag=f"tag-{dist.get_rank()}")
else:
- model.save_checkpoint(save_dir=tmpdir,
- tag=f"tag-{torch.distributed.get_rank()}")
+ model.save_checkpoint(save_dir=tmpdir, tag=f"tag-{dist.get_rank()}")
_helper(args=args, model=model, hidden_dim=hidden_dim)
diff --git a/tests/unit/test_coalesced_collectives.py b/tests/unit/test_coalesced_collectives.py
index fb6b5354a158..a7e0ec35751b 100644
--- a/tests/unit/test_coalesced_collectives.py
+++ b/tests/unit/test_coalesced_collectives.py
@@ -3,7 +3,7 @@
import pytest
import torch
-import torch.distributed as dist
+import deepspeed.comm as dist
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced
from .common import distributed_test
@@ -17,7 +17,7 @@ def test_reduce_scatter_coalesced_single_input():
dtype=torch.half,
device=torch.cuda.current_device())
- (output, ) = reduce_scatter_coalesced([input], dist.group.WORLD)
+ (output, ) = reduce_scatter_coalesced([input], dist.get_world_group())
assert output.shape == (3, )
assert torch.allclose(output, torch.full_like(output, 0.5))
@@ -35,7 +35,7 @@ def test_reduce_scatter_coalesced_two_inputs():
**tensor_kwargs),
]
- output1, output2 = reduce_scatter_coalesced(inputs, dist.group.WORLD)
+ output1, output2 = reduce_scatter_coalesced(inputs, dist.get_world_group())
if dist.get_rank() == 0:
assert output1.shape == (3, )
@@ -53,7 +53,7 @@ def test_reduce_scatter_coalesced_two_inputs():
def test_reduce_scatter_coalesced_tensor_smaller_than_world_sz():
input = torch.zeros((1, ), dtype=torch.half, device=torch.cuda.current_device())
- (output, ) = reduce_scatter_coalesced([input], dist.group.WORLD)
+ (output, ) = reduce_scatter_coalesced([input], dist.get_world_group())
if dist.get_rank() == 0:
assert output.shape == (1, )
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
index a88cb2931d95..a493fd9ca505 100755
--- a/tests/unit/test_config.py
+++ b/tests/unit/test_config.py
@@ -8,7 +8,7 @@
from .common import distributed_test, get_test_path
from .simple_model import SimpleModel, create_config_from_dict, random_dataloader
-import torch.distributed as dist
+import deepspeed.comm as dist
# A test on its own
import deepspeed
diff --git a/tests/unit/test_configurable_parallel.py b/tests/unit/test_configurable_parallel.py
index 35486181072b..daa2cd1791b0 100755
--- a/tests/unit/test_configurable_parallel.py
+++ b/tests/unit/test_configurable_parallel.py
@@ -6,7 +6,7 @@
import random
import numpy as np
import torch.multiprocessing as mp
-import torch.distributed as dist
+import deepspeed.comm as dist
from .common import distributed_test
from .simple_model import args_from_dict, create_deepspeed_args
from .megatron_model import get_gpt2_model, get_megatron_version
diff --git a/tests/unit/test_curriculum_learning.py b/tests/unit/test_curriculum_learning.py
index 3677b5966781..22dde25fcd35 100644
--- a/tests/unit/test_curriculum_learning.py
+++ b/tests/unit/test_curriculum_learning.py
@@ -1,5 +1,5 @@
import torch
-import torch.distributed as dist
+import deepspeed.comm as dist
import deepspeed
import argparse
import pytest
diff --git a/tests/unit/test_dist.py b/tests/unit/test_dist.py
index d37133603ce4..6e6fabbfa9d8 100644
--- a/tests/unit/test_dist.py
+++ b/tests/unit/test_dist.py
@@ -1,5 +1,5 @@
import torch
-import torch.distributed as dist
+import deepspeed.comm as dist
from .common import distributed_test
diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py
index 0cd258f590d1..d8826e59e8e7 100755
--- a/tests/unit/test_fp16.py
+++ b/tests/unit/test_fp16.py
@@ -1,7 +1,7 @@
import math
from deepspeed.utils import groups
import torch
-import torch.distributed as dist
+import deepspeed.comm as dist
import deepspeed
import argparse
import pytest
diff --git a/tests/unit/test_inference.py b/tests/unit/test_inference.py
index 5b3a0cc681bc..c3dbc49055b2 100644
--- a/tests/unit/test_inference.py
+++ b/tests/unit/test_inference.py
@@ -1,123 +1,325 @@
import os
+import sys
+import time
import torch
import pytest
+import itertools
import deepspeed
+from deepspeed.git_version_info import torch_info
from collections import defaultdict
-from transformers import pipeline
from .common import distributed_test
from packaging import version as pkg_version
+from deepspeed.ops.op_builder import OpBuilder
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+from huggingface_hub import HfApi
-pytest.task_query_dict = {
- "fill-mask":
- defaultdict(
- lambda: "Hello I'm a [MASK] model.",
- {"roberta-base": "Hello I'm a model."},
- ),
- "question-answering":
- defaultdict(lambda: {
- "question": "What is the greatest?",
- "context": "DeepSpeed is the greatest",
- }),
- "text-classification":
- defaultdict(lambda: "DeepSpeed is the greatest"),
- "token-classification":
- defaultdict(lambda: "My name is jean-baptiste and I live in montreal."),
- "text-generation":
- defaultdict(lambda: "DeepSpeed is the greatest"),
-}
-pytest.task_model_dict = {
- "fill-mask": {
- "bert": "bert-base-cased",
- "roberta": "roberta-base"
- },
- "question-answering": {
- "bert": "deepset/minilm-uncased-squad2",
- "roberta": "deepset/roberta-base-squad2",
- },
- "text-classification": {
- "bert": "cross-encoder/ms-marco-MiniLM-L-12-v2",
- "roberta": "j-hartmann/emotion-english-distilroberta-base",
- },
- "token-classification": {
- "bert": "dslim/bert-base-NER",
- "roberta": "Jean-Baptiste/roberta-large-ner-english",
- },
- "text-generation": {
- "gpt2": "distilgpt2",
- "gpt_neo": "Norod78/hebrew-bad_wiki-gpt_neo-tiny",
- "gptj": "EleutherAI/gpt-j-6B",
- },
+
+@pytest.fixture(scope="module", autouse=True)
+def lm_eval_imports():
+ global lm_eval
+ import lm_eval
+ import lm_eval.models
+ import lm_eval.tasks
+ import lm_eval.evaluator
+
+
+rocm_version = OpBuilder.installed_rocm_version()
+if rocm_version != (0, 0):
+ pytest.skip("skip inference tests on rocm for now", allow_module_level=True)
+
+_bert_models = [
+ "bert-base-cased",
+ "bert-base-uncased",
+ "bert-large-cased",
+ "bert-large-uncased",
+ "bert-base-multilingual-cased",
+ "bert-base-multilingual-uncased",
+ "deepset/minilm-uncased-squad2",
+ "cross-encoder/ms-marco-MiniLM-L-12-v2",
+ "dslim/bert-base-NER",
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
+ "distilbert-base-cased-distilled-squad",
+]
+_roberta_models = [
+ "roberta-large",
+ "roberta-base",
+ "deepset/roberta-base-squad2",
+ "j-hartmann/emotion-english-distilroberta-base",
+ "Jean-Baptiste/roberta-large-ner-english",
+]
+_gpt_models = [
+ "gpt2",
+ "distilgpt2",
+ "Norod78/hebrew-bad_wiki-gpt_neo-tiny",
+ "EleutherAI/gpt-j-6B",
+]
+_all_models = HfApi().list_models()
+
+test_models = set(_bert_models + _roberta_models + _gpt_models)
+test_tasks = [
+ "fill-mask",
+ "question-answering",
+ "text-classification",
+ "token-classification",
+ "text-generation",
+]
+pytest.all_models = {
+ task: [m.modelId for m in _all_models if m.pipeline_tag == task]
+ for task in test_tasks
}
+_model_w_tasks = itertools.product(*[test_models, test_tasks])
+
+
+def _valid_model_task(model_task):
+ m, t = model_task
+ return m in pytest.all_models[t]
+
+
+pytest.models_w_tasks = list(filter(_valid_model_task, _model_w_tasks))
+pytest.mt_names = [f"{m}-{t}" for m, t in pytest.models_w_tasks]
+"""
+These fixtures iterate all combinations of tasks and models, dtype, & cuda_graph
+"""
+
+
+@pytest.fixture(params=pytest.models_w_tasks, ids=pytest.mt_names)
+def model_w_task(request):
+ return request.param
+
+
+@pytest.fixture(params=[torch.float, torch.half], ids=["fp32", "fp16"])
+def dtype(request):
+ return request.param
+
+
+@pytest.fixture(params=[True, False], ids=["CG", "noCG"])
+def enable_cuda_graph(request):
+ return request.param
+
+
+"""
+This fixture will validate the configuration
+"""
+
+
+@pytest.fixture()
+def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph):
+ model, task = model_w_task
+ if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
+ msg = "DS inference injection doesn't work well on older torch versions"
+ elif model not in pytest.all_models[task]:
+ msg = f"Not a valid model / task combination: {model} / {task}"
+ elif enable_cuda_graph and (torch_info["cuda_version"] == "0.0"):
+ msg = "CUDA not detected, cannot use CUDA Graph"
+ elif enable_cuda_graph and pkg_version.parse(
+ torch.__version__) < pkg_version.parse("1.10"):
+ msg = "CUDA Graph is only available in torch versions >= 1.10"
+ elif ("gpt-j-6B" in model) and (dtype == torch.float):
+ msg = f"Not enough GPU memory to run {model} with dtype {dtype}"
+ else:
+ msg = ""
+ return msg
+
+
+"""
+These fixtures can be used to customize the query, inference args, and assert
+statement for each combination of model /task
+"""
+
@pytest.fixture
-def model(task, model_family):
- if model_family not in pytest.task_model_dict[task]:
- pytest.skip(f"No models in family {model_family} for task {task}")
- return pytest.task_model_dict[task][model_family]
+def query(model_w_task):
+ model, task = model_w_task
+ if task == "fill-mask":
+ if "roberta" in model:
+ return "Hello I'm a model."
+ else:
+ return "Hell I'm a [MASK] model."
+ elif task == "question-answering":
+ return {
+ "question": "What's my name?",
+ "context": "My name is Clara and I live in Berkeley",
+ }
+ elif task == "text-classification":
+ return "DeepSpeed is the greatest"
+ elif task == "token-classification":
+ return "My name is jean-baptiste and I live in montreal."
+ elif task == "text-generation":
+ return "DeepSpeed is the greatest"
+ else:
+ NotImplementedError(f'query for task "{task}" is not implemented')
@pytest.fixture
-def query(task, model):
- return pytest.task_query_dict[task][model]
+def inf_kwargs(model_w_task):
+ model, task = model_w_task
+ if task == "text-generation":
+ return {"do_sample": False}
+ else:
+ return {}
-@pytest.mark.parametrize(
- "task",
- (
- "fill-mask",
- "question-answering",
- "text-classification",
- "token-classification",
- "text-generation",
- ),
-)
-@pytest.mark.parametrize("model_family", ("bert", "roberta", "gpt2", "gpt_neo"))
-def test_model_task_inject(task, model, query, dtype=torch.float):
- if pkg_version.parse(torch.__version__) <= pkg_version.parse('1.2'):
- pytest.skip("DS inference injection doesn't work well on older torch versions")
+@pytest.fixture
+def assert_fn(model_w_task):
+ model, task = model_w_task
+ if task == "fill-mask":
+ return lambda x, y: set(res["token_str"] for res in x) == set(
+ res["token_str"] for res in y
+ )
+ elif task == "question-answering":
+ return lambda x, y: x["answer"] == y["answer"]
+ elif task == "text-classification":
+ return lambda x, y: set(res["label"] for res in x) == set(
+ res["label"] for res in y
+ )
+ elif task == "token-classification":
+ return lambda x, y: set(ent["word"] for ent in x) == set(
+ ent["word"] for ent in y
+ )
+ elif task == "text-generation":
+ return lambda x, y: set(res["generated_text"] for res in x) == set(
+ res["generated_text"] for res in y
+ )
+ else:
+ NotImplementedError(f'assert_fn for task "{task}" is not implemented')
+
+
+"""
+Tests
+"""
+
+
+@pytest.mark.inference
+def test_model_task(
+ model_w_task,
+ dtype,
+ enable_cuda_graph,
+ query,
+ inf_kwargs,
+ assert_fn,
+ invalid_model_task_config,
+):
+ if invalid_model_task_config:
+ pytest.skip(invalid_model_task_config)
+
+ model, task = model_w_task
@distributed_test(world_size=[1])
def _go():
local_rank = int(os.getenv("LOCAL_RANK", "0"))
- world_size = int(os.getenv("WORLD_SIZE", "1"))
- generator = pipeline(task, model=model, device=local_rank)
- generator.model = deepspeed.init_inference(
- generator.model,
- mp_size=world_size,
+ if "gpt-j-6B" in model and dtype == torch.half:
+ _model = AutoModelForCausalLM.from_pretrained(model)
+ tokenizer = AutoTokenizer.from_pretrained(model)
+ _model.half()
+ pipe = pipeline(
+ task,
+ model=_model,
+ tokenizer=tokenizer,
+ device=local_rank,
+ framework="pt",
+ )
+ else:
+ pipe = pipeline(task, model=model, device=local_rank, framework="pt")
+ if dtype == torch.half:
+ pipe.model.half()
+
+ # Warm-up queries for perf measurement
+ for i in range(10):
+ _ = pipe(query, **inf_kwargs)
+ torch.cuda.synchronize()
+ start = time.time()
+ bs_output = pipe(query, **inf_kwargs)
+ torch.cuda.synchronize()
+ bs_time = time.time() - start
+
+ pipe.model = deepspeed.init_inference(
+ pipe.model,
+ mp_size=1,
dtype=dtype,
replace_method="auto",
replace_with_kernel_inject=True,
+ enable_cuda_graph=enable_cuda_graph,
)
+ # Warm-up queries for perf measurement
+ for i in range(10):
+ _ = pipe(query, **inf_kwargs)
+ torch.cuda.synchronize()
+ start = time.time()
+ ds_output = pipe(query, **inf_kwargs)
+ torch.cuda.synchronize()
+ ds_time = time.time() - start
- response = generator(query)
+ if task == "text-generation":
+ bs_output = pipe(query, **inf_kwargs)
- _go()
+ # These performance tests are only measuring the time for a single
+ # inference request, we just want to check that performance isn't terrible
+ assert ds_time <= (bs_time * 1.1)
+ assert assert_fn(bs_output, ds_output)
+ _go()
-@pytest.mark.parametrize("dtype", [(torch.float), (torch.half)])
-def test_gpt2_inject(dtype):
- if pkg_version.parse(torch.__version__) <= pkg_version.parse('1.2'):
- pytest.skip("DS inference injection doesn't work well on older torch versions")
+@pytest.mark.nightly
+@pytest.mark.parametrize(
+ "model_family, model_name",
+ (
+ ["gpt2",
+ "EleutherAI/gpt-neo-2.7B"],
+ ["gpt2",
+ "EleutherAI/gpt-j-6B"],
+ ["gpt2",
+ "gpt2-xl"],
+ ),
+)
+@pytest.mark.parametrize("task", ["lambada"])
+def test_lm_correctness(model_family, model_name, task):
@distributed_test(world_size=[1])
def _go():
- local_rank = int(os.getenv("LOCAL_RANK", "0"))
- world_size = int(os.getenv("WORLD_SIZE", "1"))
- generator = pipeline("text-generation", model="gpt2", device=local_rank)
+ local_rank = os.getenv("LOCAL_RANK", "0")
+ device = torch.device(f"cuda:{local_rank}")
+ dtype = torch.float
+ task_dict = lm_eval.tasks.get_task_dict([task])
+
+ if 'gpt-j-6B' in model_name:
+ dtype = torch.half
+ lm = lm_eval.models.get_model(model_family).create_from_arg_string(
+ f"pretrained={model_name}",
+ {"device": "cpu"})
+ setattr(lm, model_family, getattr(lm, model_family).half().to(device))
+ lm._device = device
+ else:
+ lm = lm_eval.models.get_model(model_family).create_from_arg_string(
+ f"pretrained={model_name}",
+ {"device": f"cuda:{local_rank}"})
+
+ torch.cuda.synchronize()
+ start = time.time()
+ bs_output = lm_eval.evaluator.evaluate(lm=lm, task_dict=task_dict)
+ torch.cuda.synchronize()
+ bs_time = time.time() - start
- generator.model = deepspeed.init_inference(
- generator.model,
- mp_size=world_size,
+ ds_model = deepspeed.init_inference(
+ getattr(lm,
+ model_family),
+ mp_size=1,
dtype=dtype,
replace_method="auto",
replace_with_kernel_inject=True,
+ enable_cuda_graph=False,
)
+ setattr(lm, model_family, ds_model)
+ torch.cuda.synchronize()
+ start = time.time()
+ ds_output = lm_eval.evaluator.evaluate(lm=lm, task_dict=task_dict)
+ torch.cuda.synchronize()
+ ds_time = time.time() - start
- prompt = "DeepSpeed is"
- string_1 = generator(prompt, do_sample=False, max_length=128)
- string_2 = generator(prompt, do_sample=False, max_length=128)
- assert string_1 == string_2
+ ppl_diff = abs(bs_output["results"][task]["ppl"] -
+ ds_output["results"][task]["ppl"])
+ assert ds_time <= bs_time
+ assert ppl_diff < 0.01
_go()
diff --git a/tests/unit/test_moe.py b/tests/unit/test_moe.py
index e10356902a68..779bafbb758f 100644
--- a/tests/unit/test_moe.py
+++ b/tests/unit/test_moe.py
@@ -1,7 +1,7 @@
import math
from deepspeed.utils import groups
import torch
-import torch.distributed as dist
+import deepspeed.comm as dist
import deepspeed
import argparse
import pytest
diff --git a/tests/unit/test_monitor.py b/tests/unit/test_monitor.py
new file mode 100644
index 000000000000..95f045d54dea
--- /dev/null
+++ b/tests/unit/test_monitor.py
@@ -0,0 +1,133 @@
+import pytest
+
+from deepspeed.monitor.constants import *
+
+from deepspeed.monitor.monitor import MonitorMaster
+from deepspeed.monitor.tensorboard import TensorBoardMonitor
+from deepspeed.monitor.wandb import WandbMonitor
+from deepspeed.monitor.csv_monitor import csvMonitor
+
+from .simple_model import *
+from .common import distributed_test
+from deepspeed.runtime.config import DeepSpeedConfig
+from deepspeed.monitor.config import DeepSpeedMonitorConfig
+
+try:
+ import tensorboard
+ _tb_available = True
+except ImportError:
+ _tb_available = False
+tb_available = pytest.mark.skipif(not _tb_available,
+ reason="tensorboard is not installed")
+
+try:
+ import wandb
+ _wandb_available = True
+except ImportError:
+ _wandb_available = False
+wandb_available = pytest.mark.skipif(not _wandb_available,
+ reason="wandb is not installed")
+
+
+@tb_available
+def test_tensorboard(tmpdir):
+ @distributed_test(world_size=2)
+ def _test_tensorboard():
+ config_dict = {
+ "train_batch_size": 2,
+ "tensorboard": {
+ "enabled": True,
+ "output_path": "test_output/ds_logs/",
+ "job_name": "test"
+ }
+ }
+ ds_config = DeepSpeedConfig(config_dict)
+ tb_monitor = TensorBoardMonitor(ds_config.monitor_config)
+ assert tb_monitor.enabled == True
+ assert tb_monitor.output_path == "test_output/ds_logs/"
+ assert tb_monitor.job_name == "test"
+
+ _test_tensorboard()
+
+
+@tb_available
+def test_empty_tensorboard(tmpdir):
+ @distributed_test(world_size=2)
+ def _test_empty_tensorboard():
+ config_dict = {"train_batch_size": 2, "tensorboard": {}}
+ ds_config = DeepSpeedConfig(config_dict)
+ tb_monitor = TensorBoardMonitor(ds_config.monitor_config)
+ assert tb_monitor.enabled == TENSORBOARD_ENABLED_DEFAULT
+ assert tb_monitor.output_path == TENSORBOARD_OUTPUT_PATH_DEFAULT
+ assert tb_monitor.job_name == TENSORBOARD_JOB_NAME_DEFAULT
+
+ _test_empty_tensorboard()
+
+
+@wandb_available
+def test_wandb(tmpdir):
+ @distributed_test(world_size=2)
+ def _test_wandb():
+ config_dict = {
+ "train_batch_size": 2,
+ "wandb": {
+ "enabled": False,
+ "group": "my_group",
+ "team": "my_team",
+ "project": "my_project"
+ }
+ }
+ ds_config = DeepSpeedConfig(config_dict)
+ wandb_monitor = WandbMonitor(ds_config.monitor_config)
+ assert wandb_monitor.enabled == False
+ assert wandb_monitor.group == "my_group"
+ assert wandb_monitor.team == "my_team"
+ assert wandb_monitor.project == "my_project"
+
+ _test_wandb()
+
+
+@wandb_available
+def test_empty_wandb(tmpdir):
+ @distributed_test(world_size=2)
+ def _test_empty_wandb():
+ config_dict = {"train_batch_size": 2, "wandb": {}}
+ ds_config = DeepSpeedConfig(config_dict)
+ wandb_monitor = WandbMonitor(ds_config.monitor_config)
+ assert wandb_monitor.enabled == WANDB_ENABLED_DEFAULT
+ assert wandb_monitor.group == WANDB_GROUP_NAME_DEFAULT
+ assert wandb_monitor.team == WANDB_TEAM_NAME_DEFAULT
+ assert wandb_monitor.project == WANDB_PROJECT_NAME_DEFAULT
+
+ _test_empty_wandb()
+
+
+def test_csv_monitor(tmpdir):
+ @distributed_test(world_size=2)
+ def _test_csv_monitor():
+ config_dict = {
+ "train_batch_size": 2,
+ "csv_monitor": {
+ "enabled": True,
+ "output_path": "test_output/ds_logs/",
+ "job_name": "test"
+ }
+ }
+ ds_config = DeepSpeedConfig(config_dict)
+ csv_monitor = csvMonitor(ds_config.monitor_config)
+ assert csv_monitor.enabled == True
+ assert csv_monitor.output_path == "test_output/ds_logs/"
+ assert csv_monitor.job_name == "test"
+
+ _test_csv_monitor()
+
+
+def test_empty_csv_monitor(tmpdir):
+ @distributed_test(world_size=2)
+ def _test_empty_csv_monitor():
+ config_dict = {"train_batch_size": 2, "csv_monitor": {}}
+ ds_config = DeepSpeedConfig(config_dict)
+ csv_monitor = csvMonitor(ds_config.monitor_config)
+ assert csv_monitor.enabled == CSV_MONITOR_ENABLED_DEFAULT
+ assert csv_monitor.output_path == CSV_MONITOR_OUTPUT_PATH_DEFAULT
+ assert csv_monitor.job_name == CSV_MONITOR_JOB_NAME_DEFAULT
diff --git a/tests/unit/test_onebit.py b/tests/unit/test_onebit.py
index bfcbdceb0ba7..b7806b0831c7 100644
--- a/tests/unit/test_onebit.py
+++ b/tests/unit/test_onebit.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.distributed as dist
+import deepspeed.comm as dist
import deepspeed
import argparse
import pytest
@@ -1274,7 +1274,7 @@ def _test_compressed_allreduce_basic():
local_rank = dist.get_rank()
device = torch.device("cuda", dist.get_rank())
- # A simulated compression function using torch.distributed
+ # A simulated compression function using deepspeed.comm
def torch_sim(a):
a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0)
scale = a.norm() / np.sqrt(a.numel())
@@ -1295,7 +1295,7 @@ def torch_sim(a):
rank = dist.get_rank()
server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank]
torch.cuda.synchronize()
- torch.distributed.barrier()
+ dist.barrier()
return a_server_compressed, worker_error, server_error
tensor_size = 300 * 2**20
diff --git a/tests/unit/test_partition.py b/tests/unit/test_partition.py
index f766e4596509..cf4852e477e2 100644
--- a/tests/unit/test_partition.py
+++ b/tests/unit/test_partition.py
@@ -1,7 +1,7 @@
import pytest
import torch
-import torch.distributed as dist
+import deepspeed.comm as dist
from deepspeed.runtime.utils import partition_uniform
from deepspeed.runtime.utils import partition_balanced
diff --git a/tests/unit/test_pipe.py b/tests/unit/test_pipe.py
index f7f2b1a1eff4..832d06f9d3ce 100755
--- a/tests/unit/test_pipe.py
+++ b/tests/unit/test_pipe.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.distributed as dist
+import deepspeed.comm as dist
import pytest
diff --git a/tests/unit/test_pipe_module.py b/tests/unit/test_pipe_module.py
index 281101492c37..e50c7d6231a5 100644
--- a/tests/unit/test_pipe_module.py
+++ b/tests/unit/test_pipe_module.py
@@ -2,7 +2,7 @@
import torch
import torch.nn as nn
-import torch.distributed as dist
+import deepspeed.comm as dist
import pytest
diff --git a/tests/unit/test_runtime_utils.py b/tests/unit/test_runtime_utils.py
index fb5c8e394546..2012233cdf63 100644
--- a/tests/unit/test_runtime_utils.py
+++ b/tests/unit/test_runtime_utils.py
@@ -1,7 +1,7 @@
from deepspeed.moe.utils import is_moe_param, split_params_grads_into_shared_and_expert_params, split_params_into_shared_and_expert_params
import torch
from torch._utils import _flatten_dense_tensors
-import torch.distributed as dist
+import deepspeed.comm as dist
import pytest
import deepspeed.runtime.utils as ds_utils
@@ -42,7 +42,7 @@ def _test_clip_grad_norm_() -> None:
world_size = dist.get_world_size()
gathered_norm = [torch.zeros(1).cuda() for i in range(world_size)]
- torch.distributed.all_gather(gathered_norm, norm)
+ dist.all_gather(gathered_norm, norm)
assert gathered_norm[0] == gathered_norm[1], "norm at rank 0 does not match the norm at rank 1"
diff --git a/tests/unit/test_sparse_grads.py b/tests/unit/test_sparse_grads.py
index 2506a1d4c8eb..b146946f30a8 100644
--- a/tests/unit/test_sparse_grads.py
+++ b/tests/unit/test_sparse_grads.py
@@ -1,5 +1,5 @@
import torch
-import torch.distributed as dist
+import deepspeed.comm as dist
import deepspeed
import pytest
from .common import distributed_test
diff --git a/tests/unit/test_topology.py b/tests/unit/test_topology.py
index 89bb8ec3dc7c..78b7867013c1 100644
--- a/tests/unit/test_topology.py
+++ b/tests/unit/test_topology.py
@@ -1,7 +1,7 @@
import pytest
import torch
-import torch.distributed as dist
+import deepspeed.comm as dist
from deepspeed.runtime.pipe.topology import PipelineParallelGrid as Grid
from deepspeed.runtime.pipe.topology import ProcessTopology as Topo
diff --git a/tests/unit/test_zero.py b/tests/unit/test_zero.py
index 154c354cd3ac..453eaaadb0f7 100755
--- a/tests/unit/test_zero.py
+++ b/tests/unit/test_zero.py
@@ -1,7 +1,7 @@
import math
from typing import Dict, List, Set
import pytest
-import torch.distributed as dist
+import deepspeed.comm as dist
import torch
from torch import Tensor
from torch.nn import Linear, Module
@@ -486,35 +486,42 @@ def __init__(
self.loss = L1Loss(reduction="none")
- def forward(self, x: Tensor, y: Tensor, prefetching: bool) -> Dict[str, Tensor]:
+ def forward(self,
+ x: Tensor,
+ y: Tensor,
+ use_module_trace: bool,
+ param_prefetching: bool) -> Dict[str,
+ Tensor]:
_assert_partition_status(
self,
{
ZeroParamStatus.NOT_AVAILABLE,
ZeroParamStatus.INFLIGHT,
ZeroParamStatus.AVAILABLE
- } if prefetching else {ZeroParamStatus.NOT_AVAILABLE})
+ } if use_module_trace else {ZeroParamStatus.NOT_AVAILABLE})
- layerwise_expected_states = {
- ZeroParamStatus.INFLIGHT if prefetching else ZeroParamStatus.NOT_AVAILABLE,
+ pre_layer_expected_states = {
+ ZeroParamStatus.INFLIGHT
+ if param_prefetching else ZeroParamStatus.NOT_AVAILABLE,
ZeroParamStatus.AVAILABLE,
}
- _assert_partition_status(self.__layer1, layerwise_expected_states)
+ post_layer_expected_states = {
+ ZeroParamStatus.AVAILABLE
+ if param_prefetching else ZeroParamStatus.NOT_AVAILABLE,
+ }
+
+ _assert_partition_status(self.__layer1, pre_layer_expected_states)
hidden1 = self.__layer1(x)
- _assert_partition_status(self.__layer1, {ZeroParamStatus.NOT_AVAILABLE})
+ _assert_partition_status(self.__layer1, post_layer_expected_states)
- _assert_partition_status(self.__layer2, layerwise_expected_states)
+ _assert_partition_status(self.__layer2, pre_layer_expected_states)
hidden2 = self.__layer2(hidden1)
- _assert_partition_status(self.__layer2, {ZeroParamStatus.NOT_AVAILABLE})
+ _assert_partition_status(self.__layer2, post_layer_expected_states)
- _assert_partition_status(self.__layer3, layerwise_expected_states)
+ _assert_partition_status(self.__layer3, pre_layer_expected_states)
y_hat = self.__layer3(hidden2)
- _assert_partition_status(self.__layer3,
- {
- ZeroParamStatus.AVAILABLE
- if prefetching else ZeroParamStatus.NOT_AVAILABLE
- })
+ _assert_partition_status(self.__layer3, post_layer_expected_states)
loss = self.loss(y_hat, y)
@@ -524,7 +531,7 @@ def forward(self, x: Tensor, y: Tensor, prefetching: bool) -> Dict[str, Tensor]:
ZeroParamStatus.NOT_AVAILABLE,
ZeroParamStatus.INFLIGHT,
ZeroParamStatus.AVAILABLE
- } if prefetching else {ZeroParamStatus.NOT_AVAILABLE})
+ } if use_module_trace else {ZeroParamStatus.NOT_AVAILABLE})
return {
"hidden1": hidden1,
@@ -539,14 +546,14 @@ def forward(self, x: Tensor, y: Tensor, prefetching: bool) -> Dict[str, Tensor]:
@pytest.mark.parametrize("contiguous_gradients", [True, False])
@pytest.mark.parametrize("offload_optimizer", [True, False])
@pytest.mark.parametrize("zero_grad", [True, False])
-@pytest.mark.parametrize("iteration", list(range(1)))
+@pytest.mark.parametrize("prefetching", [True, False])
def test_zero3_param_partitioning_base(
param_persistence_threshold: int,
fp16_enabled: bool,
contiguous_gradients: bool,
offload_optimizer: bool,
zero_grad: bool,
- iteration: int,
+ prefetching: bool,
) -> None:
@distributed_test(world_size=[2])
def _test_zero3_param_partitioning():
@@ -557,7 +564,7 @@ def _test_zero3_param_partitioning():
n = 5
weights = [Parameter(torch.zeros((m, n), dtype=torch.float32)) for _ in range(3)]
model = EltwiseMultiplicationTestNetwork(*weights)
-
+ prefetch_bucket_size = sum([p.numel() for p in model.parameters(recurse=True)])
cfg = {
"train_micro_batch_size_per_gpu": 1,
"zero_optimization": {
@@ -565,6 +572,7 @@ def _test_zero3_param_partitioning():
"stage3_max_reuse_distance": 0,
"stage3_param_persistence_threshold": param_persistence_threshold,
"contiguous_gradients": contiguous_gradients,
+ "stage3_prefetch_bucket_size": prefetch_bucket_size if prefetching else 0
},
"optimizer": {
"type": "Adam",
@@ -672,7 +680,8 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor:
n),
dtype=torch.float16 if fp16_enabled else torch.float32,
device=ds_engine.device),
- prefetching=train_iter > 0,
+ use_module_trace=train_iter > 0,
+ param_prefetching=prefetching and train_iter > 0,
)
assert torch.allclose(activations["hidden1"], expected_hidden1)
assert torch.allclose(activations["hidden2"], expected_hidden2)
@@ -1215,7 +1224,7 @@ def _go(model, hidden_dim):
total_samples=50,
hidden_dim=hidden_dim,
device=model.device)
- torch.distributed.barrier()
+ dist.barrier()
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
@@ -1275,7 +1284,7 @@ def _go(hidden_dim):
total_samples=50,
hidden_dim=hidden_dim,
device=model.device)
- torch.distributed.barrier()
+ dist.barrier()
for n, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
if return_type == dict:
diff --git a/tests/unit/test_zero_context.py b/tests/unit/test_zero_context.py
index 66521e075ce1..e689005709d9 100644
--- a/tests/unit/test_zero_context.py
+++ b/tests/unit/test_zero_context.py
@@ -6,6 +6,7 @@
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape
+import deepspeed.comm as dist
from .common import distributed_test, get_master_port
@@ -21,9 +22,9 @@ def setup_serial_env():
def test_scattered_init_dist():
setup_serial_env()
- assert not torch.distributed.is_initialized()
+ assert not dist.is_initialized()
with deepspeed.zero.Init():
- assert torch.distributed.is_initialized()
+ assert dist.is_initialized()
@distributed_test(world_size=2)
@@ -52,7 +53,7 @@ def test_gather_update():
# Gather and make a change
with deepspeed.zero.GatheredParameters(l.weight, modifier_rank=1):
assert l.weight.ds_status == ZeroParamStatus.AVAILABLE
- if torch.distributed.get_rank() == 1:
+ if dist.get_rank() == 1:
with torch.no_grad():
l.weight.zero_()
diff --git a/tests/unit/util.py b/tests/unit/util.py
index 966733b1d929..79a459da3c14 100644
--- a/tests/unit/util.py
+++ b/tests/unit/util.py
@@ -13,20 +13,15 @@ def required_torch_version():
def bf16_required_version_check():
- TORCH_MAJOR = int(torch.__version__.split('.')[0])
- TORCH_MINOR = int(torch.__version__.split('.')[1])
-
- if type(torch.cuda.nccl.version()) != tuple:
- return False
- else:
- NCCL_MAJOR = torch.cuda.nccl.version()[0]
- NCCL_MINOR = torch.cuda.nccl.version()[1]
+ split_version = lambda x: map(int, x.split('.')[:2])
+ TORCH_MAJOR, TORCH_MINOR = split_version(torch_info['version'])
+ NCCL_MAJOR, NCCL_MINOR = split_version(torch_info['nccl_version'])
+ CUDA_MAJOR, CUDA_MINOR = split_version(torch_info['cuda_version'])
- CUDA_MAJOR = int(torch_info['cuda_version'].split('.')[0])
if (TORCH_MAJOR > 1 or
(TORCH_MAJOR == 1 and TORCH_MINOR >= 10)) and (CUDA_MAJOR >= 11) and (
NCCL_MAJOR > 2 or
- (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)) and torch.cuda.is_bf16_supported():
+ (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)) and torch_info['bf16_support']:
return True
else:
return False