Add Ascend NPU accelerator support (#1676)
* add Ascend NPU accelerator support

* fix code styles

* enable accelerate test on npu

* fix typo & code styles

---------

Co-authored-by: jihuazhong <[email protected]>
ji-huazhong and jihuazhong authored Jul 12, 2023
1 parent 518c206 commit c33adec
Showing 15 changed files with 165 additions and 36 deletions.
23 changes: 18 additions & 5 deletions src/accelerate/accelerator.py
@@ -75,6 +75,7 @@
is_fp8_available,
is_ipex_available,
is_megatron_lm_available,
is_npu_available,
is_safetensors_available,
is_torch_version,
is_tpu_available,
@@ -413,13 +414,15 @@ def __init__(
and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
):
self.native_amp = True
if self.device.type not in ("cuda", "mps"):
if self.device.type not in ("cuda", "mps", "npu"):
raise ValueError(err.format(mode="fp16", requirement="a GPU"))
kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
if self.distributed_type == DistributedType.FSDP:
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

self.scaler = ShardedGradScaler(**kwargs)
elif is_npu_available():
self.scaler = torch.npu.amp.GradScaler(**kwargs)
else:
self.scaler = torch.cuda.amp.GradScaler(**kwargs)

@@ -965,7 +968,7 @@ def join_uneven_inputs(self, joinables, even_batches=None):
... optimizer.zero_grad()
```
"""
if self.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_XPU):
if self.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.MULTI_XPU):
dl_even_batches_values = []

if even_batches is not None:
@@ -1292,7 +1295,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
model._original_forward = model.forward
model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward
if self.mixed_precision == "fp16":
new_forward = torch.cuda.amp.autocast(dtype=torch.float16)(model_forward_func)
if is_npu_available():
new_forward = torch.npu.amp.autocast(dtype=torch.float16)(model_forward_func)
else:
new_forward = torch.cuda.amp.autocast(dtype=torch.float16)(model_forward_func)
elif self.mixed_precision == "bf16" and self.distributed_type != DistributedType.TPU:
new_forward = torch.autocast(device_type=self.device.type, dtype=torch.bfloat16)(model_forward_func)

@@ -1324,7 +1330,11 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
)
model.forward = fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe)(model.forward)
if not evaluation_mode:
if self.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_XPU):
if self.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_NPU,
DistributedType.MULTI_XPU,
):
if any(p.requires_grad for p in model.parameters()):
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
model = torch.nn.parallel.DistributedDataParallel(
@@ -2686,7 +2696,10 @@ def load_state(self, input_dir: str, **load_model_func_kwargs):

map_location = load_model_func_kwargs.pop("map_location", None)
if map_location is None:
if self.num_processes > 1 and self.distributed_type == DistributedType.MULTI_GPU:
if self.num_processes > 1 and self.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_NPU,
):
map_location = "on_device"
else:
map_location = "cpu"
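
The accelerator.py hunks route fp16 gradient scaling, DDP wrapping, and `load_state` device mapping through the NPU code paths. A minimal sketch of the scaler selection the first hunk implements (non-FSDP branch), assuming `torch_npu` is installed; the `make_grad_scaler` helper is illustrative, not part of the Accelerate API:

```python
import torch

from accelerate.utils import is_npu_available


def make_grad_scaler(**kwargs):
    # Mirrors the branch added to Accelerator.__init__: Ascend devices use the
    # GradScaler shipped by torch_npu, everything else keeps the CUDA one.
    if is_npu_available():
        return torch.npu.amp.GradScaler(**kwargs)
    return torch.cuda.amp.GradScaler(**kwargs)
```
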
24 changes: 20 additions & 4 deletions src/accelerate/commands/config/cluster.py
@@ -47,7 +47,7 @@
def get_cluster_input():
distributed_type = _ask_options(
"Which type of machine are you using?",
["No distributed training", "multi-CPU", "multi-XPU", "multi-GPU", "TPU"],
["No distributed training", "multi-CPU", "multi-XPU", "multi-GPU", "multi-NPU", "TPU"],
_convert_distributed_mode,
)

@@ -60,7 +60,12 @@ def get_cluster_input():
rdzv_backend = "static"
same_network = True

if distributed_type in [DistributedType.MULTI_GPU, DistributedType.MULTI_XPU, DistributedType.MULTI_CPU]:
if distributed_type in [
DistributedType.MULTI_GPU,
DistributedType.MULTI_NPU,
DistributedType.MULTI_XPU,
DistributedType.MULTI_CPU,
]:
num_machines = _ask_field(
"How many different machines will you use (use more than 1 for multi-node training)? [1]: ",
int,
@@ -110,7 +115,11 @@ def get_cluster_input():
default=False,
error_message="Please enter yes or no.",
)
if not use_cpu and is_xpu_available() and distributed_type not in [DistributedType.MULTI_GPU, DistributedType.TPU]:
if (
not use_cpu
and is_xpu_available()
and distributed_type not in [DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.TPU]
):
ipex_config["use_xpu"] = _ask_field(
"Do you want to use XPU plugin to speed up training on XPU? [yes/NO]:",
_convert_yes_no_to_bool,
@@ -444,6 +453,7 @@ def get_cluster_input():
DistributedType.MULTI_CPU,
DistributedType.MULTI_XPU,
DistributedType.MULTI_GPU,
DistributedType.MULTI_NPU,
DistributedType.TPU,
]:
machine_type = str(distributed_type).split(".")[1].replace("MULTI_", "")
@@ -468,7 +478,13 @@ def get_cluster_input():
num_processes = 1

if (
distributed_type in [DistributedType.MULTI_GPU, DistributedType.MULTI_XPU, DistributedType.NO]
distributed_type
in [
DistributedType.MULTI_GPU,
DistributedType.MULTI_NPU,
DistributedType.MULTI_XPU,
DistributedType.NO,
]
and not use_cpu
and not use_mps
):
2 changes: 1 addition & 1 deletion src/accelerate/commands/config/config_utils.py
@@ -66,7 +66,7 @@ def _convert_compute_environment(value):

def _convert_distributed_mode(value):
value = int(value)
return DistributedType(["NO", "MULTI_CPU", "MULTI_XPU", "MULTI_GPU", "TPU"][value])
return DistributedType(["NO", "MULTI_CPU", "MULTI_XPU", "MULTI_GPU", "MULTI_NPU", "TPU"][value])


def _convert_dynamo_backend(value):
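
The index-to-enum list here has to stay in step with the prompt choices added in cluster.py above: inserting "MULTI_NPU" moves "TPU" from index 4 to index 5. A small standalone check of the new mapping (illustrative only):

```python
from accelerate.utils import DistributedType

choices = ["NO", "MULTI_CPU", "MULTI_XPU", "MULTI_GPU", "MULTI_NPU", "TPU"]

# Index 4 now maps to the new MULTI_NPU member, and TPU moves to index 5,
# one slot later than before this commit.
assert DistributedType(choices[4]) is DistributedType.MULTI_NPU
assert DistributedType(choices[5]) is DistributedType.TPU
```
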
10 changes: 9 additions & 1 deletion src/accelerate/commands/config/default.py
@@ -18,7 +18,7 @@

import torch

from ...utils import is_xpu_available
from ...utils import is_npu_available, is_xpu_available
from .config_args import ClusterConfig, default_json_config_file
from .config_utils import SubcommandHelpFormatter

@@ -73,6 +73,14 @@ def write_basic_config(mixed_precision="no", save_location: str = default_json_c
config["distributed_type"] = "MULTI_XPU"
else:
config["distributed_type"] = "NO"
elif is_npu_available():
num_npus = torch.npu.device_count()
config["num_processes"] = num_npus
config["use_cpu"] = False
if num_npus > 1:
config["distributed_type"] = "MULTI_NPU"
else:
config["distributed_type"] = "NO"
else:
num_xpus = 0
config["use_cpu"] = True
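
With this branch, the default config written by `accelerate config default` (which calls `write_basic_config`) selects MULTI_NPU automatically when more than one Ascend device is visible. A hedged sketch of producing that config programmatically; the field values shown are what the new branch would write on a hypothetical 8-NPU host:

```python
from accelerate.commands.config.default import write_basic_config

# Writes the default Accelerate config file. On a host where torch_npu reports
# 8 devices, the new branch sets:
#   distributed_type: MULTI_NPU
#   num_processes:    8
#   use_cpu:          false
write_basic_config(mixed_precision="no")
```
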
4 changes: 3 additions & 1 deletion src/accelerate/commands/env.py
@@ -25,7 +25,7 @@
from accelerate import __version__ as version
from accelerate.commands.config import default_config_file, load_config_from_file

from ..utils import is_xpu_available
from ..utils import is_npu_available, is_xpu_available


def env_command_parser(subparsers=None):
@@ -47,6 +47,7 @@ def env_command(args):
pt_version = torch.__version__
pt_cuda_available = torch.cuda.is_available()
pt_xpu_available = is_xpu_available()
pt_npu_available = is_npu_available()

accelerate_config = "Not found"
# Get the default from the config file.
@@ -60,6 +61,7 @@
"Numpy version": np.__version__,
"PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
"PyTorch XPU available": str(pt_xpu_available),
"PyTorch NPU available": str(pt_npu_available),
"System RAM": f"{psutil.virtual_memory().total / 1024 ** 3:.2f} GB",
}
if pt_cuda_available:
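
`accelerate env` now reports Ascend visibility next to the CUDA and XPU flags. A hypothetical excerpt of the printed report on an NPU host (version numbers and sizes are illustrative):

```
- Numpy version: 1.24.3
- PyTorch version (GPU?): 2.0.1 (False)
- PyTorch XPU available: False
- PyTorch NPU available: True
- System RAM: 755.09 GB
```
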
12 changes: 10 additions & 2 deletions src/accelerate/commands/launch.py
@@ -36,6 +36,7 @@
_filter_args,
is_bf16_available,
is_deepspeed_available,
is_npu_available,
is_rich_available,
is_sagemaker_available,
is_torch_version,
@@ -828,7 +829,10 @@ def _validate_launch_command(args):
):
args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
args.multi_gpu = (
True if defaults.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_XPU) else False
True
if defaults.distributed_type
in (DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.MULTI_XPU)
else False
)
args.tpu = defaults.distributed_type == DistributedType.TPU
args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
@@ -896,11 +900,15 @@ def _validate_launch_command(args):
if args.num_processes is None:
if args.use_xpu and is_xpu_available():
args.num_processes = torch.xpu.device_count()
elif is_npu_available():
args.num_processes = torch.npu.device_count()
else:
args.num_processes = torch.cuda.device_count()
warned.append(f"\t`--num_processes` was set to a value of `{args.num_processes}`")
if not args.multi_gpu and (
(args.use_xpu and is_xpu_available() and torch.xpu.device_count() > 1) or (torch.cuda.device_count() > 1)
(args.use_xpu and is_xpu_available() and torch.xpu.device_count() > 1)
or (is_npu_available() and torch.npu.device_count() > 1)
or (torch.cuda.device_count() > 1)
):
warned.append(
"\t\tMore than one GPU was found, enabling multi-GPU training.\n"
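
The launcher changes fold NPUs into the default process-count fallback and into the multi-device warning. A sketch of the fallback order these hunks implement (the `default_num_processes` helper is illustrative, not an Accelerate function):

```python
import torch

from accelerate.utils import is_npu_available, is_xpu_available


def default_num_processes(use_xpu: bool = False) -> int:
    # Mirrors _validate_launch_command: an explicit XPU request wins,
    # then Ascend NPUs are counted, then CUDA devices.
    if use_xpu and is_xpu_available():
        return torch.xpu.device_count()
    if is_npu_available():
        return torch.npu.device_count()
    return torch.cuda.device_count()
```
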
18 changes: 18 additions & 0 deletions src/accelerate/state.py
@@ -35,6 +35,7 @@
is_fp8_available,
is_ipex_available,
is_mps_available,
is_npu_available,
is_tpu_available,
is_xpu_available,
parse_choice_from_env,
@@ -195,6 +196,19 @@ def __init__(self, cpu: bool = False, **kwargs):
if self.device is None:
self.device = torch.device("cuda", self.local_process_index)
torch.cuda.set_device(self.device)
elif is_npu_available() and not cpu and int(os.environ.get("LOCAL_RANK", -1)) != -1:
self.distributed_type = DistributedType.MULTI_NPU
if not torch.distributed.is_initialized():
# Backend is not set by the user, we set it here
kwargs.pop("backend", None)
self.backend = "hccl"
torch.distributed.init_process_group(backend=self.backend, **kwargs)
self.num_processes = torch.distributed.get_world_size()
self.process_index = torch.distributed.get_rank()
self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
if self.device is None:
self.device = torch.device("npu", self.local_process_index)
torch.npu.set_device(self.device)
elif get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1:
if not cpu and is_xpu_available():
self.distributed_type = DistributedType.MULTI_XPU
@@ -343,6 +357,7 @@ def wait_for_everyone(self):
"""
if self.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_NPU,
DistributedType.MULTI_XPU,
DistributedType.MULTI_CPU,
DistributedType.DEEPSPEED,
@@ -649,6 +664,7 @@ def default_device(self) -> torch.device:
Returns the default device which is:
- MPS if `torch.backends.mps.is_available()` and `torch.backends.mps.is_built()` both return True.
- CUDA if `torch.cuda.is_available()`
- NPU if `is_npu_available()`
- CPU otherwise
"""
if is_mps_available():
Expand All @@ -658,6 +674,8 @@ def default_device(self) -> torch.device:
return torch.device("cuda")
elif is_xpu_available():
return torch.device("xpu:0")
elif is_npu_available():
return torch.device("npu")
else:
return torch.device("cpu")

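
The state.py branch is where the actual multi-NPU process group comes from: when `torch_npu` is present and `LOCAL_RANK` is set, the backend is forced to `hccl` (Ascend's counterpart to NCCL) and each rank is pinned to its local NPU. A standalone sketch of the same sequence, assuming `torch_npu` is installed and torchrun-style environment variables (`RANK`, `WORLD_SIZE`, `LOCAL_RANK`, `MASTER_ADDR`, `MASTER_PORT`) are set:

```python
import os

import torch
import torch_npu  # noqa: F401  (registers the torch.npu namespace)

# hccl is the collective-communications backend for Ascend devices.
torch.distributed.init_process_group(backend="hccl")

local_rank = int(os.environ["LOCAL_RANK"])
torch.npu.set_device(torch.device("npu", local_rank))
```
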
3 changes: 2 additions & 1 deletion src/accelerate/test_utils/scripts/test_script.py
@@ -33,6 +33,7 @@
gather,
is_bf16_available,
is_ipex_available,
is_npu_available,
is_xpu_available,
set_seed,
synchronize_rng_states,
@@ -358,7 +359,7 @@ def training_check():

accelerator.print("Training yielded the same results on one CPU or distributed setup with batch split.")

if torch.cuda.is_available():
if torch.cuda.is_available() or is_npu_available():
# Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
print("FP16 training check.")
AcceleratorState._reset_state()
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -51,6 +51,7 @@
is_megatron_lm_available,
is_mlflow_available,
is_mps_available,
is_npu_available,
is_rich_available,
is_safetensors_available,
is_sagemaker_available,
3 changes: 3 additions & 0 deletions src/accelerate/utils/dataclasses.py
@@ -182,6 +182,7 @@ class DistributedType(str, enum.Enum):
- **NO** -- Not a distributed environment, just a single process.
- **MULTI_CPU** -- Distributed on multiple CPU nodes.
- **MULTI_GPU** -- Distributed on multiple GPUs.
- **MULTI_NPU** -- Distributed on multiple NPUs.
- **MULTI_XPU** -- Distributed on multiple XPUs.
- **DEEPSPEED** -- Using DeepSpeed.
- **TPU** -- Distributed on TPUs.
@@ -191,6 +192,7 @@
NO = "NO"
MULTI_CPU = "MULTI_CPU"
MULTI_GPU = "MULTI_GPU"
MULTI_NPU = "MULTI_NPU"
MULTI_XPU = "MULTI_XPU"
DEEPSPEED = "DEEPSPEED"
FSDP = "FSDP"
@@ -335,6 +337,7 @@ class PrecisionType(BaseEnum):
class RNGType(BaseEnum):
TORCH = "torch"
CUDA = "cuda"
NPU = "npu"
XLA = "xla"
XPU = "xpu"
GENERATOR = "generator"
21 changes: 21 additions & 0 deletions src/accelerate/utils/imports.py
@@ -103,6 +103,8 @@ def is_bf16_available(ignore_tpu=False):
return not ignore_tpu
if torch.cuda.is_available():
return torch.cuda.is_bf16_supported()
if is_npu_available():
return False
return True


@@ -220,6 +222,25 @@ def get_major_and_minor_from_version(full_version):
return True


@lru_cache
def is_npu_available(check_device=False):
"Checks if `torch_npu` is installed and potentially if a NPU is in the environment"
if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
return False

import torch
import torch_npu # noqa: F401

if check_device:
try:
# Will raise a RuntimeError if no NPU is found
_ = torch.npu.device_count()
return torch.npu.is_available()
except RuntimeError:
return False
return hasattr(torch, "npu") and torch.npu.is_available()


@lru_cache
def is_xpu_available(check_device=False):
"check if user disables it explicitly"
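
`is_npu_available` is the check everything else in this commit keys off: it requires both `torch` and `torch_npu` to be importable, and with `check_device=True` it additionally probes for a usable device. A hedged usage sketch for device selection in user code (simplified; Accelerate's own `default_device` also considers MPS and XPU):

```python
import torch

from accelerate.utils import is_npu_available


def pick_device() -> torch.device:
    # Prefer CUDA, then Ascend NPU, then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if is_npu_available(check_device=True):
        return torch.device("npu")
    return torch.device("cpu")
```
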