Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fixing duplication in operator profiling #15240

Merged
merged 9 commits into from
Jun 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/profiler/aggregate_stats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ inline std::priority_queue<pi>

void AggregateStats::OnProfileStat(const ProfileStat& stat) {
std::unique_lock<std::mutex> lk(m_);
stat.SaveAggregate(&stats_[stat.categories_.c_str()][stat.name_.c_str()]);
if (stat.enable_aggregate_) {
stat.SaveAggregate(&stats_[stat.categories_.c_str()][stat.name_.c_str()]);
}
}

void AggregateStats::DumpTable(std::ostream& os, int sort_by, int ascending) {
Expand Down
15 changes: 15 additions & 0 deletions src/profiler/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ struct ProfileStat {
/*! \brief operation categories (comma-delimited) */
profile_stat_string categories_;

/*! \brief whether to add this stat to AggregateStats */
bool enable_aggregate_ = true;

/* !\brief Process id */
size_t process_id_ = current_process_id();

Expand Down Expand Up @@ -807,6 +810,13 @@ struct ProfileTask : public ProfileDuration {

ProfileObjectType type() const override { return kTask; }

/*!
* \brief Whether to add stat to AggregateStats
*/
void enableAggregateStats(bool enabled = true) {
enable_aggregate_ = enabled;
}

protected:
/*!
* \brief Task statistic object
Expand All @@ -831,6 +841,7 @@ struct ProfileTask : public ProfileDuration {
inline void SendStat() {
Profiler::Get()->AddNewProfileStat<ProfileTaskStat>([this](ProfileTaskStat *stat) {
stat->categories_.set(domain_->name());
stat->enable_aggregate_ = enable_aggregate_;
}, name_.c_str(), start_time_, ProfileStat::NowInMicrosec());
}
/*! \brief Task name */
Expand All @@ -843,6 +854,8 @@ struct ProfileTask : public ProfileDuration {
VTUNE_ONLY_CODE(std::unique_ptr<vtune::VTuneTask> vtune_task_);
/*! \brief NVTX duration object */
NVTX_ONLY_CODE(std::unique_ptr<nvtx::NVTXDuration> nvtx_duration_);
/*! \brief whether to add this stat to AggregateStats */
bool enable_aggregate_ = true;

protected:
/*! \brief Task's start tick */
Expand Down Expand Up @@ -1150,6 +1163,8 @@ struct ProfileOperator : public ProfileEvent {
, as_task_(name, &domain_)
, name_(name)
, attributes_(attributes) {
// make as_task_ not to add stat to AggregateStats; otherwise we will add twice
as_task_.enableAggregateStats(false);
SetCategories(domain_.name());
}
/*!
Expand Down
22 changes: 22 additions & 0 deletions tests/python/unittest/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,28 @@ def check_sorting(debug_str, sort_by, ascending):
check_sorting(debug_str, sb, asc)
profiler.set_state('stop')

def test_aggregate_duplication():
file_name = 'test_aggregate_duplication.json'
enable_profiler(profile_filename = file_name, run=True, continuous_dump=True, \
aggregate_stats=True)
inp = mx.nd.zeros(shape=(100, 100))
y = mx.nd.sqrt(inp)
inp = inp + 1
inp = inp + 1
mx.nd.waitall()
profiler.dump(False)
debug_str = profiler.dumps(format = 'json')
target_dict = json.loads(debug_str)
assert 'Time' in target_dict and 'operator' in target_dict['Time'] \
and 'sqrt' in target_dict['Time']['operator'] \
and 'Count' in target_dict['Time']['operator']['sqrt'] \
and '_plus_scalar' in target_dict['Time']['operator'] \
and 'Count' in target_dict['Time']['operator']['_plus_scalar']
# they are called once and twice respectively
assert target_dict['Time']['operator']['sqrt']['Count'] == 1
assert target_dict['Time']['operator']['_plus_scalar']['Count'] == 2
profiler.set_state('stop')

if __name__ == '__main__':
import nose
nose.runmodule()