
[RFC] A faster version of Gamma sampling on GPU. #15928

Closed
xidulu opened this issue Aug 16, 2019 · 11 comments
Labels: Backend (Issues related to the backend of MXNet), Operator, RFC (Post requesting for comments)

Comments

@xidulu
Contributor

xidulu commented Aug 16, 2019

Description

Sampling from the Gamma distribution requires rejection sampling, which is what the implementation of mxnet.ndarray.random.gamma() uses. However, the current implementation ( sampler.h ) has two main drawbacks:

  1. Random numbers used in the rejection sampling (N(0,1) and U(0,1)) are generated inside the kernel using the CUDA device API. Also, although every batch of threads has its own RNG, samples are generated serially within each batch of threads.

  2. Rejection sampling is implemented with an unbounded while loop inside the kernel, which can hurt performance on the GPU.

To address the problems above, I wrote a new version of Gamma sampling on GPU, inspired by this blog post.

Implementation details

My implementation differs from the current version in the following aspects:

  1. Instead of generating random numbers inside the kernel, we generate them in advance using the host API, which lets us fill a buffer with random samples directly.

  2. Redundant samples are generated to replace the while loop. Suppose we are going to generate a Gamma tensor of size (N,); then N x (M + 1) standard Gaussian samples and N x (M + 1) uniform(0, 1) samples are generated before entering the kernel, where M is a predefined constant. For each entity, we generate M proposed Gamma r.v.s and select the first accepted one as the output. The one extra sample per entity is needed when \alpha is less than one.

  3. In case all M proposals are rejected for some entities (which are marked as -1), we simply refill the random buffer and perform another round of rejection sampling, but only for the entities that failed the last round.

Here's part of the implementation:
(https://gist.github.com/xidulu/cd9da21f2ecbccd9b784cadd67844e23)

In my experiment, I set M to 1 (i.e. no redundant samples are generated), as the adopted policy (Marsaglia and Tsang's method) has a rather high acceptance rate of around 98%.
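To make the scheme concrete, here is a minimal, illustrative NumPy sketch of the batched Marsaglia and Tsang rejection step described above (the function and its bookkeeping are hypothetical; the actual kernel code is in the gist above):

import numpy as np

def batched_gamma(alpha, beta, n, M=1, rng=None):
    # Illustrative sketch: Marsaglia-Tsang rejection with M pre-generated
    # proposals per entity, plus one extra uniform for the alpha < 1 case.
    if rng is None:
        rng = np.random.default_rng()
    boost = alpha < 1.0
    a = alpha + 1.0 if boost else alpha          # Marsaglia-Tsang needs shape >= 1
    d = a - 1.0 / 3.0
    c = 1.0 / np.sqrt(9.0 * d)

    out = np.full(n, -1.0)                       # -1 marks "not yet accepted"
    while np.any(out < 0):
        idx = np.nonzero(out < 0)[0]             # entities still needing a sample
        # N x (M + 1) normals and uniforms, filled up front (host-side analogue)
        x = rng.standard_normal((idx.size, M + 1))
        u = rng.random((idx.size, M + 1))
        v = (1.0 + c * x[:, :M]) ** 3
        logv = np.log(np.where(v > 0, v, 1.0))   # guard log() for rejected v <= 0
        ok = (v > 0) & (np.log(u[:, :M]) < 0.5 * x[:, :M] ** 2 + d - d * v + d * logv)
        first = np.argmax(ok, axis=1)            # first accepted proposal per entity
        rows = np.arange(idx.size)
        accepted = ok[rows, first]               # entities with at least one acceptance
        g = d * v[rows, first]
        if boost:                                # the one extra sample mentioned above
            g = g * u[:, M] ** (1.0 / alpha)
        out[idx[accepted]] = g[accepted] * beta
    return out

With the ~98% per-proposal acceptance rate and M = 1, only about 2% of entities fall through to a second round, so the outer resampling loop terminates after one or two iterations in practice.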

The profiling result is listed below:

| Size | Native NumPy | ndarray on GPU | My implementation |
| --- | --- | --- | --- |
| 10e2 | <0.1 ms | 3~5 ms | 0.5~0.7 ms |
| 10e4 | 0.76 ms | 7.6~7.8 ms | 0.72~0.76 ms |
| 10e6 | 70 ms | 12~13 ms | 3.1 ms |
| 10e8 | 7200 ms | 1600~1700 ms | 150~160 ms |

The new version is currently under development on the numpy branch. It is also designed to support broadcastable parameters.


Correctness has been verified with the following script:

import numpy as _np
import mxnet as mx
from mxnet import np, npx
from mxnet.base import MXNetError
from mxnet.gluon import HybridBlock
from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray
from mxnet.test_utils import check_numeric_gradient, use_np
from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, retry
import scipy.stats as ss
import random

samples = 1000000
trials = 8
num_buckets = 5
ctx = npx.gpu(0)

for alpha, beta in [(2.0, 3.0), (0.5, 1.0)]:
    buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.gamma.ppf(x, a=alpha, loc=0, scale=beta), num_buckets)
    buckets = np.array(buckets).tolist()
    generator_mx = lambda x: np.random.gamma(alpha, beta, size=samples, ctx=ctx).asnumpy()
    verify_generator(generator=generator_mx, buckets=buckets, probs=probs,
                        nsamples=samples, nrepeat=trials)
    generator_mx_same_seed =\
        lambda x: _np.concatenate(
            [np.random.gamma(alpha, beta, size=(x // 10), ctx=ctx).asnumpy()
                for _ in range(10)])
    verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs,
                        nsamples=samples, nrepeat=trials)


UPDATE:
ndarray on GPU is actually about 50 times faster than what I profiled: I did not turn on the release option during my testing, which prevented the CUDA compiler from applying further optimizations.
In short, my implementation is not actually faster than the existing one. However, features like broadcastable parameters and the global random seed mechanism remain useful.
I will be looking into the performance issue in the following weeks and hope to find a solution.

@mxnet-label-bot
Contributor

Hey, this is the MXNet Label Bot.
Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it.
Here are my recommended labels: Performance

@yzhliu
Member

yzhliu commented Aug 16, 2019

cc @apache/mxnet-committers I think we can gradually refactor current implementation (ndarray api) by adopting this new approach.

@xidulu could you please fix the url links in your post.

@xidulu
Contributor Author

xidulu commented Aug 16, 2019

> cc @apache/mxnet-committers I think we can gradually refactor current implementation (ndarray api) by adopting this new approach.
>
> @xidulu could you please fix the url links in your post.

Links fixed.

@ptrendx
Member

ptrendx commented Aug 16, 2019

Hi @xidulu. I did not look at the differences in the implementation of the host-side vs device-side API for RNG in MXNet, but if they are comparable in terms of performance, a possibly better approach would be something like this:

  • launch only as many blocks and threads as necessary to fill the GPU, each with its own RNG
  • use the following pseudocode:
while(my_sample_id < N_samples) {
  float rng = generate_next_rng();
  bool accepted = ... // compute whether this rng value is accepted
  if (accepted) {
    // write the result
    my_sample_id = next_sample();
  }
}

There are two ways of implementing next_sample here: either by atomicInc on some global counter, or just by adding the total number of threads (so every thread processes the same number of samples). The atomic approach is potentially faster (as with the static assignment you could end up hitting a corner case where one thread would still do a lot more work than the other threads), but it is nondeterministic, so I think static assignment is preferable here.

@zachgk zachgk added Backend Issues related to the backend of MXNet Operator RFC Post requesting for comments labels Aug 16, 2019
@xidulu
Contributor Author

xidulu commented Aug 17, 2019

Hi @ptrendx, thanks for your reply. According to my discussion with @yzhliu, the device-side API is much slower than the host-side API.

Also, could you please talk a little bit about the advantages of your approach compared with mine? Thanks :)

@yzhliu
Member

yzhliu commented Aug 17, 2019

@ptrendx If I understand correctly, "static assignment" is what current mxnet is doing, which is "ndarray on GPU" in @xidulu 's table.

@ptrendx
Member

ptrendx commented Aug 18, 2019

@yzhliu No. What MXNet currently does is a scheme where, yes, each thread is statically assigned some number of elements, but it has a separate while loop for each of them. The scheme I proposed has a single while loop that processes all elements assigned to a given thread. There is a big difference between these approaches due to the SIMT architecture of the GPU. Basically, you can treat a group of threads (called a warp, 32 threads on NVIDIA GPUs) like the lanes of a SIMD vector instruction on the CPU. This means that if one thread needs to perform some computation, all threads in the warp need to execute the same instruction (and possibly discard the result).
So in the current MXNet implementation, for each output element every group of 32 threads does a number of loop iterations equal to that of the slowest thread (because no thread in the warp can exit the while loop while at least one thread is not yet finished).
In the proposed implementation there is only one while loop, and the only divergence between threads lies inside the if (accepted) part, which is cheap compared to generating a random number. Here every warp does a number of loop iterations equal to the total number of steps of its slowest thread (which is hopefully fairly uniform across threads, especially since we are talking about RNG output and not some crafted input, and is definitely much better than the previous "for each element take the slowest and sum that").
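A rough way to write this comparison down: let k_{t,i} be the number of proposals thread t needs for the i-th element assigned to it. Per warp, the current scheme performs roughly

\sum_i \max_{t \in \text{warp}} k_{t,i}

loop iterations, while the single-loop scheme performs roughly

\max_{t \in \text{warp}} \sum_i k_{t,i},

and the second quantity is never larger than the first (and is usually much smaller once each thread handles many elements).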

@xidulu What RNG is used for the host-side and device-side APIs? The cuRAND ones should not really differ much in performance between device-side and host-side.
There are a few advantages:

  • you don't need to store and load the RNG numbers you made (and in the fully optimized case generating random numbers should actually be a pretty bandwidth-limited operation)
  • you don't need additional storage (besides the RNG generator state which you need anyway)
  • you compute only as many RNG numbers as you really need

@xidulu
Contributor Author

xidulu commented Aug 19, 2019

@ptrendx

The device-side API I mentioned is the RandGenerator class (the one used in ndarray.random()); it generates random numbers with curand_uniform():
https://github.com/apache/incubator-mxnet/blob/master/include/mxnet/random_generator.h#L111

The host API (the one I used) can be seen here:
https://github.com/apache/incubator-mxnet/blob/master/3rdparty/mshadow/mshadow/random.h#L370
Random numbers are generated with curandGenerateUniform().

In terms of random number generation, RandGenerator (which is basically a wrapper over the CUDA device API, IMO) may be comparable to mshadow/random.
However, is it possible that the overhead of managing random states in RandGenerator affects its performance?


Update:

To find the bottleneck of ndarray.random(), I removed the while loop in the kernel: https://github.com/apache/incubator-mxnet/blob/fb4f9d55382538fe688638b741830d84ae0d783e/src/operator/random/sampler.h#L183

The new version becomes ten times faster than the original one: 160 ms vs. 1600 ms at size 10e7 (of course, some samples are no longer sampled correctly).


A few words about additional storage:

In my experiment, I tracked GPU memory usage with watch -d -n 0.5 nvidia-smi (the method may be imprecise) and found that my method, though it explicitly requests extra storage, only consumed an acceptable amount of extra memory in practice: ndarray.random.gamma() used around 2400 MB while my method used around 2500 MB when sampling 10e7 samples.
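As a rough sanity check (assuming float32 buffers): with M = 1, the Gaussian and uniform buffers together hold 2 x (M + 1) x 10e7 = 4 x 10e7 values, i.e. roughly 160 MB, which is the same order of magnitude as the ~100 MB difference observed above.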

@yzhliu
Member

yzhliu commented Aug 19, 2019

@ptrendx Thanks, now I get what you mean. I'm open to what you proposed, though I think one of the major problems with the device API is the maintenance of the random generator (and its states).

@xidulu
Contributor Author

xidulu commented Sep 4, 2019

@ptrendx @yzhliu
I will create a PR for np.random.gamma implemented with the method I proposed before the end of the week, as I need to move on to implementing more distribution samplers, for which the gamma sampler is a prerequisite.
Refactoring nd.random can be left for further discussion.


(P.S. Personally speaking, I think performance is not the only problem in nd.random; the poor support for broadcasting parameters and shapes may cause trouble for heavy users of the random sampling module.)

@szha
Member

szha commented Jul 31, 2020

This feature has been accepted. Thanks @xidulu for the contribution!

@szha szha closed this as completed Jul 31, 2020