Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Custom Operator Profiling Enhancement (#15210)
Browse files Browse the repository at this point in the history
* working version

* style fix

* several fixes

* resolve issues in the comments

* revert to using thread-safe Get() for singleton class CustomOpProfiler

* indentation

* Now supports Naive Engine

* style fix

* tidiness

* tests added

* style fix

* add a new test case which has multiple custom ops

* testcases fix

* fix

* fix style

* minor naive engine fix

* simplify some branching logic

* better desing style

* fix

* fix

* fix

* fix

* fix

* fix

* add isprofiling check to onCustomStart

* fix

* rename dummy_wait

* fix conflict

* improve test

* fix

* fix test cases

* fix test cases

* fix testcases

* revert back to reduce overhead

* fix style

* Re-Trigger build

* rename var

* Re-Trigger build
  • Loading branch information
Zha0q1 authored and anirudh2290 committed Jun 28, 2019
1 parent cd19367 commit 92fce90
Show file tree
Hide file tree
Showing 8 changed files with 365 additions and 19 deletions.
9 changes: 7 additions & 2 deletions src/engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "../profiler/profiler.h"
#include "./openmp.h"
#include "../common/object_pool.h"
#include "../profiler/custom_op_profiler.h"

namespace mxnet {
namespace engine {
Expand Down Expand Up @@ -160,15 +161,19 @@ class NaiveEngine final : public Engine {
profiler::Profiler *profiler = profiler::Profiler::Get();
NaiveOpr *opr = nullptr;
const bool profiling = opr_name && profiler->IsProfiling(profiler::Profiler::kImperative);
// GenerateDisplayName() will return a pointer to the correct name of the operator
const char* display_name = profiling ?
profiler::CustomOpProfiler::Get()->GenerateDisplayName(opr_name) :
opr_name;
if (profiling) {
opr = NewOperator(exec_fun, const_vars, mutable_vars,
prop, opr_name)->Cast<NaiveOpr>();
prop, display_name)->Cast<NaiveOpr>();
opr->profiling = profiling;
std::unique_ptr<profiler::ProfileOperator::Attributes> attrs;
if (profiler->AggregateEnabled()) {
attrs.reset(new profiler::ProfileOperator::Attributes());
}
opr->opr_profile.reset(new profiler::ProfileOperator(opr->opr_name, attrs.release()));
opr->opr_profile.reset(new profiler::ProfileOperator(display_name, attrs.release()));
opr->opr_profile->start(exec_ctx.dev_type, exec_ctx.dev_id);
}
if (exec_ctx.dev_mask() == gpu::kDevMask) {
Expand Down
10 changes: 7 additions & 3 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,11 @@ void ThreadedEngine::DeleteOperator(OprHandle op) {

void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool profiling) {
BulkFlush();

ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
if (profiling) {
threaded_opr->opr_name =
profiler::CustomOpProfiler::Get()->GenerateDisplayName(threaded_opr->opr_name);
}
OprBlock* opr_block = OprBlock::New();
opr_block->opr = threaded_opr;

Expand Down Expand Up @@ -333,9 +336,10 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
<< device_count_;
}
#endif
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait);
opr->temporary = true;
const bool profiling = profiler_->IsProfiling(profiler::Profiler::kImperative);
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars,
prop, opr_name, wait);
opr->temporary = true;
Push(opr, exec_ctx, priority, profiling);
}

Expand Down
1 change: 1 addition & 0 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "../profiler/profiler.h"
#include "./openmp.h"
#include "../common/object_pool.h"
#include "../profiler/custom_op_profiler.h"

namespace mxnet {
namespace engine {
Expand Down
23 changes: 18 additions & 5 deletions src/operator/custom/custom-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include <condition_variable>
#include <queue>
#include "../operator_common.h"
#include "../../profiler/custom_op_profiler.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -76,9 +77,16 @@ class CustomOperator {
bool training, const std::vector<NDArray>& arrs,
const std::vector<int>& tags,
const std::unordered_set<int>& output_tags,
const std::vector<NDArray>& outputs) {
const std::vector<NDArray>& outputs,
const std::string op_type = "") {
if (naive_engine_) {
func();
if (profiler::Profiler::Get()->IsProfiling(profiler::Profiler::kImperative)) {
profiler::CustomOpProfiler::Get()->OnCustomBegin(op_type);
func();
profiler::CustomOpProfiler::Get()->OnCustomEnd();
} else {
func();
}
for (size_t i = 0, out_idx = 0; i < arrs.size(); i++) {
if (arrs[i].storage_type() == kDefaultStorage ||
arrs[i].storage_type() == kUndefinedStorage)
Expand All @@ -97,7 +105,13 @@ class CustomOperator {
bool prev_training = Imperative::Get()->set_is_training(training);

try {
func();
if (profiler::Profiler::Get()->IsProfiling(profiler::Profiler::kImperative)) {
profiler::CustomOpProfiler::Get()->OnCustomBegin(op_type);
func();
profiler::CustomOpProfiler::Get()->OnCustomEnd();
} else {
func();
}
} catch (dmlc::Error& e) {
exception_ =
std::make_shared<std::exception_ptr>(std::current_exception());
Expand Down Expand Up @@ -143,8 +157,7 @@ class CustomOperator {

ctx.async_on_complete();
},
ctx.run_ctx.ctx, vars, vars2, FnProperty::kNoSkip, 0,
"CustomOperator");
ctx.run_ctx.ctx, vars, vars2, FnProperty::kNoSkip, 0, "CustomOperatorWait");
});
// increase num_threads if there is not enough threads to execute custom operator
if (q_.size() > num_free_threads_)
Expand Down
4 changes: 2 additions & 2 deletions src/operator/custom/custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ void ForwardEx(const OpStatePtr& state, const OpContext& ctx,
static_cast<int>(ctx.is_train),
params.info->contexts[kCustomOpForward]));
},
ctx, false, ctx.is_train, cpys, tags, output_tags, outputs);
ctx, false, ctx.is_train, cpys, tags, output_tags, outputs, params.op_type);
}

