diff --git a/src/common/random_generator.h b/include/mxnet/random_generator.h similarity index 92% rename from src/common/random_generator.h rename to include/mxnet/random_generator.h index 5d78b616e534..6e37efd40598 100644 --- a/src/common/random_generator.h +++ b/include/mxnet/random_generator.h @@ -22,16 +22,15 @@ * \file random_generator.h * \brief Parallel random number generator. */ -#ifndef MXNET_COMMON_RANDOM_GENERATOR_H_ -#define MXNET_COMMON_RANDOM_GENERATOR_H_ +#ifndef MXNET_RANDOM_GENERATOR_H_ +#define MXNET_RANDOM_GENERATOR_H_ -#include #include #include +#include "./base.h" #if MXNET_USE_CUDA #include -#include "../common/cuda_utils.h" #endif // MXNET_USE_CUDA namespace mxnet { @@ -50,6 +49,7 @@ class RandGenerator { static const int kNumRandomStates; // implementation class for random number generator + // TODO(alexzai): move impl class to separate file - tracked in MXNET-948 class Impl { public: typedef typename std::conditional::value, @@ -116,6 +116,7 @@ class RandGenerator { // by using 1.0-curand_uniform(). // Needed as some samplers in sampler.h won't be able to deal with // one of the boundary cases. + // TODO(alexzai): move impl class to separate file - tracked in MXNET-948 class Impl { public: Impl &operator=(const Impl &) = delete; @@ -150,14 +151,9 @@ class RandGenerator { curandStatePhilox4_32_10_t state_; }; // class RandGenerator::Impl - static void AllocState(RandGenerator *inst) { - CUDA_CALL(cudaMalloc(&inst->states_, - kNumRandomStates * sizeof(curandStatePhilox4_32_10_t))); - } + static void AllocState(RandGenerator *inst); - static void FreeState(RandGenerator *inst) { - CUDA_CALL(cudaFree(inst->states_)); - } + static void FreeState(RandGenerator *inst); void Seed(mshadow::Stream *s, uint32_t seed); @@ -172,6 +168,7 @@ class RandGenerator { // by using 1.0-curand_uniform(). // Needed as some samplers in sampler.h won't be able to deal with // one of the boundary cases. + // TODO(alexzai): move impl class to separate file - tracked in MXNET-948 class Impl { public: Impl &operator=(const Impl &) = delete; @@ -215,4 +212,4 @@ class RandGenerator { } // namespace random } // namespace common } // namespace mxnet -#endif // MXNET_COMMON_RANDOM_GENERATOR_H_ +#endif // MXNET_RANDOM_GENERATOR_H_ diff --git a/include/mxnet/resource.h b/include/mxnet/resource.h index 74ae7e321af6..67c14b66abdd 100644 --- a/include/mxnet/resource.h +++ b/include/mxnet/resource.h @@ -28,7 +28,7 @@ #include #include "./base.h" #include "./engine.h" -#include "../../src/common/random_generator.h" +#include "./random_generator.h" namespace mxnet { diff --git a/src/common/random_generator.cu b/src/common/random_generator.cu index 930e5e07b89a..a2d3e0d911e3 100644 --- a/src/common/random_generator.cu +++ b/src/common/random_generator.cu @@ -23,8 +23,8 @@ * \brief gpu implements for parallel random number generator. */ +#include #include -#include "./random_generator.h" #include "../operator/mxnet_op.h" namespace mxnet { @@ -59,6 +59,17 @@ void RandGenerator::Seed(mshadow::Stream *s, uint32_t seed) { s->Wait(); } +template<> +void RandGenerator::AllocState(RandGenerator *inst) { + CUDA_CALL(cudaMalloc(&inst->states_, + kNumRandomStates * sizeof(curandStatePhilox4_32_10_t))); +} + +template<> +void RandGenerator::FreeState(RandGenerator *inst) { + CUDA_CALL(cudaFree(inst->states_)); +} + } // namespace random } // namespace common } // namespace mxnet diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h index 1c4f48b32ed2..1d2baa4b6c3f 100644 --- a/src/operator/leaky_relu-inl.h +++ b/src/operator/leaky_relu-inl.h @@ -28,13 +28,13 @@ #include #include +#include #include #include #include #include #include #include -#include "../common/random_generator.h" #include "./operator_common.h" #include "./mshadow_op.h" #include "./random/sampler.h" diff --git a/src/resource.cc b/src/resource.cc index 2794d48f85bf..ba4ab7270bdb 100644 --- a/src/resource.cc +++ b/src/resource.cc @@ -27,12 +27,12 @@ #include #include #include +#include #include #include #include #include #include "./common/lazy_alloc_array.h" -#include "./common/random_generator.h" #include "./common/utils.h" namespace mxnet {