Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-72] Improve sparse.adam_update #10062

Merged
merged 19 commits into from
Mar 27, 2018

Conversation

eric-haibin-lin
Copy link
Member

Description

Fix https://discuss.mxnet.io/t/lazy-update-with-adam-optimizer-is-much-slower-for-sparse-input/724

Many kernels written for row_sparse ndarrays are slow on GPU. For example, in sparse SGD, the kernel has the following logic using kernel::launch:

# pragma parallel for
for (i = 0 ... num_nonzero_rows) {
    # thread i computes an entire row in the output
    row = indices[i]
    for (col = 0 .. num_columns) 
        output[row][col] = weight - lr * grad 
}

Such "parallelization by rows"(A) works fine on CPU but is problematic on GPU, because:

  • when num_nonzero_rows is small, only a small number of cuda threads are launched. GPU utilization will be low
  • the memory access pattern is horrible. GPU threads are launched as groups of warps. Each thread in the same warp are accessing different memory location, resulting 32x more unnecessary memory transactions.

Instead, the kernels on GPU should be "parallelized by the number of elements"(B) to update. This way, all threads in the same warp are accessing the same chunk of memory.

On the other hand, if I apply B on CPU with openmp, I see the performance is 3-4x slower. Hence I only applied B for GPUs in this PR. (I didn't dig deeper why this happens - the performance should be comparable if static openmp scheduling is used.I didn't check what default omp scheduling strategy is on my instance. Maybe @cjolivier01 has more insight on cpu performance?).

time(s) for 300 iterations for lazy_update=True (26x improvement)

nnr 1280 12800
Before 1.31863999367 13.2637498379
After 0.050815820694 0.430017948151

time(s) for 30 iterations for lazy_update=False (34x improvement)

nnr 1280 12800
Before 22.2087020874 21.9239070415
After 0.658735990524 0.678998947144

(Just want to trigger the CI. Will update destination branch to master later.)

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here
import time
import mxnet as mx
import numpy as np
import argparse

mx.random.seed(0)
np.random.seed(0)

parser = argparse.ArgumentParser(description='Bench updater')
parser.add_argument('--dim-in', type=int, default=240000, help='weight.shape[0]')
parser.add_argument('--dim-out', type=int, default=512, help='weight.shape[1]')
parser.add_argument('--nnr', type=int, default=5000, help='grad.indices.shape[0]')
parser.add_argument('--repeat', type=int, default=1000, help='num repeat')
parser.add_argument('--dense-grad', action='store_true')
parser.add_argument('--dense-state', action='store_true')
parser.add_argument('--cpu', action='store_true')
args = parser.parse_args()
dim_in = args.dim_in
dim_out = args.dim_out
nnr = args.nnr
ctx = mx.cpu() if args.cpu else mx.gpu()

ones = mx.nd.ones((dim_in, dim_out), ctx=ctx)

if not args.dense_grad:
    weight = ones.tostype('row_sparse')
    indices = np.arange(dim_in)
    np.random.shuffle(indices)
    indices = np.unique(indices[:nnr])
    indices = mx.nd.array(indices, ctx=ctx)
    grad = mx.nd.sparse.retain(weight, indices)
else:
    weight = ones.copy()
    grad = ones.copy()

if args.dense_state:
    mean = ones.copy()
else:
    mean = ones.tostype('row_sparse')

var = mean.copy()

# warmup
for i in range(10):
    mx.nd.sparse.adam_update(weight, grad, mean, var, out=weight,
                         lr=1, wd=0, beta1=0.9, beta2=0.99, rescale_grad=0.5, epsilon=1e-8)
weight.wait_to_read()
a = time.time()
for i in range(args.repeat):
    mx.nd.sparse.adam_update(weight, grad, mean, var, out=weight,
                         lr=1, wd=0, beta1=0.9, beta2=0.99, rescale_grad=0.5, epsilon=1e-8)