void BackwardEx(const OpStatePtr& state, const OpContext& ctx,
Expand Down Expand Up @@ -415,7 +415,7 @@ void BackwardEx(const OpStatePtr& state, const OpContext& ctx,
ptrs.size(), const_cast<void**>(ptrs.data()), const_cast<int*>(tags.data()),
reinterpret_cast<const int*>(req.data()), static_cast<int>(ctx.is_train),
params.info->contexts[kCustomOpBackward]));
}, ctx, false, ctx.is_train, cpys, tags, output_tags, outputs);
}, ctx, false, ctx.is_train, cpys, tags, output_tags, outputs, "_backward_" + params.op_type);
}

// infer storage backward function for custom op which assigns kDefaultStorage for
Expand Down
125 changes: 125 additions & 0 deletions src/profiler/custom_op_profiler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef MXNET_PROFILER_CUSTOM_OP_PROFILER_H_
#define MXNET_PROFILER_CUSTOM_OP_PROFILER_H_

#include <string>
#include <unordered_set>
#include <unordered_map>
#include <thread>
#include "./profiler.h"

namespace mxnet {
namespace profiler {

using Tid = std::thread::id;
using TaskPtr = std::unique_ptr<ProfileTask>;

/*!
* \brief Singleton class to assist profiling python callback of custom operators
* and to assist linking sub-operators to custom operators
*/
class CustomOpProfiler {
public:
static CustomOpProfiler* Get() {
static std::mutex mtx;
static std::unique_ptr<CustomOpProfiler> prof = nullptr;
if (!prof) {
std::unique_lock<std::mutex> lk(mtx);
if (!prof)
prof = std::make_unique<CustomOpProfiler>();
}
return prof.get();
}
/*!
* \brief Called before the callback of custom operators to start a profile task for python
* code execution time
* \param op_type The registed name of the custom operator
*/
void OnCustomBegin(const std::string& op_type) {
const Tid tid = std::this_thread::get_id();
const std::string task_name = MakePythonCodeName(op_type);
std::lock_guard<std::mutex> lock(mutex_);
tid_to_op_type_[tid] = op_type;
tasks_[tid] = std::make_unique<ProfileTask>(task_name.c_str(), &custom_op_domain);
tasks_[tid]->start();
}

/*!
* \brief Called after the callback of custom operators to stop the profile task for python
* code execution time
*/
void OnCustomEnd() {
const Tid tid = std::this_thread::get_id();
std::lock_guard<std::mutex> lock(mutex_);
tid_to_op_type_.erase(tid);
// this should never fail
CHECK(tasks_.find(tid) != tasks_.end()) << "thread_id not found. " <<
"Please use OnCustomBegin() and OnCustomEnd() in pairs.";
tasks_[tid]->stop();
tasks_.erase(tid);
}

/*!
* \brief Generate a display name for sub-operators, which is the name used for OprBlock
* and later by profiler, and store it in a unordered_set so that it can be referenced
* in the future.
* Notice if the operator is not a sub-operator, just return the char pointer back.
* \param op_type The registed name of the operator
* \return Returns a pointer to the display name generated
*/
const char* GenerateDisplayName(const char* op_type) {
if (!op_type) {
return nullptr;
}
Tid tid = std::this_thread::get_id();
std::lock_guard<std::mutex> lock(mutex_);
if (tid_to_op_type_.find(tid) == tid_to_op_type_.end()) {
return op_type;
}
std::string name = MakeSubOperatorName(tid, op_type);
return display_names_.insert(name).first->c_str();
}

private:
/* !\brief make the display name for sub-operators */
inline std::string MakeSubOperatorName(const Tid& tid, const char* op_type) {
return tid_to_op_type_[tid] + "::" + std::string(op_type);
}
/* !\brief make the display name for the pure python call back function i.e.
* forward() or backward() in the custom operator definition
*/
inline std::string MakePythonCodeName(const std::string& op_type) {
return op_type + "::pure_python";
}
/*! \brief class mutex */
std::mutex mutex_;
/* !\brief display names for sub-operators in custom ops */
std::unordered_set<std::string> display_names_;
/* !\brief profiling tasks for pure python code in custom operators */
std::unordered_map<Tid, TaskPtr> tasks_;
/* !\brief the maping from thread id to the registered name op the custom operator
* that is runnin on that thread
*/
std::unordered_map<Tid, std::string> tid_to_op_type_;
};
} // namespace profiler
} // namespace mxnet

