diff --git a/include/mxnet/random_generator.h b/include/mxnet/random_generator.h index 49c1b2dddd49..db1daf638d1b 100644 --- a/include/mxnet/random_generator.h +++ b/include/mxnet/random_generator.h @@ -79,7 +79,7 @@ class RandGenerator { std::mt19937 *engine_; }; - static void AllocState(RandGenerator *inst) { + static void AllocState(RandGenerator *inst) { inst->states_ = new std::mt19937[kNumRandomStates]; } diff --git a/src/common/random_generator.cu b/src/common/random_generator.cu index ec7bedb120f4..b3a225184d22 100644 --- a/src/common/random_generator.cu +++ b/src/common/random_generator.cu @@ -59,14 +59,14 @@ void RandGenerator::Seed(mshadow::Stream *s, uint32_t seed) { s->Wait(); } -template -void RandGenerator::AllocState(RandGenerator *inst) { +template +void RandGenerator::AllocState(RandGenerator *inst) { CUDA_CALL(cudaMalloc(&inst->states_, kNumRandomStates * sizeof(curandStatePhilox4_32_10_t))); } -template -void RandGenerator::FreeState(RandGenerator *inst) { +template<> +void RandGenerator::FreeState(RandGenerator *inst) { CUDA_CALL(cudaFree(inst->states_)); }