diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 2d5122cf670f..b3dca69b0c8b 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -2792,9 +2792,9 @@ MXNET_DLL int MXNDArrayCreateFromSharedMemEx(int shared_pid, int shared_id, cons * \param ctx_handle Execution context. * \param const_vars_handle The variables that current operation will use * but not mutate. - * \param num_const_vars The number of const_vars. + * \param num_const_vars The number of const_vars_handle. * \param mutable_vars_handle The variables that current operation will mutate. - * \param num_mutable_vars The number of mutable_vars. + * \param num_mutable_vars The number of mutable_vars_handle. * \param prop_handle Property of the function. * \param priority Priority of the action, as hint to the engine. * \param opr_name The operation name. @@ -2816,9 +2816,9 @@ MXNET_DLL int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param, * \param ctx_handle Execution context. * \param const_vars_handle The variables that current operation will use * but not mutate. - * \param num_const_vars The number of const_vars. + * \param num_const_vars The number of const_vars_handle. * \param mutable_vars_handle The variables that current operation will mutate. - * \param num_mutable_vars The number of mutable_vars. + * \param num_mutable_vars The number of mutable_vars_handle. * \param prop_handle Property of the function. * \param priority Priority of the action, as hint to the engine. * \param opr_name The operation name. @@ -2830,6 +2830,53 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param, EngineFnPropertyHandle prop_handle DEFAULT(NULL), int priority DEFAULT(0), const char* opr_name DEFAULT(NULL)); +/*! + * \brief Push an asynchronous operation to the engine. + * \param async_func Execution function whici takes a parameter on_complete + * that must be called when the execution ompletes. + * \param func_param The parameter set on calling async_func, can be NULL. + * \param deleter The callback to free func_param, can be NULL. + * \param ctx_handle Execution context. + * \param const_nds_handle The NDArrays that current operation will use + * but not mutate. + * \param num_const_nds The number of const_nds_handle. + * \param mutable_nds_handle The NDArrays that current operation will mutate. + * \param num_mutable_nds The number of mutable_nds_handle. + * \param prop_handle Property of the function. + * \param priority Priority of the action, as hint to the engine. + * \param opr_name The operation name. + * \param wait Whether this is a WaitForVar operation. + */ +MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param, + EngineFuncParamDeleter deleter, ContextHandle ctx_handle, + NDArrayHandle const_nds_handle, int num_const_nds, + NDArrayHandle mutable_nds_handle, int num_mutable_nds, + EngineFnPropertyHandle prop_handle DEFAULT(NULL), + int priority DEFAULT(0), const char* opr_name DEFAULT(NULL), + bool wait DEFAULT(false)); + +/*! + * \brief Push a synchronous operation to the engine. + * \param sync_func Execution function that executes the operation. + * \param func_param The parameter set on calling sync_func, can be NULL. + * \param deleter The callback to free func_param, can be NULL. + * \param ctx_handle Execution context. + * \param const_nds_handle The NDArrays that current operation will use + * but not mutate. + * \param num_const_nds The number of const_nds_handle. + * \param mutable_nds_handle The NDArrays that current operation will mutate. + * \param num_mutable_nds The number of mutable_nds_handle. + * \param prop_handle Property of the function. + * \param priority Priority of the action, as hint to the engine. + * \param opr_name The operation name. + */ +MXNET_DLL int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param, + EngineFuncParamDeleter deleter, ContextHandle ctx_handle, + NDArrayHandle const_nds_handle, int num_const_nds, + NDArrayHandle mutable_nds_handle, int num_mutable_nds, + EngineFnPropertyHandle prop_handle DEFAULT(NULL), + int priority DEFAULT(0), const char* opr_name DEFAULT(NULL)); + #ifdef __cplusplus } #endif // __cplusplus diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index a9a49b08a001..35bd3eeb477a 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1534,6 +1534,46 @@ int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param, API_END(); } +int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param, + EngineFuncParamDeleter deleter, ContextHandle ctx_handle, + NDArrayHandle const_nds_handle, int num_const_nds, + NDArrayHandle mutable_nds_handle, int num_mutable_nds, + EngineFnPropertyHandle prop_handle, int priority, + const char* opr_name, bool wait) { + API_BEGIN(); + NDArray* const_nds = static_cast(const_nds_handle); + NDArray* mutable_nds = static_cast(mutable_nds_handle); + std::vector const_var_vec(num_const_nds); + for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = (const_nds+i)->var(); + std::vector mutable_var_vec(num_mutable_nds); + for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = (mutable_nds+i)->var(); + return MXEnginePushAsync(async_func, func_param, deleter, ctx_handle, + const_var_vec.data(), num_const_nds, + mutable_var_vec.data(), num_mutable_nds, + prop_handle, priority, opr_name, wait); + API_END(); +} + +int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param, + EngineFuncParamDeleter deleter, ContextHandle ctx_handle, + NDArrayHandle const_nds_handle, int num_const_nds, + NDArrayHandle mutable_nds_handle, int num_mutable_nds, + EngineFnPropertyHandle prop_handle, int priority, + const char* opr_name) { + API_BEGIN(); + NDArray* const_nds = static_cast(const_nds_handle); + NDArray* mutable_nds = static_cast(mutable_nds_handle); + std::vector const_var_vec(num_const_nds); + for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = (const_nds+i)->var(); + std::vector mutable_var_vec(num_mutable_nds); + for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = (mutable_nds+i)->var(); + return MXEnginePushSync(sync_func, func_param, deleter, ctx_handle, + const_var_vec.data(), num_const_nds, + mutable_var_vec.data(), num_mutable_nds, + prop_handle, priority, opr_name); + API_END(); +} + int MXStorageEmptyCache(int dev_type, int dev_id) { API_BEGIN(); Context ctx = Context::Create(static_cast(dev_type), dev_id); diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc index ef3aec18e529..6b863f8efdc0 100644 --- a/tests/cpp/engine/threaded_engine_test.cc +++ b/tests/cpp/engine/threaded_engine_test.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -254,6 +255,53 @@ TEST(Engine, PushFunc) { EXPECT_EQ(res, -1); } +TEST(Engine, PushFuncND) { + auto ctx = mxnet::Context{}; + mxnet::NDArray nd(ctx); + + // Test #1 + LOG(INFO) << "===== Test #1: PushAsyncND param and deleter ====="; + int* a = new int(100); + int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0); + EXPECT_EQ(res, 0); + + // Test #2 + LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter ====="; + res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 0); + EXPECT_EQ(res, 0); + + // Test #3 + LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds ====="; + res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0); + EXPECT_EQ(res, -1); + + // Test #4 + LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds ====="; + res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1); + EXPECT_EQ(res, -1); + + // Test #5 + LOG(INFO) << "===== Test #5: PushSyncND param and deleter ====="; + int* b = new int(101); + res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0); + EXPECT_EQ(res, 0); + + // Test #6 + LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter ====="; + res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 1); + EXPECT_EQ(res, 0); + + // Test #7 + LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds ====="; + res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0); + EXPECT_EQ(res, -1); + + // Test #8 + LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds ====="; + res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1); + EXPECT_EQ(res, -1); +} + TEST(Engine, basics) { auto&& engine = mxnet::Engine::Get(); auto&& var = engine->NewVariable();