Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add async GPU dependency Engine #20331

Merged
merged 8 commits into from
Oct 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ci/jenkins/Jenkinsfile_unix_gpu
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ core_logic: {
custom_steps.test_unix_cpp_package_gpu('gpu'),
// TODO(szha): fix and reenable the hanging issue. tracked in #18098
// custom_steps.test_unix_distributed_kvstore_gpu('gpu'),
custom_steps.test_unix_byteps_gpu('gpu'),
// TODO(spanev): reenable when byteps is updated with the new dep engine API
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add an issue to track this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, tracked here #20697

// custom_steps.test_unix_byteps_gpu('gpu'),
])
}
,
Expand Down
4 changes: 2 additions & 2 deletions include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,9 @@ struct RunContext {
*/
void *aux_stream;
/*!
* \brief indicator of whether this execution is run in bulk mode
* \brief pointer to the cuda event pool used by the dependency engine
*/
bool is_bulk;
void* event_pool = nullptr;
/*!
* \brief get mshadow stream from Context
* \return the mshadow stream
Expand Down
2 changes: 1 addition & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ typedef const void *EngineFnPropertyHandle;
typedef void *EngineVarHandle;

/*! \brief Engine asynchronous operation */
typedef void (*EngineAsyncFunc)(void*, void*, void*);
typedef void (*EngineAsyncFunc)(void*, void*, void*, void*);
/*! \brief Engine synchronous operation */
typedef void (*EngineSyncFunc)(void*, void*);
/*! \brief Callback to free the param for EngineAsyncFunc/EngineSyncFunc */
Expand Down
133 changes: 127 additions & 6 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <memory>
#include <functional>
#endif
#include <utility>
#include <vector>
#include "./base.h"

Expand All @@ -39,6 +40,73 @@ class Engine;

/*! \brief namespace of engine internal types. */
namespace engine {
#if MXNET_USE_CUDA
/*!
 * \brief Move-only wrapper around a CUDA event with timing disabled.
 *
 * The event handle is held through a shared_ptr so that other components can
 * keep std::weak_ptr observers to it (see GetEvent) without extending its
 * lifetime. The constructor taking a Context and the destructor are defined
 * out of line.
 */
class CUDAEvent final {
 public:
  /*!
   * \brief Create the wrapper for the device described by ctx.
   * \param ctx context whose device the event belongs to.
   */
  explicit CUDAEvent(Context const& ctx);

  /*!
   * \brief Take over the event owned by other.
   *
   * Moving the shared_ptr (rather than copying it) avoids an atomic
   * reference-count round trip and leaves other.event_ empty, so the
   * moved-from wrapper no longer refers to the event. The constructor is
   * noexcept so that containers can relocate elements without falling back
   * to a potentially throwing path.
   */
  CUDAEvent(CUDAEvent&& other) noexcept
      : event_(std::move(other.event_)), dev_id_(other.dev_id_) {}

  CUDAEvent(const CUDAEvent& other) = delete;
  void operator=(const CUDAEvent& other) = delete;

  ~CUDAEvent();

  /*!
   * \brief Get a non-owning handle to the wrapped event.
   * \return weak_ptr observing the event; it expires once this wrapper
   *         (the only owner created here) is destroyed.
   */
  inline std::weak_ptr<cudaEvent_t> GetEvent() noexcept {
    return event_;
  }

 private:
  /*! \brief owning pointer to the event handle; empty after being moved from */
  std::shared_ptr<cudaEvent_t> event_;
  /*! \brief device id copied from the Context the event was created with */
  int dev_id_;
};

class CUDAEventPool final {
public:
explicit CUDAEventPool(Context const& ctx) : counter_(0) {
for (size_t i = 0; i < kPoolSize; ++i) {
events_.emplace_back(ctx);
}
}

inline std::weak_ptr<cudaEvent_t> GetEvent(size_t i) noexcept {
return events_.at(i).GetEvent();
}

inline std::pair<std::weak_ptr<cudaEvent_t>, uint64_t> GetNextEvent() noexcept {
uint64_t c = counter_++;
return {events_.at((c) % kPoolSize).GetEvent(), c};
}

inline uint64_t GetCounterValue() noexcept {
return counter_.load();
}

private:
static constexpr size_t kPoolSize = 64;
std::vector<CUDAEvent> events_;
std::atomic<uint64_t> counter_;
};

/*!
 * \brief full event info for the sync object.
 *
 * Bundles a non-owning reference to a CUDA event with a stream handle and a
 * pool index. Members are not default-initialized.
 */
struct EventInfo {
// non-owning reference to the CUDA event; must be lock()ed before use
std::weak_ptr<cudaEvent_t> event;
// CUDA stream associated with the event
// NOTE(review): presumably the stream the event was recorded on - confirm
cudaStream_t stream;
// NOTE(review): presumably the counter value returned by
// CUDAEventPool::GetNextEvent() for this event - confirm against callers
uint64_t pool_index;
};
/*!
 * \brief struct containing cuda events and variables needed for the dependencies.
 *
 * Holds the reader and writer event records together with a mutex.
 * The std::mutex member makes this struct non-copyable and non-movable.
 */
struct SyncObject {
// vector can carry multiple reader events
std::vector<EventInfo> reader_events;
// vector should carry only 1 writer event
std::vector<EventInfo> writer_event;
// NOTE(review): presumably guards reader_events and writer_event against
// concurrent access - confirm against the code that locks it
std::mutex mutex;
};
#endif

/*! \brief base class of engine variables.*/
struct Var {
virtual size_t version() {
Expand All @@ -57,6 +125,12 @@ struct Var {
* is modified, the version number is incremented by 1.
*/
size_t version_{0};
#if MXNET_USE_CUDA
/*!
* \brief struct containing cuda events and variables needed for the dependencies.
*/
SyncObject sync_object;
#endif
}; // struct Var

/*! \brief Internal representation of operator. */
Expand All @@ -65,6 +139,29 @@ struct Opr;
typedef Var* VarHandle;
/*! \brief Operator pointer type, usually hold by user.*/
typedef Opr* OprHandle;
/*!
* \brief OnStart callback to the engine,
* called by AsyncFn before the action
*/
class CallbackOnStart {
public:
// use implicit copy and assign
/*! \brief involve the callback */
inline void operator()(const dmlc::Error* error = nullptr) const {
if (callback_ != nullptr)
(*callback_)(engine_, param_, error);
}

private:
/*! \brief engine can see content of callback */
friend class ::mxnet::Engine;
/*! \brief the real callback */
void (*callback_)(Engine*, void*, const dmlc::Error*);
/*! \brief the engine class passed to callback */
Engine* engine_;
/*! \brief the parameter set on callback */
void* param_;
};
/*!
* \brief OnComplete Callback to the engine,
* called by AsyncFn when action completes
Expand Down Expand Up @@ -115,12 +212,14 @@ enum class FnProperty {
*/
class MXNET_API Engine {
public:
/*! \brief on start*/
typedef engine::CallbackOnStart CallbackOnStart;
/*! \brief callback on complete*/
typedef engine::CallbackOnComplete CallbackOnComplete;
/*! \brief Synchronous operation to pass to engine. */
typedef std::function<void(RunContext)> SyncFn;
/*! \brief Asynchronous operation to pass to engine. */
typedef std::function<void(RunContext, CallbackOnComplete)> AsyncFn;
typedef std::function<void(RunContext, CallbackOnStart, CallbackOnComplete)> AsyncFn;
/*! \brief Variable pointer */
typedef engine::VarHandle VarHandle;
/*! \brief Operator pointer */
Expand Down Expand Up @@ -247,7 +346,7 @@ class MXNET_API Engine {
*
* \return A shared pointer to Engine singleton.
*/
static std::shared_ptr<Engine> _GetSharedRef();
static const std::shared_ptr<Engine>& _GetSharedRef();
/*!
* \brief Push a synchronous operation to the engine.
* \param exec_fn Execution function that executes the operation.
Expand All @@ -266,10 +365,32 @@ class MXNET_API Engine {
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) {
this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
exec_fn(ctx);
on_complete();
}, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
this->PushAsync(
[exec_fn](RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
on_start();
exec_fn(ctx);
on_complete();
},
exec_ctx,
const_vars,
mutable_vars,
prop,
priority,
opr_name);
}

/*!
 * \brief factory function to create an OnStart callback bound to this engine.
 * \param callback the static callback function the returned object invokes.
 * \param param the parameter passed to callback.
 * \return a CallbackOnStart holding callback, this engine and param.
 */
inline CallbackOnStart CreateOnStart(void (*callback)(Engine*, void*, const dmlc::Error*),
                                     void* param) {
  CallbackOnStart on_start;
  // The fields are private to CallbackOnStart; Engine may set them as a friend.
  on_start.engine_   = this;
  on_start.param_    = param;
  on_start.callback_ = callback;
  return on_start;
}

/*!
Expand Down
19 changes: 18 additions & 1 deletion include/mxnet/storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <memory>
#include <string>
#include <vector>
#include "./base.h"

namespace mxnet {
Expand All @@ -38,6 +39,17 @@ namespace mxnet {
*/
class Storage {
public:
/*!
* \brief Storage sync object.
*
* Contains members only when MXNET_USE_CUDA is enabled; otherwise it is an
* empty struct.
*/
struct SyncObj {
#if MXNET_USE_CUDA
/*!
* \brief All the events from the engine variable.
*
* Stored as weak_ptr, so holding them here does not keep the events alive;
* each one must be lock()ed and checked before use.
*/
std::vector<std::weak_ptr<cudaEvent_t>> events;
#endif
};
/*!
* \brief Storage handle.
*/
Expand All @@ -64,6 +76,11 @@ class Storage {
*/
std::string profiler_scope{MXNET_STORAGE_DEFAULT_PROFILER_SCOPE_CSTR};
std::string name{MXNET_STORAGE_DEFAULT_NAME_CSTR};
/*!
* \brief Used to pass events back and forth between the engine Var
* and the storage manager.
*/
SyncObj sync_obj;
};
/*!
* \brief Allocate a new contiguous memory for a given size.
Expand Down Expand Up @@ -137,7 +154,7 @@ class Storage {
*
* \return A shared pointer to Storage singleton.
*/
static std::shared_ptr<Storage> _GetSharedRef();
static const std::shared_ptr<Storage>& _GetSharedRef();

private:
std::mutex cpu_mutex_;
Expand Down
11 changes: 7 additions & 4 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3764,6 +3764,7 @@ int MXNDArrayCreateFromSharedMem(int shared_pid,
}

using VarHandle = Engine::VarHandle;
using CallbackOnStart = Engine::CallbackOnStart;
using CallbackOnComplete = Engine::CallbackOnComplete;

void AssertValidNumberVars(int num_const_vars, int num_mutable_vars) {
Expand Down Expand Up @@ -3795,15 +3796,17 @@ int MXEnginePushAsync(EngineAsyncFunc async_func,

Engine::AsyncFn exec_fn;
if (deleter == nullptr) {
exec_fn = [async_func, func_param](RunContext rctx, CallbackOnComplete on_complete) {
async_func(&rctx, &on_complete, func_param);
exec_fn = [async_func, func_param](
RunContext rctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
async_func(&rctx, &on_start, &on_complete, func_param);
};
} else {
// Wrap func_param in a shared_ptr with deleter such that deleter
// will be called when the lambda goes out of scope.
std::shared_ptr<void> shared_func_param(func_param, deleter);
exec_fn = [async_func, shared_func_param](RunContext rctx, CallbackOnComplete on_complete) {
async_func(&rctx, &on_complete, shared_func_param.get());
exec_fn = [async_func, shared_func_param](
RunContext rctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
async_func(&rctx, &on_start, &on_complete, shared_func_param.get());
};
}

Expand Down
4 changes: 2 additions & 2 deletions src/common/object_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class ObjectPool {
* \brief Get a shared ptr of the singleton instance of pool.
* \return Shared pointer to the Object Pool.
*/
static std::shared_ptr<ObjectPool> _GetSharedRef();
static const std::shared_ptr<ObjectPool>& _GetSharedRef();

private:
/*!
Expand Down Expand Up @@ -170,7 +170,7 @@ ObjectPool<T>* ObjectPool<T>::Get() {
}

template <typename T>
std::shared_ptr<ObjectPool<T> > ObjectPool<T>::_GetSharedRef() {
const std::shared_ptr<ObjectPool<T> >& ObjectPool<T>::_GetSharedRef() {
static std::shared_ptr<ObjectPool<T> > inst_ptr(new ObjectPool<T>());
return inst_ptr;
}
Expand Down
28 changes: 27 additions & 1 deletion src/engine/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <memory>
#include <cstdlib>
#include "./engine_impl.h"
#include "../common/cuda/utils.h"

namespace mxnet {
namespace engine {
Expand All @@ -35,6 +36,13 @@ inline Engine* CreateEngine() {
type = "ThreadedEnginePerDevice";
std::string stype = type;

// The async tag is used later to determine if we use the GPU dependency engine
std::string async_engine_tag = "Async";
auto tag_pos = stype.find(async_engine_tag);
if (tag_pos != std::string::npos && tag_pos + async_engine_tag.length() == stype.length()) {
stype = stype.substr(0, tag_pos);
}

Engine* ret = nullptr;
#if MXNET_PREDICT_ONLY == 0
if (stype == "NaiveEngine") {
Expand All @@ -56,9 +64,27 @@ inline Engine* CreateEngine() {
}
return ret;
}

#if MXNET_USE_CUDA
/*!
 * \brief Create a CUDA event with timing disabled for the device of ctx.
 * \param ctx context whose dev_id selects the device.
 */
CUDAEvent::CUDAEvent(Context const& ctx)
    : event_(std::make_shared<cudaEvent_t>()), dev_id_(ctx.dev_id) {
  // The handle is created into a local first and only published into the
  // shared storage once the CUDA call has been checked.
  cudaEvent_t created_event;
  // NOTE(review): DeviceStore presumably selects dev_id_ for this scope and
  // restores the previous device on exit - confirm in common/cuda/utils.h
  common::cuda::DeviceStore device_guard(dev_id_);
  CUDA_CALL(cudaEventCreateWithFlags(&created_event, cudaEventDisableTiming));
  *event_ = created_event;
}

/*!
 * \brief Wait for the event and destroy it, if this wrapper still holds one.
 */
CUDAEvent::~CUDAEvent() {
  // Nothing to release when the pointer is empty (e.g. after being moved
  // from) or when the stored handle is null.
  const bool holds_event = event_ && *event_ != nullptr;
  if (!holds_event) {
    return;
  }
  // NOTE(review): DeviceStore presumably selects dev_id_ for this scope and
  // restores the previous device on exit - confirm in common/cuda/utils.h
  common::cuda::DeviceStore device_guard(dev_id_);
  // Synchronize before destroying, in that order, as the original does.
  CUDA_CALL(cudaEventSynchronize(*event_));
  CUDA_CALL(cudaEventDestroy(*event_));
}
#endif
} // namespace engine

std::shared_ptr<Engine> Engine::_GetSharedRef() {
const std::shared_ptr<Engine>& Engine::_GetSharedRef() {
static std::shared_ptr<Engine> sptr(engine::CreateEngine());
return sptr;
}
Expand Down
Loading