#endif // MXNET_PROFILER_CUSTOM_OP_PROFILER_H_
53 changes: 46 additions & 7 deletions src/profiler/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,13 @@ struct ProfileTask : public ProfileDuration {
NVTX_ONLY_CODE(nvtx_duration_.reset(new nvtx::NVTXDuration(name)));
}

/*!
* \brief Set the domain
*/
void setDomain(ProfileDomain* domain) {
domain_ = domain;
}

/*!
* \brief Start the profiling scope
*/
Expand Down Expand Up @@ -1111,6 +1118,8 @@ struct ProfileMarker {
VTUNE_ONLY_CODE(std::unique_ptr<vtune::VTuneInstantMarker> vtune_instant_marker_);
};

static ProfileDomain custom_op_domain("Custom Operator");

/*!
* \brief Operator profiler object. Logs as both an independent event and a task in
* the operator domain
Expand Down Expand Up @@ -1162,10 +1171,16 @@ struct ProfileOperator : public ProfileEvent {
: ProfileEvent(name)
, as_task_(name, &domain_)
, name_(name)
, attributes_(attributes) {
, attributes_(attributes)
, profiling_(!IsDeprecatedOperator(name)) {
if (IsSubOperatorOfCustom(name)) {
as_task_.setDomain(&custom_op_domain);
SetCategories(custom_op_domain.name());
} else {
SetCategories(domain_.name());
}
// make as_task_ not to add stat to AggregateStats; otherwise we will add twice
as_task_.enableAggregateStats(false);
SetCategories(domain_.name());
}
/*!
* \brief Start the profiling scope
Expand All @@ -1175,15 +1190,19 @@ struct ProfileOperator : public ProfileEvent {
void start(mxnet::Context::DeviceType dev_type, uint32_t dev_id) {
dev_type_ = dev_type;
dev_id_ = dev_id;
ProfileEvent::start();
as_task_.start();
if (profiling_) {
ProfileEvent::start();
as_task_.start();
}
}
/*!
* \brief Stop the profiling scope
*/
void stop() override {
as_task_.stop();
ProfileEvent::stop();
if (profiling_) {
as_task_.stop();
ProfileEvent::stop();
}
}

/*!
Expand All @@ -1208,7 +1227,11 @@ struct ProfileOperator : public ProfileEvent {
if (attributes) {
name_.append(attributes->to_string().c_str());
}
categories_.set("operator");
if (IsSubOperatorOfCustom(name)) {
categories_.set(custom_op_domain.name());
} else {
categories_.set("operator");
}
items_[kStart].timestamp_ = start_time;
items_[kStop].timestamp_ = stop_time;
}
Expand All @@ -1228,6 +1251,20 @@ struct ProfileOperator : public ProfileEvent {
start_time_, ProfileStat::NowInMicrosec(),
attributes_.get());
}
/*!
* \brief Check if this operator is no longer profiled
* Notice that this operator may still be used for e.g synchronization
*/
inline static bool IsDeprecatedOperator(const char* name) {
return strcmp(name, "CustomOperatorWait") == 0 ||
strcmp(name, "Custom") == 0 || strcmp(name, "_backward_Custom") == 0;
}
/*!
* \brief Check if this operator a sub-operator of a custom operator
*/
inline static bool IsSubOperatorOfCustom(const char* name) {
return strstr(name, "::");
}
/*! \brief Also log the operator as a task in the operator domain */
ProfileTask as_task_;
/* !\brief Operator name */
Expand All @@ -1240,6 +1277,8 @@ struct ProfileOperator : public ProfileEvent {
static ProfileDomain domain_;
/*! \brief Optional operator attributes */
std::unique_ptr<Attributes> attributes_;
/*! \brief Whether to profile or not */
const bool profiling_;
};

/*
Expand Down
Loading

0 comments on commit 92fce90

Please sign in to comment.