Skip to content

Commit

Permalink
cuda 6.5 compatibility fix
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong committed Dec 28, 2015
1 parent 794232c commit b3a08e9
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 17 deletions.
8 changes: 7 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,15 @@ LIB_DEP += $(DMLC_CORE)/libdmlc.a
ALL_DEP = $(OBJ) $(EXTRA_OBJ) $(LIB_DEP)
ifeq ($(USE_CUDA), 1)
ALL_DEP += $(CUOBJ) $(EXTRA_CUOBJ)
LDFLAGS += -lnvrtc -lcuda
LDFLAGS += -lcuda
endif

ifeq ($(USE_NVRTC), 1)
LDFLAGS += -lnvrtc
CFLAGS += -DMXNET_USE_NVRTC=1
else
CFLAGS += -DMXNET_USE_NVRTC=0
endif


build/%.o: src/%.cc
Expand Down
5 changes: 2 additions & 3 deletions include/mxnet/mxrtc.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
#ifndef MXNET_MXRTC_H_
#define MXNET_MXRTC_H_
#include "./base.h"
#if MXNET_USE_CUDA

#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
#include <nvrtc.h>
#include <cuda.h>

Expand Down Expand Up @@ -88,5 +87,5 @@ class MXRtc {

} // namespace mxnet

#endif // MXNET_USE_CUDA
#endif // MXNET_USE_CUDA && MXNET_USE_NVRTC
#endif // MXNET_MXRTC_H_
3 changes: 3 additions & 0 deletions make/config.mk
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ USE_CUDA_PATH = NONE
# whether use CUDNN R3 library
USE_CUDNN = 0

# whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
USE_NVRTC = 0

# whether use opencv during compilation
# you can disable it, however, you will not able to use
# imbin iterator
Expand Down
3 changes: 3 additions & 0 deletions make/osx.mk
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ USE_CUDA_PATH = NONE
# whether use CUDNN R3 library
USE_CUDNN = 0

# whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
USE_NVRTC = 0

# whether use opencv during compilation
# you can disable it, however, you will not able to use
# imbin iterator
Expand Down
18 changes: 9 additions & 9 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1154,7 +1154,7 @@ int MXRtcCreate(char* name, mx_uint num_input, mx_uint num_output,
NDArrayHandle* inputs, NDArrayHandle* outputs,
char* kernel, RtcHandle *out) {
API_BEGIN();
#if MXNET_USE_CUDA
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
std::vector<std::pair<std::string, NDArray> > input, output;
for (mx_uint i = 0; i < num_input; ++i) {
input.push_back(std::pair<std::string, NDArray>(input_names[i],
Expand All @@ -1167,8 +1167,8 @@ int MXRtcCreate(char* name, mx_uint num_input, mx_uint num_output,
MXRtc *rtc = new MXRtc(name, input, output, kernel);
*out = reinterpret_cast<RtcHandle>(rtc);
#else
LOG(FATAL) << "Need to compile with USE_CUDA=1 for MXRtc.";
#endif // MXNET_USE_CUDA
LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
API_END();
}

Expand All @@ -1181,7 +1181,7 @@ int MXRtcPush(RtcHandle handle, mx_uint num_input, mx_uint num_output,
mx_uint blockDimY,
mx_uint blockDimZ) {
API_BEGIN();
#if MXNET_USE_CUDA
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
std::vector<NDArray> input, output;
for (mx_uint i = 0; i < num_input; ++i) {
input.push_back(*reinterpret_cast<NDArray*>(inputs[i]));
Expand All @@ -1197,18 +1197,18 @@ int MXRtcPush(RtcHandle handle, mx_uint num_input, mx_uint num_output,
blockDimY,
blockDimZ);
#else
LOG(FATAL) << "Need to compile with USE_CUDA=1 for MXRtc.";
#endif // MXNET_USE_CUDA
LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
API_END();
}

int MXRtcFree(RtcHandle handle) {
API_BEGIN();
#if MXNET_USE_CUDA
#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
delete reinterpret_cast<MXRtc*>(handle);
#else
LOG(FATAL) << "Need to compile with USE_CUDA=1 for MXRtc.";
#endif // MXNET_USE_CUDA
LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc.";
#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
API_END();
}

Expand Down
6 changes: 2 additions & 4 deletions src/common/mxrtc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
* \author Junyuan Xie
*/
#include <mxnet/mxrtc.h>
#if MXNET_USE_CUDA

#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))
namespace mxnet {

const std::string MXRtc::str_type = "float";
std::unordered_map<std::string, char*> MXRtc::kernel_registry;

Expand Down Expand Up @@ -139,4 +137,4 @@ char* MXRtc::compile(const std::string& name, const std::string& code) {

} // namespace mxnet

#endif // MXNET_USE_CUDA
#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))

0 comments on commit b3a08e9

Please sign in to comment.