Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Subgraph API for integrating accelerators with MXNet (#12157)
Browse files Browse the repository at this point in the history
* Graph partitioner and subgraph op (#11251)

Graph partitioner and subgraph op

Fix duplicate entry bugs (#11767)

Make subgraph var node name unique (#11876)

[DO NOT REVIEW] Fix bug of eliminating cycles (#11907)

* Fix cycle bug

* Fix decycle bug

* Fix comment

[DO NOT REVIEW] Subgraph API (#12104)

* Initial commit

* Add unit tests

* Fix lint

* Fix lint

* Clean up

* Add graph partitioning to Bind

* Add property name to graph partitioning c api

* Fix unit test gpu context

* Address cr

* Move subgraph to attrs.subgraphs and fix the example

* Fix lint

* Add var version unit test

* Address cr

* Enable unit test that was flaky

* Clean up

* Clean up

* Clean up

* Change version return type in NDArray

* Clean up

* Add register or get for subgraph prop registry

* Address cr

* Remove unnecessary code

* Handle var version issue in naive engine

* Delete example

* Remove registration of resource request for default subgraph op

* Add doc string

* Improve doc string
  • Loading branch information
reminisce authored and eric-haibin-lin committed Aug 31, 2018
1 parent 32c9ca7 commit a64cf7d
Show file tree
Hide file tree
Showing 19 changed files with 2,059 additions and 23 deletions.
66 changes: 66 additions & 0 deletions include/mxnet/c_api_test.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2018 by Contributors
* \file c_api_test.h
* \brief C API of mxnet for ease of testing backend in Python
*/
#ifndef MXNET_C_API_TEST_H_
#define MXNET_C_API_TEST_H_

/*! \brief Inhibit C++ name-mangling for MXNet functions. */
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus

#include <mxnet/c_api.h>

/*!
* \brief This API partitions a graph only by the operator names
* provided by users. A copy of the input symbol is run through the
* "PartitionGraph" pass. When at least one op name is given, the
* subgraph property registered under prop_name is created, receives
* the names as its "op_names" attribute, and is attached to the graph
* as "subgraph_property". This function should be used only for
* testing purposes.
* \param sym_handle handle of the symbol to partition; a copy of it is what gets partitioned
* \param prop_name name used to create the subgraph property from the registry
* \param num_ops number of entries in op_names
* \param op_names array of num_ops operator names
* \param ret_sym_handle output; receives the handle of the newly allocated partitioned symbol
* \return presumably 0 on success and non-zero on failure, like the other MXNet C APIs (TODO confirm)
*/
MXNET_DLL int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
const char* prop_name,
const mx_uint num_ops,
const char** op_names,
SymbolHandle* ret_sym_handle);

/*!
* \brief Given a subgraph property name, use the provided op names
* as the op_names attribute for that subgraph property, instead of
* the predefined one. The names are stored in the
* SubgraphPropertyOpNameSet under prop_name, replacing any set that
* was stored under that name before. This is only for the purpose
* of testing.
* \param prop_name name of the subgraph property the op names belong to
* \param num_ops number of entries in op_names
* \param op_names array of num_ops operator names
* \return presumably 0 on success and non-zero on failure, like the other MXNet C APIs (TODO confirm)
*/
MXNET_DLL int MXSetSubgraphPropertyOpNames(const char* prop_name,
const mx_uint num_ops,
const char** op_names);

/*!
* \brief Given a subgraph property name, delete the op name set
* stored for it in the SubgraphPropertyOpNameSet.
* \param prop_name name of the subgraph property whose op name set is erased
* \return presumably 0 on success and non-zero on failure, like the other MXNet C APIs (TODO confirm)
*/
MXNET_DLL int MXRemoveSubgraphPropertyOpNames(const char* prop_name);

#ifdef __cplusplus
}
#endif // __cplusplus

#endif // MXNET_C_API_TEST_H_
22 changes: 20 additions & 2 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,26 @@ class Engine;

/*! \brief namespace of engine internal types. */
namespace engine {
/*! \brief Internal representation of variable. */
struct Var;
/*! \brief Base class of the variables handled by the engine. */
struct Var {
  /*!
   * \brief Version number of the var. It is meant to grow by 1 every time the
   *  object this var is associated with gets modified.
   */
  size_t version_ = 0;

  /*! \return the current version number; derived types may override the lookup. */
  virtual size_t version() {
    return this->version_;
  }
  /*! \brief Virtual destructor so derived vars can be destroyed through a Var pointer. */
  virtual ~Var() = default;
  /*!
   * \brief Cast this variable to the derived type T.
   * \tparam T the derived type to cast into.
   * \return this variable as a pointer to T.
   */
  template <typename T>
  inline T* Cast();
};  // struct Var

/*! \brief Internal representation of operator. */
struct Opr;
/*! \brief Variable pointer type, usually held by the user and used to specify dependencies. */
Expand Down
4 changes: 4 additions & 0 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,10 @@ class NDArray {
inline size_t byte_offset() const {
return byte_offset_;
}
/*!
 * \brief Version of the engine variable behind this NDArray.
 * \return the value reported by version() on the handle returned by var().
 */
inline size_t version() const {
  const auto engine_var = var();
  return engine_var->version();
}
/*!
* \brief save the content into binary stream
* \param strm the output stream
Expand Down
73 changes: 73 additions & 0 deletions src/c_api/c_api_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2018 by Contributors
* \file c_api_test.cc
* \brief C API of mxnet for the ease of testing backend in Python
*/
#include <mxnet/c_api_test.h>
#include <nnvm/pass.h>
#include "./c_api_common.h"
#include "../operator/subgraph/subgraph_property.h"

/*!
 * \brief Partition a copy of the given symbol with the "PartitionGraph" pass.
 *
 * When at least one operator name is supplied, the subgraph property created
 * from prop_name receives the names as its "op_names" attribute and is stored
 * in the graph attributes under "subgraph_property" before the pass runs.
 *
 * \param sym_handle handle of the symbol to copy and partition
 * \param prop_name name passed to the subgraph property registry
 * \param num_ops number of entries in op_names
 * \param op_names array of num_ops operator names
 * \param ret_sym_handle receives the newly allocated, partitioned symbol
 */
int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
                              const char* prop_name,
                              const mx_uint num_ops,
                              const char** op_names,
                              SymbolHandle* ret_sym_handle) {
  // Allocated ahead of API_BEGIN so the error handler at the bottom can free it.
  nnvm::Symbol* partitioned = new nnvm::Symbol();
  API_BEGIN();
  // Collect the requested operator names, dropping duplicates.
  std::unordered_set<std::string> op_name_set(op_names, op_names + num_ops);

  // Work on a copy of the caller's symbol.
  *partitioned = static_cast<nnvm::Symbol*>(sym_handle)->Copy();

  nnvm::Graph g;
  g.outputs = partitioned->outputs;
  if (!op_name_set.empty()) {
    // Only attach a property when there is something to select by.
    mxnet::op::SubgraphPropertyPtr property =
        mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
    property->SetAttr("op_names", op_name_set);
    g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
  }
  g = nnvm::ApplyPass(std::move(g), "PartitionGraph");

  partitioned->outputs = g.outputs;
  *ret_sym_handle = partitioned;
  API_END_HANDLE_ERROR(delete partitioned);
}

/*!
 * \brief Store the given operator names for a subgraph property.
 *
 * The names are gathered into a set and assigned to the entry of
 * SubgraphPropertyOpNameSet keyed by prop_name, replacing whatever was there.
 *
 * \param prop_name name of the subgraph property
 * \param num_ops number of entries in op_names
 * \param op_names array of num_ops operator names
 */
int MXSetSubgraphPropertyOpNames(const char* prop_name,
                                 const mx_uint num_ops,
                                 const char** op_names) {
  API_BEGIN();
  // Build the complete set first, then publish it under the property name.
  std::unordered_set<std::string> op_name_set(op_names, op_names + num_ops);
  auto registry = mxnet::op::SubgraphPropertyOpNameSet::Get();
  (*registry)[prop_name] = op_name_set;
  API_END();
}

/*!
 * \brief Erase the operator-name entry stored for a subgraph property.
 * \param prop_name key of the entry to erase from SubgraphPropertyOpNameSet
 */
int MXRemoveSubgraphPropertyOpNames(const char* prop_name) {
  API_BEGIN();
  auto registry = mxnet::op::SubgraphPropertyOpNameSet::Get();
  registry->erase(prop_name);
  API_END();
}
14 changes: 0 additions & 14 deletions src/engine/engine_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,6 @@
namespace mxnet {
namespace engine {

/*! \brief base class of engine variables, used for type checking */
struct Var {
#if ENGINE_DEBUG
virtual ~Var() = default;
#endif // ENGINE_DEBUG
/*!
* \brief cast variable to derived type T
* \tparam T the type we want to cast into.
* \return A casted variable.
*/
template <typename T>
inline T* Cast();
}; // struct Var

/*! \brief base class of engine operators, used for type checking */
struct Opr {
#if ENGINE_DEBUG
Expand Down
31 changes: 25 additions & 6 deletions src/engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,24 @@
#include "./engine_impl.h"
#include "../profiler/profiler.h"
#include "./openmp.h"
#include "../common/object_pool.h"

namespace mxnet {
namespace engine {

/*!
 * \brief Variable type of the naive engine. It exists so that every handle
 *  points at a real Var object whose version can be tracked.
 */
class NaiveVar final
    : public Var, public common::ObjectPoolAllocatable<NaiveVar> {
 public:
  /*!
   * \brief Cast a Var pointer to a NaiveVar pointer.
   * \param ptr pointer to the base class.
   * \return ptr as a NaiveVar pointer, obtained through Var::Cast.
   */
  static inline NaiveVar* CastFromBase(Var* ptr) {
    NaiveVar* const naive = ptr->Cast<NaiveVar>();
    return naive;
  }
};  // class NaiveVar


// implement naive engine
class NaiveEngine final : public Engine {
public:
Expand Down Expand Up @@ -71,8 +85,7 @@ class NaiveEngine final : public Engine {

// new variables
/*!
 * \brief Create a new engine variable.
 * \return handle to a NaiveVar obtained from NaiveVar::New().
 */
VarHandle NewVariable() override {
  // The handle must point at a real NaiveVar: the engine increments
  // var->version_ on it and DeleteVariable casts it back to NaiveVar, neither
  // of which is valid for an integer id reinterpreted as a pointer.
  return NaiveVar::New();
}

OprHandle NewOperator(AsyncFn fn,
Expand Down Expand Up @@ -146,6 +159,10 @@ class NaiveEngine final : public Engine {
opr->opr_profile.reset(new profiler::ProfileOperator(opr->opr_name, attrs.release()));
opr->opr_profile->start(exec_ctx.dev_type, exec_ctx.dev_id);
}
// increment mutable var version
for (auto var : mutable_vars) {
++var->version_;
}
if (exec_ctx.dev_mask() == gpu::kDevMask) {
#if MXNET_USE_CUDA
size_t dev_id = static_cast<size_t>(exec_ctx.dev_id);
Expand All @@ -171,8 +188,12 @@ class NaiveEngine final : public Engine {
}

/*!
 * \brief Schedule the deletion of a variable.
 * \param delete_fn callback invoked with the run context as part of the deletion.
 * \param exec_ctx context the deletion operation is pushed with.
 * \param var variable to delete; passed as the only mutable variable of the operation.
 */
void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override {
  NaiveVar* naive_var = NaiveVar::CastFromBase(var);
  // Push exactly one operation so that delete_fn runs a single time; pushing it
  // separately as well would invoke the delete callback twice.
  this->PushAsync([delete_fn, naive_var](RunContext ctx, CallbackOnComplete on_complete) mutable {
    delete_fn(ctx);
    // Release the variable object itself only after the user callback has run.
    NaiveVar::Delete(naive_var);
    on_complete();
  }, exec_ctx, {}, {var}, FnProperty::kDeleteVar, 0, "DeleteVariable");
}

void WaitForVar(VarHandle var) override {
Expand All @@ -192,8 +213,6 @@ class NaiveEngine final : public Engine {
}
// whether action is completed
bool req_completed_;
// counter
std::atomic<size_t> counter_{0};
/*! \brief whether it is during shutdown phase*/
std::atomic<bool> shutdown_phase_{false};
// CPU stream
Expand Down
10 changes: 9 additions & 1 deletion src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
assert(pending_write_ != nullptr);
CHECK_EQ(num_pending_reads_, kWriteTriggered);

// increment version number
++version_;

// really delete
if (to_delete_) {
VersionedVarBlock *head = pending_write_->next;
Expand Down Expand Up @@ -164,7 +167,7 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) {
}
// This is outside of lock scope
// Be very careful, pending_write_ and num_pending_reads_
// can change now, do not reply ont the two variables.
// can change now, do not rely on these two variables.
// The linked list \in [old_pending_write, end_of_read_chain)
// is already detached from this Var.
// So it is safe to modify these
Expand Down Expand Up @@ -196,6 +199,11 @@ inline bool ThreadedVar::ready_to_read() {
return this->is_ready_to_read();
}

inline size_t ThreadedVar::version() {
std::lock_guard<std::mutex> lock{mutex_};
return this->version_;
}

// implementation of threaded engine
ThreadedVar* ThreadedEngine::NewVariable() {
return ThreadedVar::New(VersionedVarBlock::New());
Expand Down
1 change: 1 addition & 0 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ class ThreadedVar final
inline void SetToDelete();
/*! \return whether this variable is ready to read. */
inline bool ready_to_read();
/*! \return the version number of this variable. */
inline size_t version() override;
/*!
* \brief Cast a Var pointer to ThreadedVar pointer
* \param ptr pointer from base.
Expand Down
Loading

0 comments on commit a64cf7d

Please sign in to comment.