Add Compressedbackend for Onebit optimizers (#5473)
In the process of adding one-bit optimizer support for XPU devices, we noticed that across accelerators the main implementation difference in `compressed_allreduce` lies in `packbits` and `unpackbits`: CUDA uses cupy and NPU uses torch_npu. Instead of replacing these with XPU-only functions, we provide a CompressedBackend that performs the `compressed_allreduce` work and lets users plug in their own packbits/unpackbits kernels, giving a general path for all kinds of accelerators. In this PR, we:

1. Add CompressedBackend for the OneBitAdam, OneBitLamb and ZeroOneAdam optimizers
2. Add an XPU implementation of packbits/unpackbits in SYCL, built by PackbitsBuilder
3. Add tests for the one-bit optimizers with CompressedBackend

Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent 6b6d641 · commit 11a62a0 · Showing 11 changed files with 504 additions and 2 deletions.
New file: `csrc/xpu/packbits/packing.cpp` (+100 lines)
```cpp
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <ipex.h>
#include <torch/extension.h>
#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace xpu;

void packbitskernel(const float* input, uint8_t* output, const int input_size, id<1> item_ct1)
{
    // pack the sign bits of eight consecutive floats into one byte, MSB first
    int i = item_ct1;
    for (int j = 0; j < 8; ++j) {
        int k = i * 8 + j;
        int bit = k < input_size && (!sycl::signbit(input[k]));
        output[i] |= bit << (7 - j);
    }
}

void unpackbitskernel(const uint8_t* input, float* output, id<1> item_ct1)
{
    // use the bit value to set the float: bit 0 -> -1.0f, bit 1 -> 1.0f
    int i = item_ct1;
    output[i] = (float((input[i / 8] >> (7 - i % 8)) & 1) - 0.5) * 2;
}

sycl::queue get_current_queue(at::Device device)
{
    c10::impl::VirtualGuardImpl impl(device.type());
    c10::Stream _stream = impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false);
    sycl::queue queue = xpu::get_queue_from_stream(_stream);
    return queue;
}

/*
Pack a float tensor into a uint8 tensor. Every eight float elements get packed into one uint8.
If float x >= 0, it is packed as a '1' bit, otherwise as a '0' bit.
Arguments:
    tensor: the float tensor to be packed.
    input_size: numel of the input tensor
    rank: device id, used to get the corresponding stream
*/
at::Tensor packbits(at::Tensor tensor, int input_size, int rank)
{
    at::Device device = "xpu:" + std::to_string(rank);
    sycl::queue q = get_current_queue(device);

    int packed_size = (input_size + 7) / 8;
    auto uint8_options = at::TensorOptions().dtype(at::kByte).device(at::kXPU);
    at::Tensor packed = torch::zeros({packed_size}, uint8_options);

    float* input = (float*)tensor.data_ptr();
    uint8_t* output = (uint8_t*)packed.data_ptr();

    auto event = q.submit([&](sycl::handler& cgh) {
        cgh.parallel_for<>(range(packed_size), [=](id<1> item_ct1) {
            packbitskernel(input, output, input_size, item_ct1);
        });
    });

    return packed;
}

/*
Unpack a uint8 tensor into a float tensor. Every uint8 element gets unpacked into eight floats.
A '1' bit is converted to float(1), a '0' bit is converted to float(-1).
Arguments:
    tensor: the uint8 tensor to be unpacked.
    input_size: numel of the input tensor
    rank: device id, used to get the corresponding stream
*/
at::Tensor unpackbits(at::Tensor tensor, int input_size, int rank)
{
    at::Device device = "xpu:" + std::to_string(rank);
    sycl::queue q = get_current_queue(device);

    auto float_options = at::TensorOptions().dtype(at::kFloat).device(at::kXPU);
    at::Tensor unpacked = torch::empty({input_size * 8}, float_options);

    uint8_t* input = (uint8_t*)tensor.data_ptr();
    float* output = (float*)unpacked.data_ptr();

    auto event = q.submit([&](sycl::handler& cgh) {
        cgh.parallel_for<>(range(input_size * 8),
                           [=](id<1> item_ct1) { unpackbitskernel(input, output, item_ct1); });
    });

    return unpacked;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("packbits", &packbits, "DeepSpeed XPU packbits (C++)");
    m.def("unpackbits", &unpackbits, "DeepSpeed XPU unpackbits (C++)");
}
```
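For reference, the semantics of these two kernels can be reproduced device-agnostically in plain PyTorch. The sketch below uses hypothetical helper names (`packbits_ref`, `unpackbits_ref`, not part of this commit) and packs sign bits MSB-first, zero-padding the last byte, exactly as the SYCL code does:

```python
import torch


def packbits_ref(tensor: torch.Tensor, input_size: int) -> torch.Tensor:
    # bit 1 for non-negative floats, bit 0 for negative floats (matches !signbit)
    bits = (~torch.signbit(tensor[:input_size])).to(torch.uint8)
    pad = (-input_size) % 8  # zero-pad up to a whole byte
    if pad:
        bits = torch.cat([bits, bits.new_zeros(pad)])
    weights = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1], dtype=torch.uint8, device=bits.device)
    return (bits.view(-1, 8) * weights).sum(dim=1, dtype=torch.uint8)


def unpackbits_ref(packed: torch.Tensor) -> torch.Tensor:
    # MSB-first: bit 1 -> +1.0, bit 0 -> -1.0
    shifts = torch.arange(7, -1, -1, device=packed.device)
    bits = (packed.view(-1, 1).int() >> shifts) & 1
    return bits.flatten().float() * 2 - 1
```

Any accelerator-specific kernel plugged into the CompressedBackend only needs to match this bit layout.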
New file (+137 lines): the `CompressedBackend` implementation
```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import numpy as np
import torch
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import PackbitsBuilder


class CompressedBackend(object):

    def __init__(self, mpu=None):
        if mpu is None:
            self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
        else:
            self.mpu = mpu
            self.world_group = self.mpu.get_data_parallel_group()
        self.size = dist.get_world_size(group=self.world_group)
        self.rank = dist.get_rank(group=self.world_group)
        self.packer = PackbitsBuilder().load()

    def my_igather(self, rank, size, group, sendbuf, recvbuf, root):
        req = []
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    req.append(dist.irecv(recvbuf[idx], src=idx, group=group))
                else:
                    recvbuf[rank] = sendbuf
        else:
            req.append(dist.isend(sendbuf, group=group, dst=root))
        return req

    def my_gather(self, rank, size, group, sendbuf, recvbuf, root):
        if rank == root:
            for idx in range(size):
                if idx != rank:
                    dist.recv(recvbuf[idx], src=idx, group=group)
                else:
                    recvbuf[rank] = sendbuf
        else:
            dist.send(sendbuf, group=group, dst=root)

    def pack(self, buffer, size):
        # pack the float tensor into a uint8 tensor, one sign bit per element
        packed = self.packer.packbits(buffer.float(), buffer.numel(), self.rank)
        return packed.reshape(size, -1)

    def unpack(self, buffer, size, dtype):
        # unpack the uint8 tensor back into a float tensor of +/-1 values
        unpacked = self.packer.unpackbits(buffer, buffer.numel(), self.rank)
        return unpacked.reshape(size, -1).to(dtype)

    def compressed_allreduce(self, buffer_m: torch.Tensor, worker_error, server_error, local_rank):
        original_shape = buffer_m.size()
        if len(original_shape) > 1:
            buffer_m = torch.flatten(buffer_m)

        # align the sizes of the original buffer and the error buffer
        original_size = buffer_m.numel()
        worker_error_size = worker_error.numel()
        if original_size != worker_error_size:
            empty_tensor = torch.zeros(worker_error_size - original_size, device=buffer_m.device)
            buffer_m = torch.cat([buffer_m, empty_tensor])

        buffer_m.add_(worker_error)
        worker_scale = torch.linalg.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))

        worker_error.set_(buffer_m - worker_scale * buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        sign_list_packed_tmp = self.pack(buffer_m, self.size).type(torch.int8)

        recvbuf_sign = torch.zeros([self.size, len(sign_list_packed_tmp[self.rank])],
                                   dtype=sign_list_packed_tmp[0].dtype,
                                   device=sign_list_packed_tmp.device)

        sign_list_packed = [sign_list_packed_tmp[idx] for idx in range(self.size)]

        recvbuf_scale = [
            torch.zeros(1, dtype=worker_scale.dtype, device=get_accelerator().current_device_name())
            for _ in range(self.size)
        ]

        # communication phase 1
        # all-to-all for the packed signs
        dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed), group=self.world_group)
        # all-gather for the scales
        dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)

        flattened_recvbuf_sign = recvbuf_sign.type(torch.uint8).flatten()
        compensated_server_m = self.unpack(flattened_recvbuf_sign, self.size, torch.float32) \
            .mul_(torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0)

        compensated_server_m.add_(server_error)

        server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())

        server_error.set_(compensated_server_m -
                          server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        server_sign_packed = self.pack(compensated_server_m, 1).type(torch.int8)

        # recvbuf_sign_server
        recvbuf_sign_server_tmp = torch.zeros([self.size, len(server_sign_packed[0])],
                                              dtype=recvbuf_sign.dtype,
                                              device=server_sign_packed.device)

        recvbuf_sign_server = [recvbuf_sign_server_tmp[idx] for idx in range(self.size)]

        # recvbuf_scale_server
        recvbuf_scale_server_tmp = torch.zeros([self.size, 1],
                                               dtype=worker_scale.dtype,
                                               device=server_sign_packed.device)

        recvbuf_scale_server = [recvbuf_scale_server_tmp[idx] for idx in range(self.size)]

        # communication phase 2
        dist.all_gather(recvbuf_sign_server, server_sign_packed[0], group=self.world_group)
        dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)

        recvbuf_sign_server = torch.stack(recvbuf_sign_server)

        flattened_recvbuf_sign_server = recvbuf_sign_server.type(torch.uint8).flatten()

        buffer_m.data.copy_(
            self.unpack(flattened_recvbuf_sign_server, self.size,
                        torch.float32).mul_(recvbuf_scale_server_tmp).flatten().data)

        if original_size != worker_error_size:
            buffer_m = buffer_m[0:original_size]
        if len(original_shape) > 1:
            buffer_m = buffer_m.reshape(original_shape)

        return buffer_m
```
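A minimal round-trip sketch of the `pack`/`unpack` helpers above, assuming `deepspeed.comm` is already initialized and the current accelerator provides `PackbitsBuilder` (e.g. an XPU device with this commit's op):

```python
import torch
from deepspeed.accelerator import get_accelerator

backend = CompressedBackend()  # uses the full world group when no mpu is passed
device = get_accelerator().current_device_name()

# 64 elements per rank, so the packed signs split evenly across ranks
x = torch.randn(backend.size * 64, device=device)
packed = backend.pack(x, backend.size)                            # (world_size, 8) uint8
signs = backend.unpack(packed.flatten(), backend.size, x.dtype)   # (world_size, 64) of +/-1
assert torch.equal(signs.flatten(),
                   torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x)))
```

`compressed_allreduce` builds on exactly this round trip, adding the worker/server error-feedback buffers and the two communication phases shown above.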
New file: `op_builder/xpu/packbits.py` (+26 lines)
```python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .builder import SYCLOpBuilder


class PackbitsBuilder(SYCLOpBuilder):
    BUILD_VAR = "DS_BUILD_PACK_BITS"
    NAME = "pack_bits"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.{self.NAME}_op'

    def sources(self):
        return ['csrc/xpu/packbits/packing.cpp']

    def include_paths(self):
        return ['csrc/xpu/includes']

    def cxx_args(self):
        args = super().cxx_args()
        return args + self.version_dependent_macros()
```
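A quick way to check that the op builds and loads is to JIT-compile it directly through the builder. This is a sketch only; it assumes a working XPU PyTorch setup and the SYCL toolchain:

```python
import torch
from deepspeed.ops.op_builder import PackbitsBuilder

packer = PackbitsBuilder().load()  # JIT-compiles csrc/xpu/packbits/packing.cpp on first use

x = torch.randn(64, device='xpu:0')
packed = packer.packbits(x, x.numel(), 0)              # 64 sign bits -> 8 bytes
signs = packer.unpackbits(packed, packed.numel(), 0)   # 8 bytes -> 64 floats of +/-1.0
print(packed.dtype, packed.shape, signs.shape)         # torch.uint8, torch.Size([8]), torch.Size([64])
```

Pre-building at install time should follow the usual DeepSpeed `DS_BUILD_*` convention via the `DS_BUILD_PACK_BITS` variable declared above.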
New file (+31 lines): one-bit tests README
# One-Bit tests

In this folder, you can test the functionality and performance of the different backends for compressed allreduce, the main algorithm in one-bit optimizers such as [One-Bit Adam](https://www.deepspeed.ai/tutorials/onebit-adam/), [One-Bit Lamb](https://www.deepspeed.ai/tutorials/onebit-lamb/) and [Zero-One Adam](https://www.deepspeed.ai/tutorials/zero-one-adam/).

## How to run

### NCCL and MPI backend

These backends require the corresponding communication library to be installed in your environment: the NCCL backend of PyTorch distributed, or a Message Passing Interface (MPI) implementation such as MVAPICH2-GDR or OpenMPI. [Detailed Pre-requisites](https://www.deepspeed.ai/tutorials/zero-one-adam/#12-pre-requisites-for-01-adam).

To test the accuracy and performance of the NCCL backend:
```bash
python test_nccl_backend.py
python test_nccl_perf.py
```
Similarly, for the MPI backend:
```bash
python test_mpi_backend.py
python test_mpi_perf.py
```

### Compressed backend

This backend abstracts the generic part of the one-bit optimizers and implements the accelerator-dependent part with a DeepSpeed custom op builder. To use this `CompressedBackend` and test it, make sure that your current accelerator supports `PackbitsBuilder`, so that it can be loaded for high-performance packing and unpacking between float and byte data types.
An example can be found in `Deepspeed/op_builder/xpu/packbits.py`.

The test usage is the same as the others:
```bash
python test_compressed_backend.py
python test_compressed_perf.py
```
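Beyond these standalone tests, the one-bit optimizers select their communication backend through the `comm_backend_name` field of the optimizer params in the DeepSpeed config. A hedged sketch of a config exercising the new backend is below; the value `"compressed"` is an assumption based on the backend name added in this commit, mirroring how the existing `"nccl"`/`"mpi"` backends are selected:

```python
# Assumption: the CompressedBackend is selected with comm_backend_name="compressed".
ds_config = {
    "train_batch_size": 32,
    "optimizer": {
        "type": "OneBitAdam",
        "params": {
            "lr": 1e-3,
            "freeze_step": 1000,
            "comm_backend_name": "compressed",
        },
    },
}
```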