Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add async GPU dependency Engine
Browse files Browse the repository at this point in the history
Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L committed Jun 6, 2021
1 parent a6fdc7a commit c94b243
Show file tree
Hide file tree
Showing 32 changed files with 815 additions and 343 deletions.
6 changes: 3 additions & 3 deletions include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,10 @@ struct RunContext {
* \brief the auxiliary stream of the device, can be nullptr or Stream<gpu>* in GPU mode
*/
void *aux_stream;
/*!
* \brief indicator of whether this execution is run in bulk mode
/*!
* \brief pointer to the cuda event pool used by the dependency engine
*/
bool is_bulk;
void *event_pool = nullptr;
/*!
* \brief get mshadow stream from Context
* \return the mshadow stream
Expand Down
2 changes: 1 addition & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ typedef const void *EngineFnPropertyHandle;
typedef void *EngineVarHandle;

/*! \brief Engine asynchronous operation */
typedef void (*EngineAsyncFunc)(void*, void*, void*);
typedef void (*EngineAsyncFunc)(void*, void*, void*, void*);
/*! \brief Engine synchronous operation */
typedef void (*EngineSyncFunc)(void*, void*);
/*! \brief Callback to free the param for EngineAsyncFunc/EngineSyncFunc */
Expand Down
121 changes: 118 additions & 3 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <memory>
#include <functional>
#endif
#include <utility>
#include <vector>
#include "./base.h"

Expand All @@ -40,6 +41,72 @@ class Engine;

/*! \brief namespace of engine internal types. */
namespace engine {
#if MXNET_USE_CUDA
/*!
 * \brief The class wrapping CUDA event with timing disabled.
 *
 * The event handle is stored in a std::shared_ptr owned by this object;
 * GetEvent() only hands out std::weak_ptr observers to it. The type is
 * move-only: copy construction and copy assignment are deleted.
 */
class CUDAEvent final {
 public:
  /*!
   * \brief Create the event for the given context.
   * \param ctx context the event is created for.
   */
  explicit CUDAEvent(Context const& ctx);

  /*!
   * \brief Move constructor: takes over the event handle of other.
   *
   * The shared_ptr is moved (no reference-count round trip), which leaves
   * other.event_ empty so that the destructor of other has nothing to release.
   * \param other the event to move from.
   */
  CUDAEvent(CUDAEvent&& other) noexcept
      : event_(std::move(other.event_)), dev_id_(other.dev_id_) {}

  CUDAEvent(const CUDAEvent& other) = delete;
  void operator=(const CUDAEvent& other) = delete;

  ~CUDAEvent();

  /*!
   * \brief Get a non-owning handle to the wrapped event.
   * \return weak pointer to the event; it is empty for a moved-from object.
   */
  inline std::weak_ptr<cudaEvent_t> GetEvent() noexcept {
    return event_;
  }

 private:
  /*! \brief owning pointer to the CUDA event handle */
  std::shared_ptr<cudaEvent_t> event_;
  /*! \brief id of the device the event belongs to */
  int dev_id_;
};

class CUDAEventPool final {
public:
explicit CUDAEventPool(Context const& ctx) : counter_(0) {
for (size_t i = 0; i < kPoolSize; ++i) {
events_.emplace_back(ctx);
}
}

inline std::weak_ptr<cudaEvent_t> GetEvent(size_t i) noexcept {
return events_.at(i).GetEvent();
}

inline std::pair<std::weak_ptr<cudaEvent_t>, uint64_t> GetNextEvent() noexcept {
int c = counter_++;
return {events_.at((c) % kPoolSize).GetEvent(), c};
}

inline uint64_t GetCounterValue() noexcept {
return counter_.load();
}
private:
static constexpr size_t kPoolSize = 64;
std::vector<CUDAEvent> events_;
std::atomic<uint64_t> counter_;
};

/*!
 * \brief full event info for the sync object.
 *
 * One entry of the reader/writer event lists kept in SyncObject.
 */
struct EventInfo {
// non-owning handle to the CUDA event
std::weak_ptr<cudaEvent_t> event;
// NOTE(review): presumably the stream the event was recorded on - confirm against the engine code
cudaStream_t stream;
// NOTE(review): presumably the counter value returned by CUDAEventPool::GetNextEvent() - confirm
uint64_t pool_index;
};
/*!
 * \brief struct containing cuda events and variables needed for the dependencies.
 *
 * Holds the events of the readers and of the writer, together with a mutex.
 */
struct SyncObject {
// vector can carry multiple reader events
std::vector<EventInfo> reader_events;
// vector should carry only 1 writer event
std::vector<EventInfo> writer_event;
// NOTE(review): presumably guards reader_events and writer_event - confirm against the engine code
std::mutex mutex;
};
#endif

/*! \brief base class of engine variables.*/
struct Var {
virtual size_t version() {
Expand All @@ -58,6 +125,12 @@ struct Var {
* is modified, the version number is incremented by 1.
*/
size_t version_{0};
#if MXNET_USE_CUDA
/*!
* \brief struct containing cuda events and variables needed for the dependencies.
*/
SyncObject sync_object;
#endif
}; // struct Var

/*! \brief Internal representation of operator. */
Expand All @@ -66,6 +139,29 @@ struct Opr;
typedef Var* VarHandle;
/*! \brief Operator pointer type, usually hold by user.*/
typedef Opr* OprHandle;
/*!
* \brief OnStart callback to the engine,
* called by AsyncFn before the action
*/
class CallbackOnStart {
public:
// use implicit copy and assign
/*! \brief involve the callback */
inline void operator()(const dmlc::Error* error = nullptr) const {
if (callback_ != nullptr)
(*callback_)(engine_, param_, error);
}

private:
/*! \brief engine can see content of callback */
friend class ::mxnet::Engine;
/*! \brief the real callback */
void (*callback_)(Engine *, void *, const dmlc::Error *);
/*! \brief the engine class passed to callback */
Engine* engine_;
/*! \brief the parameter set on callback */
void* param_;
};
/*!
* \brief OnComplete Callback to the engine,
* called by AsyncFn when action completes
Expand Down Expand Up @@ -116,12 +212,14 @@ enum class FnProperty {
*/
class MXNET_API Engine {
public:
/*! \brief callback on start*/
typedef engine::CallbackOnStart CallbackOnStart;
/*! \brief callback on complete*/
typedef engine::CallbackOnComplete CallbackOnComplete;
/*! \brief Synchronous operation to pass to engine. */
typedef std::function<void(RunContext)> SyncFn;
/*! \brief Asynchronous operation to pass to engine. */
typedef std::function<void(RunContext, CallbackOnComplete)> AsyncFn;
typedef std::function<void(RunContext, CallbackOnStart, CallbackOnComplete)> AsyncFn;
/*! \brief Variable pointer */
typedef engine::VarHandle VarHandle;
/*! \brief Operator pointer */
Expand Down Expand Up @@ -248,7 +346,7 @@ class MXNET_API Engine {
*
* \return A shared pointer to Engine singleton.
*/
static std::shared_ptr<Engine> _GetSharedRef();
static const std::shared_ptr<Engine> &_GetSharedRef();
/*!
* \brief Push a synchronous operation to the engine.
* \param exec_fn Execution function that executes the operation.
Expand All @@ -267,12 +365,29 @@ class MXNET_API Engine {
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) {
this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
this->PushAsync([exec_fn](RunContext ctx,
CallbackOnStart on_start,
CallbackOnComplete on_complete) {
on_start();
exec_fn(ctx);
on_complete();
}, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
}

/*!
 * \brief factory function to create OnStart callback.
 * \param callback the static callback function.
 * \param param the parameter passed to callback.
 * \return an OnStart callback holding callback, this engine and param.
 */
inline CallbackOnStart CreateOnStart(
    void (*callback)(Engine *, void *, const dmlc::Error *), void *param) {
  CallbackOnStart on_start;
  // Bind the callback to this engine instance together with its user parameter.
  on_start.param_ = param;
  on_start.engine_ = this;
  on_start.callback_ = callback;
  return on_start;
}

/*!
* \brief factory function to create OnComplete callback.
* \param callback the static callback function.
Expand Down
19 changes: 18 additions & 1 deletion include/mxnet/storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <memory>
#include <string>
#include <vector>
#include "./base.h"

namespace mxnet {
Expand All @@ -39,6 +40,17 @@ namespace mxnet {
*/
class Storage {
public:
/*!
* \brief Storage sync object.
*
* The struct has no members unless MXNET_USE_CUDA is enabled.
*/
struct SyncObj {
#if MXNET_USE_CUDA
/*!
* \brief All the events from the engine variable.
*
* The entries are non-owning weak pointers to the CUDA events.
*/
std::vector<std::weak_ptr<cudaEvent_t>> events;
#endif
};
/*!
* \brief Storage handle.
*/
Expand All @@ -65,6 +77,11 @@ class Storage {
*/
std::string profiler_scope{MXNET_STORAGE_DEFAULT_PROFILER_SCOPE_CSTR};
std::string name{MXNET_STORAGE_DEFAULT_NAME_CSTR};
/*!
* \brief Used to pass events back and forth between the engine Var
* and the storage manager.
*/
SyncObj sync_obj;
};
/*!
* \brief Allocate a new contiguous memory for a given size.
Expand Down Expand Up @@ -138,7 +155,7 @@ class Storage {
*
* \return A shared pointer to Storage singleton.
*/
static std::shared_ptr<Storage> _GetSharedRef();
static const std::shared_ptr<Storage> &_GetSharedRef();

private:
std::mutex cpu_mutex_;
Expand Down
7 changes: 5 additions & 2 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3454,6 +3454,7 @@ int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const int *shape
}

using VarHandle = Engine::VarHandle;
using CallbackOnStart = Engine::CallbackOnStart;
using CallbackOnComplete = Engine::CallbackOnComplete;

void AssertValidNumberVars(int num_const_vars, int num_mutable_vars) {
Expand All @@ -3480,16 +3481,18 @@ int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param,
Engine::AsyncFn exec_fn;
if (deleter == nullptr) {
exec_fn = [async_func, func_param](RunContext rctx,
CallbackOnStart on_start,
CallbackOnComplete on_complete) {
async_func(&rctx, &on_complete, func_param);
async_func(&rctx, &on_start, &on_complete, func_param);
};
} else {
// Wrap func_param in a shared_ptr with deleter such that deleter
// will be called when the lambda goes out of scope.
std::shared_ptr<void> shared_func_param(func_param, deleter);
exec_fn = [async_func, shared_func_param](RunContext rctx,
CallbackOnStart on_start,
CallbackOnComplete on_complete) {
async_func(&rctx, &on_complete, shared_func_param.get());
async_func(&rctx, &on_start, &on_complete, shared_func_param.get());
};
}

Expand Down
4 changes: 2 additions & 2 deletions src/common/object_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class ObjectPool {
* \brief Get a shared ptr of the singleton instance of pool.
* \return Shared pointer to the Object Pool.
*/
static std::shared_ptr<ObjectPool> _GetSharedRef();
static const std::shared_ptr<ObjectPool> &_GetSharedRef();

private:
/*!
Expand Down Expand Up @@ -173,7 +173,7 @@ ObjectPool<T>* ObjectPool<T>::Get() {
}

template <typename T>
std::shared_ptr<ObjectPool<T> > ObjectPool<T>::_GetSharedRef() {
const std::shared_ptr<ObjectPool<T> > &ObjectPool<T>::_GetSharedRef() {
static std::shared_ptr<ObjectPool<T> > inst_ptr(new ObjectPool<T>());
return inst_ptr;
}
Expand Down
21 changes: 20 additions & 1 deletion src/engine/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <memory>
#include <cstdlib>
#include "./engine_impl.h"
#include "../common/cuda/utils.h"

namespace mxnet {
namespace engine {
Expand Down Expand Up @@ -56,9 +57,27 @@ inline Engine* CreateEngine() {
}
return ret;
}

#if MXNET_USE_CUDA
/*!
 * \brief Create a CUDA event with timing disabled and store its handle.
 * \param ctx context whose dev_id identifies the device used for the event.
 */
CUDAEvent::CUDAEvent(Context const& ctx)
    : event_(std::make_shared<cudaEvent_t>()), dev_id_(ctx.dev_id) {
  // NOTE(review): DeviceStore presumably makes dev_id_ the current device for
  // this scope - confirm in common/cuda/utils.h
  common::cuda::DeviceStore device_guard(dev_id_);
  cudaEvent_t created;
  CUDA_CALL(cudaEventCreateWithFlags(&created, cudaEventDisableTiming));
  // Publish the freshly created handle through the shared slot.
  *event_ = created;
}

/*!
 * \brief Wait for the event to complete, then destroy it.
 *
 * Nothing is done when the object holds no event handle, which is the case
 * for an object that has been moved from.
 */
CUDAEvent::~CUDAEvent() {
  // Nothing to release: no handle slot, or the slot holds a null handle.
  if (!event_ || *event_ == nullptr) {
    return;
  }
  // NOTE(review): DeviceStore presumably makes dev_id_ the current device for
  // this scope - confirm in common/cuda/utils.h
  common::cuda::DeviceStore device_guard(dev_id_);
  CUDA_CALL(cudaEventSynchronize(*event_));
  CUDA_CALL(cudaEventDestroy(*event_));
}
#endif
} // namespace engine

std::shared_ptr<Engine> Engine::_GetSharedRef() {
const std::shared_ptr<Engine> &Engine::_GetSharedRef() {
static std::shared_ptr<Engine> sptr(engine::CreateEngine());
return sptr;
}
Expand Down
Loading

0 comments on commit c94b243

Please sign in to comment.