diff --git a/Makefile b/Makefile index 6bd9ff2f175d..ea71cd3fff86 100644 --- a/Makefile +++ b/Makefile @@ -114,9 +114,15 @@ LIB_DEP += $(DMLC_CORE)/libdmlc.a ALL_DEP = $(OBJ) $(EXTRA_OBJ) $(LIB_DEP) ifeq ($(USE_CUDA), 1) ALL_DEP += $(CUOBJ) $(EXTRA_CUOBJ) - LDFLAGS += -lnvrtc -lcuda + LDFLAGS += -lcuda endif +ifeq ($(USE_NVRTC), 1) + LDFLAGS += -lnvrtc + CFLAGS += -DMXNET_USE_NVRTC=1 +else + CFLAGS += -DMXNET_USE_NVRTC=0 +endif build/%.o: src/%.cc diff --git a/include/mxnet/mxrtc.h b/include/mxnet/mxrtc.h index e0418110277c..de8c385549bb 100644 --- a/include/mxnet/mxrtc.h +++ b/include/mxnet/mxrtc.h @@ -7,8 +7,7 @@ #ifndef MXNET_MXRTC_H_ #define MXNET_MXRTC_H_ #include "./base.h" -#if MXNET_USE_CUDA - +#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC)) #include #include @@ -88,5 +87,5 @@ class MXRtc { } // namespace mxnet -#endif // MXNET_USE_CUDA +#endif // MXNET_USE_CUDA && MXNET_USE_NVRTC #endif // MXNET_MXRTC_H_ diff --git a/make/config.mk b/make/config.mk index 6585e5299f5e..8e9f8af3a5da 100644 --- a/make/config.mk +++ b/make/config.mk @@ -48,6 +48,9 @@ USE_CUDA_PATH = NONE # whether use CUDNN R3 library USE_CUDNN = 0 +# whether to use CUDA runtime compilation (NVRTC) for writing kernels in a native language (e.g. Python) +USE_NVRTC = 0 + # whether use opencv during compilation # you can disable it, however, you will not able to use # imbin iterator diff --git a/make/osx.mk b/make/osx.mk index 13a6389bba04..23c2c7a363e5 100644 --- a/make/osx.mk +++ b/make/osx.mk @@ -48,6 +48,9 @@ USE_CUDA_PATH = NONE # whether use CUDNN R3 library USE_CUDNN = 0 +# whether to use CUDA runtime compilation (NVRTC) for writing kernels in a native language (e.g. 
Python) +USE_NVRTC = 0 + # whether use opencv during compilation # you can disable it, however, you will not able to use # imbin iterator diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 2c913d85ddf5..3deea52f9e9d 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1154,7 +1154,7 @@ int MXRtcCreate(char* name, mx_uint num_input, mx_uint num_output, NDArrayHandle* inputs, NDArrayHandle* outputs, char* kernel, RtcHandle *out) { API_BEGIN(); -#if MXNET_USE_CUDA +#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC)) std::vector > input, output; for (mx_uint i = 0; i < num_input; ++i) { input.push_back(std::pair(input_names[i], @@ -1167,8 +1167,8 @@ int MXRtcCreate(char* name, mx_uint num_input, mx_uint num_output, MXRtc *rtc = new MXRtc(name, input, output, kernel); *out = reinterpret_cast(rtc); #else - LOG(FATAL) << "Need to compile with USE_CUDA=1 for MXRtc."; -#endif // MXNET_USE_CUDA + LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc."; +#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC)) API_END(); } @@ -1181,7 +1181,7 @@ int MXRtcPush(RtcHandle handle, mx_uint num_input, mx_uint num_output, mx_uint blockDimY, mx_uint blockDimZ) { API_BEGIN(); -#if MXNET_USE_CUDA +#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC)) std::vector input, output; for (mx_uint i = 0; i < num_input; ++i) { input.push_back(*reinterpret_cast(inputs[i])); @@ -1197,18 +1197,18 @@ int MXRtcPush(RtcHandle handle, mx_uint num_input, mx_uint num_output, blockDimY, blockDimZ); #else - LOG(FATAL) << "Need to compile with USE_CUDA=1 for MXRtc."; -#endif // MXNET_USE_CUDA + LOG(FATAL) << "Need to compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc."; +#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC)) API_END(); } int MXRtcFree(RtcHandle handle) { API_BEGIN(); -#if MXNET_USE_CUDA +#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC)) delete reinterpret_cast(handle); #else - LOG(FATAL) << "Need to compile with USE_CUDA=1 for MXRtc."; -#endif // MXNET_USE_CUDA + LOG(FATAL) << "Need to 
compile with USE_CUDA=1 and USE_NVRTC=1 for MXRtc."; +#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC)) API_END(); } diff --git a/src/common/mxrtc.cc b/src/common/mxrtc.cc index c23e5eacc94f..4fd687267409 100644 --- a/src/common/mxrtc.cc +++ b/src/common/mxrtc.cc @@ -5,10 +5,8 @@ * \author Junyuan Xie */ #include -#if MXNET_USE_CUDA - +#if ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC)) namespace mxnet { - const std::string MXRtc::str_type = "float"; std::unordered_map MXRtc::kernel_registry; @@ -139,4 +137,4 @@ char* MXRtc::compile(const std::string& name, const std::string& code) { } // namespace mxnet -#endif // MXNET_USE_CUDA +#endif // ((MXNET_USE_CUDA) && (MXNET_USE_NVRTC))