Skip to content

Commit

Permalink
mxnet: fix ABI compatibility issues (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin authored and ymjiang committed Sep 22, 2019
1 parent 33a7f91 commit b9a5ba6
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions byteps/mxnet/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,25 @@ namespace mxnet {
namespace {

std::atomic_int op_count;
const auto MX_EXEC_CTX = Context::CPU();
const auto MX_FUNC_PROP = FnProperty::kCPUPrioritized;

// struct to hold parameters for pushpull with MXNet Engine
struct PushPullParam {
BPSContext* context;
NDArray* input;
int version;
int priority;

PushPullParam(BPSContext* context, NDArray* input, int version, int priority)
: context(context), input(input), version(version), priority(priority) {}
};

// callback function to release parameters used for pushpull with MXNet Engine
void DeletePushPullParam(void* param) {
auto push_pull_param = static_cast<PushPullParam*>(param);
delete push_pull_param;
}

std::string GetOpName(std::string prefix, char* name) {
if (name != nullptr) {
Expand All @@ -48,9 +67,14 @@ inline void InvokeCompleteCallback(Callback on_complete, const Status& status) {
}
}

void DoPushPull(BPSContext& context, NDArray* input, int version, int priority,
Callback on_complete) {
void DoPushPull(void*, void* on_complete_ptr, void* param) {
ThrowIfError(common::CheckInitialized());
auto on_complete = *static_cast<Callback*>(on_complete_ptr);
auto push_pull_param = static_cast<PushPullParam*>(param);
int priority = push_pull_param->priority;
int version = push_pull_param->version;
NDArray* input = push_pull_param->input;
BPSContext& context = *push_pull_param->context;

auto device = TensorUtil::GetDevice(input);
auto byteps_input = std::make_shared<MXTensor<NDArray>>(input);
Expand Down Expand Up @@ -85,14 +109,13 @@ extern "C" int byteps_mxnet_push_pull_async(NDArray* tensor, char* name,
: nullptr;
common::InitTensor(context, size, dtype, cpubuff);

auto push_pull_async_fn = [&context, tensor, version, priority](
RunContext rctx, Callback on_complete) mutable {
DoPushPull(context, tensor, version, priority, on_complete);
};

Engine::Get()->PushAsync(push_pull_async_fn, Context::CPU(), {},
{tensor->var()}, FnProperty::kCPUPrioritized, 0,
"BytePSPushPull");
auto push_pull_param = new PushPullParam(&context, tensor, version, priority);
auto var = tensor->var();
// Use MXEnginePushAsync instead of Engine::Get()->PushAsync to avoid ABI
// compatibility issues
MXEnginePushAsync(DoPushPull, push_pull_param, DeletePushPullParam,
&MX_EXEC_CTX, nullptr, 0, &var, 1,
&MX_FUNC_PROP, 0, "BytePSPushPull");

if (is_average) {
// average the aggregated gradient
Expand Down

0 comments on commit b9a5ba6

Please sign in to comment.