diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index e0ed316a1bdc..657b32452cd3 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -345,7 +345,7 @@ MXNET_DLL int MXAggregateProfileStatsPrint(const char **out_str, int reset); * \param out_str will receive a pointer to the output string * \param reset clear the aggregate stats after printing * \param format whether to return in tabular or json format - * \param sort_by sort by avg, min, max, or count + * \param sort_by sort by total, avg, min, max, or count * \param ascending whether to sort ascendingly * \return 0 when success, -1 when failure happens. * \note diff --git a/python/mxnet/profiler.py b/python/mxnet/profiler.py index 373a11592f86..7dbc060ed60f 100644 --- a/python/mxnet/profiler.py +++ b/python/mxnet/profiler.py @@ -148,7 +148,7 @@ def dump_profile(): dump(True) -def dumps(reset=False, format='table', sort_by='avg', ascending=False): +def dumps(reset=False, format='table', sort_by='total', ascending=False): """Return a printable string of aggregate profile stats. Parameters @@ -160,9 +160,9 @@ def dumps(reset=False, format='table', sort_by='avg', ascending=False): can take 'table' or 'json' defaults to 'table' sort_by: string - can take 'avg', 'min', 'max', or 'count' + can take 'total', 'avg', 'min', 'max', or 'count' by which stat to sort the entries in each category - defaults to 'avg' + defaults to 'total' ascending: boolean whether to sort ascendingly defaults to False @@ -170,12 +170,13 @@ def dumps(reset=False, format='table', sort_by='avg', ascending=False): debug_str = ctypes.c_char_p() reset_to_int = {False: 0, True: 1} format_to_int = {'table': 0, 'json': 1} - sort_by_to_int = {'avg': 0, 'min': 1, 'max': 2, 'count': 3} + sort_by_to_int = {'total': 0, 'avg': 1, 'min': 2, 'max': 3, 'count': 4} asc_to_int = {False: 0, True: 1} assert format in format_to_int.keys(),\ "Invalid value provided for format: {0}. Support: 'table', 'json'".format(format) assert sort_by in sort_by_to_int.keys(),\ - "Invalid value provided for sort_by: {0}. Support: 'avg', 'min', 'max', 'count'"\ + "Invalid value provided for sort_by: {0}.\ + Support: 'total', 'avg', 'min', 'max', 'count'"\ .format(sort_by) assert ascending in asc_to_int.keys(),\ "Invalid value provided for ascending: {0}. Support: False, True".format(ascending) diff --git a/src/profiler/aggregate_stats.cc b/src/profiler/aggregate_stats.cc index 98ac5e96b761..86791ebf3074 100644 --- a/src/profiler/aggregate_stats.cc +++ b/src/profiler/aggregate_stats.cc @@ -56,6 +56,9 @@ inline std::priority_queue const AggregateStats::StatData& data = iter.second; double value = 0; switch (static_cast(sort_by)) { + case AggregateStats::SortBy::Total: + value = data.total_aggregate_; + break; case AggregateStats::SortBy::Avg: if (data.type_ == AggregateStats::StatData::kCounter) value = (data.max_aggregate_ - data.min_aggregate_) / 2; diff --git a/src/profiler/aggregate_stats.h b/src/profiler/aggregate_stats.h index b9634250b82a..fe7521adc544 100644 --- a/src/profiler/aggregate_stats.h +++ b/src/profiler/aggregate_stats.h @@ -74,7 +74,7 @@ class AggregateStats { void clear(); /* !\brief by which stat to sort */ enum class SortBy { - Avg, Min, Max, Count + Total, Avg, Min, Max, Count }; private: diff --git a/tests/python/unittest/test_profiler.py b/tests/python/unittest/test_profiler.py index d04f390f219b..e5fdd2b1798a 100644 --- a/tests/python/unittest/test_profiler.py +++ b/tests/python/unittest/test_profiler.py @@ -243,11 +243,12 @@ def test_aggregate_stats_valid_json_return(): debug_str = profiler.dumps(format = 'json') assert(len(debug_str) > 0) target_dict = json.loads(debug_str) - assert "Memory" in target_dict and "Time" in target_dict and "Unit" in target_dict + assert 'Memory' in target_dict and 'Time' in target_dict and 'Unit' in target_dict profiler.set_state('stop') def test_aggregate_stats_sorting(): - sort_by_options = {'avg': "Avg", 'min': "Min", 'max': "Max", 'count': "Count"} + sort_by_options = {'total': 'Total', 'avg': 'Avg', 'min': 'Min',\ + 'max': 'Max', 'count': 'Count'} ascending_options = [False, True] def check_ascending(lst, asc): assert(lst == sorted(lst, reverse = not asc)) @@ -258,9 +259,11 @@ def check_sorting(debug_str, sort_by, ascending): for domain_name, domain in target_dict['Time'].items(): lst = [item[sort_by_options[sort_by]] for item_name, item in domain.items()] check_ascending(lst, ascending) - for domain_name, domain in target_dict['Memory'].items(): - lst = [item[sort_by_options[sort_by]] for item_name, item in domain.items()] - check_ascending(lst, ascending) + # Memory items do not have stat 'Total' + if sort_by != 'total': + for domain_name, domain in target_dict['Memory'].items(): + lst = [item[sort_by_options[sort_by]] for item_name, item in domain.items()] + check_ascending(lst, ascending) file_name = 'test_aggregate_stats_sorting.json' enable_profiler(file_name, True, True, True)