weight.wait_to_read()
b = time.time()
print(b - a)

eric-haibin-lin and others added 16 commits January 28, 2018 02:37
* bump

* also update base.h

* revert website changes

* Update index.html
* Update NEWS.md

* Update README.md
* refactor regression ops

* fix err for instantiation of minus_sign

* remove useless header file init_op.h

* replace with macro and address other comments

* update

* minor revise docs

* add mae test
* Fixed 4 broken links

* fixed pylint for long line disable and 1 broken link
* Revert "avoid per-batch blocking in metric (apache#9636)"

This reverts commit 3fe694e.

* Revert "proper flatten in acc (apache#9619)"

This reverts commit ed823b2.

* Revert "use nd for accuracy calculation (apache#9583)"

This reverts commit f5f1b91.

* keep doc change
…icenses to LICENSE file (apache#9701)

* Revert "[Review Required] Fixing Licenses: Cleaning up the Top Level LICENSE file (apache#9484)"

This reverts commit 8930d96.

* Some more LICENSE fixes

* Adding some more packages to the LICENSE file

* Adding dependencies of dependencies
* update navbar model zoo link

* update
@marcoabreu
Copy link
Contributor

Very nice catch! Do you have an estimation how much overall speedup this could bring? We could highlight this in the release notes

@marcoabreu
Copy link
Contributor

marcoabreu commented Mar 11, 2018

By the way, could you add the benchmark at tests/python/benchmark so we can use them later on?

@marcoabreu
Copy link
Contributor

@piiswrong
Copy link
Contributor

does this include the fix for wd & states?
Have we decided what to do with wd in general?

@eric-haibin-lin
Copy link
Member Author

@piiswrong no it doesn't. It only supports wd=0. No decision for wd in general yet

@eric-haibin-lin eric-haibin-lin changed the title [MXNET-72] [WIP] Improve sparse.adam_update [MXNET-72] Improve sparse.adam_update Mar 21, 2018
@eric-haibin-lin eric-haibin-lin changed the base branch from v1.1.0 to master March 21, 2018 21:46
@eric-haibin-lin
Copy link
Member Author

@haojin2 @cjolivier01 @reminisce @anirudh2290 @rahul003 @ZiyueHuang could you help review?

@haojin2
Copy link
Contributor

haojin2 commented Mar 26, 2018

LGTM!

@eric-haibin-lin eric-haibin-lin merged commit f0f745d into apache:master Mar 27, 2018
jinhuang415 pushed a commit to jinhuang415/incubator-mxnet that referenced this pull request Mar 30, 2018
* Bump 1.1 (apache#192)

* bump

* also update base.h

* revert website changes

* Update index.html

* update news.md (apache#191)

* Update NEWS.md

* Update README.md

* refactor regression ops to nnvm interface (apache#9540)

* refactor regression ops

* fix err for instantiation of minus_sign

* remove useless header file init_op.h

* replace with macro and address other comments

* update

* minor revise docs

* add mae test

* Update KEYS

* Update NEWS.md

* fixed links that were missng ndarray folder path (apache#9618)

* Fixed 4 broken links (apache#9698)

* Fixed 4 broken links

* fixed pylint for long line disable and 1 broken link

* Update NEWS.md

* Update NOTICE (apache#9706)

* revert acc changes (apache#9731)

* Revert "avoid per-batch blocking in metric (apache#9636)"

This reverts commit 3fe694e.

* Revert "proper flatten in acc (apache#9619)"

This reverts commit ed823b2.

* Revert "use nd for accuracy calculation (apache#9583)"

This reverts commit f5f1b91.

* keep doc change

* PGP keys add liuyizhi AT apache.org (apache#9728)

* Add my key (apache#9736)

* [REVIEW REQUIRED] Revert PR apache#9484 & add additional dependency licenses to LICENSE file (apache#9701)

* Revert "[Review Required] Fixing Licenses: Cleaning up the Top Level LICENSE file (apache#9484)"

This reverts commit 8930d96.

* Some more LICENSE fixes

* Adding some more packages to the LICENSE file

* Adding dependencies of dependencies

* update navbar model zoo link (apache#9749)

* update navbar model zoo link

* update

* initial commit

* clean up

* refactor

* fix test
cjolivier01 pushed a commit to cjolivier01/mxnet that referenced this pull request Mar 30, 2018
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
* Bump 1.1 (apache#192)

* bump

* also update base.h

* revert website changes

* Update index.html

* update news.md (apache#191)

* Update NEWS.md

* Update README.md

* refactor regression ops to nnvm interface (apache#9540)

* refactor regression ops

* fix err for instantiation of minus_sign

* remove useless header file init_op.h

* replace with macro and address other comments

* update

* minor revise docs

* add mae test

* Update KEYS

* Update NEWS.md

* fixed links that were missng ndarray folder path (apache#9618)

* Fixed 4 broken links (apache#9698)

* Fixed 4 broken links

* fixed pylint for long line disable and 1 broken link

* Update NEWS.md

* Update NOTICE (apache#9706)

* revert acc changes (apache#9731)

* Revert "avoid per-batch blocking in metric (apache#9636)"

This reverts commit 3fe694e.

* Revert "proper flatten in acc (apache#9619)"

This reverts commit ed823b2.

* Revert "use nd for accuracy calculation (apache#9583)"

This reverts commit f5f1b91.

* keep doc change

* PGP keys add liuyizhi AT apache.org (apache#9728)

* Add my key (apache#9736)

* [REVIEW REQUIRED] Revert PR apache#9484 & add additional dependency licenses to LICENSE file (apache#9701)

* Revert "[Review Required] Fixing Licenses: Cleaning up the Top Level LICENSE file (apache#9484)"

This reverts commit 8930d96.

* Some more LICENSE fixes

* Adding some more packages to the LICENSE file

* Adding dependencies of dependencies

* update navbar model zoo link (apache#9749)

* update navbar model zoo link

* update

* initial commit

* clean up

* refactor

* fix test
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
* Bump 1.1 (apache#192)

* bump

* also update base.h

* revert website changes

* Update index.html

* update news.md (apache#191)

* Update NEWS.md

* Update README.md

* refactor regression ops to nnvm interface (apache#9540)

* refactor regression ops

* fix err for instantiation of minus_sign

* remove useless header file init_op.h

* replace with macro and address other comments

* update

* minor revise docs

* add mae test

* Update KEYS

* Update NEWS.md

* fixed links that were missng ndarray folder path (apache#9618)

* Fixed 4 broken links (apache#9698)

* Fixed 4 broken links

* fixed pylint for long line disable and 1 broken link

* Update NEWS.md

* Update NOTICE (apache#9706)

* revert acc changes (apache#9731)

* Revert "avoid per-batch blocking in metric (apache#9636)"

This reverts commit 3fe694e.

* Revert "proper flatten in acc (apache#9619)"

This reverts commit ed823b2.

* Revert "use nd for accuracy calculation (apache#9583)"

This reverts commit f5f1b91.

* keep doc change

* PGP keys add liuyizhi AT apache.org (apache#9728)

* Add my key (apache#9736)

* [REVIEW REQUIRED] Revert PR apache#9484 & add additional dependency licenses to LICENSE file (apache#9701)

* Revert "[Review Required] Fixing Licenses: Cleaning up the Top Level LICENSE file (apache#9484)"

This reverts commit 8930d96.

* Some more LICENSE fixes

* Adding some more packages to the LICENSE file

* Adding dependencies of dependencies

* update navbar model zoo link (apache#9749)

* update navbar model zoo link

* update

* initial commit

* clean up

* refactor

* fix test
@eric-haibin-lin eric-haibin-lin deleted the updater branch September 2, 2019 23:35
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

10 participants