This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1107] Fix CPUPinned unexpected behaviour #12031

Merged
merged 20 commits on Oct 19, 2018
2 changes: 1 addition & 1 deletion include/mxnet/base.h
@@ -156,7 +156,7 @@ struct Context {
* \brief Returns dev_id for kGPU, 0 otherwise
Contributor

Should also fix the documentation here?

Contributor Author

Thanks!
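Presumably the documentation fix just extends the brief above to cover the new case, along these lines (hypothetical wording, not the exact committed text):

 * \brief Returns dev_id for kGPU and kCPUPinned, 0 otherwise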

*/
inline int real_dev_id() const {
if (dev_type == kGPU) return dev_id;
if (dev_type == kCPUPinned || dev_type == kGPU) return dev_id;
return 0;
}
/*!
Expand Down
229 changes: 129 additions & 100 deletions src/common/cuda_utils.h
@@ -68,6 +68,110 @@ inline __device__ bool __is_supported_cuda_architecture() {
}
#endif // __CUDACC__

/*!
* \brief Check CUDA error.
* \param msg Message to print if an error occurred.
*/
#define CHECK_CUDA_ERROR(msg) \
{ \
cudaError_t e = cudaGetLastError(); \
CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
}

/*!
* \brief Protected CUDA call.
* \param func Expression to call.
*
* It checks for CUDA errors after invocation of the expression.
*/
#define CUDA_CALL(func) \
{ \
cudaError_t e = (func); \
CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
<< "CUDA: " << cudaGetErrorString(e); \
}

/*!
* \brief Protected cuBLAS call.
* \param func Expression to call.
*
* It checks for cuBLAS errors after invocation of the expression.
*/
#define CUBLAS_CALL(func) \
{ \
cublasStatus_t e = (func); \
CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \
<< "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \
}

/*!
* \brief Protected cuSolver call.
* \param func Expression to call.
*
* It checks for cuSolver errors after invocation of the expression.
*/
#define CUSOLVER_CALL(func) \
{ \
cusolverStatus_t e = (func); \
CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \
<< "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \
}

/*!
* \brief Protected cuRAND call.
* \param func Expression to call.
*
* It checks for cuRAND errors after invocation of the expression.
*/
#define CURAND_CALL(func) \
{ \
curandStatus_t e = (func); \
CHECK_EQ(e, CURAND_STATUS_SUCCESS) \
<< "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \
}

/*!
* \brief Protected NVRTC call.
* \param func Expression to call.
*
* It checks for NVRTC errors after invocation of the expression.
*/
#define NVRTC_CALL(x) \
{ \
nvrtcResult result = x; \
CHECK_EQ(result, NVRTC_SUCCESS) \
<< #x " failed with error " \
<< nvrtcGetErrorString(result); \
}

/*!
* \brief Protected CUDA driver call.
* \param func Expression to call.
*
* It checks for CUDA driver errors after invocation of the expression.
*/
#define CUDA_DRIVER_CALL(func) \
{ \
CUresult e = (func); \
if (e != CUDA_SUCCESS) { \
char const * err_msg = nullptr; \
if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \
LOG(FATAL) << "CUDA Driver: Unknown error " << e; \
} else { \
LOG(FATAL) << "CUDA Driver: " << err_msg; \
} \
} \
}


#if !defined(_MSC_VER)
#define CUDA_UNROLL _Pragma("unroll")
#define CUDA_NOUNROLL _Pragma("nounroll")
#else
#define CUDA_UNROLL
#define CUDA_NOUNROLL
#endif

namespace mxnet {
namespace common {
/*! \brief common utils for cuda */
@@ -179,113 +283,38 @@ inline DType __device__ CudaMin(DType a, DType b) {
return a < b ? a : b;
}

} // namespace cuda
} // namespace common
} // namespace mxnet

/*!
* \brief Check CUDA error.
* \param msg Message to print if an error occurred.
*/
#define CHECK_CUDA_ERROR(msg) \
{ \
cudaError_t e = cudaGetLastError(); \
CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
}

/*!
* \brief Protected CUDA call.
* \param func Expression to call.
*
* It checks for CUDA errors after invocation of the expression.
*/
#define CUDA_CALL(func) \
{ \
cudaError_t e = (func); \
CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
<< "CUDA: " << cudaGetErrorString(e); \
}

/*!
* \brief Protected cuBLAS call.
* \param func Expression to call.
*
* It checks for cuBLAS errors after invocation of the expression.
*/
#define CUBLAS_CALL(func) \
{ \
cublasStatus_t e = (func); \
CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \
<< "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \
}

/*!
* \brief Protected cuSolver call.
* \param func Expression to call.
*
* It checks for cuSolver errors after invocation of the expression.
*/
#define CUSOLVER_CALL(func) \
{ \
cusolverStatus_t e = (func); \
CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \
<< "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \
}

/*!
* \brief Protected cuRAND call.
* \param func Expression to call.
*
* It checks for cuRAND errors after invocation of the expression.
*/
#define CURAND_CALL(func) \
{ \
curandStatus_t e = (func); \
CHECK_EQ(e, CURAND_STATUS_SUCCESS) \
<< "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \
class SetDevice {
Member

Rename to DeviceScope, add a member function SetDevice, and remove all occurrences of cudaGetDevice?
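For concreteness, the proposed interface might look roughly like the sketch below. The name DeviceScope and the SetDevice member are only the reviewer's suggestion, not what this PR currently implements, and the MXNET_USE_CUDA guards are omitted for brevity.

class DeviceScope {
 public:
  // Remember the caller's current device so it can be restored on destruction.
  DeviceScope() {
    CUDA_CALL(cudaGetDevice(&restore_device_));
  }
  // Convenience constructor: remember the current device, then switch to `device`.
  explicit DeviceScope(int device) : DeviceScope() {
    SetDevice(device);
  }
  // Member function so call sites no longer invoke cudaSetDevice directly.
  void SetDevice(int device) {
    CUDA_CALL(cudaSetDevice(device));
  }
  ~DeviceScope() {
    CUDA_CALL(cudaSetDevice(restore_device_));
  }

 private:
  int restore_device_;
};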

public:
/*! \brief default constructor only restores previous device upon going out of scope */
SetDevice() {
#if MXNET_USE_CUDA
CUDA_CALL(cudaGetDevice(&restore_device_));
#else
LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
#endif
}

/*!
* \brief Protected NVRTC call.
* \param func Expression to call.
*
* It checks for NVRTC errors after invocation of the expression.
*/
#define NVRTC_CALL(x) \
{ \
nvrtcResult result = x; \
CHECK_EQ(result, NVRTC_SUCCESS) \
<< #x " failed with error " \
<< nvrtcGetErrorString(result); \
/*! \brief standard constructor is cudaSetDevice + restore previous device */
explicit SetDevice(int device) {
#if MXNET_USE_CUDA
CUDA_CALL(cudaGetDevice(&restore_device_));
CUDA_CALL(cudaSetDevice(device));
#else
LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
#endif
}

/*!
* \brief Protected CUDA driver call.
* \param func Expression to call.
*
* It checks for CUDA driver errors after invocation of the expression.
*/
#define CUDA_DRIVER_CALL(func) \
{ \
CUresult e = (func); \
if (e != CUDA_SUCCESS) { \
char const * err_msg = nullptr; \
if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \
LOG(FATAL) << "CUDA Driver: Unknown error " << e; \
} else { \
LOG(FATAL) << "CUDA Driver: " << err_msg; \
} \
} \
~SetDevice() {
CUDA_CALL(cudaSetDevice(restore_device_));
}

private:
int restore_device_;
};

#if !defined(_MSC_VER)
#define CUDA_UNROLL _Pragma("unroll")
#define CUDA_NOUNROLL _Pragma("nounroll")
#else
#define CUDA_UNROLL
#define CUDA_NOUNROLL
#endif
} // namespace cuda
} // namespace common
} // namespace mxnet

/*!
* \brief Determine major version number of the gpu's cuda compute architecture.
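For context, here is a minimal usage sketch of the new RAII guard together with CUDA_CALL. TouchAllDevices is a made-up helper used purely for illustration (it is not code from this PR), and the sketch assumes MXNET_USE_CUDA is defined.

#include <vector>

// Visit every listed GPU, then land back on whatever device the caller had active.
void TouchAllDevices(const std::vector<int>& gpus) {
  mxnet::common::cuda::SetDevice set_device;  // records the currently active device
  for (int gpu : gpus) {
    CUDA_CALL(cudaSetDevice(gpu));
    CUDA_CALL(cudaDeviceSynchronize());
  }
}  // set_device's destructor restores the original device here, even on early return

The point of the guard is that the restore happens in the destructor, so every exit path (including error paths) leaves the caller's device selection untouched.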
1 change: 1 addition & 0 deletions src/common/rtc.cc
@@ -77,6 +77,7 @@ CUfunction CudaModule::Chunk::GetFunction(
CHECK_EQ(ctx.dev_mask(), Context::kGPU)
<< "CUDA Runtime compilation only supports Nvidia GPU.";
auto iter = mod_.find(ctx.dev_id);
mxnet::common::cuda::SetDevice set_device;
CUmodule module;
if (iter != mod_.end()) {
module = iter->second;
3 changes: 2 additions & 1 deletion src/engine/naive_engine.cc
@@ -28,6 +28,7 @@
#include "./engine_impl.h"
#include "../profiler/profiler.h"
#include "./openmp.h"
#include "../common/cuda_utils.h"

namespace mxnet {
namespace engine {
@@ -149,7 +150,7 @@ class NaiveEngine final : public Engine {
if (exec_ctx.dev_mask() == gpu::kDevMask) {
#if MXNET_USE_CUDA
size_t dev_id = static_cast<size_t>(exec_ctx.dev_id);
MSHADOW_CATCH_ERROR(mshadow::SetDevice<gpu>(exec_ctx.dev_id));
mxnet::common::cuda::SetDevice set_device(exec_ctx.dev_id);
if (streams_.size() <= dev_id) {
streams_.resize(dev_id + 1, nullptr);
}
6 changes: 6 additions & 0 deletions src/engine/stream_manager.h
@@ -65,6 +65,9 @@ template <std::size_t kNumGpus, std::size_t kStreams>
RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
Context const& ctx) {
RunContext ret;
#if MXNET_USE_CUDA
mxnet::common::cuda::SetDevice set_device;
#endif
switch (ctx.dev_mask()) {
case cpu::kDevMask:
ret = RunContext{ctx, nullptr};
@@ -101,6 +104,9 @@ template <std::size_t kNumGpus, std::size_t kStreams>
RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
Context const& ctx) {
RunContext ret;
#if MXNET_USE_CUDA
mxnet::common::cuda::SetDevice set_device;
#endif
switch (ctx.dev_mask()) {
case cpu::kDevMask:
ret = RunContext{ctx, nullptr};
4 changes: 4 additions & 0 deletions src/engine/threaded_engine_perdevice.cc
@@ -92,6 +92,9 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
protected:
void PushToExecute(OprBlock *opr_block, bool pusher_thread) override {
const Context& ctx = opr_block->ctx;
#if MXNET_USE_CUDA
mxnet::common::cuda::SetDevice set_device;
#endif
if ((opr_block->opr->prop == FnProperty::kAsync ||
opr_block->opr->prop == FnProperty::kDeleteVar) && pusher_thread) {
if (ctx.dev_mask() == Context::kGPU) {
@@ -245,6 +248,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
#if MXNET_USE_CUDA
CHECK(block != nullptr);
mshadow::Stream<gpu> *stream;
mxnet::common::cuda::SetDevice set_device;
do {
ThreadPool::SetReadyOnDestroy setReady(ready_event);
// allocate stream
2 changes: 1 addition & 1 deletion src/engine/threaded_engine_pooled.cc
@@ -133,7 +133,7 @@ class ThreadedEnginePooled : public ThreadedEngine {
assert(opr_block->wait.load() == 0);
if (opr_block->ctx.dev_mask() == gpu::kDevMask) {
#if MXNET_USE_CUDA
CUDA_CALL(cudaSetDevice(opr_block->ctx.dev_id));
mxnet::common::cuda::SetDevice set_device(opr_block->ctx.dev_id);
#else // MXNET_USE_CUDA
LOG(FATAL) << "Please compile with CUDA enabled";
#endif // MXNET_USE_CUDA
3 changes: 3 additions & 0 deletions src/kvstore/comm.h
@@ -720,6 +720,9 @@ class CommDevice : public Comm {
int n = static_cast<int>(gpus.size());
int enabled = 0;
std::vector<int> p2p(n*n);

// Restores active device to what it was before EnableP2P
mxnet::common::cuda::SetDevice set_device;
for (int i = 0; i < n; ++i) {
cudaSetDevice(gpus[i]);
Contributor

As @eric-haibin-lin suggested, replace this with a call to a member function of your RAII SetDevice object (its name before the proposed renaming).

Contributor

Replace this with a call to a member function of the SetDevice object (renaming pending), as @eric-haibin-lin suggested.
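Roughly, the suggested call-site change in EnableP2P could look like the following. This is illustrative only and assumes the RAII class gains a SetDevice member as proposed above; the inner peer-access loop is left as it is in the PR.

mxnet::common::cuda::SetDevice set_device;   // restores the caller's device on scope exit
for (int i = 0; i < n; ++i) {
  set_device.SetDevice(gpus[i]);             // instead of a bare cudaSetDevice(gpus[i])
  for (int j = 0; j < n; j++) {
    // ... peer-access checks unchanged ...
  }
}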

for (int j = 0; j < n; j++) {
1 change: 1 addition & 0 deletions src/kvstore/comm_tree.h
@@ -339,6 +339,7 @@ class CommDeviceTree : public CommDevice {
int n = static_cast<int>(gpus.size());
int enabled = 0;
std::vector<int> p2p(n*n);
mxnet::common::cuda::SetDevice set_device;
for (int i = 0; i < n; ++i) {
cudaSetDevice(gpus[i]);
Contributor

same here

for (int j = 0; j < n; j++) {
2 changes: 2 additions & 0 deletions src/kvstore/kvstore_nccl.h
@@ -428,6 +428,7 @@ class KVStoreNCCL : public KVStoreLocal {
mutate_vars.push_back(ptr(dst[i])->var());
}
Engine::Get()->PushSync([this](RunContext rctx) {
mxnet::common::cuda::SetDevice set_device;
for (auto cur : nccl_data_) {
CUDA_CALL(cudaSetDevice(cur.second.dev_id));
Contributor

same here

CUDA_CALL(cudaStreamSynchronize(cur.second.stream));
@@ -479,6 +480,7 @@ class KVStoreNCCL : public KVStoreLocal {
std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
std::vector<ncclComm_t> comms(devs.size());
ncclCommInitAll(&(comms[0]), devs.size(), &(device_ids_[0]));
mxnet::common::cuda::SetDevice set_device;
for (size_t i = 0; i < devs.size(); ++i) {
NCCLEntry e;
e.dev_id = device_ids_[i];