[RFC] A faster version of Gamma sampling on GPU. #15928
Comments
cc @apache/mxnet-committers I think we can gradually refactor the current implementation (ndarray api) by adopting this new approach. @xidulu, could you please fix the url links in your post?
Links fixed.
Hi @xidulu. I did not look at the differences in the implementation of host-side vs device-side API for RNG in MXNet, but if they are comparable in terms of performance, a possible better approach would be something like this: …
There are 2 ways of implementing …
@yzhliu No. What MXNet currently does is a scheme where, yes, each thread gets assigned statically some number of elements, but it has a while loop for each of them. The scheme I proposed has a single while loop that processes all elements assigned to a given thread. There is a big difference between these approaches, due to SIMT architecture of the GPU. Basically you can treat some number of threads (called warp, 32 threads on NVIDIA's GPU) as lanes in SIMD vector instruction on the CPU. This means that if 1 thread needs to perform some computation, all threads in the warp need to perform the same instruction (and possibly discard the result). @xidulu What is the RNG used for host-side and device-side API? cuRAND ones should not really differ much in perf between device-side and host-side.
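A minimal sketch of the scheme @ptrendx describes, with a toy accept test (keep only non-negative N(0,1) draws, i.e. a half-normal sampler); the kernel and names below are illustrative, not MXNet code. With a while loop per element, the warp pays the max-over-lanes iteration count for every element; with a single fused loop, a lane that rejects simply retries while its neighbors advance through their own elements, so the warp pays only the max over lanes of each lane's total work, which is never larger.

```cuda
// Sketch of the proposed single-while-loop scheme (toy accept test).
#include <curand_kernel.h>

__global__ void sample_fused_loop(float* out, int n, unsigned long long seed) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, tid, 0, &state);
  int i = tid;                        // this thread's current element
  while (i < n) {                     // ONE loop for all assigned elements
    float x = curand_normal(&state);  // propose
    if (x >= 0.f) {                   // accepted:
      out[i] = x;                     //   write result and advance
      i += stride;
    }                                 // rejected: just loop and retry
  }
}
```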
The device-side api I mentioned is … The host api (the one I used) can be seen here: …

In terms of random number generation, …

Update: To find out the bottleneck of … The new version becomes ten times faster than the original one: 160 ms vs. 1600 ms at size 10e7 (of course, some samples are not sampled correctly).

A few words about additional storage: in my experiment, I tracked the GPU memory usage with …
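For context, pre-filling device buffers with the cuRAND host api looks roughly like the sketch below. This is a generic illustration of the host-api shape, not the actual MXNet code path (MXNet wraps its own generator); the generator choice and seed here are assumptions.

```cuda
// Sketch: one host-side call fills an entire device buffer, so no
// per-thread RNG state has to be created or advanced inside the kernel.
#include <curand.h>

void fill_buffers(float* d_normal, float* d_uniform, size_t n) {
  curandGenerator_t gen;
  curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_PHILOX4_32_10);
  curandSetPseudoRandomGeneratorSeed(gen, 1234ULL);  // assumed seed
  curandGenerateNormal(gen, d_normal, n, 0.f, 1.f);  // n must be even
  curandGenerateUniform(gen, d_uniform, n);
  curandDestroyGenerator(gen);
}
```

One practical upside of this layout is that the generator and its state live in a single host-side object, which is relevant to the state-maintenance concern raised in the next comment.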
@ptrendx Thanks, now I get what you mean. I'm open to what you proposed, though I think one of the major problems with the device api is the maintenance of the random generator (and its states).
@ptrendx @yzhliu (P.S. Personally speaking, I think performance is not the only problem in …)
This feature has been accepted. Thanks @xidulu for the contribution!
Description
Sampling from the Gamma distribution requires rejection sampling, which is applied in the implementation of `mxnet.ndarray.random.gamma()`. However, two main drawbacks exist in the current implementation (sampler.h):

- Random numbers used in the rejection sampling (N(0,1) and U(0,1)) are generated inside the kernel using the CUDA device api. Also, although every batch of threads has its own RNG, samples are actually generated serially within each batch of threads.
- Rejection sampling is achieved by using an infinite while loop inside the kernel, which may potentially hurt performance on GPU (a simplified sketch of this pattern is shown below).
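To make the drawbacks concrete, the current pattern has roughly the following shape. This is a simplified sketch with a toy accept test (keep non-negative draws), not the actual sampler.h code:

```cuda
// Sketch of the in-kernel pattern: per-thread device-side RNG state plus
// an unbounded rejection loop for every element (toy accept test).
#include <curand_kernel.h>

__global__ void sample_per_element(float* out, int n, unsigned long long seed) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  curandStatePhilox4_32_10_t state;   // RNG state created inside the kernel
  curand_init(seed, tid, 0, &state);
  for (int i = tid; i < n; i += stride) {
    float x;
    do {                              // unbounded loop per element: the
      x = curand_normal(&state);      // whole warp keeps iterating until
    } while (x < 0.f);                // its slowest lane accepts
    out[i] = x;
  }
}
```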
To solve the problems above, I wrote a new version of Gamma sampling on GPU, inspired by this blog post.
Implementation details
My implementation differs from the current version in the following aspects:
- Instead of generating samples in the kernel, we generate them in advance using the host api, which allows us to fill a buffer with random samples directly.
- Redundant samples are generated to replace the while loop. Suppose we are going to generate a Gamma tensor of size (N,); then N x (M + 1) standard gaussian N(0,1) samples and N x (M + 1) uniform U(0,1) samples are generated before entering the kernel, where M is a predefined constant. For each entry, we generate M proposed Gamma r.v.s and select the first accepted one as the output. The one extra sample is required when \alpha is less than one.
- In case all M proposed samples get rejected for some entries (which are then marked as -1), we simply refill the random buffer and perform another round of rejection sampling, but only on the entries that failed the last round (see the sketch after the gist link below).
Here's part of the implementation: https://gist.github.com/xidulu/cd9da21f2ecbccd9b784cadd67844e23
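For readers who skip the gist, here is a minimal sketch of the buffered scheme with M = 1 and \alpha >= 1; the names and layout are illustrative, not the author's actual code:

```cuda
// Sketch: one accept/reject round per launch, consuming pre-generated
// N(0,1) and U(0,1) buffers (Marsaglia-Tsang test, alpha >= 1, M = 1).
__global__ void gamma_round(const float* normal, const float* uniform,
                            float* out, int n, float alpha) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n || out[i] >= 0.f) return;  // skip entries accepted earlier
  float d = alpha - 1.f / 3.f;
  float c = rsqrtf(9.f * d);
  float x = normal[i];                  // one proposal per entry (M = 1)
  float t = 1.f + c * x;
  float v = t * t * t;                  // v = (1 + c*x)^3
  bool ok = (v > 0.f) &&
            (logf(uniform[i]) < 0.5f * x * x + d - d * v + d * logf(v));
  out[i] = ok ? d * v : -1.f;           // -1 marks a rejected entry
}

// Host driver (pseudocode): initialize out[] to -1, then repeat
//   { refill normal/uniform buffers; launch gamma_round; }
// until no -1 entries remain.
```

Because the kernel returns early for entries that already hold an accepted sample, each resampling round only performs the rejection test on the small fraction of entries that failed the previous round.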
In my experiment, I set M to 1 (i.e. no redundant samples are generated), as the adopted policy (Marsaglia and Tsang's method) has a rather high acceptance rate of around 98%.
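As a side note, the extra sample mentioned earlier handles \alpha < 1 through the standard boost identity: if G ~ Gamma(\alpha + 1) and U ~ Uniform(0,1) are independent, then G * U^(1/\alpha) ~ Gamma(\alpha). As a one-line sketch (hypothetical helper name):

```cuda
// Boost for alpha < 1; this consumes the one extra uniform per entry.
__device__ float gamma_boost(float g_alpha_plus_one, float u, float alpha) {
  return g_alpha_plus_one * powf(u, 1.f / alpha);
}
```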
The profiling result is listed below: …
The new version is currently under development on the `numpy` branch. It is also designed to support broadcastable parameters. Correctness has been verified with the following script: …
UPDATE:
ndarray on GPU is actually 50 times faster than what I profiled, as I did not turn on the release option during my testing, which prevented the CUDA compiler from applying further optimizations.

In short, my implementation is not fast. However, features like broadcastable parameters and the global random seed mechanism still remain useful.

I will be looking into the performance issue in the following weeks and hope I can find a solution.