This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix/public internal header #12374

Merged Sep 13, 2018 (19 commits)
@@ -22,16 +22,15 @@
  * \file random_generator.h
  * \brief Parallel random number generator.
  */
-#ifndef MXNET_COMMON_RANDOM_GENERATOR_H_
-#define MXNET_COMMON_RANDOM_GENERATOR_H_
+#ifndef MXNET_RANDOM_GENERATOR_H_
+#define MXNET_RANDOM_GENERATOR_H_

-#include <mxnet/base.h>
 #include <random>
 #include <new>
+#include "./base.h"

 #if MXNET_USE_CUDA
 #include <curand_kernel.h>
-#include "../common/cuda_utils.h"
 #endif  // MXNET_USE_CUDA

 namespace mxnet {
@@ -50,6 +49,7 @@ class RandGenerator<cpu, DType> {
   static const int kNumRandomStates;

   // implementation class for random number generator
+  // TODO(alexzai): move impl class to separate file - tracked in MXNET-948
   class Impl {
    public:
     typedef typename std::conditional<std::is_floating_point<DType>::value,
@@ -116,6 +116,7 @@ class RandGenerator<gpu, DType> {
   // by using 1.0-curand_uniform().
   // Needed as some samplers in sampler.h won't be able to deal with
   // one of the boundary cases.
+  // TODO(alexzai): move impl class to separate file - tracked in MXNET-948
   class Impl {
    public:
     Impl &operator=(const Impl &) = delete;
@@ -150,14 +151,9 @@ class RandGenerator<gpu, DType> {
     curandStatePhilox4_32_10_t state_;
   };  // class RandGenerator<gpu, DType>::Impl

-  static void AllocState(RandGenerator<gpu, DType> *inst) {
-    CUDA_CALL(cudaMalloc(&inst->states_,
-                         kNumRandomStates * sizeof(curandStatePhilox4_32_10_t)));
-  }
+  static void AllocState(RandGenerator<gpu, DType> *inst);
Contributor

Why do we handle this differently from the CPU implementation?

Contributor Author
@azai91 azai91 Aug 30, 2018

What do you mean? We are calling cudaMalloc in this case instead of just "new".

Contributor
@apeforest apeforest Aug 30, 2018

I guess my question was not well phrased. What I meant is: why move this function out of the header file while leaving the same function for CPU in the header?

See line: https://github.com/apache/incubator-mxnet/pull/12374/files#diff-ba5bcd7d0b76b85a2df1f793dc4d3302R82

Aside from that, I think these functions are inside the inner class Impl, which is supposed to handle all the implementation. Therefore I think it is logical to leave them here in the header file, not to mention the performance advantage of calling an inline function.

Contributor Author

I see what you mean. These CUDA calls depend on cuda_utils.h, which is in the common folder. I am not sure we want to expose it. Thoughts?

Member

@apeforest are you referring to performance advantages at compile time? I think it's okay to put the definition in random_generator.cu, since we don't want to expose cuda_utils.
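
The split being discussed can be sketched in plain C++. This is a minimal illustration, not MXNet's code: DeviceStateHolder and kNumStates are hypothetical names, and std::malloc/std::free stand in for cudaMalloc/cudaFree so the sketch builds without CUDA. The header part only declares the static functions, so it needs no allocation utilities; the source part defines them as explicit specializations for float, which is the shape the random_generator.cu hunk in this PR takes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>

// ---- Public header part (hypothetical class, not MXNet's RandGenerator) ----
template <typename DType>
class DeviceStateHolder {
 public:
  static constexpr std::size_t kNumStates = 4;  // hypothetical state count

  // Declarations only: the header does not need the allocation utilities.
  static void AllocState(DeviceStateHolder<DType> *inst);
  static void FreeState(DeviceStateHolder<DType> *inst);

  bool HasState() const { return states_ != nullptr; }

 private:
  void *states_ = nullptr;
};

// ---- Source file part (would live in the .cu/.cc file) ----
// std::malloc/std::free are stand-ins for cudaMalloc/cudaFree.
template <>
void DeviceStateHolder<float>::AllocState(DeviceStateHolder<float> *inst) {
  inst->states_ = std::malloc(kNumStates * sizeof(unsigned int));
}

template <>
void DeviceStateHolder<float>::FreeState(DeviceStateHolder<float> *inst) {
  std::free(inst->states_);
  inst->states_ = nullptr;
}
```

With this layout, only the float specialization has a definition, so using AllocState with another DType compiles but fails at link time. The cost is that the calls can no longer be inlined across translation units, which matters little for one-off allocation and deallocation.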

Contributor Author

@apeforest any thoughts on what to revise? If we want to be consistent, we could create a random_generator.cc file and put the non-CUDA implementations there, but I personally think that is unnecessary for such small functions.

Contributor

@azai91 Thanks for your efforts in refactoring this. Since the main purpose of this PR is refactoring, I hope we can do it in the most elegant way. If we are making this file a public header, I would extract the implementation class Impl into another, internal header file so that we do not expose the internal implementation details. By doing that, we can put the implementation details for both CPU and GPU there. Please let me know your thoughts.

Contributor Author

It seems that, the way we are using random_generator, we expect the developer to be able to access the internal Impl class (https://github.com/azai91/incubator-mxnet/blob/6f7254c91709904a9fb6290f1998fcf2da818d0e/src/operator/random/unique_sample_op.h#L118).

The class is public as well. I may be interpreting the developers' intentions incorrectly, though.

Contributor
@apeforest apeforest Sep 12, 2018

I suppose the developers did not intend to access the internal Impl class directly, but rather lacked a proper interface API. In fact, throughout the repository I found only one place that uses RandGenerator<cpu, GType>::Impl directly, so I think refactoring that one line of code is not a big overhead. Regarding how to separate the Impl class from the interface, you might find this reference helpful: https://cpppatterns.com/patterns/pimpl.html
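
For readers unfamiliar with the pattern linked above, here is a minimal pimpl sketch. It is an illustration under assumptions, not a proposal for RandGenerator itself: the Sampler class, its Uniform method, and the use of std::mt19937 are all hypothetical. The public part forward-declares Impl and holds it through a std::unique_ptr, so the implementation members never appear in the header.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <random>

// ---- Public header part: only the interface is visible ----
class Sampler {  // hypothetical class name
 public:
  explicit Sampler(std::uint32_t seed);
  ~Sampler();  // defined in the source part, where Impl is a complete type
  Sampler(const Sampler &) = delete;
  Sampler &operator=(const Sampler &) = delete;

  float Uniform();  // one sample from the unit interval

 private:
  class Impl;  // forward declaration; the definition stays internal
  std::unique_ptr<Impl> impl_;
};

// ---- Source file part: implementation details live here ----
class Sampler::Impl {
 public:
  explicit Impl(std::uint32_t seed) : engine_(seed), dist_(0.0f, 1.0f) {}
  float Uniform() { return dist_(engine_); }

 private:
  std::mt19937 engine_;
  std::uniform_real_distribution<float> dist_;
};

Sampler::Sampler(std::uint32_t seed) : impl_(new Impl(seed)) {}
Sampler::~Sampler() = default;
float Sampler::Uniform() { return impl_->Uniform(); }
```

The trade-off is one extra pointer indirection per call. Since operators currently use Impl directly, as noted earlier in this thread, whether that indirection is acceptable there is a separate question.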


-  static void FreeState(RandGenerator<gpu, DType> *inst) {
-    CUDA_CALL(cudaFree(inst->states_));
-  }
+  static void FreeState(RandGenerator<gpu, DType> *inst);

   void Seed(mshadow::Stream<gpu> *s, uint32_t seed);

@@ -172,6 +168,7 @@ class RandGenerator<gpu, double> {
   // by using 1.0-curand_uniform().
   // Needed as some samplers in sampler.h won't be able to deal with
   // one of the boundary cases.
+  // TODO(alexzai): move impl class to separate file - tracked in MXNET-948
   class Impl {
    public:
     Impl &operator=(const Impl &) = delete;
@@ -215,4 +212,4 @@ class RandGenerator<gpu, double> {
 }  // namespace random
 }  // namespace common
 }  // namespace mxnet
-#endif  // MXNET_COMMON_RANDOM_GENERATOR_H_
+#endif  // MXNET_RANDOM_GENERATOR_H_
2 changes: 1 addition & 1 deletion include/mxnet/resource.h
@@ -28,7 +28,7 @@
 #include <dmlc/logging.h>
 #include "./base.h"
 #include "./engine.h"
-#include "../../src/common/random_generator.h"
+#include "./random_generator.h"

 namespace mxnet {

13 changes: 12 additions & 1 deletion src/common/random_generator.cu
@@ -23,8 +23,8 @@
  * \brief gpu implements for parallel random number generator.
  */

+#include <mxnet/random_generator.h>
 #include <algorithm>
-#include "./random_generator.h"
 #include "../operator/mxnet_op.h"

 namespace mxnet {
@@ -59,6 +59,17 @@ void RandGenerator<gpu, float>::Seed(mshadow::Stream<gpu> *s, uint32_t seed) {
   s->Wait();
 }

+template<>
+void RandGenerator<gpu, float>::AllocState(RandGenerator<gpu> *inst) {
Member

Doesn't this remove support for other DTypes like half_t?

Contributor

No, the random generator always generates floats between 0 and 1; we scale them as needed wherever the random generator is used.
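
The scaling idea in this reply can be sketched as follows. This is a minimal illustration in standard C++, not MXNet's sampler code: ScaleUniform and NextUnit are hypothetical helpers, and std::mt19937 stands in for the generator. The generator only ever produces a float in the unit interval, and the caller maps it to the output type and range it needs.

```cpp
#include <cassert>
#include <random>

// Hypothetical helper: map a unit-interval sample u onto [low, high],
// converting to the caller's output type.
template <typename OutT>
OutT ScaleUniform(float u, OutT low, OutT high) {
  return static_cast<OutT>(low + (high - low) * u);
}

// Hypothetical helper: draw one float from the unit interval.
// std::mt19937 is a stand-in for the real generator state.
inline float NextUnit(std::mt19937 *engine) {
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  return dist(*engine);
}
```

A caller that needs doubles in [2, 4] would write ScaleUniform(NextUnit(&engine), 2.0, 4.0); the generator itself stays float-only.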

+  CUDA_CALL(cudaMalloc(&inst->states_,
+                       kNumRandomStates * sizeof(curandStatePhilox4_32_10_t)));
+}
+
+template<>
+void RandGenerator<gpu, float>::FreeState(RandGenerator<gpu> *inst) {
+  CUDA_CALL(cudaFree(inst->states_));
+}
+
 }  // namespace random
 }  // namespace common
 }  // namespace mxnet
2 changes: 1 addition & 1 deletion src/operator/leaky_relu-inl.h
@@ -28,13 +28,13 @@

 #include <dmlc/logging.h>
 #include <dmlc/parameter.h>
+#include <mxnet/random_generator.h>
 #include <mxnet/operator.h>
 #include <cstring>
 #include <map>
 #include <string>
 #include <vector>
 #include <utility>
-#include "../common/random_generator.h"
 #include "./operator_common.h"
 #include "./mshadow_op.h"
 #include "./random/sampler.h"
2 changes: 1 addition & 1 deletion src/resource.cc
@@ -27,12 +27,12 @@
 #include <dmlc/thread_local.h>
 #include <mxnet/base.h>
 #include <mxnet/engine.h>
+#include <mxnet/random_generator.h>
 #include <mxnet/resource.h>
 #include <mxnet/storage.h>
 #include <limits>
 #include <atomic>
 #include "./common/lazy_alloc_array.h"
-#include "./common/random_generator.h"
 #include "./common/utils.h"

 namespace mxnet {