Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.ops.adam import DeepSpeedCPUAdam

from deepspeed.utils import logger
#Toggle this to true to enable correctness test
Expand Down Expand Up @@ -1416,6 +1415,7 @@ def step(self, closure=None):
#torch.set_num_threads(12)
timers('optimizer_step').start()
if self.deepspeed_adam_offload:
from deepspeed.ops.adam import DeepSpeedCPUAdam
self.optimizer.step(fp16_param_groups=self.parallel_partitioned_fp16_groups)
#self.optimizer.step()
#for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
Expand Down
38 changes: 27 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import shutil
import subprocess
import warnings
import cpufeature
from setuptools import setup, find_packages
from torch.utils.cpp_extension import CUDAExtension, BuildExtension, CppExtension

Expand All @@ -25,6 +24,27 @@ def fetch_requirements(path):
return [r.strip() for r in fd.readlines()]


def available_vector_instructions():
    """Return the CPU feature table exposed by the optional ``cpufeature`` package.

    The result is only ever queried with ``.get(<feature name>, False)`` by the
    build script, so an empty dict means "no vector instruction sets detected".

    Returns:
        ``cpufeature.CPUFeature`` when the package imports and the attribute can
        be read; otherwise ``{}`` after emitting a warning.
    """
    # cpufeature is an optional build-time dependency: import it lazily so the
    # build script still runs when it is not installed.
    try:
        import cpufeature
    except ImportError:
        warnings.warn(
            'import cpufeature failed - CPU vector optimizations are not available for CPUAdam'
        )
        return {}

    # Best-effort probe: whatever goes wrong while reading the feature table,
    # fall back to "nothing detected" rather than aborting the build. The
    # previous `except _:` referenced an undefined name and therefore raised
    # NameError instead of reaching this fallback.
    try:
        return cpufeature.CPUFeature
    except Exception:
        warnings.warn(
            'cpufeature.CPUFeature failed - CPU vector optimizations are not available for CPUAdam'
        )
        return {}


install_requires = fetch_requirements('requirements/requirements.txt')
dev_requires = fetch_requirements('requirements/requirements-dev.txt')
sparse_attn_requires = fetch_requirements('requirements/requirements-sparse-attn.txt')
Expand All @@ -43,29 +63,26 @@ def fetch_requirements(path):
SPARSE_ATTN = "sparse-attn"
CPU_ADAM = "cpu-adam"

cpu_vector_instructions = available_vector_instructions()

# Build environment variables for custom builds
# One value per optional op. Each is multiplied by an integer read from the
# matching DS_BUILD_* environment variable, and the products are combined
# with bitwise OR.
# NOTE(review): these are decimal literals (1, 10, 100, ...), not distinct
# bit flags. Their binary forms overlap (e.g. 10 & 1000 == 8 and
# 100 & 1000 == 96), so testing a combined mask with `&` can report an op
# that was not selected - confirm whether 0b1 / 0b10 / 0b100 / ... were
# intended.
DS_BUILD_LAMB_MASK = 1
DS_BUILD_TRANSFORMER_MASK = 10
DS_BUILD_SPARSE_ATTN_MASK = 100
DS_BUILD_CPU_ADAM_MASK = 1000
DS_BUILD_AVX512_MASK = 10000

# Allow for build_cuda to turn on or off all ops
DS_BUILD_ALL_OPS = DS_BUILD_LAMB_MASK | DS_BUILD_TRANSFORMER_MASK | DS_BUILD_SPARSE_ATTN_MASK | DS_BUILD_CPU_ADAM_MASK | DS_BUILD_AVX512_MASK
DS_BUILD_ALL_OPS = DS_BUILD_LAMB_MASK | DS_BUILD_TRANSFORMER_MASK | DS_BUILD_SPARSE_ATTN_MASK | DS_BUILD_CPU_ADAM_MASK
DS_BUILD_CUDA = int(os.environ.get('DS_BUILD_CUDA', 1)) * DS_BUILD_ALL_OPS

# Set the default of each op based on whether build_cuda is set
OP_DEFAULT = DS_BUILD_CUDA == DS_BUILD_ALL_OPS
DS_BUILD_CPU_ADAM = int(os.environ.get('DS_BUILD_CPU_ADAM',
OP_DEFAULT)) * DS_BUILD_CPU_ADAM_MASK
DS_BUILD_CPU_ADAM = int(os.environ.get('DS_BUILD_CPU_ADAM', 0)) * DS_BUILD_CPU_ADAM_MASK
DS_BUILD_LAMB = int(os.environ.get('DS_BUILD_LAMB', OP_DEFAULT)) * DS_BUILD_LAMB_MASK
DS_BUILD_TRANSFORMER = int(os.environ.get('DS_BUILD_TRANSFORMER',
OP_DEFAULT)) * DS_BUILD_TRANSFORMER_MASK
DS_BUILD_SPARSE_ATTN = int(os.environ.get('DS_BUILD_SPARSE_ATTN',
OP_DEFAULT)) * DS_BUILD_SPARSE_ATTN_MASK
DS_BUILD_AVX512 = int(os.environ.get(
'DS_BUILD_AVX512',
cpufeature.CPUFeature['AVX512f'])) * DS_BUILD_AVX512_MASK

# Final effective mask is the bitwise OR of each op
BUILD_MASK = (DS_BUILD_LAMB | DS_BUILD_TRANSFORMER | DS_BUILD_SPARSE_ATTN
Expand Down Expand Up @@ -111,11 +128,10 @@ def fetch_requirements(path):
version_ge_1_5 = ['-DVERSION_GE_1_5']
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5

cpu_info = cpufeature.CPUFeature
SIMD_WIDTH = ''
if cpu_info['AVX512f'] and DS_BUILD_AVX512:
if cpu_vector_instructions.get('AVX512f', False):
SIMD_WIDTH = '-D__AVX512__'
elif cpu_info['AVX2']:
elif cpu_vector_instructions.get('AVX2', False):
SIMD_WIDTH = '-D__AVX256__'
print("SIMD_WIDTH = ", SIMD_WIDTH)

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_cpu_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import copy

import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam

if not deepspeed.ops.__installed_ops__['cpu-adam']:
pytest.skip("cpu-adam is not installed", allow_module_level=True)
else:
from deepspeed.ops.adam import DeepSpeedCPUAdam


def check_equal(first, second, atol=1e-2, verbose=False):
Expand Down