diff --git a/src/profiler/aggregate_stats.cc b/src/profiler/aggregate_stats.cc index ff3894650545..98ac5e96b761 100644 --- a/src/profiler/aggregate_stats.cc +++ b/src/profiler/aggregate_stats.cc @@ -85,7 +85,9 @@ inline std::priority_queue void AggregateStats::OnProfileStat(const ProfileStat& stat) { std::unique_lock lk(m_); - stat.SaveAggregate(&stats_[stat.categories_.c_str()][stat.name_.c_str()]); + if (stat.enable_aggregate_) { + stat.SaveAggregate(&stats_[stat.categories_.c_str()][stat.name_.c_str()]); + } } void AggregateStats::DumpTable(std::ostream& os, int sort_by, int ascending) { diff --git a/src/profiler/profiler.h b/src/profiler/profiler.h index f9eb0af9acc1..a6d9ecf06fee 100644 --- a/src/profiler/profiler.h +++ b/src/profiler/profiler.h @@ -128,6 +128,9 @@ struct ProfileStat { /*! \brief operation categories (comma-delimited) */ profile_stat_string categories_; + /*! \brief whether to add this stat to AggregateStats */ + bool enable_aggregate_ = true; + /* !\brief Process id */ size_t process_id_ = current_process_id(); @@ -807,6 +810,13 @@ struct ProfileTask : public ProfileDuration { ProfileObjectType type() const override { return kTask; } + /*! + * \brief Whether to add stat to AggregateStats + */ + void enableAggregateStats(bool enabled = true) { + enable_aggregate_ = enabled; + } + protected: /*! * \brief Task statistic object @@ -831,6 +841,7 @@ struct ProfileTask : public ProfileDuration { inline void SendStat() { Profiler::Get()->AddNewProfileStat([this](ProfileTaskStat *stat) { stat->categories_.set(domain_->name()); + stat->enable_aggregate_ = enable_aggregate_; }, name_.c_str(), start_time_, ProfileStat::NowInMicrosec()); } /*! \brief Task name */ @@ -843,6 +854,8 @@ struct ProfileTask : public ProfileDuration { VTUNE_ONLY_CODE(std::unique_ptr vtune_task_); /*! \brief NVTX duration object */ NVTX_ONLY_CODE(std::unique_ptr nvtx_duration_); + /*! \brief whether to add this stat to AggregateStats */ + bool enable_aggregate_ = true; protected: /*! \brief Task's start tick */ @@ -1150,6 +1163,8 @@ struct ProfileOperator : public ProfileEvent { , as_task_(name, &domain_) , name_(name) , attributes_(attributes) { + // make as_task_ not to add stat to AggregateStats; otherwise we will add twice + as_task_.enableAggregateStats(false); SetCategories(domain_.name()); } /*! diff --git a/tests/python/unittest/test_profiler.py b/tests/python/unittest/test_profiler.py index 09e571be0088..b76bfbc82d12 100644 --- a/tests/python/unittest/test_profiler.py +++ b/tests/python/unittest/test_profiler.py @@ -269,6 +269,28 @@ def check_sorting(debug_str, sort_by, ascending): check_sorting(debug_str, sb, asc) profiler.set_state('stop') +def test_aggregate_duplication(): + file_name = 'test_aggregate_duplication.json' + enable_profiler(profile_filename = file_name, run=True, continuous_dump=True, \ + aggregate_stats=True) + inp = mx.nd.zeros(shape=(100, 100)) + y = mx.nd.sqrt(inp) + inp = inp + 1 + inp = inp + 1 + mx.nd.waitall() + profiler.dump(False) + debug_str = profiler.dumps(format = 'json') + target_dict = json.loads(debug_str) + assert 'Time' in target_dict and 'operator' in target_dict['Time'] \ + and 'sqrt' in target_dict['Time']['operator'] \ + and 'Count' in target_dict['Time']['operator']['sqrt'] \ + and '_plus_scalar' in target_dict['Time']['operator'] \ + and 'Count' in target_dict['Time']['operator']['_plus_scalar'] + # they are called once and twice respectively + assert target_dict['Time']['operator']['sqrt']['Count'] == 1 + assert target_dict['Time']['operator']['_plus_scalar']['Count'] == 2 + profiler.set_state('stop') + if __name__ == '__main__': import nose nose.runmodule()