Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fixing duplication in operator profiling (#15240)
Browse files Browse the repository at this point in the history
* initial

* adding a bool in ProfileStat to control if add to aggregate stats

* add test case

* fix style

* stylefix

* testcases

* fix type

* fix comment

* Update profiler.h
  • Loading branch information
Zha0q1 authored and sandeep-krishnamurthy committed Jun 21, 2019
1 parent 4a9e9f6 commit 3f8fd00
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/profiler/aggregate_stats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ inline std::priority_queue<pi>

void AggregateStats::OnProfileStat(const ProfileStat& stat) {
std::unique_lock<std::mutex> lk(m_);
stat.SaveAggregate(&stats_[stat.categories_.c_str()][stat.name_.c_str()]);
if (stat.enable_aggregate_) {
stat.SaveAggregate(&stats_[stat.categories_.c_str()][stat.name_.c_str()]);
}
}

void AggregateStats::DumpTable(std::ostream& os, int sort_by, int ascending) {
Expand Down
15 changes: 15 additions & 0 deletions src/profiler/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ struct ProfileStat {
/*! \brief operation categories (comma-delimited) */
profile_stat_string categories_;

/*! \brief whether to add this stat to AggregateStats */
bool enable_aggregate_ = true;

/* !\brief Process id */
size_t process_id_ = current_process_id();

Expand Down Expand Up @@ -807,6 +810,13 @@ struct ProfileTask : public ProfileDuration {

ProfileObjectType type() const override { return kTask; }

/*!
* \brief Whether to add stat to AggregateStats
*/
void enableAggregateStats(bool enabled = true) {
enable_aggregate_ = enabled;
}

protected:
/*!
* \brief Task statistic object
Expand All @@ -831,6 +841,7 @@ struct ProfileTask : public ProfileDuration {
inline void SendStat() {
Profiler::Get()->AddNewProfileStat<ProfileTaskStat>([this](ProfileTaskStat *stat) {
stat->categories_.set(domain_->name());
stat->enable_aggregate_ = enable_aggregate_;
}, name_.c_str(), start_time_, ProfileStat::NowInMicrosec());
}
/*! \brief Task name */
Expand All @@ -843,6 +854,8 @@ struct ProfileTask : public ProfileDuration {
VTUNE_ONLY_CODE(std::unique_ptr<vtune::VTuneTask> vtune_task_);
/*! \brief NVTX duration object */
NVTX_ONLY_CODE(std::unique_ptr<nvtx::NVTXDuration> nvtx_duration_);
/*! \brief whether to add this stat to AggregateStats */
bool enable_aggregate_ = true;

protected:
/*! \brief Task's start tick */
Expand Down Expand Up @@ -1150,6 +1163,8 @@ struct ProfileOperator : public ProfileEvent {
, as_task_(name, &domain_)
, name_(name)
, attributes_(attributes) {
// make as_task_ not to add stat to AggregateStats; otherwise we will add twice
as_task_.enableAggregateStats(false);
SetCategories(domain_.name());
}
/*!
Expand Down
22 changes: 22 additions & 0 deletions tests/python/unittest/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,28 @@ def check_sorting(debug_str, sort_by, ascending):
check_sorting(debug_str, sb, asc)
profiler.set_state('stop')

def test_aggregate_duplication():
file_name = 'test_aggregate_duplication.json'
enable_profiler(profile_filename = file_name, run=True, continuous_dump=True, \
aggregate_stats=True)
inp = mx.nd.zeros(shape=(100, 100))
y = mx.nd.sqrt(inp)
inp = inp + 1
inp = inp + 1
mx.nd.waitall()
profiler.dump(False)
debug_str = profiler.dumps(format = 'json')
target_dict = json.loads(debug_str)
assert 'Time' in target_dict and 'operator' in target_dict['Time'] \
and 'sqrt' in target_dict['Time']['operator'] \
and 'Count' in target_dict['Time']['operator']['sqrt'] \
and '_plus_scalar' in target_dict['Time']['operator'] \
and 'Count' in target_dict['Time']['operator']['_plus_scalar']
# they are called once and twice respectively
assert target_dict['Time']['operator']['sqrt']['Count'] == 1
assert target_dict['Time']['operator']['_plus_scalar']['Count'] == 2
profiler.set_state('stop')

if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 3f8fd00

Please sign in to comment.