Merged
Changes from all 98 commits
0518252
add manual workflow to run tests with precompiled ops
jeffra Dec 11, 2020
8a184b6
[build] fix computer capability arch flags, add PTX, handle PTX (#591)
stas00 Dec 11, 2020
66268bd
add DeepSpeedZeroConfig repr method (#596)
stas00 Dec 11, 2020
a4763f5
Supported customizing kwargs for lr_scheduler (#584)
carefree0910 Dec 11, 2020
c5a449f
Update launcher to set local rank environ variable (#597)
jeffra Dec 11, 2020
9f8e8f3
implement missing get_last_lr (#595)
stas00 Dec 14, 2020
007466e
[doc] xref to hostfile discussion (#604)
stas00 Dec 15, 2020
6380ee3
Fixes for RTD build errors (#606)
jeffra Dec 15, 2020
fd2f970
Transformer-kernel - supporting any arbitrary sequence-length (#587)
RezaYazdaniAminabadi Dec 17, 2020
7435b2f
Ability to initialize distributed backend outside deepspeed runtime (…
jeffra Dec 18, 2020
81aeea3
Elastic training support (#602)
jeffra Dec 23, 2020
24e0739
update SA comp check to fix torch-cpu issue (#631)
jeffra Jan 4, 2021
e6ac731
Support initialization with dict configuration (#632)
tjruwase Jan 4, 2021
a9a83a6
Allow DeepSpeed models to be initialized with optimizer=None (#469)
gcooper-isi Jan 5, 2021
d38ad6a
change dist to torch.distributed to fix bug in assert. (#638)
awan-10 Jan 5, 2021
46d2e28
docs: minor spelling tweaks (#623)
brettkoonce Jan 5, 2021
5ab1279
Fix docstring format (#640)
tjruwase Jan 5, 2021
44bd538
Module replacement support (#586)
jeffra Jan 6, 2021
64461da
Update builder.py (#642)
sxjscience Jan 7, 2021
8cea96d
Bump nokogiri from 1.10.10 to 1.11.0 in /docs (#630)
dependabot[bot] Jan 7, 2021
4e2dc4e
Add deepspeed.init_distributed to RTD page (#645)
jeffra Jan 7, 2021
828d75b
document deepspeed.initialize() (#644)
stas00 Jan 8, 2021
bc046dc
add additional validation checks in elastic config (#646)
jeffra Jan 8, 2021
af212f6
Remove a very verbose print statement. (#649)
awan-10 Jan 8, 2021
c14b839
version bump to 0.3.10
jeffra Jan 8, 2021
da5563a
LR scheduler unit tests (#429)
tjruwase Jan 8, 2021
adcfd26
Handle actvitation checkpointing args that are None or non-tensors (#…
Jan 12, 2021
e2fbe4d
squash latest flops profiling changes (#1) (#664)
cli99 Jan 13, 2021
981bc7d
Move workspace memory-allocation to PyTorch (#661)
RezaYazdaniAminabadi Jan 13, 2021
f032e56
Validate consistent ckpt tags across ranks (#667)
jeffra Jan 14, 2021
865104b
Support optimizer AdamW type (#670)
tjruwase Jan 15, 2021
6217a6c
skip empty lines in hostfile (#669)
jeffra Jan 15, 2021
c5e4264
Add AdamW to the supported optimizers (#672)
stas00 Jan 15, 2021
e729a3f
add missing config menu entries (#652)
stas00 Jan 15, 2021
7b07e12
doc fix (#651)
stas00 Jan 15, 2021
82cecf6
add zero-offload paper (#680)
jeffra Jan 19, 2021
7b0bee0
[tutorials] typos (#676)
stas00 Jan 20, 2021
e59ba12
make test_pipe more stable (#683)
Jan 20, 2021
34c83a5
Fix ZeRO 2 + Pipelining (#677)
leogao2 Jan 20, 2021
852c524
Add optional timeout parameter to deepspeed.init_distributed (#637)
sdtblck Jan 25, 2021
5221832
Fix wrong idx bug in invertible LayerNormBackward1 (#692)
Taka152 Jan 26, 2021
7833aed
Create torch16.yml (#699)
jeffra Jan 27, 2021
cd29f8b
Update torch16.yml
jeffra Jan 27, 2021
91b1b7f
[transformer-kernel] turn off unit test printing (#701)
jeffra Jan 27, 2021
2e2dd86
Dist testing backend fixes, etc. (#708)
jeffra Jan 29, 2021
5e522ef
set_batch_fn and remove old sanity check (#712)
Jan 29, 2021
3cecbc1
properly set engine.local_rank if it's set to -1
jeffra Feb 1, 2021
6332e31
Add executable permission to `ds_elastic` and `ds_report` in `bin`. (…
joneyolfson Feb 1, 2021
45c33ee
local rank of -1 means not set (#720)
jeffra Feb 1, 2021
72b23ea
bump to 0.3.11
jeffra Feb 2, 2021
4f1d827
[launcher] look ma, no more zombies (#714)
stas00 Feb 4, 2021
b08aa6f
Improve starred expressions (#696)
joneyolfson Feb 8, 2021
c5b3f40
Fixed typo in Readme. (#737)
TheDudeFromCI Feb 9, 2021
6beca3c
1bit_adam dependencies (#742)
stas00 Feb 10, 2021
6ee3b29
Clickable screenshots (#746)
tjruwase Feb 10, 2021
e2dfe0d
Add flops profiler tutorial (#682)
cli99 Feb 11, 2021
59eed17
Only initialize distributed if required (#734)
Feb 11, 2021
1b8ca8e
fix spelling mistake (#749)
sdtblck Feb 11, 2021
248f638
1-bit Adam documentation fix (#747)
conglongli Feb 11, 2021
6fb1610
Replace timer print rank 0 with logging (#732)
Feb 12, 2021
78e776a
[install] fixes/improvements/docs (#752)
stas00 Feb 12, 2021
7bf1b83
[install] add -e/--examples flag to checkout submodules (#755)
jeffra Feb 12, 2021
ec8b1cb
Activation checkpointing for non-tensor arguments and return values (…
tjruwase Feb 12, 2021
7cab55c
Checks for None tensors and skip them when splitting the buckets in z…
cli99 Feb 16, 2021
c28a71f
Minor doc tweaks (#761)
tjruwase Feb 16, 2021
8067efa
Fix NameError: name 'dist' is not defined (#763)
tma15 Feb 17, 2021
68e138b
[dist] set args.local_rank to LOCAL_RANK (#764)
jeffra Feb 17, 2021
1fcc5f7
Fix transformer kernel CUDA illegal memory access error (#765)
conglongli Feb 18, 2021
ee1ffe2
CPU-Adam fix for scalar mode (#735)
RezaYazdaniAminabadi Feb 18, 2021
29fa4b2
Update engine.py (#767)
jeffra Feb 19, 2021
e60e92e
[doc] fix incorrect param name (#773)
stas00 Feb 20, 2021
48065c0
Fixing the module-inject Api (#786)
RezaYazdaniAminabadi Feb 24, 2021
e2dfcad
Fix the bias-add and add the layer-norm-eps parameter (#791)
RezaYazdaniAminabadi Feb 24, 2021
62396b7
Delete out2 (#798)
vfdev-5 Feb 26, 2021
490e6f7
fixing the compiling issue for the AMD architecture (#796)
RezaYazdaniAminabadi Feb 26, 2021
7eb083c
document the requirement to call for all ranks (#801)
stas00 Feb 26, 2021
db987cf
fixed typo (#802)
vfdev-5 Feb 27, 2021
937c5ce
issue with the implementation of column_sum_reduce (#804)
zmxdream Feb 28, 2021
8295d7a
Fixing gelu_checkpointing memory issue (#812)
RezaYazdaniAminabadi Mar 3, 2021
ba33e86
Update ZeRO-Offload tutorials (#824)
tjruwase Mar 8, 2021
599258f
ZeRO 3 Offload (#834)
samyam Mar 8, 2021
d7de916
update tutorial/doc links for zero3 (#835)
jeffra Mar 8, 2021
75ffdaf
Fix zero3 tutorial link
jeffra Mar 8, 2021
9c5eee3
bump DSE to include ZeRO-3
jeffra Mar 8, 2021
af54897
Fix for RTD
jeffra Mar 8, 2021
6adc19a
Model scale changing 5x to 3x
samyam Mar 8, 2021
4949636
replace home env with ~
jeffra Mar 9, 2021
2e6692c
Fix regression in runner (#843)
jeffra Mar 9, 2021
564eb4b
bumping DSE pointer (#847)
Mar 10, 2021
dd03cff
set adamw_mode default true (follows FusedAdam and < 0.3.11 logic) (#…
jeffra Mar 11, 2021
29853c3
less scary overflow notice (#833)
stas00 Mar 11, 2021
e0f36ed
Add optimizers and schedules to RTD and updated the corresponding par…
cli99 Mar 11, 2021
7925d0c
small tweaks (#839)
stas00 Mar 11, 2021
311795d
Control ZeRO wall clock timers (#849)
tjruwase Mar 11, 2021
18a26f3
[WarmupDecayLR] fix log(0) & 1/log(1) bugs (#772)
stas00 Mar 12, 2021
35fd7cc
bump to v0.3.12
jeffra Mar 12, 2021
458ff02
Bug fix: Remove client optimizer param_group list item that does not …
cli99 Mar 12, 2021
73d762c
[doc] pipeline doc typos/improvements (#659)
stas00 Mar 14, 2021
12 changes: 5 additions & 7 deletions .github/workflows/main.yml
@@ -4,14 +4,12 @@ name: Build

# Controls when the action will run.
on:
# Triggers the workflow on push or pull request events but only for the master branch
push:
branches: [ master ]
paths-ignore:
- 'docs/**'
pull_request:
branches: [ master ]

# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:
paths-ignore:
- 'docs/**'

# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
@@ -50,4 +48,4 @@ jobs:
- name: Unit tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose tests/unit/
47 changes: 47 additions & 0 deletions .github/workflows/pre-compile-ops.yml
@@ -0,0 +1,47 @@
# This is a basic workflow to help you get started with Actions

name: Tests-w-precompiled-ops

# Controls when the action will run.
on:
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:

# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
# This workflow contains a single job called "build"
build:
# The type of runner that the job will run on
runs-on: self-hosted

# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2

# Runs a single command using the runners shell
- name: environment
run: |
nvidia-smi
which python
python --version
which nvcc
nvcc --version
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"

# Runs a set of commands using the runners shell
- name: Install deepspeed
run: |
DS_BUILD_OPS=1 pip install .[dev]
ds_report

- name: Formatting checks
run: |
pre-commit run --all-files

# Runs a set of commands using the runners shell
- name: Unit tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/
44 changes: 44 additions & 0 deletions .github/workflows/torch16.yml
@@ -0,0 +1,44 @@
# Unit test config for manual use on torch1.6 runners

name: Torch16

# Controls when the action will run.
on:
#pull_request:
# paths-ignore:
# - 'docs/**'
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:

# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
# This workflow contains a single job called "build"
build:
# The type of runner that the job will run on
runs-on: [self-hosted, torch1.6]

# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2

# Runs a single command using the runners shell
- name: environment
run: |
nvidia-smi
which python
python --version
which nvcc
nvcc --version
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
# Runs a set of commands using the runners shell
- name: Install deepspeed
run: |
pip install .[dev]
ds_report
# Runs a set of commands using the runners shell
- name: Unit tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/
2 changes: 1 addition & 1 deletion DeepSpeedExamples
5 changes: 4 additions & 1 deletion README.md
@@ -31,6 +31,7 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale)


# News
* [2021/03/08] [ZeRO-3 Offload: Scale your models to trillion parameters without code changes while leveraging both CPUs & GPUs](https://www.deepspeed.ai/news/2021/03/07/zero3-offload.html)
* [2020/11/12] [Simplified install, JIT compiled ops, PyPI releases, and reduced dependencies](#installation)
* [2020/11/10] [Efficient and robust compressed training through progressive layer dropping](https://www.deepspeed.ai/news/2020/10/28/progressive-layer-dropping-news.html)
* [2020/09/10] [DeepSpeed v0.3: Extreme-scale model training for everyone](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/)
@@ -113,7 +114,7 @@ overview](https://www.deepspeed.ai/features/) for descriptions and usage.
* [Ultra-fast dense transformer kernels](https://www.deepspeed.ai/news/2020/05/18/bert-record.html)
* [Sparse attention](https://www.deepspeed.ai/news/2020/09/08/sparse-attention.html)
* Memory- and compute-efficient sparse kernels
* Support 10x long sequences than dense
* Support 10x longer sequences than dense
* Flexible support to different sparse structures
* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html)
* Custom communication collective
@@ -185,6 +186,8 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
1. Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He. (2019) ZeRO: memory optimizations toward training trillion parameter models. [arXiv:1910.02054](https://arxiv.org/abs/1910.02054) and [In Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis (SC '20)](https://dl.acm.org/doi/10.5555/3433701.3433727).
2. Jeff Rasley, Samyam Rajbhandari, Olatunji Ruwase, and Yuxiong He. (2020) DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters. [In Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining (KDD '20, Tutorial)](https://dl.acm.org/doi/10.1145/3394486.3406703).
3. Minjia Zhang, Yuxiong He. (2020) Accelerating Training of Transformer-Based Language Models with Progressive Layer Dropping. [arXiv:2010.13369](https://arxiv.org/abs/2010.13369) and [NeurIPS 2020](https://proceedings.neurips.cc/paper/2020/hash/a1140a3d0df1c81e24ae954d935e8926-Abstract.html).
4. Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. (2021) ZeRO-Offload: Democratizing Billion-Scale Model Training. [arXiv:2101.06840](https://arxiv.org/abs/2101.06840).
5. Hanlin Tang, Shaoduo Gan, Ammar Ahmad Awan, Samyam Rajbhandari, Conglong Li, Xiangru Lian, Ji Liu, Ce Zhang, Yuxiong He. (2021) 1-bit Adam: Communication Efficient Large-Scale Training with Adam's Convergence Speed. [arXiv:2102.02888](https://arxiv.org/abs/2102.02888).

# Videos
1. DeepSpeed KDD 2020 Tutorial
42 changes: 42 additions & 0 deletions bin/ds_elastic
@@ -0,0 +1,42 @@
#!/usr/bin/env python

import argparse
import json

import deepspeed
from deepspeed.elasticity import compute_elastic_config

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, help="DeepSpeed config json")
parser.add_argument('-w',
'--world-size',
type=int,
default=0,
help="Intended/current world size")
args = parser.parse_args()
ds_config = json.load(open(args.config, 'r'))

ds_version = deepspeed.__version__

elastic_config = ds_config['elasticity']
print('------------------------------------------')
print("Elasticity config:")
print('------------------------------------------')
print(json.dumps(elastic_config, indent=4, sort_keys=True))

if args.world_size > 0:
final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version, world_size=args.world_size)
print('------------------------------------------')
print(f"Calculated results for world size {args.world_size}:")
print('------------------------------------------')
print(f'final_batch_size .... {final_batch_size}')
print(f'valid_gpus .......... {valid_gpus}')
print(f'micro_batch_size .... {micro_batch_size}')
else:
final_batch_size, valid_gpus = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version)
print('------------------------------------------')
print("Calculated results:")
print('------------------------------------------')
print(f'final_batch_size .... {final_batch_size}')
print(f'valid_gpus .......... {valid_gpus}')
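
Aside: the script above just pretty-prints what deepspeed.elasticity.compute_elastic_config returns. Below is a minimal sketch of calling that helper directly from Python; the "elasticity" field names in the config dict are illustrative assumptions about the schema introduced with elastic training (#602), not a verified spec, so check them against the validation added in #646 before relying on them.

import deepspeed
from deepspeed.elasticity import compute_elastic_config

# Illustrative DeepSpeed config with an elasticity section (field names assumed).
ds_config = {
    "elasticity": {
        "enabled": True,
        "max_train_batch_size": 2048,
        "micro_batch_sizes": [2, 4, 8],
        "min_gpus": 1,
        "max_gpus": 256,
        "version": 0.1,
    },
}

# Without a world size: the negotiated batch size and the set of valid GPU counts.
final_batch_size, valid_gpus = compute_elastic_config(
    ds_config=ds_config, target_deepspeed_version=deepspeed.__version__)

# With a world size: additionally the micro-batch size chosen for that many GPUs.
final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(
    ds_config=ds_config, target_deepspeed_version=deepspeed.__version__, world_size=8)
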
Empty file modified: bin/ds_report (mode 100644 → 100755)
87 changes: 46 additions & 41 deletions csrc/adam/cpu_adam.cpp (mode 100755 → 100644)
@@ -62,6 +62,8 @@ void Adam_Optimizer::Step(float* _params,
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }

#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
AVX_Data grad_4;
@@ -101,47 +103,50 @@ void Adam_Optimizer::Step(float* _params,
SIMD_STORE(_exp_avg_sq + i, variance_4.data);
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + t,
copy_size,
Context::Instance().GetCurrentStream());
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
_buf_index = !_buf_index;
}
}

#endif

if (_param_size > rounded_size) {
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
for (size_t k = rounded_size; k < _param_size; k++) {
float grad = grads[k];
float param = _params[k];
float momentum = _exp_avg[k];
float variance = _exp_avg_sq[k];
if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;

variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;

grad = sqrt(variance);
grad = grad * _bias_correction2 + _eps;
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
param = grad * step_size + param;
if (dev_params) _doubled_buffer[_buf_index][k - rounded_size] = (__half)param;

_params[k] = param;
_exp_avg[k] = momentum;
_exp_avg_sq[k] = variance;
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + rounded_size,
(_param_size - rounded_size),
Context::Instance().GetCurrentStream());
for (size_t k = t; k < offset; k++) {
float grad = grads[k];
float param = _params[k];
float momentum = _exp_avg[k];
float variance = _exp_avg_sq[k];
if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;

variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;

grad = sqrt(variance);
grad = grad * _bias_correction2 + _eps;
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
param = grad * step_size + param;
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;

_params[k] = param;
_exp_avg[k] = momentum;
_exp_avg_sq[k] = variance;
}
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
_buf_index = !_buf_index;
}
}
}
}
@@ -189,6 +194,7 @@ void Adam_Optimizer::Step_4(float* _params,
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) {
AVX_Data grad_4[4];
@@ -295,10 +301,8 @@ void Adam_Optimizer::Step_4(float* _params,
}

if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + t,
copy_size,
Context::Instance().GetCurrentStream());
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
_buf_index = !_buf_index;
}
}
@@ -400,6 +404,7 @@ void Adam_Optimizer::Step_8(float* _params,
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) {
AVX_Data grad_4[8];
@@ -582,10 +587,8 @@ void Adam_Optimizer::Step_8(float* _params,
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 7, variance_4[7].data);
}
if (dev_params) {
launch_param_update(_doubled_buffer[_buf_index],
dev_params + t,
copy_size,
Context::Instance().GetCurrentStream());
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
_buf_index = !_buf_index;
}
}
@@ -628,6 +631,7 @@ int ds_adam_step(int optimizer_id,
opt->update_state(lr, epsilon, weight_decay, bias_correction);
opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));

opt->SynchronizeStreams();
return 0;
}

@@ -664,6 +668,7 @@ int ds_adam_step_plus_copy(int optimizer_id,
opt->Step_8(
params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);

opt->SynchronizeStreams();
return 0;
}

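The cpu_adam.cpp changes above give each half of the double buffer its own CUDA stream and add a cudaStreamSynchronize before a buffer is reused, so the asynchronous parameter copy to the GPU can overlap with the CPU Adam update of the next tile. A rough PyTorch-level sketch of that double-buffering pattern follows (illustration only, assumes a CUDA device; it mirrors the structure of the change, not the actual kernel code).

import torch

def copy_tiles_double_buffered(src_cpu, dst_gpu, tile=1 << 20):
    # Two pinned host buffers and two streams: while one tile is in flight on its
    # stream, the CPU fills the other buffer. Before reusing a buffer, wait on its
    # stream -- the analogue of the cudaStreamSynchronize(_streams[_buf_index]) above.
    bufs = [torch.empty(tile, pin_memory=True) for _ in range(2)]
    streams = [torch.cuda.Stream() for _ in range(2)]
    idx = 0
    for t in range(0, src_cpu.numel(), tile):
        n = min(tile, src_cpu.numel() - t)
        streams[idx].synchronize()                # buffer idx is free to overwrite again
        bufs[idx][:n].copy_(src_cpu[t:t + n])     # CPU-side work into the pinned buffer
        with torch.cuda.stream(streams[idx]):
            dst_gpu[t:t + n].copy_(bufs[idx][:n], non_blocking=True)
        idx ^= 1                                  # flip buffers, like _buf_index = !_buf_index
    for s in streams:
        s.synchronize()
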
16 changes: 5 additions & 11 deletions csrc/includes/context.h (mode 100755 → 100644)
@@ -64,17 +64,10 @@ class Context {
return _ctx;
}

void GenWorkSpace(size_t size)
void SetWorkSpace(void* workspace)
{
if (!_workspace) {
assert(_workspace == nullptr);
cudaMalloc(&_workspace, size);
} else if (_workSpaceSize < size) {
cudaFree(_workspace);
cudaMalloc(&_workspace, size);
}

_workSpaceSize = size;
if (!workspace) { throw std::runtime_error("Workspace is null."); }
_workspace = workspace;
}

void* GetWorkSpace() { return _workspace; }
@@ -88,6 +81,8 @@ class Context {
return stream;
}

cudaStream_t GetNewStream() { return at::cuda::getStreamFromPool(); }

cublasHandle_t GetCublasHandle() { return _cublasHandle; }

std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
@@ -172,6 +167,5 @@ class Context {
void* _workspace;
uint64_t _seed;
uint64_t _curr_offset;
size_t _workSpaceSize;
std::vector<std::array<int, 3>> _gemm_algos;
};
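
The context.h change above drops Context's own cudaMalloc-based GenWorkSpace and _workSpaceSize bookkeeping in favor of SetWorkSpace(void*): per commit 981bc7d ("Move workspace memory-allocation to PyTorch"), the buffer is now allocated on the Python side through PyTorch's caching allocator and only its pointer is handed to the C++ context. A minimal sketch of that hand-off is below; the set_workspace binding name is a hypothetical stand-in, and in the real code the kernels presumably wire this up internally.

import torch

# Allocate the workspace through PyTorch's caching allocator instead of a raw
# cudaMalloc inside Context; the caller must keep this tensor alive for as long
# as the kernels use the pointer.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

raw_ptr = workspace.data_ptr()            # the void* that SetWorkSpace() receives
# kernel_module.set_workspace(workspace)  # hypothetical binding; the real hand-off happens in C++
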