Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Profiler API Enhancements #15132

Merged
merged 24 commits into from
Jun 15, 2019
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,15 +320,30 @@ MXNET_DLL int MXDumpProcessProfile(int finished, int profile_process, KVStoreHan
*/
MXNET_DLL int MXDumpProfile(int finished);


/*!
* \brief Print aggregate stats to the a string
* \brief Deprecated, use MXAggregateProfileStatsPrintEx instead.
* \param out_str Will receive a pointer to the output string
* \param reset Clear the aggregate stats after printing
* \return 0 when success, -1 when failure happens.
* \note
*/
MXNET_DLL int MXAggregateProfileStatsPrint(const char **out_str, int reset);

/*!
Zha0q1 marked this conversation as resolved.
Show resolved Hide resolved
* \brief Print sorted aggregate stats to the a string
* How aggregate stats are stored will not change
* \param out_str will receive a pointer to the output string
* \param reset clear the aggregate stats after printing
* \param format whether to return in tabular or json format
* \param sort_by sort by avg, min, max, or count
* \param ascending whether to sort ascendingly
* \return 0 when success, -1 when failure happens.
* \note
*/
MXNET_DLL int MXAggregateProfileStatsPrintEx(const char **out_str, int reset, int format,
int sort_by, int ascending);

/*!
* \brief Pause profiler tuning collection
* \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
Expand Down
35 changes: 31 additions & 4 deletions python/mxnet/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,44 @@ def dump_profile():
dump(True)


def dumps(reset=False):
def dumps(reset=False, format='table', sort_by='avg', ascending=False):
"""Return a printable string of aggregate profile stats.
Parameters
----------
reset: boolean
Indicates whether to clean aggeregate statistical data collected up to this point
indicates whether to clean aggeregate statistical data collected up to this point
Zha0q1 marked this conversation as resolved.
Show resolved Hide resolved
format: string
whether to return the aggregate stats in table of json format
can take 'table' or 'json'
Zha0q1 marked this conversation as resolved.
Show resolved Hide resolved
defaults to 'table'
sort_by: string
can take 'avg', 'min', 'max', or 'count'
by which stat to sort the entries in each category
defaults to 'avg'
ascending: boolean
whether to sort ascendingly
defaults to False
"""
debug_str = ctypes.c_char_p()
do_reset = 1 if reset is True else 0
check_call(_LIB.MXAggregateProfileStatsPrint(ctypes.byref(debug_str), int(do_reset)))
reset_to_int = {False: 0, True: 1}
format_to_int = {'table': 0, 'json': 1}
sort_by_to_int = {'avg': 0, 'min': 1, 'max': 2, 'count': 3}
asc_to_int = {False: 0, True: 1}
assert format in format_to_int.keys(),\
"Invalid value provided for format: {0}. Support: 'table', 'json'".format(format)
assert sort_by in sort_by_to_int.keys(),\
"Invalid value provided for sort_by: {0}. Support: 'avg', 'min', 'max', 'count'"\
.format(sort_by)
assert ascending in asc_to_int.keys(),\
"Invalid value provided for ascending: {0}. Support: False, True".format(ascending)
assert reset in reset_to_int.keys(),\
"Invalid value provided for reset: {0}. Support: False, True".format(reset)
check_call(_LIB.MXAggregateProfileStatsPrintEx(ctypes.byref(debug_str),
reset_to_int[reset],
format_to_int[format],
sort_by_to_int[sort_by],
sandeep-krishnamurthy marked this conversation as resolved.
Show resolved Hide resolved
asc_to_int[ascending]))
return py_str(debug_str.value)


Expand Down
18 changes: 17 additions & 1 deletion src/c_api/c_api_profile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ enum class ProfileProcess {
kWorker, kServer
};

enum class PrintFormat {
table, json
};

struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
bool profile_all;
bool profile_symbolic;
Expand Down Expand Up @@ -303,6 +307,11 @@ int MXSetProfilerConfig(int num_params, const char* const* keys, const char* con
}

int MXAggregateProfileStatsPrint(const char **out_str, int reset) {
return MXAggregateProfileStatsPrintEx(out_str, reset, 0, 0, 0);
}

int MXAggregateProfileStatsPrintEx(const char **out_str, int reset, int format, int sort_by,
Zha0q1 marked this conversation as resolved.
Show resolved Hide resolved
int ascending) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
API_BEGIN();
CHECK_NOTNULL(out_str);
Expand All @@ -314,8 +323,15 @@ int MXAggregateProfileStatsPrint(const char **out_str, int reset) {
std::shared_ptr<profiler::AggregateStats> stats = profiler->GetAggregateStats();
std::ostringstream os;
if (stats) {
stats->Dump(os, reset != 0);
if (static_cast<PrintFormat>(format) == PrintFormat::table)
stats->DumpTable(os, sort_by, ascending);
else if (static_cast<PrintFormat>(format) == PrintFormat::json)
stats->DumpJson(os, sort_by, ascending);
else
LOG(FATAL) << "Invliad value for parameter format";
Zha0q1 marked this conversation as resolved.
Show resolved Hide resolved
}
if (reset != 0)
stats->clear();
ret->ret_str = os.str();
*out_str = (ret->ret_str).c_str();
API_END();
Expand Down
Loading