Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-1415]Add MXEnginePushAsyncND and MXEnginePushSyncND C APIs (#1…
Browse files Browse the repository at this point in the history
…5177)

* add MXEnginePushAsyncND and MXEnginePushSyncND

* fix test build

* return exception value

* retrigger CI
  • Loading branch information
wkcn authored and anirudh2290 committed Jun 14, 2019
1 parent 41d35c4 commit 3b663ef
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 4 deletions.
55 changes: 51 additions & 4 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2792,9 +2792,9 @@ MXNET_DLL int MXNDArrayCreateFromSharedMemEx(int shared_pid, int shared_id, cons
* \param ctx_handle Execution context.
* \param const_vars_handle The variables that current operation will use
* but not mutate.
* \param num_const_vars The number of const_vars.
* \param num_const_vars The number of const_vars_handle.
* \param mutable_vars_handle The variables that current operation will mutate.
* \param num_mutable_vars The number of mutable_vars.
* \param num_mutable_vars The number of mutable_vars_handle.
* \param prop_handle Property of the function.
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operation name.
Expand All @@ -2816,9 +2816,9 @@ MXNET_DLL int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param,
* \param ctx_handle Execution context.
* \param const_vars_handle The variables that current operation will use
* but not mutate.
* \param num_const_vars The number of const_vars.
* \param num_const_vars The number of const_vars_handle.
* \param mutable_vars_handle The variables that current operation will mutate.
* \param num_mutable_vars The number of mutable_vars.
* \param num_mutable_vars The number of mutable_vars_handle.
* \param prop_handle Property of the function.
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operation name.
Expand All @@ -2830,6 +2830,53 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));

/*!
* \brief Push an asynchronous operation to the engine.
* \param async_func Execution function whici takes a parameter on_complete
* that must be called when the execution ompletes.
* \param func_param The parameter set on calling async_func, can be NULL.
* \param deleter The callback to free func_param, can be NULL.
* \param ctx_handle Execution context.
* \param const_nds_handle The NDArrays that current operation will use
* but not mutate.
* \param num_const_nds The number of const_nds_handle.
* \param mutable_nds_handle The NDArrays that current operation will mutate.
* \param num_mutable_nds The number of mutable_nds_handle.
* \param prop_handle Property of the function.
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operation name.
* \param wait Whether this is a WaitForVar operation.
*/
MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL),
bool wait DEFAULT(false));

/*!
* \brief Push a synchronous operation to the engine.
* \param sync_func Execution function that executes the operation.
* \param func_param The parameter set on calling sync_func, can be NULL.
* \param deleter The callback to free func_param, can be NULL.
* \param ctx_handle Execution context.
* \param const_nds_handle The NDArrays that current operation will use
* but not mutate.
* \param num_const_nds The number of const_nds_handle.
* \param mutable_nds_handle The NDArrays that current operation will mutate.
* \param num_mutable_nds The number of mutable_nds_handle.
* \param prop_handle Property of the function.
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operation name.
*/
MXNET_DLL int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));

#ifdef __cplusplus
}
#endif // __cplusplus
Expand Down
40 changes: 40 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,46 @@ int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
API_END();
}

int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name, bool wait) {
API_BEGIN();
NDArray* const_nds = static_cast<NDArray*>(const_nds_handle);
NDArray* mutable_nds = static_cast<NDArray*>(mutable_nds_handle);
std::vector<VarHandle> const_var_vec(num_const_nds);
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = (const_nds+i)->var();
std::vector<VarHandle> mutable_var_vec(num_mutable_nds);
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = (mutable_nds+i)->var();
return MXEnginePushAsync(async_func, func_param, deleter, ctx_handle,
const_var_vec.data(), num_const_nds,
mutable_var_vec.data(), num_mutable_nds,
prop_handle, priority, opr_name, wait);
API_END();
}

int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name) {
API_BEGIN();
NDArray* const_nds = static_cast<NDArray*>(const_nds_handle);
NDArray* mutable_nds = static_cast<NDArray*>(mutable_nds_handle);
std::vector<VarHandle> const_var_vec(num_const_nds);
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = (const_nds+i)->var();
std::vector<VarHandle> mutable_var_vec(num_mutable_nds);
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = (mutable_nds+i)->var();
return MXEnginePushSync(sync_func, func_param, deleter, ctx_handle,
const_var_vec.data(), num_const_nds,
mutable_var_vec.data(), num_mutable_nds,
prop_handle, priority, opr_name);
API_END();
}

int MXStorageEmptyCache(int dev_type, int dev_id) {
API_BEGIN();
Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
Expand Down
48 changes: 48 additions & 0 deletions tests/cpp/engine/threaded_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <gtest/gtest.h>
#include <mxnet/c_api.h>
#include <mxnet/engine.h>
#include <mxnet/ndarray.h>
#include <dmlc/timer.h>
#include <cstdio>
#include <thread>
Expand Down Expand Up @@ -254,6 +255,53 @@ TEST(Engine, PushFunc) {
EXPECT_EQ(res, -1);
}

TEST(Engine, PushFuncND) {
auto ctx = mxnet::Context{};
mxnet::NDArray nd(ctx);

// Test #1
LOG(INFO) << "===== Test #1: PushAsyncND param and deleter =====";
int* a = new int(100);
int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0);
EXPECT_EQ(res, 0);

// Test #2
LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 0);
EXPECT_EQ(res, 0);

// Test #3
LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0);
EXPECT_EQ(res, -1);

// Test #4
LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1);
EXPECT_EQ(res, -1);

// Test #5
LOG(INFO) << "===== Test #5: PushSyncND param and deleter =====";
int* b = new int(101);
res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0);
EXPECT_EQ(res, 0);

// Test #6
LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 1);
EXPECT_EQ(res, 0);

// Test #7
LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0);
EXPECT_EQ(res, -1);

// Test #8
LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1);
EXPECT_EQ(res, -1);
}

TEST(Engine, basics) {
auto&& engine = mxnet::Engine::Get();
auto&& var = engine->NewVariable();
Expand Down

0 comments on commit 3b663ef

Please sign in to comment.