Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
adding "total" (total time) to profiler aggregate stats sorting crite…
Browse files Browse the repository at this point in the history
…ria (#16055)

* Update aggregate_stats.h

* add 'total' to sorting criteria
  • Loading branch information
Zha0q1 authored and eric-haibin-lin committed Sep 2, 2019
1 parent 5def003 commit 1abf05b
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 12 deletions.
2 changes: 1 addition & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ MXNET_DLL int MXAggregateProfileStatsPrint(const char **out_str, int reset);
* \param out_str will receive a pointer to the output string
* \param reset clear the aggregate stats after printing
* \param format whether to return in tabular or json format
* \param sort_by sort by avg, min, max, or count
* \param sort_by sort by total, avg, min, max, or count
* \param ascending whether to sort ascendingly
* \return 0 when success, -1 when failure happens.
* \note
Expand Down
11 changes: 6 additions & 5 deletions python/mxnet/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def dump_profile():
dump(True)


def dumps(reset=False, format='table', sort_by='avg', ascending=False):
def dumps(reset=False, format='table', sort_by='total', ascending=False):
"""Return a printable string of aggregate profile stats.
Parameters
Expand All @@ -160,22 +160,23 @@ def dumps(reset=False, format='table', sort_by='avg', ascending=False):
can take 'table' or 'json'
defaults to 'table'
sort_by: string
can take 'avg', 'min', 'max', or 'count'
can take 'total', 'avg', 'min', 'max', or 'count'
by which stat to sort the entries in each category
defaults to 'avg'
defaults to 'total'
ascending: boolean
whether to sort ascendingly
defaults to False
"""
debug_str = ctypes.c_char_p()
reset_to_int = {False: 0, True: 1}
format_to_int = {'table': 0, 'json': 1}
sort_by_to_int = {'avg': 0, 'min': 1, 'max': 2, 'count': 3}
sort_by_to_int = {'total': 0, 'avg': 1, 'min': 2, 'max': 3, 'count': 4}
asc_to_int = {False: 0, True: 1}
assert format in format_to_int.keys(),\
"Invalid value provided for format: {0}. Support: 'table', 'json'".format(format)
assert sort_by in sort_by_to_int.keys(),\
"Invalid value provided for sort_by: {0}. Support: 'avg', 'min', 'max', 'count'"\
"Invalid value provided for sort_by: {0}.\
Support: 'total', 'avg', 'min', 'max', 'count'"\
.format(sort_by)
assert ascending in asc_to_int.keys(),\
"Invalid value provided for ascending: {0}. Support: False, True".format(ascending)
Expand Down
3 changes: 3 additions & 0 deletions src/profiler/aggregate_stats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ inline std::priority_queue<pi>
const AggregateStats::StatData& data = iter.second;
double value = 0;
switch (static_cast<AggregateStats::SortBy>(sort_by)) {
case AggregateStats::SortBy::Total:
value = data.total_aggregate_;
break;
case AggregateStats::SortBy::Avg:
if (data.type_ == AggregateStats::StatData::kCounter)
value = (data.max_aggregate_ - data.min_aggregate_) / 2;
Expand Down
2 changes: 1 addition & 1 deletion src/profiler/aggregate_stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class AggregateStats {
void clear();
/* !\brief by which stat to sort */
enum class SortBy {
Avg, Min, Max, Count
Total, Avg, Min, Max, Count
};

private:
Expand Down
13 changes: 8 additions & 5 deletions tests/python/unittest/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,12 @@ def test_aggregate_stats_valid_json_return():
debug_str = profiler.dumps(format = 'json')
assert(len(debug_str) > 0)
target_dict = json.loads(debug_str)
assert "Memory" in target_dict and "Time" in target_dict and "Unit" in target_dict
assert 'Memory' in target_dict and 'Time' in target_dict and 'Unit' in target_dict
profiler.set_state('stop')

def test_aggregate_stats_sorting():
sort_by_options = {'avg': "Avg", 'min': "Min", 'max': "Max", 'count': "Count"}
sort_by_options = {'total': 'Total', 'avg': 'Avg', 'min': 'Min',\
'max': 'Max', 'count': 'Count'}
ascending_options = [False, True]
def check_ascending(lst, asc):
assert(lst == sorted(lst, reverse = not asc))
Expand All @@ -258,9 +259,11 @@ def check_sorting(debug_str, sort_by, ascending):
for domain_name, domain in target_dict['Time'].items():
lst = [item[sort_by_options[sort_by]] for item_name, item in domain.items()]
check_ascending(lst, ascending)
for domain_name, domain in target_dict['Memory'].items():
lst = [item[sort_by_options[sort_by]] for item_name, item in domain.items()]
check_ascending(lst, ascending)
# Memory items do not have stat 'Total'
if sort_by != 'total':
for domain_name, domain in target_dict['Memory'].items():
lst = [item[sort_by_options[sort_by]] for item_name, item in domain.items()]
check_ascending(lst, ascending)

file_name = 'test_aggregate_stats_sorting.json'
enable_profiler(file_name, True, True, True)
Expand Down

0 comments on commit 1abf05b

Please sign in to comment.