9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
@@ -68,3 +68,12 @@ repos:
  hooks:
    - id: flake8
      args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401']

- repo: local
  hooks:
    - id: check-torchcuda
      name: check-torchcuda
      entry: ./scripts/check-torchcuda.py
      language: script
      exclude: ^(.github/workflows/|scripts/check-torchcuda.py|accelerator/cuda_accelerator.py|deepspeed/inference/engine.py|deepspeed/model_implementations/transformers/clip_encoder.py|deepspeed/model_implementations/diffusers/vae.py|deepspeed/model_implementations/diffusers/unet.py|op_builder/spatial_inference.py|op_builder/transformer_inference.py|op_builder/builder.py|setup.py|tests/unit/ops/sparse_attention/test_sparse_attention.py)
      # Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm
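For context, here is a minimal sketch (not part of the PR) of the replacement pattern the hook enforces, assuming an accelerator device is actually present at runtime; the tensor and the guard below are illustrative:

```python
import torch
from deepspeed.accelerator import get_accelerator

# Flagged by the hook:  x = torch.randn(8).cuda()
# Accelerator-agnostic replacement:
x = torch.randn(8).to(get_accelerator().device_name())

# Code that is intentionally CUDA-specific can opt out with the escape comment:
if torch.cuda.is_available():  #ignore-cuda
    torch.cuda.synchronize()  #ignore-cuda
```

Once the hook is configured, it can also be run on demand with `pre-commit run check-torchcuda --all-files`.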
6 changes: 4 additions & 2 deletions deepspeed/profiling/flops_profiler/README.md
@@ -309,8 +309,9 @@ The following example shows how to profile AlexNet using the DeepSpeed flops profiler
import torchvision.models as models
import torch
from deepspeed.profiling.flops_profiler import get_model_profile
from deepspeed.accelerator import get_accelerator

with torch.cuda.device(0):
with get_accelerator().device(0):
    model = models.alexnet()
    batch_size = 256
    flops, macs, params = get_model_profile(model=model, # model
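The `get_model_profile` call is truncated in this view; a hedged completion of the AlexNet example follows, where the keyword arguments beyond `model=` are assumptions based on the profiler's documented interface, not the README's exact text:

```python
import torchvision.models as models
from deepspeed.profiling.flops_profiler import get_model_profile
from deepspeed.accelerator import get_accelerator

with get_accelerator().device(0):
    model = models.alexnet()
    batch_size = 256
    flops, macs, params = get_model_profile(
        model=model,                            # the model to profile
        input_shape=(batch_size, 3, 224, 224),  # shape of one input batch
        print_profile=True,                     # print the annotated model graph
        detailed=True,                          # include a per-module breakdown
        as_string=True,                         # return human-readable strings
    )
```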
@@ -334,6 +335,7 @@ from functools import partial
import torch
from transformers import BertForSequenceClassification, BertTokenizer
from deepspeed.profiling.flops_profiler import get_model_profile
from deepspeed.accelerator import get_accelerator


def bert_input_constructor(batch_size, seq_len, tokenizer):
@@ -350,7 +352,7 @@ def bert_input_constructor(batch_size, seq_len, tokenizer):
    return inputs


with torch.cuda.device(0):
with get_accelerator().device(0):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    batch_size = 4
4 changes: 2 additions & 2 deletions deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -92,7 +92,7 @@ def _set_cuda_rng_state(new_state, device=-1):

Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
This function is adapted from PyTorch repo (torch.cuda.set_rng_state) #ignore-cuda
with a single change: the input state is not cloned. Cloning caused
major performance issues for +4 GPU cases.
"""
@@ -499,7 +499,7 @@ def get_cpu_activations_for_backward(args, inputs):
class CheckpointFunction(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with
two main changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` #ignore-cuda
2) the states in the model parallel tracker are also properly
tracked/set/reset.
3) Performance activation partitioning, contiguous memory optimization
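To make the RNG handling in the adapted checkpoint function concrete, here is a hedged sketch (the helper is illustrative, not code from this file) of saving and restoring device RNG state through the accelerator API, assuming it exposes `set_rng_state` alongside the `get_rng_state` used elsewhere in this PR:

```python
import torch
from deepspeed.accelerator import get_accelerator

def run_with_preserved_rng(fn, *args):
    """Run fn(*args), then restore CPU and accelerator RNG state.

    Activation checkpointing needs this so that dropout masks produced during
    recomputation match those of the original forward pass.
    """
    cpu_state = torch.get_rng_state()
    device_state = get_accelerator().get_rng_state()
    try:
        return fn(*args)
    finally:
        torch.set_rng_state(cpu_state)
        get_accelerator().set_rng_state(device_state)
```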
13 changes: 4 additions & 9 deletions deepspeed/runtime/zero/linear.py
@@ -31,15 +31,10 @@ def print_rank_0(message, debug=False, force=False):
        print(message)


device = get_accelerator().device_name()
if device == 'cuda':
    try:
        autocast_custom_fwd = torch.cuda.amp.custom_fwd
        autocast_custom_bwd = torch.cuda.amp.custom_bwd
    except (ImportError, AttributeError) as exp:
        autocast_custom_fwd = noop_decorator
        autocast_custom_bwd = noop_decorator
else:
try:
    autocast_custom_fwd = get_accelerator().amp().custom_fwd
    autocast_custom_bwd = get_accelerator().amp().custom_bwd
except (ImportError, AttributeError) as exp:
    autocast_custom_fwd = noop_decorator
    autocast_custom_bwd = noop_decorator

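As a usage illustration, a hedged sketch of how the resolved decorators are typically applied to a custom autograd function, reusing the `autocast_custom_fwd`/`autocast_custom_bwd` names from the hunk above; the class itself is illustrative and not taken from this file:

```python
import torch

class ScaledLinearFn(torch.autograd.Function):
    """Illustrative autograd function guarded by the autocast decorators resolved above."""

    @staticmethod
    @autocast_custom_fwd
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return input.matmul(weight.t())

    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)      # (N, out) x (out, in) -> (N, in)
        grad_weight = grad_output.t().matmul(input)  # (out, N) x (N, in) -> (out, in)
        return grad_input, grad_weight
```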
2 changes: 1 addition & 1 deletion docs/_pages/config-json.md
@@ -849,7 +849,7 @@ Configuring the asynchronous I/O module for offloading parameter and optimizer states

| Description | Default |
| ------------------------------------------------------------- | ------- |
| Inserts torch.cuda.synchronize() at each checkpoint boundary. | `false` |
| Inserts get_accelerator().synchronize() at each checkpoint boundary. | `false` |


<i>**profile**</i>: [boolean]
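For orientation, a hedged sketch of where this flag sits in a DeepSpeed config, written as a Python dict; the enclosing `activation_checkpointing` section and the `synchronize_checkpoint_boundary` key name are assumptions based on the surrounding documentation rather than text from this page:

```python
# Hypothetical config fragment; key names are assumptions.
ds_config = {
    "activation_checkpointing": {
        "synchronize_checkpoint_boundary": False,  # insert get_accelerator().synchronize() at each checkpoint boundary
        "profile": False,                          # the next option documented on this page
    }
}
```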
6 changes: 4 additions & 2 deletions docs/_tutorials/flops-profiler.md
@@ -316,8 +316,9 @@ The following example shows how to profile AlexNet using the DeepSpeed flops profiler
import torchvision.models as models
import torch
from deepspeed.profiling.flops_profiler import get_model_profile
from deepspeed.accelerator import get_accelerator

with torch.cuda.device(0):
with get_accelerator().device(0):
    model = models.alexnet()
    batch_size = 256
    flops, macs, params = get_model_profile(model=model, # model
@@ -341,6 +342,7 @@ from functools import partial
import torch
from transformers import BertForSequenceClassification, BertTokenizer
from deepspeed.profiling.flops_profiler import get_model_profile
from deepspeed.accelerator import get_accelerator


def bert_input_constructor(batch_size, seq_len, tokenizer):
@@ -357,7 +359,7 @@ def bert_input_constructor(batch_size, seq_len, tokenizer):
    return inputs


with torch.cuda.device(0):
with get_accelerator().device(0):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    batch_size = 4
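The BERT example is truncated here; a hedged completion showing how `bert_input_constructor` (defined above) feeds `get_model_profile`, where the `seq_len` value and keyword arguments are assumptions rather than the tutorial's exact text:

```python
from transformers import BertForSequenceClassification, BertTokenizer
from deepspeed.profiling.flops_profiler import get_model_profile
from deepspeed.accelerator import get_accelerator

with get_accelerator().device(0):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
    batch_size = 4
    seq_len = 128
    flops, macs, params = get_model_profile(
        model,
        kwargs=bert_input_constructor(batch_size, seq_len, tokenizer),  # model inputs as keyword arguments
        print_profile=True,
        detailed=True,
    )
```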
2 changes: 1 addition & 1 deletion docs/_tutorials/megatron.md
@@ -275,7 +275,7 @@ DeepSpeed's `save_checkpoint()`.
sd['random_rng_state'] = random.getstate()
sd['np_rng_state'] = np.random.get_state()
sd['torch_rng_state'] = torch.get_rng_state()
sd['cuda_rng_state'] = torch.cuda.get_rng_state()
sd['cuda_rng_state'] = get_accelerator().get_rng_state()
sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

model.save_checkpoint(args.save, iteration, client_state = sd)
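A hedged sketch of the matching restore path after loading the checkpoint, mirroring the keys saved above; the `client_sd` name and the exact `load_checkpoint` call shape are assumptions, and `random`, `np`, `mpu`, and `get_accelerator` refer to the same names used in the snippet above:

```python
_, client_sd = model.load_checkpoint(args.load, iteration)
random.setstate(client_sd['random_rng_state'])
np.random.set_state(client_sd['np_rng_state'])
torch.set_rng_state(client_sd['torch_rng_state'])
get_accelerator().set_rng_state(client_sd['cuda_rng_state'])
mpu.get_cuda_rng_tracker().set_states(client_sd['rng_tracker_states'])
```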
66 changes: 66 additions & 0 deletions scripts/check-torchcuda.py
@@ -0,0 +1,66 @@
#!/usr/bin/env python3
from __future__ import annotations
'''Copyright The Microsoft DeepSpeed Team'''
"""
Checks each file in sys.argv for the string "torch.cuda".
Modified from https://github.com/jlebar/pre-commit-hooks/blob/master/check_do_not_submit.py
"""

import subprocess
import sys


def err(s: str) -> None:
    print(s, file=sys.stderr)


# There are many ways we could search for the string "torch.cuda", but `git
# grep --no-index` is nice because
# - it's very fast (as compared to iterating over the file in Python)
# - we can reasonably assume it's available on all machines
# - unlike plain grep, which is slower and has different flags on MacOS versus
# Linux, git grep is always the same.
res = subprocess.run(
    [
        "git",
        "grep",
        "-Hn",
        "--no-index",
        "-e",
        r"torch\.cuda",
        "--and",
        "--not",
        "-e",
        "#ignore-cuda",
        *sys.argv[1:]
    ],
    capture_output=True,
)
if res.returncode == 0:
    err('Error: The string "torch.cuda" was found.\nPlease replace all calls to torch.cuda with "get_accelerator()" and add the following import line:\n\n from deepspeed.accelerator import get_accelerator\n\nIf your code is meant to be cuda-specific, please add the following comment on the line with torch.cuda:\n\n #ignore-cuda\n')
    err(res.stdout.decode("utf-8"))
    sys.exit(1)
elif res.returncode == 2:
    err(f"Error invoking grep on {', '.join(sys.argv[1:])}:")
    err(res.stderr.decode("utf-8"))
    sys.exit(2)

res = subprocess.run(
    ["git",
     "grep",
     "-Hn",
     "--no-index",
     r"\.cuda()",
     *sys.argv[1:]],
    capture_output=True,
)
if res.returncode == 0:
    err('Error: The string ".cuda()" was found. This implies converting a tensor to a cuda tensor. Please replace all calls to tensor.cuda() with "tensor.to(get_accelerator().device_name())" and add the following import line:\nfrom deepspeed.accelerator import get_accelerator')
    err(res.stdout.decode("utf-8"))
    sys.exit(1)
elif res.returncode == 2:
    err(f"Error invoking grep on {', '.join(sys.argv[1:])}:")
    err(res.stderr.decode("utf-8"))
    sys.exit(2)
5 changes: 3 additions & 2 deletions tests/unit/common.py
@@ -124,9 +124,10 @@ def _get_fixture_kwargs(self, request, func):
        return fixture_kwargs

    def _launch_procs(self, num_procs):
        if torch.cuda.is_available() and torch.cuda.device_count() < num_procs:
        if get_accelerator().is_available(
        ) and get_accelerator().device_count() < num_procs:
            pytest.skip(
                f"Skipping test because not enough GPUs are available: {num_procs} required, {torch.cuda.device_count()} available"
                f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available"
            )
        mp.set_start_method('forkserver', force=True)
        skip_msg = mp.Queue()  # Allows forked processes to share pytest.skip reason
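A hedged sketch of the same device-count guard written as a reusable pytest marker; the `requires_num_devices` helper and the sample test are hypothetical, not part of the test suite:

```python
import pytest
from deepspeed.accelerator import get_accelerator

def requires_num_devices(n):
    """Skip a test unless at least n accelerator devices are visible."""
    return pytest.mark.skipif(
        get_accelerator().is_available() and get_accelerator().device_count() < n,
        reason=f"requires at least {n} accelerator devices",
    )

@requires_num_devices(2)
def test_two_device_example():
    ...
```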
2 changes: 1 addition & 1 deletion tests/unit/util.py
@@ -8,7 +8,7 @@

def skip_on_arch(min_arch=7):
    if deepspeed.accelerator.get_accelerator().device_name() == 'cuda':
        if torch.cuda.get_device_capability()[0] < min_arch:
        if torch.cuda.get_device_capability()[0] < min_arch:  #ignore-cuda
            pytest.skip(f"needs higher compute capability than {min_arch}")
    else:
        assert deepspeed.accelerator.get_accelerator().device_name() == 'xpu'
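A short usage sketch (the test below is hypothetical) showing how `skip_on_arch` guards a test that needs a newer GPU architecture:

```python
def test_bf16_matmul():
    # bf16 kernels on CUDA require compute capability 8.0 (Ampere) or newer.
    skip_on_arch(min_arch=8)
    ...
```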