Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into arange
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin committed Feb 27, 2019
2 parents d5ca06e + 0eed3da commit 18b620d
Show file tree
Hide file tree
Showing 13 changed files with 584 additions and 46 deletions.
10 changes: 5 additions & 5 deletions ci/jenkins/Jenkinsfile_windows_cpu
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ utils.assign_node_labels(utility: 'utility', windows_cpu: 'mxnetwindows-cpu')
utils.main_wrapper(
core_logic: {
utils.parallel_stage('Build', [
custom_steps.compile_windows_cpu()
// custom_steps.compile_windows_cpu()
])

utils.parallel_stage('Tests', [
custom_steps.test_windows_python2_cpu(),
custom_steps.test_windows_python3_cpu(),
custom_steps.test_windows_julia07_cpu(),
custom_steps.test_windows_julia10_cpu()
// custom_steps.test_windows_python2_cpu(),
// custom_steps.test_windows_python3_cpu(),
// custom_steps.test_windows_julia07_cpu(),
// custom_steps.test_windows_julia10_cpu()
])
}
,
Expand Down
10 changes: 5 additions & 5 deletions ci/jenkins/Jenkinsfile_windows_gpu
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ utils.assign_node_labels(utility: 'utility', windows_cpu: 'mxnetwindows-cpu', wi
utils.main_wrapper(
core_logic: {
utils.parallel_stage('Build', [
custom_steps.compile_windows_gpu(),
custom_steps.compile_windows_gpu_mkldnn()
// custom_steps.compile_windows_gpu(),
// custom_steps.compile_windows_gpu_mkldnn()
])

utils.parallel_stage('Tests', [
custom_steps.test_windows_python2_gpu(),
custom_steps.test_windows_python3_gpu(),
custom_steps.test_windows_python3_gpu_mkldnn()
// custom_steps.test_windows_python2_gpu(),
// custom_steps.test_windows_python3_gpu(),
// custom_steps.test_windows_python3_gpu_mkldnn()
])
}
,
Expand Down
6 changes: 6 additions & 0 deletions docs/faq/env_var.md
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,12 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca

## Other Environment Variables

* MXNET_GPU_WORKER_NSTREAMS
- Values: 1, or 2 ```(default=1)```
- Determines the number of GPU streams available to operators for their functions.
- Setting this to 2 may yield a modest performance increase, since ops like the cuDNN convolution op can then calculate their data- and weight-gradients in parallel.
- Setting this to 2 may also increase a model's demand for GPU global memory.

* MXNET_CUDNN_AUTOTUNE_DEFAULT
- Values: 0, 1, or 2 ```(default=1)```
- The default value of cudnn auto tuning for convolution layers.
Expand Down
133 changes: 133 additions & 0 deletions include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ struct Context {
* \return The number of GPUs that are available.
*/
inline static int32_t GetGPUCount();
/*!
* Get the number of streams that a GPU Worker has available to operations.
* \return The number of streams that are available.
*/
inline static int32_t GetGPUStreamsPerWorker();
/*!
* \brief get the free and total available memory on a GPU
* \param dev the GPU number to query
Expand Down Expand Up @@ -221,6 +226,112 @@ struct Context {
inline static Context FromString(const std::string& str);
};

#if MXNET_USE_CUDA
/*! \brief Holds an auxiliary mshadow gpu stream that can be synced with a primary stream. */
class GPUAuxStream {
public:
/*!
* \brief constructor.
* \param primary_stream gpu stream that is synced with the created auxiliary stream.
*/
explicit GPUAuxStream(mshadow::Stream<gpu> *primary_stream) :
primary_stream_(primary_stream),
aux_stream_(primary_stream),
gpu_stream_sync_event_(nullptr) {
if (Context::GetGPUStreamsPerWorker() >= 2) {
// Create auxiliary stream on the same device with the same properties as the primary stream
bool primary_has_blas_handle =
primary_stream->blas_handle_ownership_ == mshadow::Stream<gpu>::OwnHandle;
bool primary_has_dnn_handle =
primary_stream->dnn_handle_ownership_ == mshadow::Stream<gpu>::OwnHandle;
aux_stream_ = mshadow::NewStream<gpu>(primary_has_blas_handle,
primary_has_dnn_handle,
primary_stream->dev_id);
MSHADOW_CUDA_CALL(cudaEventCreateWithFlags(&gpu_stream_sync_event_, cudaEventDisableTiming));
}
}
/*! \brief destructor */
~GPUAuxStream() {
// If the aux_stream_ == primary_stream_, then we created no new streams to destroy.
if (aux_stream_ != primary_stream_) {
MSHADOW_CATCH_ERROR(mshadow::DeleteStream<gpu>(aux_stream_));
MSHADOW_CATCH_ERROR(cudaEventDestroy(gpu_stream_sync_event_));
}
}
/*!
* \brief Makes future aux stream work wait on the completion of existing primary stream work.
*/
void PreAuxStreamUseSync() {
// If the aux_stream_ == primary_stream_, then no synchronization is necessary.
if (aux_stream_ != primary_stream_)
StreamSync(primary_stream_, aux_stream_, gpu_stream_sync_event_);
}
/*!
* \brief Makes future primary stream work wait on the completion of existing aux stream work.
*/
void PostAuxStreamUseSync() {
// If the aux_stream_ == primary_stream_, then no synchronization is necessary.
if (aux_stream_ != primary_stream_)
StreamSync(aux_stream_, primary_stream_, gpu_stream_sync_event_);
}
/*! \brief Getter for created auxiliary stream. */
mshadow::Stream<gpu> *GetStream() { return aux_stream_; }
/*!
* \brief Make future work enqueued to `s2` wait on completion of current work enqueued to `s1`.
* \param s1 stream with work that must be completed before future s2 work can begin.
* \param s2 stream whose future work is made to wait on the completion of existing s1 work.
* \param event used to pass s1 state to s2.
*/
static void StreamSync(mshadow::Stream<gpu> *s1, mshadow::Stream<gpu> *s2, cudaEvent_t event) {
MSHADOW_CUDA_CALL(cudaEventRecord(event, s1->stream_));
MSHADOW_CUDA_CALL(cudaStreamWaitEvent(s2->stream_, event, 0));
}

private:
mshadow::Stream<gpu> *primary_stream_;
mshadow::Stream<gpu> *aux_stream_;
cudaEvent_t gpu_stream_sync_event_;
};

/*!
* \brief Provides automatic coordination of an auxilary stream with a primary one.
* This object, upon construction, prepares an aux stream for use by syncing it with enqueued
* primary-stream work. Object destruction will sync again so future primary-stream work
* will wait on enqueued aux-stream work. If MXNET_GPU_WORKER_NSTREAMS == 1, then this defaults
* simply: the primary stream will equal the aux stream and the syncs will be executed as nops.
* See ./src/operator/cudnn/cudnn_convolution-inl.h for a usage example.
*/
class SyncedGPUAuxStream {
public:
/*!
* \brief constructor.
* \param gpu_aux_stream auxilary gpu stream that is managed by this RAII object.
*/
explicit SyncedGPUAuxStream(GPUAuxStream *gpu_aux_stream) : gpu_aux_stream_(gpu_aux_stream) {
gpu_aux_stream_->PreAuxStreamUseSync();
}
/*! \brief destructor */
~SyncedGPUAuxStream() {
gpu_aux_stream_->PostAuxStreamUseSync();
}
/*! \brief copy constructor deleted to prevent unexpected synchronizations. */
SyncedGPUAuxStream(const SyncedGPUAuxStream&) = delete;
/*! \brief copy assignment operator deleted to prevent unexpected synchronizations. */
void operator=(const SyncedGPUAuxStream&) = delete;
/*! \brief move constructor permitted as alternative to copying. */
SyncedGPUAuxStream(SyncedGPUAuxStream&&) = default;
/*! \brief move assignment operator permitted as alternative to copy assignment. */
SyncedGPUAuxStream& operator=(SyncedGPUAuxStream&&) = default;
/*! \brief Getter for underlying mshadow::Stream<gpu>. */
inline mshadow::Stream<gpu>* GetStream() const {
return gpu_aux_stream_->GetStream();
}

private:
GPUAuxStream *gpu_aux_stream_;
};
#endif // MXNET_USE_CUDA

/*!
* \brief execution time context.
* The information needed in runtime for actual execution.
Expand All @@ -232,6 +343,10 @@ struct RunContext {
* \brief the stream of the device, can be NULL or Stream<gpu>* in GPU mode
*/
void *stream;
/*!
* \brief the auxiliary stream of the device, can be NULL or Stream<gpu>* in GPU mode
*/
void *aux_stream;
/*!
* \brief indicator of whether this execution is run in bulk mode
*/
Expand All @@ -245,6 +360,15 @@ struct RunContext {
inline mshadow::Stream<xpu>* get_stream() const {
return static_cast<mshadow::Stream<xpu>*>(stream);
}
#if MXNET_USE_CUDA
/*!
* \brief get an RAII object that transparently handles the syncing of the auxiliary stream.
* \return the aux stream auto-syncing object
*/
inline SyncedGPUAuxStream get_gpu_aux_stream() const {
return SyncedGPUAuxStream(static_cast<GPUAuxStream*>(aux_stream));
}
#endif
/*! \brief get the base Context from RunContext */
inline const Context& get_ctx() const {
return ctx;
Expand Down Expand Up @@ -309,6 +433,15 @@ inline int32_t Context::GetGPUCount() {
#endif
}

inline int32_t Context::GetGPUStreamsPerWorker() {
// The default number of streams available if the user has not set MXNET_GPU_WORKER_NSTREAMS.
const int32_t default_num_streams = 1;
// The get_aux_stream() interface can supply one additional stream beyond the standard one.
static int32_t num_streams =
dmlc::GetEnv("MXNET_GPU_WORKER_NSTREAMS", default_num_streams) >= 2 ? 2 : 1;
return num_streams;
}

inline void Context::GetGPUMemoryInformation(int dev, uint64_t *free_mem,
uint64_t *total_mem) {
#if MXNET_USE_CUDA
Expand Down
9 changes: 9 additions & 0 deletions include/mxnet/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ struct OpContext {
inline mshadow::Stream<xpu>* get_stream() const {
return run_ctx.get_stream<xpu>();
}
#if MXNET_USE_CUDA
/*!
* \brief get auxilary gpu stream auto-syncing object from Context
* \return the aux stream auto-syncing object
*/
inline SyncedGPUAuxStream get_gpu_aux_stream() const {
return run_ctx.get_gpu_aux_stream();
}
#endif
};

/*! \brief the execution type of the operator */
Expand Down
25 changes: 23 additions & 2 deletions src/engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class NaiveEngine final : public Engine {
};

NaiveEngine() {
objpool_opr_ref_ = common::ObjectPool<NaiveOpr>::_GetSharedRef();
objpool_var_ref_ = common::ObjectPool<NaiveVar>::_GetSharedRef();
}
// virtual destructor
virtual ~NaiveEngine() {
Expand All @@ -74,6 +76,12 @@ class NaiveEngine final : public Engine {
streams_[i] = nullptr;
}
}
for (size_t i = 0; i < aux_streams_.size(); ++i) {
if (aux_streams_[i] != nullptr) {
delete aux_streams_[i];
aux_streams_[i] = nullptr;
}
}
#endif
}

Expand Down Expand Up @@ -169,16 +177,18 @@ class NaiveEngine final : public Engine {
MSHADOW_CATCH_ERROR(mshadow::SetDevice<gpu>(exec_ctx.dev_id));
if (streams_.size() <= dev_id) {
streams_.resize(dev_id + 1, nullptr);
aux_streams_.resize(dev_id + 1, nullptr);
}
if (streams_[dev_id] == nullptr) {
streams_[dev_id] = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, dev_id);
aux_streams_[dev_id] = new GPUAuxStream(streams_[dev_id]);
}
exec_fun(RunContext{exec_ctx, streams_[dev_id]}, callback);
exec_fun(RunContext{exec_ctx, streams_[dev_id], aux_streams_[dev_id], false}, callback);
#else
LOG(FATAL) << "GPU is not enabled";
#endif
} else {
exec_fun(RunContext{exec_ctx, &cpu_stream_}, callback);
exec_fun(RunContext{exec_ctx, &cpu_stream_, nullptr, false}, callback);
}
CHECK(this->req_completed_)
<< "NaiveEngine only support synchronize Push so far";
Expand Down Expand Up @@ -220,6 +230,17 @@ class NaiveEngine final : public Engine {
mshadow::Stream<cpu> cpu_stream_;
// GPU streams
std::vector<mshadow::Stream<gpu>*> streams_;
#if MXNET_USE_CUDA
// GPU auxiliary streams
std::vector<GPUAuxStream*> aux_streams_;
#endif
/*!
* \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early
* See also #309 (https://github.com/dmlc/mxnet/issues/309) and similar fix in threaded_engine.h.
* Without this, segfaults seen on CentOS7 in test_operator_gpu.py:test_convolution_multiple_streams
*/
std::shared_ptr<common::ObjectPool<NaiveOpr> > objpool_opr_ref_;
std::shared_ptr<common::ObjectPool<NaiveVar> > objpool_var_ref_;
}; // class NaiveEngine

Engine *CreateNaiveEngine() {
Expand Down
25 changes: 18 additions & 7 deletions src/engine/stream_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class StreamManager {
#if MXNET_USE_CUDA
std::array<std::array<mshadow::Stream<gpu>*, kStreams>, kNumGpus>
gpu_streams_;
std::array<std::array<GPUAuxStream*, kStreams>, kNumGpus>
gpu_aux_streams_;
std::array<mshadow::Stream<gpu>*, kNumGpus> gpu_io_streams_;
std::array<int, kNumGpus> gpu_cnt_;
#endif // MXNET_USE_CUDA
Expand All @@ -67,7 +69,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
RunContext ret;
switch (ctx.dev_mask()) {
case cpu::kDevMask:
ret = RunContext{ctx, nullptr, false};
ret = RunContext{ctx, nullptr, nullptr, false};
break;
case gpu::kDevMask: {
#if MXNET_USE_CUDA
Expand All @@ -77,8 +79,13 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
auto&& counter = gpu_cnt_.at(ctx.dev_id);
if (counter == -1) {
mxnet::common::cuda::DeviceStore device_store(ctx.dev_id);
for (auto&& i : gpu_streams_.at(ctx.dev_id)) {
i = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
for (auto&& primary_stream : gpu_streams_.at(ctx.dev_id)) {
primary_stream = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
}
int idx = 0;
for (auto&& aux_stream : gpu_aux_streams_.at(ctx.dev_id)) {
auto primary_stream = gpu_streams_.at(ctx.dev_id).at(idx++);
aux_stream = new GPUAuxStream(primary_stream);
}
counter = 0;
}
Expand All @@ -87,6 +94,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
}
ret = RunContext{ctx,
gpu_streams_.at(ctx.dev_id).at(use_counter),
gpu_aux_streams_.at(ctx.dev_id).at(use_counter),
false};
break;
#else
Expand All @@ -105,7 +113,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
RunContext ret;
switch (ctx.dev_mask()) {
case cpu::kDevMask:
ret = RunContext{ctx, nullptr, false};
ret = RunContext{ctx, nullptr, nullptr, false};
break;
case gpu::kDevMask: {
#if MXNET_USE_CUDA
Expand All @@ -116,7 +124,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream<gpu>(false, false, ctx.dev_id);
}
}
ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id), false};
ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id), nullptr, false};
break;
#else
LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
Expand Down Expand Up @@ -145,9 +153,12 @@ void StreamManager<kNumGpus, kStreams>::Finalize() {
#if MXNET_USE_CUDA
for (std::size_t i = 0; i < kNumGpus; ++i) {
if (gpu_cnt_.at(i) != -1) {
for (auto&& j : gpu_streams_.at(i)) {
for (auto&& primary_stream : gpu_streams_.at(i)) {
// Catch exception for CUDA driver shutdown
MSHADOW_CATCH_ERROR(mshadow::DeleteStream<gpu>(j));
MSHADOW_CATCH_ERROR(mshadow::DeleteStream<gpu>(primary_stream));
}
for (auto&& aux_stream : gpu_aux_streams_.at(i)) {
delete aux_stream;
}
gpu_cnt_.at(i) = -1;
}
Expand Down
Loading

0 comments on commit 18b620d

Please sign in to comment.