Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix template
Browse files Browse the repository at this point in the history
  • Loading branch information
azai91 committed Aug 28, 2018
1 parent 786af1b commit 556f83d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion include/mxnet/random_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class RandGenerator<cpu, DType> {
std::mt19937 *engine_;
};

static void AllocState(RandGenerator<cpu, DType> *inst) {
static void AllocState(RandGenerator<cpu, DType> *inst) {
inst->states_ = new std::mt19937[kNumRandomStates];
}

Expand Down
8 changes: 4 additions & 4 deletions src/common/random_generator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ void RandGenerator<gpu, float>::Seed(mshadow::Stream<gpu> *s, uint32_t seed) {
s->Wait();
}

template<typename DType>
void RandGenerator<gpu, DType>::AllocState(RandGenerator<gpu, DType> *inst) {
template<typename>
void RandGenerator<gpu, float>::AllocState(RandGenerator<gpu> *inst) {
CUDA_CALL(cudaMalloc(&inst->states_,
kNumRandomStates * sizeof(curandStatePhilox4_32_10_t)));
}

template<typename DType>
void RandGenerator<gpu, DType>::FreeState(RandGenerator<gpu, DType> *inst) {
template<>
void RandGenerator<gpu, float>::FreeState(RandGenerator<gpu> *inst) {
CUDA_CALL(cudaFree(inst->states_));
}

Expand Down

0 comments on commit 556f83d

Please sign in to comment.