Skip to content

Commit

Permalink
Profiler API Enhancements (apache#15132)
Browse files Browse the repository at this point in the history
* add support for sorting and printing aggregate info in json

* revert some overwrites

* revert to old c apis; added MXAggregateProfileStatsPrintEx instead

* style fix

* style fixes

* more style fixes

* rebase

* style fixes

* fix infinite loop bug

* test cases added and bugs fixed

* fix test cases

* fix testcases

* testcases

* testcases

* testcases

* fix doc

* use enum to avoid hardcoding

* sanity test fix

* sanity test fix

* add parameter validation

* add parameter validation in frontend

* validation

* removing print()

* fix typos
  • Loading branch information
Zha0q1 authored and haohuw committed Jun 23, 2019
1 parent a9a0325 commit f0355c6
Show file tree
Hide file tree
Showing 6 changed files with 302 additions and 74 deletions.
17 changes: 16 additions & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,15 +320,30 @@ MXNET_DLL int MXDumpProcessProfile(int finished, int profile_process, KVStoreHan
*/
MXNET_DLL int MXDumpProfile(int finished);


/*!
* \brief Print aggregate stats to the a string
* \brief Deprecated, use MXAggregateProfileStatsPrintEx instead.
* \param out_str Will receive a pointer to the output string
* \param reset Clear the aggregate stats after printing
* \return 0 when success, -1 when failure happens.
* \note
*/
MXNET_DLL int MXAggregateProfileStatsPrint(const char **out_str, int reset);

/*!
* \brief Print sorted aggregate stats to the a string
* How aggregate stats are stored will not change
* \param out_str will receive a pointer to the output string
* \param reset clear the aggregate stats after printing
* \param format whether to return in tabular or json format
* \param sort_by sort by avg, min, max, or count
* \param ascending whether to sort ascendingly
* \return 0 when success, -1 when failure happens.
* \note
*/
MXNET_DLL int MXAggregateProfileStatsPrintEx(const char **out_str, int reset, int format,
int sort_by, int ascending);

/*!
* \brief Pause profiler tuning collection
* \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
Expand Down
35 changes: 31 additions & 4 deletions python/mxnet/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,44 @@ def dump_profile():
dump(True)


def dumps(reset=False):
def dumps(reset=False, format='table', sort_by='avg', ascending=False):
"""Return a printable string of aggregate profile stats.
Parameters
----------
reset: boolean
Indicates whether to clean aggeregate statistical data collected up to this point
indicates whether to clean aggeregate statistical data collected up to this point
format: string
whether to return the aggregate stats in table of json format
can take 'table' or 'json'
defaults to 'table'
sort_by: string
can take 'avg', 'min', 'max', or 'count'
by which stat to sort the entries in each category
defaults to 'avg'
ascending: boolean
whether to sort ascendingly
defaults to False
"""
debug_str = ctypes.c_char_p()
do_reset = 1 if reset is True else 0
check_call(_LIB.MXAggregateProfileStatsPrint(ctypes.byref(debug_str), int(do_reset)))
reset_to_int = {False: 0, True: 1}
format_to_int = {'table': 0, 'json': 1}
sort_by_to_int = {'avg': 0, 'min': 1, 'max': 2, 'count': 3}
asc_to_int = {False: 0, True: 1}
assert format in format_to_int.keys(),\
"Invalid value provided for format: {0}. Support: 'table', 'json'".format(format)
assert sort_by in sort_by_to_int.keys(),\
"Invalid value provided for sort_by: {0}. Support: 'avg', 'min', 'max', 'count'"\
.format(sort_by)
assert ascending in asc_to_int.keys(),\
"Invalid value provided for ascending: {0}. Support: False, True".format(ascending)
assert reset in reset_to_int.keys(),\
"Invalid value provided for reset: {0}. Support: False, True".format(reset)
check_call(_LIB.MXAggregateProfileStatsPrintEx(ctypes.byref(debug_str),
reset_to_int[reset],
format_to_int[format],
sort_by_to_int[sort_by],
asc_to_int[ascending]))
return py_str(debug_str.value)


Expand Down
18 changes: 17 additions & 1 deletion src/c_api/c_api_profile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ enum class ProfileProcess {
kWorker, kServer
};

enum class PrintFormat {
table, json
};

struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
bool profile_all;
bool profile_symbolic;
Expand Down Expand Up @@ -303,6 +307,11 @@ int MXSetProfilerConfig(int num_params, const char* const* keys, const char* con
}

int MXAggregateProfileStatsPrint(const char **out_str, int reset) {
return MXAggregateProfileStatsPrintEx(out_str, reset, 0, 0, 0);
}

int MXAggregateProfileStatsPrintEx(const char **out_str, int reset, int format, int sort_by,
int ascending) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
API_BEGIN();
CHECK_NOTNULL(out_str);
Expand All @@ -314,8 +323,15 @@ int MXAggregateProfileStatsPrint(const char **out_str, int reset) {
std::shared_ptr<profiler::AggregateStats> stats = profiler->GetAggregateStats();
std::ostringstream os;
if (stats) {
stats->Dump(os, reset != 0);
if (static_cast<PrintFormat>(format) == PrintFormat::table)
stats->DumpTable(os, sort_by, ascending);
else if (static_cast<PrintFormat>(format) == PrintFormat::json)
stats->DumpJson(os, sort_by, ascending);
else
LOG(FATAL) << "Invalid value for parameter format";
}
if (reset != 0)
stats->clear();
ret->ret_str = os.str();
*out_str = (ret->ret_str).c_str();
API_END();
Expand Down
Loading

0 comments on commit f0355c6

Please sign in to comment.