@delock delock commented Nov 3, 2022

This is a snapshot of PR #2221. The purpose of this PR is to provide XPU customers with a stable DeepSpeed code base that includes the accelerator abstraction. To get the latest accelerator abstraction code, which may still be under development, use #2221.

The following is a snapshot of the description of PR #2221.

This is a proposal to add a device abstraction into DeepSpeed. Currently DeepSpeed has CUDA hard coded, which makes it work only on devices with a CUDA abstraction. In order to make more devices work with DeepSpeed, we need DeepSpeed to depend not on CUDA but on a device abstraction layer that can support different device types. With this proposal, we can support both CUDA devices and Intel GPU devices through the PyTorch XPU extension. In addition, we also support building SYCL kernels through SYCLOpBuilder for Intel GPU devices.

This proposal has the following design goals:

  1. Make DeepSpeed work on both CUDA devices and Intel GPU devices.
  2. Be friendly to extension to other parties' accelerator devices.
  3. Have minimal impact on current DeepSpeed models. Current models still work with DeepSpeed on CUDA devices without modification. Models with CUDA hard coded will need modification to work on both CUDA devices and Intel GPUs.
  4. Use as few if...else... branches as possible when a piece of code needs to support both CUDA devices and Intel GPU devices.

High level design of accelerator abstraction

The high level design and implementation of the accelerator abstraction is based on and extends #2320:

  1. Use the DeepSpeedAccelerator abstract class to define all accelerator interfaces.
  2. A single global DeepSpeedAccelerator object can be initialized eagerly or lazily and can be used throughout DeepSpeed code and models to access accelerator functionality. This object is accessed through get_accelerator() and set with set_accelerator().
  3. Concrete accelerator implementations such as CUDA or XPU can live in an external module and be imported by DeepSpeed during initialization.

DeepSpeedAccelerator abstract class

The DeepSpeedAccelerator abstract class defines the interface a concrete accelerator needs to implement. It has the following interface categories (see the sketch after this list):

  1. Relates to the accelerator device name. This mainly covers usage such as 'cuda', 'cuda:0', etc. The interface names in this category are device_name() and current_device_name().
  2. Relates to the accelerator runtime. This mainly mirrors torch.cuda.<interface_name> calls such as is_available(), synchronize(), etc.
  3. Relates to tensor operations. This mainly covers tensor operations that rely on the device type. The interface names in this category are pin_memory() and on_accelerator().
  4. Relates to the communication backend. This is used to select the accelerator-specific communication backend, such as 'nccl' for CUDA devices and 'ccl' for XPU devices. The interface name in this category is communication_backend_name().
  5. Relates to op builders. This is used to select op builders for building accelerator kernels. The interface name in this category is create_op_builder().
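As a concrete reference, here is a minimal sketch of what the abstract class could look like, assuming Python's abc module; the method names come from the five categories above, while the signatures and parameters are illustrative assumptions rather than the exact API:

import abc

class DeepSpeedAccelerator(abc.ABC):
    # 1. Device name, e.g. 'cuda' or 'xpu' (signatures are assumptions)
    @abc.abstractmethod
    def device_name(self, device_index=None):
        ...

    @abc.abstractmethod
    def current_device_name(self):
        ...

    # 2. Runtime interfaces mirroring torch.cuda.<interface_name>
    @abc.abstractmethod
    def is_available(self):
        ...

    @abc.abstractmethod
    def synchronize(self, device_index=None):
        ...

    # 3. Tensor operations that rely on the device type
    @abc.abstractmethod
    def pin_memory(self, tensor):
        ...

    @abc.abstractmethod
    def on_accelerator(self, tensor):
        ...

    # 4. Communication backend: 'nccl' for CUDA, 'ccl' for XPU
    @abc.abstractmethod
    def communication_backend_name(self):
        ...

    # 5. Op builder selection; returns None if the op is not implemented
    @abc.abstractmethod
    def create_op_builder(self, op_builder_name):
        ...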

Concrete accelerator selection

Concrete accelerator selection is done through deepspeed.accelerator.real_accelerator. There are two interfaces to set/get the concrete accelerator:
set_accelerator(accel_obj) -- set the global accelerator to the given object; this interface can be used at the beginning of a model, before DeepSpeed initialization.
get_accelerator() -- get the global accelerator. If the global accelerator has not been set, detect whether XPU or CUDA support is present in the system and set the global accelerator object accordingly; if no accelerator support is detected, return a CUDA accelerator object by default. A minimal usage sketch is shown below.
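A minimal usage sketch of the lazy path (the printed device name is illustrative):

from deepspeed.accelerator import get_accelerator

# First call detects XPU or CUDA support and caches a global
# accelerator object; falls back to CUDA if nothing is detected.
accel = get_accelerator()
print(accel.device_name())  # e.g. 'cuda' or 'xpu'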

Implement concrete accelerator in external module

A concrete accelerator can be implemented in an external module. The implementation should provide an accelerator class that derives from DeepSpeedAccelerator; an example implementation can be found in cuda_accelerator.py. A model can import this external module, instantiate an accelerator object, and use set_accelerator() to make DeepSpeed use this accelerator:

from <external-module-for-accelerator> import <EXTERNAL>_Accelerator
accel = <EXTERNAL>_Accelerator()
from deepspeed.accelerator import set_accelerator
set_accelerator(accel)

Write accelerator-specific code in DeepSpeed and models

Accelerator runtime

The accelerator abstraction provides a single entry point for accelerator-specific features, which takes the form:

from deepspeed.accelerator import get_accelerator

get_accelerator().<interface name>(...)

For an existing torch.cuda.<interface name> runtime call, we convert it as in the following example:

if torch.cuda.is_available():
    ...

-->

if get_accelerator().is_available():
    ...

For CUDA-specific device names such as 'cuda', 'cuda:0', or 'cuda:1', we convert them to get_accelerator().device_name(), get_accelerator().device_name(0), and get_accelerator().device_name(1), respectively.
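For example (the tensor creation itself is illustrative):

import torch
from deepspeed.accelerator import get_accelerator

# before: t = torch.empty(3, 4, device='cuda:0')
t = torch.empty(3, 4, device=get_accelerator().device_name(0))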

It is a little bit tricky when we convert places where torch.cuda.current_device() is called. current_device() returns a device index, but if we supply a device index in PyTorch code where a device is expected, PyTorch will interpret it as a CUDA device. To get the current device in a form that can be used as a device name, we need to call get_accelerator().current_device_name():

my_tensor = torch.empty(3, 4, device=get_accelerator().current_device_name())

Only when an integer index is expected do we use get_accelerator().current_device():

idx = get_accelerator().current_device()
default_generator = get_accelerator().default_generator(idx)

Tensor operations

When we convert a torch tensor to the accelerator device, such as my_tensor.cuda(), we use my_tensor.to(get_accelerator().device_name()).

When we check whether a torch tensor is on the accelerator device, such as my_tensor.is_cuda, we use get_accelerator().on_accelerator(my_tensor).

When pinning a tensor to GPU memory, such as my_tensor.pin_memory(), we use get_accelerator().pin_memory(my_tensor). The three conversions are sketched together below.
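An illustrative sketch of the three conversions (the tensors themselves are placeholders):

import torch
from deepspeed.accelerator import get_accelerator

my_tensor = torch.empty(3, 4)

# my_tensor = my_tensor.cuda() becomes:
my_tensor = my_tensor.to(get_accelerator().device_name())

# my_tensor.is_cuda becomes:
on_device = get_accelerator().on_accelerator(my_tensor)

# cpu_tensor.pin_memory() becomes:
pinned = get_accelerator().pin_memory(torch.empty(3, 4))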

Communication backend

When a communication backend string is needed, the interface get_accelerator().communication_backend_name() is used to get the communication backend name. So instead of torch.distributed.init_process_group('nccl'), we use torch.distributed.init_process_group(get_accelerator().communication_backend_name()).
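A sketch of the converted call (the usual process-group environment, such as rank and world size, is assumed to be configured elsewhere):

import torch.distributed as dist
from deepspeed.accelerator import get_accelerator

# 'nccl' on CUDA devices, 'ccl' on XPU devices
dist.init_process_group(backend=get_accelerator().communication_backend_name())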

Op builder abstraction

Op builders are abstracted through get_accelerator().create_op_builder(<op builder name>). If the op builder is implemented in the accelerator, an object of an OpBuilder subclass will be returned; if the op builder is not implemented, None will be returned.

A typical implementation can be found in the CUDA implementation, or in an XPU implementation which will be released later. A typical call such as CPUAdamBuilder().load() can be converted to get_accelerator().create_op_builder("CPUAdamBuilder").load().
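A sketch of the converted call, including the None check for accelerators that do not implement the op:

from deepspeed.accelerator import get_accelerator

# before: ds_opt_adam = CPUAdamBuilder().load()
builder = get_accelerator().create_op_builder("CPUAdamBuilder")
if builder is None:
    raise RuntimeError("CPUAdam op is not implemented for this accelerator")
ds_opt_adam = builder.load()  # builds/loads the kernel module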

@loadams loadams commented Aug 18, 2023

Hi @delock - as a part of clearing through some PRs, it looks like this was a snapshot, but this PR hasn't been pushed to in almost a year. But since closing this won't modify the branch, I'm going to close this for now.

@loadams loadams closed this Aug 18, 2023