Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

adding "total" (total time) to profiler aggregate stats sorting criteria #16055

Merged
merged 2 commits into from
Sep 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ MXNET_DLL int MXAggregateProfileStatsPrint(const char **out_str, int reset);
* \param out_str will receive a pointer to the output string
* \param reset clear the aggregate stats after printing
* \param format whether to return in tabular or json format
* \param sort_by sort by avg, min, max, or count
* \param sort_by sort by total, avg, min, max, or count
* \param ascending whether to sort ascendingly
* \return 0 when success, -1 when failure happens.
* \note
Expand Down
11 changes: 6 additions & 5 deletions python/mxnet/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def dump_profile():
dump(True)


def dumps(reset=False, format='table', sort_by='avg', ascending=False):
def dumps(reset=False, format='table', sort_by='total', ascending=False):
"""Return a printable string of aggregate profile stats.

Parameters
Expand All @@ -160,22 +160,23 @@ def dumps(reset=False, format='table', sort_by='avg', ascending=False):
can take 'table' or 'json'
defaults to 'table'
sort_by: string
can take 'avg', 'min', 'max', or 'count'
can take 'total', 'avg', 'min', 'max', or 'count'
by which stat to sort the entries in each category
defaults to 'avg'
defaults to 'total'
ascending: boolean
whether to sort ascendingly
defaults to False
"""
debug_str = ctypes.c_char_p()
reset_to_int = {False: 0, True: 1}
format_to_int = {'table': 0, 'json': 1}
sort_by_to_int = {'avg': 0, 'min': 1, 'max': 2, 'count': 3}
sort_by_to_int = {'total': 0, 'avg': 1, 'min': 2, 'max': 3, 'count': 4}
asc_to_int = {False: 0, True: 1}
assert format in format_to_int.keys(),\
"Invalid value provided for format: {0}. Support: 'table', 'json'".format(format)
assert sort_by in sort_by_to_int.keys(),\
"Invalid value provided for sort_by: {0}. Support: 'avg', 'min', 'max', 'count'"\
"Invalid value provided for sort_by: {0}.\
Support: 'total', 'avg', 'min', 'max', 'count'"\
.format(sort_by)
assert ascending in asc_to_int.keys(),\
"Invalid value provided for ascending: {0}. Support: False, True".format(ascending)
Expand Down
3 changes: 3 additions & 0 deletions src/profiler/aggregate_stats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ inline std::priority_queue<pi>
const AggregateStats::StatData& data = iter.second;
double value = 0;
switch (static_cast<AggregateStats::SortBy>(sort_by)) {
case AggregateStats::SortBy::Total:
value = data.total_aggregate_;
break;
case AggregateStats::SortBy::Avg:
if (data.type_ == AggregateStats::StatData::kCounter)
value = (data.max_aggregate_ - data.min_aggregate_) / 2;
Expand Down
2 changes: 1 addition & 1 deletion src/profiler/aggregate_stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class AggregateStats {
void clear();
/* !\brief by which stat to sort */
enum class SortBy {
Avg, Min, Max, Count
Total, Avg, Min, Max, Count
};

private:
Expand Down
13 changes: 8 additions & 5 deletions tests/python/unittest/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,12 @@ def test_aggregate_stats_valid_json_return():
debug_str = profiler.dumps(format = 'json')
assert(len(debug_str) > 0)
target_dict = json.loads(debug_str)
assert "Memory" in target_dict and "Time" in target_dict and "Unit" in target_dict
assert 'Memory' in target_dict and 'Time' in target_dict and 'Unit' in target_dict
profiler.set_state('stop')

def test_aggregate_stats_sorting():
sort_by_options = {'avg': "Avg", 'min': "Min", 'max': "Max", 'count': "Count"}
sort_by_options = {'total': 'Total', 'avg': 'Avg', 'min': 'Min',\
'max': 'Max', 'count': 'Count'}
ascending_options = [False, True]
def check_ascending(lst, asc):
assert(lst == sorted(lst, reverse = not asc))
Expand All @@ -258,9 +259,11 @@ def check_sorting(debug_str, sort_by, ascending):
for domain_name, domain in target_dict['Time'].items():
lst = [item[sort_by_options[sort_by]] for item_name, item in domain.items()]
check_ascending(lst, ascending)
for domain_name, domain in target_dict['Memory'].items():
lst = [item[sort_by_options[sort_by]] for item_name, item in domain.items()]
check_ascending(lst, ascending)
# Memory items do not have stat 'Total'
if sort_by != 'total':
for domain_name, domain in target_dict['Memory'].items():
lst = [item[sort_by_options[sort_by]] for item_name, item in domain.items()]
check_ascending(lst, ascending)

file_name = 'test_aggregate_stats_sorting.json'
enable_profiler(file_name, True, True, True)
Expand Down