Merged

Changes from all commits (49 commits)
e45b5e4
ZeRO-Offload v1 (squash) (#345)
jeffra Sep 2, 2020
75d70a9
update DSE to staging for zero-dual
jeffra Sep 2, 2020
1ebcd6c
Update test_sparse_attention.py
jeffra Sep 3, 2020
0159ebb
Assert ZeRO-Offload+gradient accumulation (#347)
tjruwase Sep 3, 2020
6deac82
Adding link to Sparse Attention in Navigation page (#355)
arashashari Sep 3, 2020
6604a5d
Correctness and perf fixes (#354)
tjruwase Sep 3, 2020
20c414d
add cpu adam optimizer (#356)
RezaYazdaniAminabadi Sep 3, 2020
504a643
make the adam unit test work with random params and grads and for mor…
Sep 3, 2020
af51211
Samyamr/zero offload correctness (#359)
samyam Sep 4, 2020
130dd70
Import path fixes + conditional imports (#358)
jeffra Sep 4, 2020
ea5b991
Enable contiguous gradients for cpu_offload
tjruwase Sep 4, 2020
077cfd4
Merge branch 'staging-zero-dual-v2' of github.com:microsoft/DeepSpeed…
tjruwase Sep 4, 2020
7be128a
Allocating CPU memory directly on CPU without transfering them from G…
samyam Sep 4, 2020
1a4a82b
change gpt2 pretrain to have DeepSpeed adam (#361)
samyam Sep 4, 2020
ac12833
Jekyll installation instructions (#351)
Sep 4, 2020
253b044
Generalize detection of ZeRO supported optimizers (#349)
tjruwase Sep 4, 2020
9ba232a
everything is working
samyam Sep 4, 2020
606543d
fixing the cpu_adam API and add deepspeed_adam flag in config.py (#365)
RezaYazdaniAminabadi Sep 4, 2020
1d4b41f
fixing adam copy fp16-param-add more compile flags for cpu_adam
Sep 5, 2020
59ffc1a
run precommit
Sep 5, 2020
ad8af38
fix variance indexes
Sep 5, 2020
2d17a6d
fix array-sizes
Sep 5, 2020
aa3f289
ZeRO-Offload passing model functionality tests (#366)
tjruwase Sep 5, 2020
3788724
move adam_test
Sep 5, 2020
36d5fde
rename perf test
Sep 5, 2020
d8ff56c
fixing adam copy fp16-param and add more compile flags for cpu_adam (…
RezaYazdaniAminabadi Sep 5, 2020
f0c34d0
Perf tests
tjruwase Sep 5, 2020
942ec90
BumpDSE
tjruwase Sep 5, 2020
a64b0ab
fixed a typo; this was fixed before but seems like it has been lost i…
arashashari Sep 5, 2020
4d4eafb
Move code quality tests to Azure-hosted agents. (#368)
Sep 5, 2020
bb75df7
add casting kernel
Sep 6, 2020
7aaccf4
run precommit
Sep 6, 2020
d4c3b0a
revert changes
Sep 6, 2020
4c257dd
revert changes
Sep 6, 2020
58c0741
Merge branch 'staging-zero-dual-v3' into reza/deepspeed_adam_merge_v3
Sep 6, 2020
773e454
Merge branch 'reza/deepspeed_adam_merge_v3' into staging-zero-dual-v3
Sep 6, 2020
edb770d
merging
Sep 6, 2020
19aac8a
ZeRO-Offload: Integration code fixes (#370)
tjruwase Sep 6, 2020
9e83ef2
Update installation instructions (#362)
tjruwase Sep 6, 2020
9dadf38
Update Sparse Attention Tutorial (#357)
arashashari Sep 6, 2020
bae8131
fixing corner cases (#371)
RezaYazdaniAminabadi Sep 6, 2020
75e9e32
fix adam perormance (#372)
RezaYazdaniAminabadi Sep 6, 2020
485a365
ZeRO-Offload passing model tests (#374)
tjruwase Sep 8, 2020
5b214a1
fix cpu adam compilation for AVX2 (#378)
RezaYazdaniAminabadi Sep 8, 2020
135dd08
Move code quality tests to Azure-hosted agents. (#368)
Sep 5, 2020
2df751c
Bump DSE
tjruwase Sep 9, 2020
b73894d
adding sparse attention to feature index page (#377)
arashashari Sep 9, 2020
d098ad8
Merge branch 'master' into staging-zero-dual-v4
jeffra Sep 9, 2020
9a80f4a
support avx2 by default (#383)
RezaYazdaniAminabadi Sep 9, 2020
2 changes: 1 addition & 1 deletion DeepSpeedExamples
5 changes: 4 additions & 1 deletion azure-pipelines.yml
@@ -91,11 +91,14 @@ jobs:

 - job: Code_Quality_Checks
   pool:
-    name: 'DS_testing'
+    vmImage: 'ubuntu-latest'
   variables:
     conda_env: 'ds_codetest'
 
   steps:
+  - bash: echo "##vso[task.prependpath]$CONDA/bin"
+    displayName: Add conda to PATH
+
   - script: |
       conda create --force --yes -n $(conda_env) python=3.7
       source activate $(conda_env)
627 changes: 627 additions & 0 deletions csrc/adam/cpu_adam.cpp

Large diffs are not rendered by default.

35 changes: 35 additions & 0 deletions csrc/adam/custom_cuda_kernel.cu
@@ -0,0 +1,35 @@


#include "custom_cuda_layers.h"

__global__ void param_update_kernel(const float* input, __half* output, int size)
{
    const float4* input_cast = reinterpret_cast<const float4*>(input);
    float2* output_cast = reinterpret_cast<float2*>(output);

    int id = blockIdx.x * blockDim.x + threadIdx.x;

    if (id < size) {
        float4 data = input_cast[id];
        float2 cast_data;
        __half* output_h = reinterpret_cast<__half*>(&cast_data);

        output_h[0] = (__half)data.x;
        output_h[1] = (__half)data.y;
        output_h[2] = (__half)data.z;
        output_h[3] = (__half)data.w;

        output_cast[id] = cast_data;
    }
}

void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
{
    int threads = 512;

    size /= 4;
    dim3 grid_dim((size - 1) / threads + 1);
    dim3 block_dim(threads);

    param_update_kernel<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
}
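param_update_kernel casts a block of fp32 master parameters to fp16, four elements per thread (one float4 load, one store of four packed halves), and launch_param_update divides the element count by 4, so callers are expected to pass sizes that are a multiple of four. Below is a minimal PyTorch-level sketch of the same semantics, handy for checking the kernel's output; the tensor names and sizes are illustrative, not part of this PR.

import torch

def param_update_reference(cpu_fp32: torch.Tensor, gpu_fp16: torch.Tensor) -> None:
    # Same effect as launch_param_update: cast fp32 master weights to fp16
    # and write them into the GPU-resident copy. The CUDA kernel does this
    # four elements per thread; here copy_ performs the dtype conversion.
    gpu_fp16.copy_(cpu_fp32, non_blocking=True)

# illustrative usage: master weights in pinned CPU memory, fp16 copy on the GPU
cpu_master = torch.randn(4096, dtype=torch.float32).pin_memory()
gpu_half = torch.empty(4096, dtype=torch.float16, device='cuda')
param_update_reference(cpu_master, gpu_half)
torch.cuda.synchronize()  # wait for the async H2D copy before reading gpu_half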
118 changes: 118 additions & 0 deletions csrc/includes/cpu_adam.h
@@ -0,0 +1,118 @@
#pragma once

#include <cpuid.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <x86intrin.h>
#include <cassert>
#include "context.h"
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"

#define CUDA_CHECK(callstr)                                                                    \
    {                                                                                          \
        cudaError_t error_code = callstr;                                                      \
        if (error_code != cudaSuccess) {                                                       \
            std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
            assert(0);                                                                         \
        }                                                                                      \
    }

#define TILE (1024 * 1024 * 1024)

#if defined(__AVX512__)
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_WIDTH 16
#else
#if defined(__AVX256__)
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_WIDTH 8
#endif
#endif

class Adam_Optimizer {
public:
    Adam_Optimizer(float alpha = 1e-3,
                   float betta1 = 0.9,
                   float betta2 = 0.999,
                   float eps = 1e-8,
                   float weight_decay = 0)
        : _alpha(alpha),
          _betta1(betta1),
          _betta2(betta2),
          _eps(eps),
          _weight_decay(weight_decay),
          _betta1_t(1.0),
          _betta2_t(1.0),
          _buf_index(false)
    {
        cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
        cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
    }
    ~Adam_Optimizer()
    {
        cudaFreeHost(_doubled_buffer[0]);
        cudaFreeHost(_doubled_buffer[1]);
    }
    void Step(float* _params,
              float* grads,
              float* _exp_avg,
              float* _exp_avg_sq,
              size_t param_size,
              __half* dev_param = nullptr);
    void Step_4(float* _params,
                float* grads,
                float* _exp_avg,
                float* _exp_avg_sa,
                size_t param_size,
                __half* dev_param = nullptr);
    void Step_8(float* _params,
                float* grads,
                float* _exp_avg,
                float* _exp_avg_sq,
                size_t _param_size,
                __half* dev_params = nullptr);
    inline void IncrementStep()
    {
        _betta1_t *= _betta1;
        _betta2_t *= _betta2;
    }

private:
#if defined(__AVX512__) or defined(__AVX256__)
    union AVX_Data {
#if defined(__AVX512__)
        __m512 data;
#else
        __m256 data;
#endif
        // float data_f[16];
    };
#endif

    float _alpha;
    float _betta1;
    float _betta2;
    float _eps;
    float _weight_decay;

    float _betta1_t;
    float _betta2_t;

    float* _doubled_buffer[2];
    bool _buf_index;
};
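The header declares a scalar Step plus Step_4/Step_8 variants that unroll the AVX256/AVX512 loop, and two pinned _doubled_buffer tiles so the fp16 copy back to the GPU (via launch_param_update) can overlap the next tile's update. The actual update lives in csrc/adam/cpu_adam.cpp, whose large diff is not rendered above; the sketch below is the standard Adam step that these declarations and the _betta1_t/_betta2_t bias-correction terms suggest, written as plain PyTorch for reference, not a quote of the C++ code.

import torch

def adam_step_reference(param, grad, exp_avg, exp_avg_sq, step,
                        lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0):
    # Standard Adam update; Step/Step_4/Step_8 vectorize this with the SIMD_* macros,
    # tiling the parameters through the pinned _doubled_buffer for the GPU copy.
    if weight_decay != 0.0:
        grad = grad + weight_decay * param                        # L2-style decay via _weight_decay
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)               # first moment
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # second moment
    bias_c1 = 1 - beta1 ** step                                   # corresponds to 1 - _betta1_t
    bias_c2 = 1 - beta2 ** step                                   # corresponds to 1 - _betta2_t
    denom = (exp_avg_sq / bias_c2).sqrt() + eps
    param.add_(-(lr / bias_c1) * exp_avg / denom)
    return param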
2 changes: 2 additions & 0 deletions csrc/includes/custom_cuda_layers.h
100755 → 100644
@@ -264,3 +264,5 @@ void launch_fuse_transpose_bias_kernel(const T* inp,
                                       int rows,
                                       int cols,
                                       cudaStream_t stream);

void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream);
18 changes: 10 additions & 8 deletions deepspeed/__init__.py
@@ -4,16 +4,18 @@
 import sys
 import types
 
-from deepspeed.runtime.engine import DeepSpeedEngine
-from deepspeed.runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
-from deepspeed.runtime.lr_schedules import add_tuning_arguments
-from deepspeed.runtime.config import DeepSpeedConfig
-from deepspeed.runtime.activation_checkpointing import checkpointing
-from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
-from deepspeed.utils import logger
+from . import ops
+
+from .runtime.engine import DeepSpeedEngine
+from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_ADAM
+from .runtime.lr_schedules import add_tuning_arguments
+from .runtime.config import DeepSpeedConfig
+from .runtime.activation_checkpointing import checkpointing
+from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
+from .utils import logger
 
 try:
-    from deepspeed.git_version_info import version, git_hash, git_branch
+    from .git_version_info import version, git_hash, git_branch
 except ImportError:
     version = "0.0.0+unknown"
     git_hash = None
7 changes: 7 additions & 0 deletions deepspeed/ops/__init__.py
@@ -0,0 +1,7 @@
from ..git_version_info import installed_ops as __installed_ops__
from . import lamb
from . import transformer
if __installed_ops__['sparse-attn']:
    from . import sparse_attention
if __installed_ops__['cpu-adam']:
    from . import adam
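deepspeed/ops/__init__.py only re-exports the sparse_attention and adam submodules when the corresponding ops were compiled at install time, as recorded in git_version_info.installed_ops. A hedged sketch of how a caller might guard on that flag follows; the torch.optim.Adam fallback is an illustration, not something this PR adds.

import torch
from deepspeed.git_version_info import installed_ops   # dict of op name -> whether it was built

if installed_ops.get('cpu-adam'):
    from deepspeed.ops.adam import DeepSpeedCPUAdam
    opt_cls = DeepSpeedCPUAdam
else:
    opt_cls = torch.optim.Adam                          # illustrative fallback only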
1 change: 1 addition & 0 deletions deepspeed/ops/adam/__init__.py
@@ -0,0 +1 @@
from .cpu_adam import DeepSpeedCPUAdam
81 changes: 81 additions & 0 deletions deepspeed/ops/adam/cpu_adam.py
@@ -0,0 +1,81 @@
import math
import torch
import importlib

ds_opt_adam = None


class DeepSpeedCPUAdam(torch.optim.Optimizer):

    optimizer_id = 0

    def __init__(self,
                 model_params,
                 lr=1e-3,
                 betas=(0.9,
                        0.999),
                 eps=1e-8,
                 weight_decay=0,
                 amsgrad=False):

        default_args = dict(lr=lr,
                            betas=betas,
                            eps=eps,
                            weight_decay=weight_decay,
                            amsgrad=amsgrad)
        super(DeepSpeedCPUAdam, self).__init__(model_params, default_args)

        self.opt_id = DeepSpeedCPUAdam.optimizer_id
        DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1

        global ds_opt_adam
        ds_opt_adam = importlib.import_module('deepspeed.ops.adam.cpu_adam_op')
        ds_opt_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay)

    def __setstate__(self, state):
        super(DeepSpeedCPUAdam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None, fp16_param_groups=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):

                if p.grad is None:
                    continue

                grad = p.grad.data
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    print(f'group {group_id} param {param_id} = {p.numel()}')
                    state['step'] = 0
                    # gradient momentums
                    state['exp_avg'] = torch.zeros_like(p.data, device='cpu')
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data, device='cpu')

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                if fp16_param_groups is not None:
                    p_fp16 = fp16_param_groups[group_id][param_id]
                    ds_opt_adam.adam_update_copy(self.opt_id,
                                                 p.data,
                                                 grad,
                                                 exp_avg,
                                                 exp_avg_sq,
                                                 p_fp16)
                else:
                    ds_opt_adam.adam_update(self.opt_id,
                                            p.data,
                                            grad,
                                            exp_avg,
                                            exp_avg_sq)
        return loss
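DeepSpeedCPUAdam hands each parameter tensor to the C++ cpu_adam_op, which updates the parameter and both moment buffers in place; when fp16_param_groups is supplied (as the ZeRO-Offload engine does), adam_update_copy also pushes the updated values to the GPU fp16 copies through the kernel above. A minimal standalone sketch, assuming DeepSpeed was installed with the cpu-adam op compiled (otherwise the importlib call in __init__ fails); the model and training loop are placeholders.

import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

model = torch.nn.Linear(1024, 1024)                    # params and grads stay on the CPU
optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3, weight_decay=0.01)

for _ in range(10):
    x = torch.randn(16, 1024)
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()       # the C++ op updates params, exp_avg, exp_avg_sq in place
    optimizer.zero_grad()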
33 changes: 33 additions & 0 deletions deepspeed/pt/deepspeed_zero_utils.py
@@ -0,0 +1,33 @@
import torch
from torch.autograd import Variable
import collections


def async_migrate_to(obj, dev, main_stream=None):
    if torch.is_tensor(obj):
        obj = Variable(obj)
    if isinstance(obj, Variable):
        v = obj.cuda(dev, async=True)
        if main_stream is not None:
            v.data.record_stream(main_stream)
        return v
    elif isinstance(obj, collections.Mapping):
        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
    elif isinstance(obj, collections.Sequence):
        return [async_copy_to(o, dev, main_stream) for o in obj]
    else:
        return obj


def async_copy_to(obj, dev, main_stream=None):
    if torch.is_tensor(obj):
        obj = Variable(obj)
    if isinstance(obj, Variable):
        target = torch.empty_like(obj, device=dev).copy_(obj)
        if main_stream is not None:
            target.data.record_stream(main_stream)
        return target
    elif isinstance(obj, collections.Mapping):
        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
    elif isinstance(obj, collections.Sequence):
        return [async_copy_to(o, dev, main_stream) for o in obj]
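async_copy_to walks tensors, mappings, and sequences, copying each tensor to the target device and recording it on a stream so the transfer can overlap compute. Note that async_migrate_to still calls obj.cuda(dev, async=True), which is a SyntaxError on Python 3.7+ (async is a reserved word), and collections.Mapping / collections.Sequence now live in collections.abc. The sketch below is a modernized, hypothetical stand-in with those substitutions, not the code this PR adds.

import collections.abc
import torch


def async_copy_to_sketch(obj, dev, main_stream=None):
    # Recursively copy tensors (and containers of tensors) to `dev`, recording
    # each result on `main_stream` so the transfer can overlap compute.
    if torch.is_tensor(obj):
        target = torch.empty_like(obj, device=dev).copy_(obj, non_blocking=True)
        if main_stream is not None:
            target.record_stream(main_stream)
        return target
    if isinstance(obj, collections.abc.Mapping):
        return {k: async_copy_to_sketch(o, dev, main_stream) for k, o in obj.items()}
    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
        return [async_copy_to_sketch(o, dev, main_stream) for o in obj]
    return obj


# hypothetical usage: stage a pinned CPU batch onto the GPU on a side stream
batch = {'x': torch.randn(8, 128).pin_memory(), 'y': torch.randn(8).pin_memory()}
copy_stream = torch.cuda.Stream()
with torch.cuda.stream(copy_stream):
    gpu_batch = async_copy_to_sketch(batch, 'cuda', torch.cuda.current_stream())
torch.cuda.current_stream().wait_stream(copy_stream)   # compute stream waits on the copy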