Commit

Merge branch 'master' into zero2_param_idx
loadams authored Dec 16, 2024
2 parents 9cb29d8 + da771ed commit 4bb8f91
Showing 17 changed files with 74 additions and 43 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cpu-torch-latest.yml
@@ -42,7 +42,7 @@ jobs:
git clone https://github.com/huggingface/transformers
cd transformers
# if needed switch to the last known good SHA until transformers@master is fixed
git checkout 6c3f168b3
# git checkout 6c3f168b3
git rev-parse --short HEAD
pip install .
2 changes: 1 addition & 1 deletion .github/workflows/nv-torch-latest-v100.yml
@@ -38,7 +38,7 @@ jobs:
git clone https://github.com/huggingface/transformers
cd transformers
# if needed switch to the last known good SHA until transformers@master is fixed
git checkout 6c3f168b3
# git checkout 6c3f168b3
git rev-parse --short HEAD
pip install .
2 changes: 1 addition & 1 deletion .github/workflows/nv-torch-nightly-v100.yml
@@ -37,7 +37,7 @@ jobs:
git clone https://github.com/huggingface/transformers
cd transformers
# if needed switch to the last known good SHA until transformers@master is fixed
git checkout 6c3f168b3
# git checkout 6c3f168b3
git rev-parse --short HEAD
pip install .
8 changes: 4 additions & 4 deletions .github/workflows/xpu-compile.yml
@@ -31,10 +31,10 @@ jobs:
run: |
apt-get update
apt-get install clinfo libaio-dev python3-pip -y
pip install torch==2.3.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torch/
pip install intel-extension-for-pytorch==2.3.110+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/intel-extension-for-pytorch/
pip install oneccl_bind_pt==2.3.100+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/oneccl-bind-pt/
pip install torchvision==0.18.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torchvision/
pip install torch==2.3.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/torch/
pip install intel-extension-for-pytorch==2.3.110+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/intel-extension-for-pytorch/
pip install oneccl_bind_pt==2.3.100+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/oneccl-bind-pt/
pip install torchvision==0.18.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/torchvision/
pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v3.0.0b2/triton_xpu-3.0.0b2-cp310-cp310-linux_x86_64.whl
pip install py-cpuinfo numpy
pip install .[dev,autotuning]
8 changes: 4 additions & 4 deletions .github/workflows/xpu-max1100.yml
@@ -47,10 +47,10 @@ jobs:
run: |
apt-get update
apt-get install clinfo libaio-dev python3-pip -y
pip install torch==2.3.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torch/
pip install intel-extension-for-pytorch==2.3.110+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/intel-extension-for-pytorch/
pip install oneccl_bind_pt==2.3.100+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/oneccl-bind-pt/
pip install torchvision==0.18.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torchvision/
pip install torch==2.3.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/torch/
pip install intel-extension-for-pytorch==2.3.110+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/intel-extension-for-pytorch/
pip install oneccl_bind_pt==2.3.100+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/oneccl-bind-pt/
pip install torchvision==0.18.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/torchvision/
pip install py-cpuinfo numpy
pip install .[dev,autotuning]
6 changes: 4 additions & 2 deletions COMMITTERS.md
@@ -5,5 +5,7 @@
| Olatunji Ruwase | [tjruwase](https://github.com/tjruwase) | Microsoft |
| Logan Adams | [loadams](https://github.com/loadams) | Microsoft |
| Masahiro Tanaka | [tohtana](https://github.com/tohtana) | Microsoft |
| Jeff Rasley | [jeffra](https://github.com/jeffra) | SnowFlake |
| Minjia Zhang | [minjiazhang](https://github.com/minjiazhang) | UIUC |
| Jeff Rasley | [jeffra](https://github.com/jeffra) | SnowFlake |
| Minjia Zhang | [minjiazhang](https://github.com/minjiazhang) | UIUC |
| Ashwin Aji | [ashwinma](https://github.com/ashwinma) | AMD |
| Sam Foreman | [saforem2](https://github.com/saforem2) | Argonne National Laboratory |
25 changes: 14 additions & 11 deletions accelerator/real_accelerator.py
@@ -125,10 +125,9 @@ def get_accelerator():
if accelerator_name is None:
try:
import intel_extension_for_pytorch as ipex

if ipex._C._has_xpu():
accelerator_name = "xpu"
else:
accelerator_name = "cpu"
except ImportError as e:
pass
if accelerator_name is None:
@@ -162,23 +161,27 @@ def get_accelerator():
except ImportError as e:
pass
if accelerator_name is None:
# borrow this log from PR#5084
try:
import torch

# Determine if we are on a GPU or x86 CPU with torch.
if torch.cuda.is_available(): #ignore-cuda
# "torch.cuda.is_available()" provides a stronger guarantee, #ignore-cuda
# ensuring that we are free from CUDA initialization errors.
# While "torch.cuda.device_count() > 0" check ensures that #ignore-cuda
# we won't try to do any CUDA calls when no device is available
# For reference: https://github.com/microsoft/DeepSpeed/pull/6810
if torch.cuda.device_count() > 0 and torch.cuda.is_available(): #ignore-cuda
accelerator_name = "cuda"
else:
if accel_logger is not None:
accel_logger.warn(
"Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it."
)
accelerator_name = "cpu"
except (RuntimeError, ImportError) as e:
# TODO need a more decent way to detect which accelerator to use, consider using nvidia-smi command for detection
accelerator_name = "cuda"
pass
if accelerator_name is None:
# borrow this log from PR#5084
if accel_logger is not None:
accel_logger.warn(
"Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it.")
# cpu added as catch-all when accelerator detection fails
accelerator_name = "cpu"

ds_set_method = "auto detect"

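The hunk above tightens CUDA auto-detection: torch.cuda.device_count() > 0 is now checked before torch.cuda.is_available(), so no CUDA call is attempted when no device is present. A minimal standalone sketch of that detection order (not the actual DeepSpeed code path; the function name and logger handling are illustrative):

import torch

def detect_accelerator_name(logger=None):
    try:
        # device_count() avoids touching CUDA when no device exists; is_available()
        # then guards against CUDA initialization errors (see PR #6810 referenced above).
        if torch.cuda.device_count() > 0 and torch.cuda.is_available():
            return "cuda"
    except (RuntimeError, ImportError):
        pass
    if logger is not None:
        logger.warning("Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it.")
    return "cpu"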
3 changes: 3 additions & 0 deletions deepspeed/inference/config.py
@@ -40,6 +40,9 @@ class DeepSpeedTPConfig(DeepSpeedConfigModel):
tp_size: int = 1
""" Number of devices to split the model across using tensor parallelism. """

tp_grain_size: int = 64
"Desired MLP/lm_head tp size granularity. DNN library favors tensor size in granularity of power of 2, we pick 64 as a default size."

mpu: object = None
"""
A model parallelism unit object that implements
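For context, tp_grain_size is a new field on the tensor-parallel sub-config, so it can be supplied alongside tp_size when building an inference engine. A hedged usage sketch (the model and values are placeholders; only the tensor_parallel keys come from the config class above, and running with tp_size=2 assumes a two-device launch):

import deepspeed
import torch

model = torch.nn.Linear(4096, 4096)  # placeholder model

engine = deepspeed.init_inference(
    model,
    dtype=torch.float16,
    tensor_parallel={
        "tp_size": 2,         # number of devices to split across
        "tp_grain_size": 64,  # keep sharded MLP/lm_head dims in multiples of 64
    },
)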
5 changes: 4 additions & 1 deletion deepspeed/module_inject/replace_module.py
@@ -17,7 +17,7 @@
from .layers import TensorParallelOcShardConv2d, TensorParallelIcShardConv2d

from deepspeed import comm as dist
from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads
from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads, set_tp_grain_size

from .load_checkpoint import load_model_with_checkpoint
import time
Expand Down Expand Up @@ -303,6 +303,9 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
if hasattr(model_config, 'num_attention_heads'):
set_num_attention_heads(getattr(model_config, 'num_attention_heads'))

# 4.4 set tp_grain_size
set_tp_grain_size(config.tensor_parallel.tp_grain_size)

# 5. Set linear policies
_autotp.update_linear_policies()

11 changes: 8 additions & 3 deletions deepspeed/module_inject/tp_shard.py
@@ -22,6 +22,11 @@ def set_n_embd(num):
n_embd = num


def set_tp_grain_size(num):
global tp_grain_size
tp_grain_size = num


def get_num_kv_heads():
global num_kv_heads
if 'num_kv_heads' in globals():
@@ -45,9 +50,9 @@ def get_shard_size(total_size, mp_size, name=None, rank=None):
my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
return total_size * my_slices // num_kv_heads
else:
if total_size >= 64:
grain_size = total_size // 64
return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * 64
if total_size >= tp_grain_size:
grain_size = total_size // tp_grain_size
return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * tp_grain_size
else:
return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)

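The arithmetic in get_shard_size can be read as "split whole grains of tp_grain_size across ranks, giving earlier ranks the remainder". A self-contained sketch with worked examples (the helper name and default are illustrative; the real function reads tp_grain_size from a module-level global set via set_tp_grain_size):

def shard_size(total_size, mp_size, rank, tp_grain_size=64):
    if total_size >= tp_grain_size:
        grains = total_size // tp_grain_size
        # Distribute whole grains so every shard stays a multiple of tp_grain_size.
        return (grains // mp_size + (1 if rank < (grains % mp_size) else 0)) * tp_grain_size
    # Small tensors fall back to a plain remainder-aware split.
    return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)

# 11008 columns over 4 ranks: 172 grains of 64, 43 grains per rank -> 2752 each.
assert [shard_size(11008, 4, r) for r in range(4)] == [2752, 2752, 2752, 2752]
# Over 3 ranks the single leftover grain goes to rank 0: 3712, 3648, 3648.
assert [shard_size(11008, 3, r) for r in range(3)] == [3712, 3648, 3648]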
5 changes: 5 additions & 0 deletions deepspeed/runtime/engine.py
@@ -3738,6 +3738,11 @@ def offload_states(self,
assert self.zero_optimization_stage(
) == ZeroStageEnum.weights, "Moving buffers across devices is supported only for ZeRO stage 3."

opt_offload_config = self.zero_offload_optimizer()
assert opt_offload_config is None or opt_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded optimizer states."
param_offload_config = self.zero_offload_param()
assert param_offload_config is None or param_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded parameters."

assert not self.zero_offload_param(), "Moving states across devices is not supported for offloaded parameters."

if device == OffloadDeviceEnum.none:
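The new assertions mean offload_states can only be called when neither optimizer states nor parameters are already offloaded. A hedged usage sketch under that constraint (assumes `engine` came from deepspeed.initialize(...) with ZeRO stage 3 and no offload configured; the import path and keyword arguments follow current DeepSpeed layout and are not part of this diff):

from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum

def temporarily_offload(engine):
    # Valid only when zero_offload_optimizer()/zero_offload_param() are unset or 'none'.
    engine.offload_states(device=OffloadDeviceEnum.cpu, pin_memory=True)
    try:
        pass  # do other work while the accelerator memory is freed
    finally:
        engine.reload_states()  # move the ZeRO-3 states back to the accelerator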
4 changes: 3 additions & 1 deletion deepspeed/runtime/fp16/onebit/zoadam.py
@@ -12,9 +12,11 @@


class ZeroOneAdam(torch.optim.Optimizer):
"""Implements the 0/1 Adam algorithm. Currently GPU-only.
"""
Implements the 0/1 Adam algorithm. Currently GPU-only.
For usage example please see https://www.deepspeed.ai/tutorials/zero-one-adam/
For technical details please read https://arxiv.org/abs/2202.06009
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
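Since the docstring points to the 0/1 Adam tutorial, a hedged config sketch of selecting this optimizer through a DeepSpeed config dict (parameter names follow that tutorial and are assumptions, not taken from this diff):

ds_config = {
    "train_batch_size": 32,
    "optimizer": {
        "type": "ZeroOneAdam",
        "params": {
            "lr": 1e-3,
            "weight_decay": 0.01,
            "var_freeze_step": 1000,     # step after which the local variance is frozen
            "cuda_aware": False,         # enable only with a CUDA-aware MPI stack
            "comm_backend_name": "nccl",
        },
    },
    "fp16": {"enabled": True},
}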
4 changes: 2 additions & 2 deletions deepspeed/runtime/lr_schedules.py
@@ -274,7 +274,7 @@ class LRRangeTest(object):
"""Sets the learning rate of each parameter group according to
learning rate range test (LRRT) policy. The policy increases learning
rate starting from a base value with a constant frequency, as detailed in
the paper `A disciplined approach to neural network hyper-parameters: Part1`_.
the paper `A disciplined approach to neural network hyper-parameters: Part 1 <https://arxiv.org/abs/1803.09820>`_
LRRT policy is used for finding maximum LR that trains a model without divergence, and can be used to
configure the LR boundaries for Cyclic LR schedules.
@@ -379,7 +379,7 @@ class OneCycle(object):
1CLR policy changes the learning rate after every batch.
`step` should be called after a batch has been used for training.
This implementation was adapted from the github repo: `pytorch/pytorch`_
This implementation was adapted from the github repo: `PyTorch <https://github.com/pytorch/pytorch>`_.
Args:
optimizer (Optimizer): Wrapped optimizer.
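As the LRRangeTest docstring notes, the schedule is typically used to sweep for the largest LR that still trains stably. A hedged config sketch for wiring it in (parameter names follow the DeepSpeed scheduler docs and are assumptions, not taken from this diff):

ds_config = {
    "train_batch_size": 16,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-5}},
    "scheduler": {
        "type": "LRRangeTest",
        "params": {
            "lr_range_test_min_lr": 1e-5,     # base LR the sweep starts from
            "lr_range_test_step_size": 200,   # steps between LR increases
            "lr_range_test_step_rate": 5.0,   # growth factor applied per interval
            "lr_range_test_staircase": True,  # discrete rather than continuous growth
        },
    },
}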
2 changes: 2 additions & 0 deletions docs/_data/navigation.yml
@@ -75,6 +75,8 @@ lnav:
url: /tutorials/data-efficiency/
- title: 'DeepNVMe'
url: /tutorials/deepnvme/
- title: 'Domino'
url: /tutorials/domino/
- title: 'DS4Sci_EvoformerAttention'
url: /tutorials/ds4sci_evoformerattention/
- title: 'Flops Profiler'
6 changes: 6 additions & 0 deletions docs/_tutorials/domino.md
@@ -0,0 +1,6 @@
---
title: "Domino"
tags: training
---

Domino achieves near-complete communication hiding behind computation for tensor parallel training. Please find our [Domino-tutorial](https://github.com/microsoft/DeepSpeedExamples/blob/master/training/DeepSpeed-Domino/README.md) in DeepSpeedExample repo.
18 changes: 9 additions & 9 deletions docs/code-docs/source/monitor.rst
@@ -9,15 +9,15 @@ overview of what DeepSpeed will log automatically.
:header: "Field", "Description", "Condition"
:widths: 20, 20, 10

`Train/Samples/train_loss`,The training loss.,None
`Train/Samples/lr`,The learning rate during training.,None
`Train/Samples/loss_scale`,The loss scale when training using `fp16`.,`fp16` must be enabled.
`Train/Eigenvalues/ModelBlockParam_{i}`,Eigen values per param block.,`eigenvalue` must be enabled.
`Train/Samples/elapsed_time_ms_forward`,The global duration of the forward pass.,`flops_profiler.enabled` or `wall_clock_breakdown`.
`Train/Samples/elapsed_time_ms_backward`,The global duration of the forward pass.,`flops_profiler.enabled` or `wall_clock_breakdown`.
`Train/Samples/elapsed_time_ms_backward_inner`,The backward time that does not include the gradient reduction time. Only in cases where the gradient reduction is not overlapped, if it is overlapped then the inner time should be about the same as the entire backward time.,`flops_profiler.enabled` or `wall_clock_breakdown`.
`Train/Samples/elapsed_time_ms_backward_allreduce`,The global duration of the allreduce operation.,`flops_profiler.enabled` or `wall_clock_breakdown`.
`Train/Samples/elapsed_time_ms_step`,The optimizer step time,`flops_profiler.enabled` or `wall_clock_breakdown`.
`Train/Samples/train_loss`,"The training loss.",None
`Train/Samples/lr`,"The learning rate during training.",None
`Train/Samples/loss_scale`,"The loss scale when training using `fp16`.",`fp16` must be enabled.
`Train/Eigenvalues/ModelBlockParam_{i}`,"Eigen values per param block.",`eigenvalue` must be enabled.
`Train/Samples/elapsed_time_ms_forward`,"The global duration of the forward pass.",`flops_profiler.enabled` or `wall_clock_breakdown`.
`Train/Samples/elapsed_time_ms_backward`,"The global duration of the forward pass.",`flops_profiler.enabled` or `wall_clock_breakdown`.
`Train/Samples/elapsed_time_ms_backward_inner`,"The backward time that does not include the gradient reduction time. Only in cases where the gradient reduction is not overlapped, if it is overlapped then the inner time should be about the same as the entire backward time.",`flops_profiler.enabled` or `wall_clock_breakdown`.
`Train/Samples/elapsed_time_ms_backward_allreduce`,"The global duration of the allreduce operation.",`flops_profiler.enabled` or `wall_clock_breakdown`.
`Train/Samples/elapsed_time_ms_step`,"The optimizer step time.",`flops_profiler.enabled` or `wall_clock_breakdown`.

TensorBoard
-----------
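The timing rows in the table above only appear when profiling or wall-clock breakdown is enabled. A hedged config sketch that turns them on and routes them to TensorBoard (keys follow the DeepSpeed monitoring docs; values are placeholders):

ds_config = {
    "train_batch_size": 8,
    "wall_clock_breakdown": True,  # emits the elapsed_time_ms_* fields
    "tensorboard": {
        "enabled": True,
        "output_path": "logs/",
        "job_name": "monitor_demo",
    },
}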
6 changes: 3 additions & 3 deletions docs/index.md
@@ -7,25 +7,25 @@ title: "Latest News"
---
<b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>

* [2024/12] [DeepSpeed Domino: Communication-Free LLM Training Engine](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-domino/README.md)

* [2024/08] [DeepSpeed on Windows](https://github.com/microsoft/DeepSpeed/blob/master/blogs/windows/08-2024/README.md)[[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/windows/08-2024/japanese/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/windows/08-2024/chinese/README.md)]

* [2024/08] [DeepNVMe: Improving DL Applications through I/O Optimizations](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-gds/README.md)[[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-gds/japanese/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-gds/chinese/README.md)]
* [2024/07] [DeepSpeed Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/README.md)[[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ucp/japanese/README.md)]
* [2024/03] [DeepSpeed-FP6: The Power of FP6-Centric Serving for Large Language Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)]
* [2024/01] [DeepSpeed-FastGen: Introducting Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19)

<!-- NOTE: we must use html for news items otherwise links will be broken in the 'more news' section -->

<details>
<summary>More news</summary>
<ul>
<li>[2024/01] <a href="https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19">DeepSpeed-FastGen: Introducting Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.</a></li>

<li>[2023/11] <a href="https://github.com/microsoft/DeepSpeed/tree/master/blogs/intel-inference/README.md">Llama 2 Inference on 4th Gen Intel® Xeon® Scalable Processor with DeepSpeed</a> [<a href="https://www.intel.com/content/www/us/en/developer/articles/technical/xllama-2-on-xeon-scalable-processor-with-deepspeed.html">Intel version</a>]</li>

<li>[2023/11] <a href="https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-offloadpp/README.md">DeepSpeed ZeRO-Offload++: 6x Higher Training Throughput via Collaborative CPU/GPU Twin-Flow</a></li>

<li>[2023/11] <a href="https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen">DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference</a> [<a href="https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/chinese/README.md">中文</a>] [<a href="https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/japanese/README.md">日本語</a>]</li>


</ul>
</details>
